Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pkg/auth/anonymous.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"time"

"github.com/golang-jwt/jwt/v5"

"github.com/stacklok/toolhive/pkg/auth/token"
)

// AnonymousMiddleware creates an HTTP middleware that sets up anonymous claims.
Expand All @@ -33,7 +35,7 @@ func AnonymousMiddleware(next http.Handler) http.Handler {

// Add the anonymous claims to the request context using the same key
// as the JWT middleware for consistency
ctx := context.WithValue(r.Context(), ClaimsContextKey{}, claims)
ctx := context.WithValue(r.Context(), token.ClaimsContextKey{}, claims)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
4 changes: 3 additions & 1 deletion pkg/auth/anonymous_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/stacklok/toolhive/pkg/auth/token"
)

func TestAnonymousMiddleware(t *testing.T) {
t.Parallel()
// Create a test handler that checks for claims in the context
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims, ok := GetClaimsFromContext(r.Context())
claims, ok := token.GetClaimsFromContext(r.Context())
require.True(t, ok, "Expected claims to be present in context")

// Verify the anonymous claims
Expand Down
124 changes: 124 additions & 0 deletions pkg/auth/compat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package auth

import (
"context"
"net/http"

"github.com/golang-jwt/jwt/v5"

"github.com/stacklok/toolhive/pkg/auth/middleware"
"github.com/stacklok/toolhive/pkg/auth/token"
"github.com/stacklok/toolhive/pkg/auth/token/providers"
)

// Compatibility types and constants - aliasing to new packages

// TokenValidator validates JWT or opaque tokens using OIDC configuration.
// Deprecated: Use token.Validator instead
type TokenValidator = token.Validator

// TokenValidatorConfig contains configuration for the token validator.
// Deprecated: Use token.ValidatorConfig instead
type TokenValidatorConfig = token.ValidatorConfig

// TokenIntrospector defines the interface for token introspection providers
// Deprecated: Use token.Introspector instead
type TokenIntrospector = token.Introspector

// ClaimsContextKey is the key used to store claims in the request context.
// Deprecated: Use token.ClaimsContextKey instead
type ClaimsContextKey = token.ClaimsContextKey

// RFC9728AuthInfo represents the OAuth Protected Resource metadata as defined in RFC 9728
// Deprecated: Use discovery.RFC9728AuthInfo instead
type RFC9728AuthInfo struct {
Resource string `json:"resource"`
AuthorizationServers []string `json:"authorization_servers"`
BearerMethodsSupported []string `json:"bearer_methods_supported"`
JWKSURI string `json:"jwks_uri"`
ScopesSupported []string `json:"scopes_supported"`
}

// Common errors - re-exported from token package
var (
// Deprecated: Use token.ErrNoToken instead
ErrNoToken = token.ErrNoToken
// Deprecated: Use token.ErrInvalidToken instead
ErrInvalidToken = token.ErrInvalidToken
// Deprecated: Use token.ErrTokenExpired instead
ErrTokenExpired = token.ErrTokenExpired
// Deprecated: Use token.ErrInvalidIssuer instead
ErrInvalidIssuer = token.ErrInvalidIssuer
// Deprecated: Use token.ErrInvalidAudience instead
ErrInvalidAudience = token.ErrInvalidAudience
// Deprecated: Use token.ErrMissingJWKSURL instead
ErrMissingJWKSURL = token.ErrMissingJWKSURL
// Deprecated: Use token.ErrFailedToFetchJWKS instead
ErrFailedToFetchJWKS = token.ErrFailedToFetchJWKS
// Deprecated: Use token.ErrFailedToDiscoverOIDC instead
ErrFailedToDiscoverOIDC = token.ErrFailedToDiscoverOIDC
// Deprecated: Use token.ErrMissingIssuerAndJWKSURL instead
ErrMissingIssuerAndJWKSURL = token.ErrMissingIssuerAndJWKSURL
)

// Constants
const (
// GoogleTokeninfoURL is the Google OAuth2 tokeninfo endpoint URL
// Deprecated: Use providers.GoogleTokeninfoURL instead
GoogleTokeninfoURL = providers.GoogleTokeninfoURL
)

// Compatibility functions

// NewTokenValidator creates a new token validator.
// Deprecated: Use token.NewValidator instead
func NewTokenValidator(ctx context.Context, config TokenValidatorConfig) (*TokenValidator, error) {
return token.NewValidator(ctx, config)
}

// NewTokenValidatorConfig creates a new TokenValidatorConfig with the provided parameters
// Deprecated: Use token.NewValidatorConfig instead
func NewTokenValidatorConfig(issuer, audience, jwksURL, clientID string, clientSecret string) *TokenValidatorConfig {
return token.NewValidatorConfig(issuer, audience, jwksURL, clientID, clientSecret)
}

// GetClaimsFromContext retrieves the claims from the request context.
// Deprecated: Use token.GetClaimsFromContext instead
func GetClaimsFromContext(ctx context.Context) (jwt.MapClaims, bool) {
return token.GetClaimsFromContext(ctx)
}

// GetAuthenticationMiddleware returns the appropriate authentication middleware based on the configuration.
// Deprecated: Use middleware.GetAuthenticationMiddleware instead
func GetAuthenticationMiddleware(
ctx context.Context,
oidcConfig *TokenValidatorConfig,
) (func(http.Handler) http.Handler, http.Handler, error) {
if oidcConfig == nil {
return middleware.GetAuthenticationMiddleware(ctx, nil, "")
}

validator, err := token.NewValidator(ctx, *oidcConfig)
if err != nil {
return nil, nil, err
}

return middleware.GetAuthenticationMiddleware(ctx, validator, oidcConfig.ResourceURL)
}

// NewAuthInfoHandler creates an HTTP handler that returns RFC-9728 compliant OAuth Protected Resource metadata
// Deprecated: Use middleware.NewAuthInfoHandler instead
func NewAuthInfoHandler(_, jwksURL, resourceURL string, scopes []string) http.Handler {
return middleware.NewAuthInfoHandler(jwksURL, resourceURL, scopes)
}

// LocalUserMiddleware creates middleware that adds local user claims to the context
// Deprecated: Use middleware.LocalUserMiddleware instead
func LocalUserMiddleware(username string) func(http.Handler) http.Handler {
return middleware.LocalUserMiddleware(username)
}

// EscapeQuotes escapes quotes in a string for use in a quoted-string context.
func EscapeQuotes(s string) string {
return middleware.EscapeQuotes(s)
}
196 changes: 0 additions & 196 deletions pkg/auth/discovery/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@ package discovery

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"path"
Expand All @@ -21,7 +19,6 @@ import (

"golang.org/x/oauth2"

"github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/auth/oauth"
"github.com/stacklok/toolhive/pkg/logger"
"github.com/stacklok/toolhive/pkg/networking"
Expand Down Expand Up @@ -156,83 +153,6 @@ func detectAuthWithRequest(
return nil, nil
}

// ParseWWWAuthenticate parses the WWW-Authenticate header to extract authentication information
// Supports multiple authentication schemes and complex header formats
func ParseWWWAuthenticate(header string) (*AuthInfo, error) {
// Trim whitespace and handle empty headers
header = strings.TrimSpace(header)
if header == "" {
return nil, fmt.Errorf("empty WWW-Authenticate header")
}

// Check for OAuth/Bearer authentication
// Note: We don't split by comma because Bearer parameters can contain commas in quoted values
if strings.HasPrefix(header, "Bearer") {
authInfo := &AuthInfo{Type: "OAuth"}

// Extract parameters after "Bearer"
params := strings.TrimSpace(strings.TrimPrefix(header, "Bearer"))
if params != "" {
// Parse parameters (realm, scope, resource_metadata, etc.)
realm := ExtractParameter(params, "realm")
if realm != "" {
authInfo.Realm = realm
}

// RFC 9728: Check for resource_metadata parameter
resourceMetadata := ExtractParameter(params, "resource_metadata")
if resourceMetadata != "" {
authInfo.ResourceMetadata = resourceMetadata
}

// Extract error information if present
errorParam := ExtractParameter(params, "error")
if errorParam != "" {
authInfo.Error = errorParam
}

errorDesc := ExtractParameter(params, "error_description")
if errorDesc != "" {
authInfo.ErrorDescription = errorDesc
}
}

return authInfo, nil
}

// Check for OAuth-specific schemes
if strings.HasPrefix(header, "OAuth") {
authInfo := &AuthInfo{Type: "OAuth"}

// Extract parameters after "OAuth"
params := strings.TrimSpace(strings.TrimPrefix(header, "OAuth"))
if params != "" {
// Parse parameters (realm, scope, etc.)
realm := ExtractParameter(params, "realm")
if realm != "" {
authInfo.Realm = realm
}

// RFC 9728: Check for resource_metadata parameter
resourceMetadata := ExtractParameter(params, "resource_metadata")
if resourceMetadata != "" {
authInfo.ResourceMetadata = resourceMetadata
}
}

return authInfo, nil
}

// Currently only OAuth-based authentication is supported
// Basic and Digest authentication are not implemented
if strings.HasPrefix(header, "Basic") || strings.HasPrefix(header, "Digest") {
logger.Debugf("Unsupported authentication scheme: %s", header)
return nil, fmt.Errorf("unsupported authentication scheme: %s", strings.Split(header, " ")[0])
}

return nil, fmt.Errorf("no supported authentication type found in header: %s", header)
}

// DeriveIssuerFromURL attempts to derive the OAuth issuer from the remote URL using general patterns
func DeriveIssuerFromURL(remoteURL string) string {
// Parse the URL to extract the domain
Expand Down Expand Up @@ -268,57 +188,6 @@ func DeriveIssuerFromURL(remoteURL string) string {
return issuer
}

// ExtractParameter extracts a parameter value from an authentication header
// Handles both quoted and unquoted values according to RFC 2617 and RFC 6750
func ExtractParameter(params, paramName string) string {
// Parameters can be separated by comma or space
// Handle both paramName=value and paramName="value" formats

// First try to find the parameter with equals sign
searchStr := paramName + "="
idx := strings.Index(params, searchStr)
if idx == -1 {
return ""
}

// Extract the value after the equals sign
valueStart := idx + len(searchStr)
if valueStart >= len(params) {
return ""
}

remainder := params[valueStart:]

// Check if the value is quoted
if strings.HasPrefix(remainder, `"`) {
// Find the closing quote
endIdx := 1
for endIdx < len(remainder) {
if remainder[endIdx] == '"' && (endIdx == 1 || remainder[endIdx-1] != '\\') {
// Found unescaped closing quote
value := remainder[1:endIdx]
// Unescape any escaped quotes
value = strings.ReplaceAll(value, `\"`, `"`)
return value
}
endIdx++
}
// No closing quote found, return empty
return ""
}

// Unquoted value - find the end (comma, space, or end of string)
endIdx := 0
for endIdx < len(remainder) {
if remainder[endIdx] == ',' || remainder[endIdx] == ' ' {
break
}
endIdx++
}

return strings.TrimSpace(remainder[:endIdx])
}

// DeriveIssuerFromRealm attempts to derive the OAuth issuer from the realm parameter
// According to RFC 8414, the issuer MUST be a URL using the "https" scheme with no query or fragment
func DeriveIssuerFromRealm(realm string) string {
Expand Down Expand Up @@ -546,71 +415,6 @@ func registerDynamicClient(
return registrationResponse, nil
}

// FetchResourceMetadata as specified in RFC 9728
func FetchResourceMetadata(ctx context.Context, metadataURL string) (*auth.RFC9728AuthInfo, error) {
if metadataURL == "" {
return nil, fmt.Errorf("metadata URL is empty")
}

// Validate URL
parsedURL, err := url.Parse(metadataURL)
if err != nil {
return nil, fmt.Errorf("invalid metadata URL: %w", err)
}

// RFC 9728: Must use HTTPS (except for localhost in development)
if parsedURL.Scheme != "https" && parsedURL.Hostname() != "localhost" && parsedURL.Hostname() != "127.0.0.1" {
return nil, fmt.Errorf("metadata URL must use HTTPS: %s", metadataURL)
}

// Create HTTP client with timeout
client := &http.Client{
Timeout: DefaultHTTPTimeout,
Transport: &http.Transport{
TLSHandshakeTimeout: 5 * time.Second,
ResponseHeaderTimeout: 5 * time.Second,
},
}

req, err := http.NewRequestWithContext(ctx, http.MethodGet, metadataURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}

req.Header.Set("Accept", "application/json")

resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to fetch metadata: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("metadata request failed with status %d", resp.StatusCode)
}

// Check content type
contentType := strings.ToLower(resp.Header.Get("Content-Type"))
if !strings.Contains(contentType, "application/json") {
return nil, fmt.Errorf("unexpected content type: %s", contentType)
}

// Parse the metadata
const maxResponseSize = 1024 * 1024 // 1MB limit
var metadata auth.RFC9728AuthInfo
if err := json.NewDecoder(io.LimitReader(resp.Body, maxResponseSize)).Decode(&metadata); err != nil {
return nil, fmt.Errorf("failed to parse metadata: %w", err)
}

// RFC 9728 Section 3.3: Validate that the resource value matches
// For now we just check it's not empty
if metadata.Resource == "" {
return nil, fmt.Errorf("metadata missing required 'resource' field")
}

return &metadata, nil
}

// ValidateAndDiscoverAuthServer attempts to validate if a URL is an authorization server
// and discover its actual issuer by fetching its metadata.
// This handles the case where the URL used to fetch metadata differs from the actual issuer
Expand Down
Loading
Loading