Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/server/docs.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion docs/server/swagger.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion docs/server/swagger.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 2 additions & 6 deletions pkg/authserver/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ type RunConfig struct {

// Upstreams configures connections to upstream Identity Providers.
// At least one upstream is required - the server delegates authentication to these providers.
// Currently only a single upstream is supported.
// Multiple upstreams are supported for sequential authorization chains.
Upstreams []UpstreamRunConfig `json:"upstreams" yaml:"upstreams"`

// ScopesSupported lists the OAuth 2.0 scope values advertised in discovery documents.
Expand Down Expand Up @@ -318,7 +318,7 @@ type Config struct {

// Upstreams contains configurations for connecting to upstream IDPs.
// At least one upstream is required - the server delegates authentication to the upstream IDP.
// Currently only a single upstream is supported.
// Multiple upstreams form a sequential authorization chain.
Upstreams []UpstreamConfig

// ScopesSupported lists the OAuth 2.0 scope values advertised in discovery documents.
Expand Down Expand Up @@ -389,10 +389,6 @@ func (c *Config) validateUpstreams() error {
if len(c.Upstreams) == 0 {
return fmt.Errorf("at least one upstream is required")
}
if len(c.Upstreams) > 1 {
return fmt.Errorf("multiple upstreams not yet supported (found %d)", len(c.Upstreams))
}

// Track names for uniqueness checking
seenNames := make(map[string]bool)

Expand Down
4 changes: 2 additions & 2 deletions pkg/authserver/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ func TestConfigValidate(t *testing.T) {
{name: "HMAC too short", config: Config{Issuer: "https://example.com", KeyProvider: validKeyProvider, HMACSecrets: shortHMAC, Upstreams: validUpstreams}, wantErr: true, errMsg: "HMAC secret must be at least 32 bytes"},
{name: "no upstreams", config: Config{Issuer: "https://example.com", KeyProvider: validKeyProvider, HMACSecrets: validHMAC}, wantErr: true, errMsg: "at least one upstream is required"},
{name: "nil upstream config", config: Config{Issuer: "https://example.com", KeyProvider: validKeyProvider, HMACSecrets: validHMAC, Upstreams: []UpstreamConfig{{Name: "test", Type: UpstreamProviderTypeOAuth2}}}, wantErr: true, errMsg: "oauth2_config is required"},
{name: "multiple upstreams", config: Config{Issuer: "https://example.com", KeyProvider: validKeyProvider, HMACSecrets: validHMAC, Upstreams: []UpstreamConfig{{Name: "first", Type: UpstreamProviderTypeOAuth2, OAuth2Config: validUpstream}, {Name: "second", Type: UpstreamProviderTypeOAuth2, OAuth2Config: validUpstream}}}, wantErr: true, errMsg: "multiple upstreams not yet supported (found 2)"},
{name: "duplicate upstream names", config: Config{Issuer: "https://example.com", KeyProvider: validKeyProvider, HMACSecrets: validHMAC, Upstreams: []UpstreamConfig{{Name: "same", Type: UpstreamProviderTypeOAuth2, OAuth2Config: validUpstream}, {Name: "same", Type: UpstreamProviderTypeOAuth2, OAuth2Config: validUpstream}}}, wantErr: true, errMsg: "multiple upstreams not yet supported"},
{name: "multiple upstreams", config: Config{Issuer: "https://example.com", KeyProvider: validKeyProvider, HMACSecrets: validHMAC, Upstreams: []UpstreamConfig{{Name: "first", Type: UpstreamProviderTypeOAuth2, OAuth2Config: validUpstream}, {Name: "second", Type: UpstreamProviderTypeOAuth2, OAuth2Config: validUpstream}}, AllowedAudiences: []string{"https://mcp.example.com"}}},
{name: "duplicate upstream names", config: Config{Issuer: "https://example.com", KeyProvider: validKeyProvider, HMACSecrets: validHMAC, Upstreams: []UpstreamConfig{{Name: "same", Type: UpstreamProviderTypeOAuth2, OAuth2Config: validUpstream}, {Name: "same", Type: UpstreamProviderTypeOAuth2, OAuth2Config: validUpstream}}}, wantErr: true, errMsg: "duplicate upstream name"},
{name: "missing allowed audiences", config: Config{Issuer: "https://example.com", KeyProvider: validKeyProvider, HMACSecrets: validHMAC, Upstreams: validUpstreams}, wantErr: true, errMsg: "at least one allowed audience is required"},
{name: "empty allowed audiences slice", config: Config{Issuer: "https://example.com", KeyProvider: validKeyProvider, HMACSecrets: validHMAC, Upstreams: validUpstreams, AllowedAudiences: []string{}}, wantErr: true, errMsg: "at least one allowed audience is required"},

Expand Down
310 changes: 310 additions & 0 deletions pkg/authserver/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"crypto/rand"
"crypto/rsa"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -1512,3 +1513,312 @@ func TestIntegration_RefreshPreservesUpstreamTokenBinding(t *testing.T) {
require.NotNil(t, tokensAfterRefresh, "upstream tokens should not be nil after refresh")
assert.Equal(t, "default", tokensAfterRefresh.ProviderID, "ProviderID should still be 'default' after refresh")
}

// ============================================================================
// Multi-Upstream Sequential Chain Integration Tests
// ============================================================================

// setupTestServerWithTwoUpstreams creates a test server with two mockoidc instances
// configured as sequential upstream providers. This exercises the multi-upstream
// authorization chain where the callback handler redirects to the next upstream
// after each successful code exchange.
func setupTestServerWithTwoUpstreams(t *testing.T, m1, m2 *mockoidc.MockOIDC) *testServer {
t.Helper()
ctx := context.Background()

// 1. Generate RSA key for signing
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)

// 2. Generate HMAC secret
secret := make([]byte, 32)
_, err = rand.Read(secret)
require.NoError(t, err)

// 3. Create storage
stor := storage.NewMemoryStorage()

// 4. Register test client (public client for PKCE)
err = stor.RegisterClient(ctx, &fosite.DefaultClient{
ID: testClientID,
Secret: nil, // public client
RedirectURIs: []string{testRedirectURI},
ResponseTypes: []string{"code"},
GrantTypes: []string{"authorization_code", "refresh_token"},
Scopes: registration.DefaultScopes,
Audience: []string{testAudience},
Public: true,
})
require.NoError(t, err)

// 5. Build upstream configs from the two mockoidc instances.
// Both point their RedirectURI at our auth server's /oauth/callback.
cfg1 := m1.Config()
upstreamCfg1 := &upstream.OAuth2Config{
CommonOAuthConfig: upstream.CommonOAuthConfig{
ClientID: cfg1.ClientID,
ClientSecret: cfg1.ClientSecret,
Scopes: []string{"openid", "profile", "email"},
RedirectURI: testIssuer + "/oauth/callback",
},
AuthorizationEndpoint: m1.AuthorizationEndpoint(),
TokenEndpoint: m1.TokenEndpoint(),
UserInfo: &upstream.UserInfoConfig{
EndpointURL: m1.UserinfoEndpoint(),
FieldMapping: &upstream.UserInfoFieldMapping{
SubjectFields: []string{"sub", "email"},
},
},
}

cfg2 := m2.Config()
upstreamCfg2 := &upstream.OAuth2Config{
CommonOAuthConfig: upstream.CommonOAuthConfig{
ClientID: cfg2.ClientID,
ClientSecret: cfg2.ClientSecret,
Scopes: []string{"openid", "profile", "email"},
RedirectURI: testIssuer + "/oauth/callback",
},
AuthorizationEndpoint: m2.AuthorizationEndpoint(),
TokenEndpoint: m2.TokenEndpoint(),
UserInfo: &upstream.UserInfoConfig{
EndpointURL: m2.UserinfoEndpoint(),
FieldMapping: &upstream.UserInfoFieldMapping{
SubjectFields: []string{"sub", "email"},
},
},
}

// 6. Create the two upstream providers
provider1, err := upstream.NewOAuth2Provider(upstreamCfg1)
require.NoError(t, err)
provider2, err := upstream.NewOAuth2Provider(upstreamCfg2)
require.NoError(t, err)

// Map of provider name to provider for the factory
providers := map[string]upstream.OAuth2Provider{
"provider-1": provider1,
"provider-2": provider2,
}

// 7. Create config with TWO upstreams
serverCfg := Config{
Issuer: testIssuer,
KeyProvider: &testKeyProvider{key: privateKey},
HMACSecrets: servercrypto.NewHMACSecrets(secret),
AccessTokenLifespan: time.Hour,
RefreshTokenLifespan: 24 * time.Hour,
AuthCodeLifespan: 10 * time.Minute,
Upstreams: []UpstreamConfig{
{Name: "provider-1", Type: UpstreamProviderTypeOAuth2, OAuth2Config: upstreamCfg1},
{Name: "provider-2", Type: UpstreamProviderTypeOAuth2, OAuth2Config: upstreamCfg2},
},
AllowedAudiences: []string{testAudience},
}

// 8. Create server using newServer with a factory that returns the correct provider per name
srv, err := newServer(ctx, serverCfg, stor,
withUpstreamFactory(func(_ context.Context, cfg *UpstreamConfig) (upstream.OAuth2Provider, error) {
p, ok := providers[cfg.Name]
if !ok {
return nil, fmt.Errorf("unknown upstream: %s", cfg.Name)
}
return p, nil
}),
)
require.NoError(t, err)

// 9. Create HTTP test server
httpServer := httptest.NewServer(srv.Handler())
t.Cleanup(func() {
httpServer.Close()
require.NoError(t, srv.Close())
})

return &testServer{
Server: httpServer,
PrivateKey: privateKey,
storage: srv.IDPTokenStorage(),
}
}

// TestIntegration_MultiUpstreamSequentialChain tests the complete multi-upstream
// authorization flow where the auth server chains through two upstream providers
// sequentially before issuing an authorization code to the client.
//
// Flow:
// 1. Client -> /authorize -> redirect to provider-1
// 2. provider-1 approves -> /callback -> redirect to provider-2 (chain continues)
// 3. provider-2 approves -> /callback -> 303 to client with auth code
// 4. Client -> /token -> JWT with tsid referencing both providers' tokens
func TestIntegration_MultiUpstreamSequentialChain(t *testing.T) {
t.Parallel()

// Start two independent mock OIDC providers
m1, err := mockoidc.Run()
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, m1.Shutdown()) })

m2, err := mockoidc.Run()
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, m2.Shutdown()) })

// Queue test users for each provider
m1.QueueUser(&mockoidc.MockUser{
Subject: "user-from-provider-1",
Email: "user1@provider1.example.com",
})
m2.QueueUser(&mockoidc.MockUser{
Subject: "user-from-provider-2",
Email: "user2@provider2.example.com",
})

ts := setupTestServerWithTwoUpstreams(t, m1, m2)
client := noRedirectClient()

verifier := servercrypto.GeneratePKCEVerifier()
challenge := servercrypto.ComputePKCEChallenge(verifier)
clientState := "multi-upstream-client-state"

parsedServerURL, err := url.Parse(ts.Server.URL)
require.NoError(t, err)

// === Leg 1: Client -> /authorize -> provider-1 ===

// Step 1: Start authorization flow on our server
authorizeURL := ts.Server.URL + "/oauth/authorize?" + url.Values{
"client_id": {testClientID},
"redirect_uri": {testRedirectURI},
"state": {clientState},
"code_challenge": {challenge},
"code_challenge_method": {"S256"},
"response_type": {"code"},
"scope": {"openid profile"},
}.Encode()

resp, err := client.Get(authorizeURL)
require.NoError(t, err)
require.Equal(t, http.StatusFound, resp.StatusCode, "expected redirect to provider-1")
m1Location, err := resp.Location()
require.NoError(t, err)
resp.Body.Close()

// Step 2: Follow redirect to provider-1's authorization endpoint (mockoidc auto-approves)
resp, err = client.Get(m1Location.String())
require.NoError(t, err)
require.Equal(t, http.StatusFound, resp.StatusCode, "expected redirect from provider-1 to our callback")
callbackFromM1, err := resp.Location()
require.NoError(t, err)
resp.Body.Close()

// Step 3: Rewrite callback URL to use actual test server (mockoidc redirects to localhost)
callbackFromM1.Scheme = parsedServerURL.Scheme
callbackFromM1.Host = parsedServerURL.Host

// Step 4: Call our /callback with provider-1's code
// This should NOT redirect to the client yet — it should redirect to provider-2
resp, err = client.Get(callbackFromM1.String())
require.NoError(t, err)
require.Equal(t, http.StatusFound, resp.StatusCode,
"expected 302 redirect to provider-2 (chain continues), not 303 to client")
m2Location, err := resp.Location()
require.NoError(t, err)
resp.Body.Close()

// === Leg 2: provider-2 ===

// Step 5: Follow redirect to provider-2's authorization endpoint (mockoidc auto-approves)
resp, err = client.Get(m2Location.String())
require.NoError(t, err)
require.Equal(t, http.StatusFound, resp.StatusCode, "expected redirect from provider-2 to our callback")
callbackFromM2, err := resp.Location()
require.NoError(t, err)
resp.Body.Close()

// Step 6: Rewrite callback URL to use actual test server
callbackFromM2.Scheme = parsedServerURL.Scheme
callbackFromM2.Host = parsedServerURL.Host

// Step 7: Call our /callback with provider-2's code
// All upstreams are now satisfied, so this should redirect to the client with an auth code (303)
resp, err = client.Get(callbackFromM2.String())
require.NoError(t, err)
require.Equal(t, http.StatusSeeOther, resp.StatusCode,
"expected 303 redirect to client with auth code (all upstreams satisfied)")
clientLocation, err := resp.Location()
require.NoError(t, err)
resp.Body.Close()

// === Verify the final redirect to the client ===

// Verify the original client state was preserved through the entire chain
returnedState := clientLocation.Query().Get("state")
assert.Equal(t, clientState, returnedState, "client state should be preserved through the multi-upstream chain")

// Extract authorization code
authCode := clientLocation.Query().Get("code")
require.NotEmpty(t, authCode, "authorization code should be present in the final redirect")

// === Exchange code for tokens ===

tokenData := exchangeCodeForTokens(t, ts.Server.URL, authCode, verifier, testAudience)

// Verify access token is a valid JWT
accessToken, ok := tokenData["access_token"].(string)
require.True(t, ok, "access_token should be a string")
require.NotEmpty(t, accessToken)

parsedToken, err := jwt.ParseSigned(accessToken, []jose.SignatureAlgorithm{jose.RS256})
require.NoError(t, err, "should be able to parse JWT")

var claims map[string]interface{}
err = parsedToken.Claims(ts.PrivateKey.Public(), &claims)
require.NoError(t, err, "JWT signature should be valid")

// Verify standard claims
assert.Equal(t, testIssuer, claims["iss"], "issuer should match")
assert.Equal(t, testClientID, claims["client_id"], "client_id should match")

// Verify subject is from the first upstream (identity provider)
sub, ok := claims["sub"].(string)
require.True(t, ok, "sub claim should be a string")
assert.NotEmpty(t, sub, "sub claim should not be empty")

// Verify tsid claim is present (session ID for upstream token retrieval)
tsid, ok := claims["tsid"].(string)
require.True(t, ok, "tsid claim should be a string")
require.NotEmpty(t, tsid, "tsid claim should not be empty")

// === Verify both providers' tokens are stored ===

ctx := context.Background()

// Provider-1 tokens should be stored
tokens1, err := ts.storage.GetUpstreamTokens(ctx, tsid, "provider-1")
require.NoError(t, err, "provider-1 tokens should be retrievable")
require.NotNil(t, tokens1, "provider-1 tokens should not be nil")
assert.NotEmpty(t, tokens1.AccessToken, "provider-1 access token should not be empty")
assert.Equal(t, "provider-1", tokens1.ProviderID, "provider-1 ProviderID should match")
assert.Equal(t, testClientID, tokens1.ClientID, "provider-1 ClientID should match")
assert.Equal(t, sub, tokens1.UserID, "provider-1 UserID should match JWT sub claim")

// Provider-2 tokens should be stored
tokens2, err := ts.storage.GetUpstreamTokens(ctx, tsid, "provider-2")
require.NoError(t, err, "provider-2 tokens should be retrievable")
require.NotNil(t, tokens2, "provider-2 tokens should not be nil")
assert.NotEmpty(t, tokens2.AccessToken, "provider-2 access token should not be empty")
assert.Equal(t, "provider-2", tokens2.ProviderID, "provider-2 ProviderID should match")
assert.Equal(t, testClientID, tokens2.ClientID, "provider-2 ClientID should match")
assert.Equal(t, sub, tokens2.UserID, "provider-2 UserID should match JWT sub claim")

// Verify upstream subjects trace back to the correct IDPs.
// This proves provider-1 was used as the identity source (its UpstreamSubject
// is from m1's user) and provider-2 contributed only tokens (its UpstreamSubject
// is from m2's user). Both share the same internal UserID (sub) from provider-1.
assert.Contains(t, tokens1.UpstreamSubject, "provider1.example.com",
"provider-1 UpstreamSubject should come from m1's queued user")
assert.Contains(t, tokens2.UpstreamSubject, "provider2.example.com",
"provider-2 UpstreamSubject should come from m2's queued user")
assert.NotEqual(t, tokens1.UpstreamSubject, tokens2.UpstreamSubject,
"upstream subjects should differ (different IDPs)")
}
Loading
Loading