diff --git a/internal/testutil/token.go b/internal/testutil/token.go index 27cab5d1..41778de7 100644 --- a/internal/testutil/token.go +++ b/internal/testutil/token.go @@ -8,6 +8,7 @@ import ( "errors" "time" + "github.com/muhlemmer/gu" "github.com/zitadel/oidc/v3/pkg/oidc" "gopkg.in/square/go-jose.v2" ) @@ -17,7 +18,7 @@ type KeySet struct{} // VerifySignature implments op.KeySet. func (KeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) { - if ctx.Err() != nil { + if err = ctx.Err(); err != nil { return nil, err } @@ -45,6 +46,16 @@ func init() { } } +type JWTProfileKeyStorage struct{} + +func (JWTProfileKeyStorage) GetKeyByIDAndClientID(ctx context.Context, keyID string, clientID string) (*jose.JSONWebKey, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + + return gu.Ptr(WebKey.Public()), nil +} + func signEncodeTokenClaims(claims any) string { payload, err := json.Marshal(claims) if err != nil { @@ -106,6 +117,25 @@ func NewAccessToken(issuer, subject string, audience []string, expiration time.T return NewAccessTokenCustom(issuer, subject, audience, expiration, jwtid, clientID, skew, nil) } +func NewJWTProfileAssertion(issuer, clientID string, audience []string, issuedAt, expiration time.Time) (string, *oidc.JWTTokenRequest) { + req := &oidc.JWTTokenRequest{ + Issuer: issuer, + Subject: clientID, + Audience: audience, + ExpiresAt: oidc.FromTime(expiration), + IssuedAt: oidc.FromTime(issuedAt), + } + // make sure the private claim map is set correctly + data, err := json.Marshal(req) + if err != nil { + panic(err) + } + if err = json.Unmarshal(data, req); err != nil { + panic(err) + } + return signEncodeTokenClaims(req), req +} + const InvalidSignatureToken = `eyJhbGciOiJQUzUxMiJ9.eyJpc3MiOiJsb2NhbC5jb20iLCJzdWIiOiJ0aW1AbG9jYWwuY29tIiwiYXVkIjpbInVuaXQiLCJ0ZXN0IiwiNTU1NjY2Il0sImV4cCI6MTY3Nzg0MDQzMSwiaWF0IjoxNjc3ODQwMzcwLCJhdXRoX3RpbWUiOjE2Nzc4NDAzMTAsIm5vbmNlIjoiMTIzNDUiLCJhY3IiOiJzb21ldGhpbmciLCJhbXIiOlsiZm9vIiwiYmFyIl0sImF6cCI6IjU1NTY2NiJ9.DtZmvVkuE4Hw48ijBMhRJbxEWCr_WEYuPQBMY73J9TP6MmfeNFkjVJf4nh4omjB9gVLnQ-xhEkNOe62FS5P0BB2VOxPuHZUj34dNspCgG3h98fGxyiMb5vlIYAHDF9T-w_LntlYItohv63MmdYR-hPpAqjXE7KOfErf-wUDGE9R3bfiQ4HpTdyFJB1nsToYrZ9lhP2mzjTCTs58ckZfQ28DFHn_lfHWpR4rJBgvLx7IH4rMrUayr09Ap-PxQLbv0lYMtmgG1z3JK8MXnuYR0UJdZnEIezOzUTlThhCXB-nvuAXYjYxZZTR0FtlgZUHhIpYK0V2abf_Q_Or36akNCUg` // These variables always result in a valid token @@ -137,6 +167,10 @@ func ValidAccessToken() (string, *oidc.AccessTokenClaims) { return NewAccessToken(ValidIssuer, ValidSubject, ValidAudience, ValidExpiration, ValidJWTID, ValidClientID, ValidSkew) } +func ValidJWTProfileAssertion() (string, *oidc.JWTTokenRequest) { + return NewJWTProfileAssertion(ValidClientID, ValidClientID, []string{ValidIssuer}, time.Now(), ValidExpiration) +} + // ACRVerify is a oidc.ACRVerifier func. func ACRVerify(acr string) error { if acr != ValidACR { diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go index 725715e4..bd96e160 100644 --- a/pkg/client/rp/relying_party.go +++ b/pkg/client/rp/relying_party.go @@ -63,8 +63,8 @@ type RelyingParty interface { // be used to start a DeviceAuthorization flow. GetDeviceAuthorizationEndpoint() string - // IDTokenVerifier returns the verifier interface used for oidc id_token verification - IDTokenVerifier() IDTokenVerifier + // IDTokenVerifier returns the verifier used for oidc id_token verification + IDTokenVerifier() *IDTokenVerifier // ErrorHandler returns the handler used for callback errors ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string) @@ -88,7 +88,7 @@ type relyingParty struct { cookieHandler *httphelper.CookieHandler errorHandler func(http.ResponseWriter, *http.Request, string, string, string) - idTokenVerifier IDTokenVerifier + idTokenVerifier *IDTokenVerifier verifierOpts []VerifierOption signer jose.Signer } @@ -137,7 +137,7 @@ func (rp *relyingParty) GetRevokeEndpoint() string { return rp.endpoints.RevokeURL } -func (rp *relyingParty) IDTokenVerifier() IDTokenVerifier { +func (rp *relyingParty) IDTokenVerifier() *IDTokenVerifier { if rp.idTokenVerifier == nil { rp.idTokenVerifier = NewIDTokenVerifier(rp.issuer, rp.oauthConfig.ClientID, NewRemoteKeySet(rp.httpClient, rp.endpoints.JKWsURL), rp.verifierOpts...) } diff --git a/pkg/client/rp/verifier.go b/pkg/client/rp/verifier.go index 0cf427a7..3294f407 100644 --- a/pkg/client/rp/verifier.go +++ b/pkg/client/rp/verifier.go @@ -9,19 +9,9 @@ import ( "github.com/zitadel/oidc/v3/pkg/oidc" ) -type IDTokenVerifier interface { - oidc.Verifier - ClientID() string - SupportedSignAlgs() []string - KeySet() oidc.KeySet - Nonce(context.Context) string - ACR() oidc.ACRVerifier - MaxAge() time.Duration -} - // VerifyTokens implement the Token Response Validation as defined in OIDC specification // https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation -func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v IDTokenVerifier) (claims C, err error) { +func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v *IDTokenVerifier) (claims C, err error) { var nilClaims C claims, err = VerifyIDToken[C](ctx, idToken, v) @@ -36,7 +26,7 @@ func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken str // VerifyIDToken validates the id token according to // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation -func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVerifier) (claims C, err error) { +func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v *IDTokenVerifier) (claims C, err error) { var nilClaims C decrypted, err := oidc.DecryptToken(token) @@ -52,27 +42,27 @@ func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVe return nilClaims, err } - if err = oidc.CheckIssuer(claims, v.Issuer()); err != nil { + if err = oidc.CheckIssuer(claims, v.Issuer); err != nil { return nilClaims, err } - if err = oidc.CheckAudience(claims, v.ClientID()); err != nil { + if err = oidc.CheckAudience(claims, v.ClientID); err != nil { return nilClaims, err } - if err = oidc.CheckAuthorizedParty(claims, v.ClientID()); err != nil { + if err = oidc.CheckAuthorizedParty(claims, v.ClientID); err != nil { return nilClaims, err } - if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil { + if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil { return nilClaims, err } - if err = oidc.CheckExpiration(claims, v.Offset()); err != nil { + if err = oidc.CheckExpiration(claims, v.Offset); err != nil { return nilClaims, err } - if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil { + if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT, v.Offset); err != nil { return nilClaims, err } @@ -80,16 +70,18 @@ func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVe return nilClaims, err } - if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil { + if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR); err != nil { return nilClaims, err } - if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil { + if err = oidc.CheckAuthTime(claims, v.MaxAge); err != nil { return nilClaims, err } return claims, nil } +type IDTokenVerifier oidc.Verifier + // VerifyAccessToken validates the access token according to // https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error { @@ -107,15 +99,14 @@ func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAl return nil } -// NewIDTokenVerifier returns an implementation of `IDTokenVerifier` -// for `VerifyTokens` and `VerifyIDToken` -func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...VerifierOption) IDTokenVerifier { - v := &idTokenVerifier{ - issuer: issuer, - clientID: clientID, - keySet: keySet, - offset: time.Second, - nonce: func(_ context.Context) string { +// NewIDTokenVerifier returns a oidc.Verifier suitable for ID token verification. +func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...VerifierOption) *IDTokenVerifier { + v := &IDTokenVerifier{ + Issuer: issuer, + ClientID: clientID, + KeySet: keySet, + Offset: time.Second, + Nonce: func(_ context.Context) string { return "" }, } @@ -128,95 +119,47 @@ func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ... } // VerifierOption is the type for providing dynamic options to the IDTokenVerifier -type VerifierOption func(*idTokenVerifier) +type VerifierOption func(*IDTokenVerifier) // WithIssuedAtOffset mitigates the risk of iat to be in the future // because of clock skews with the ability to add an offset to the current time -func WithIssuedAtOffset(offset time.Duration) func(*idTokenVerifier) { - return func(v *idTokenVerifier) { - v.offset = offset +func WithIssuedAtOffset(offset time.Duration) VerifierOption { + return func(v *IDTokenVerifier) { + v.Offset = offset } } // WithIssuedAtMaxAge provides the ability to define the maximum duration between iat and now -func WithIssuedAtMaxAge(maxAge time.Duration) func(*idTokenVerifier) { - return func(v *idTokenVerifier) { - v.maxAgeIAT = maxAge +func WithIssuedAtMaxAge(maxAge time.Duration) VerifierOption { + return func(v *IDTokenVerifier) { + v.MaxAgeIAT = maxAge } } // WithNonce sets the function to check the nonce func WithNonce(nonce func(context.Context) string) VerifierOption { - return func(v *idTokenVerifier) { - v.nonce = nonce + return func(v *IDTokenVerifier) { + v.Nonce = nonce } } // WithACRVerifier sets the verifier for the acr claim func WithACRVerifier(verifier oidc.ACRVerifier) VerifierOption { - return func(v *idTokenVerifier) { - v.acr = verifier + return func(v *IDTokenVerifier) { + v.ACR = verifier } } // WithAuthTimeMaxAge provides the ability to define the maximum duration between auth_time and now func WithAuthTimeMaxAge(maxAge time.Duration) VerifierOption { - return func(v *idTokenVerifier) { - v.maxAge = maxAge + return func(v *IDTokenVerifier) { + v.MaxAge = maxAge } } // WithSupportedSigningAlgorithms overwrites the default RS256 signing algorithm func WithSupportedSigningAlgorithms(algs ...string) VerifierOption { - return func(v *idTokenVerifier) { - v.supportedSignAlgs = algs + return func(v *IDTokenVerifier) { + v.SupportedSignAlgs = algs } } - -type idTokenVerifier struct { - issuer string - maxAgeIAT time.Duration - offset time.Duration - clientID string - supportedSignAlgs []string - keySet oidc.KeySet - acr oidc.ACRVerifier - maxAge time.Duration - nonce func(ctx context.Context) string -} - -func (i *idTokenVerifier) Issuer() string { - return i.issuer -} - -func (i *idTokenVerifier) MaxAgeIAT() time.Duration { - return i.maxAgeIAT -} - -func (i *idTokenVerifier) Offset() time.Duration { - return i.offset -} - -func (i *idTokenVerifier) ClientID() string { - return i.clientID -} - -func (i *idTokenVerifier) SupportedSignAlgs() []string { - return i.supportedSignAlgs -} - -func (i *idTokenVerifier) KeySet() oidc.KeySet { - return i.keySet -} - -func (i *idTokenVerifier) Nonce(ctx context.Context) string { - return i.nonce(ctx) -} - -func (i *idTokenVerifier) ACR() oidc.ACRVerifier { - return i.acr -} - -func (i *idTokenVerifier) MaxAge() time.Duration { - return i.maxAge -} diff --git a/pkg/client/rp/verifier_test.go b/pkg/client/rp/verifier_test.go index 002d65d3..11bf2f9f 100644 --- a/pkg/client/rp/verifier_test.go +++ b/pkg/client/rp/verifier_test.go @@ -13,16 +13,16 @@ import ( ) func TestVerifyTokens(t *testing.T) { - verifier := &idTokenVerifier{ - issuer: tu.ValidIssuer, - maxAgeIAT: 2 * time.Minute, - offset: time.Second, - supportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, - keySet: tu.KeySet{}, - maxAge: 2 * time.Minute, - acr: tu.ACRVerify, - nonce: func(context.Context) string { return tu.ValidNonce }, - clientID: tu.ValidClientID, + verifier := &IDTokenVerifier{ + Issuer: tu.ValidIssuer, + MaxAgeIAT: 2 * time.Minute, + Offset: time.Second, + SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, + KeySet: tu.KeySet{}, + MaxAge: 2 * time.Minute, + ACR: tu.ACRVerify, + Nonce: func(context.Context) string { return tu.ValidNonce }, + ClientID: tu.ValidClientID, } accessToken, _ := tu.ValidAccessToken() atHash, err := oidc.ClaimHash(accessToken, tu.SignatureAlgorithm) @@ -91,15 +91,15 @@ func TestVerifyTokens(t *testing.T) { } func TestVerifyIDToken(t *testing.T) { - verifier := &idTokenVerifier{ - issuer: tu.ValidIssuer, - maxAgeIAT: 2 * time.Minute, - offset: time.Second, - supportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, - keySet: tu.KeySet{}, - maxAge: 2 * time.Minute, - acr: tu.ACRVerify, - nonce: func(context.Context) string { return tu.ValidNonce }, + verifier := &IDTokenVerifier{ + Issuer: tu.ValidIssuer, + MaxAgeIAT: 2 * time.Minute, + Offset: time.Second, + SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, + KeySet: tu.KeySet{}, + MaxAge: 2 * time.Minute, + ACR: tu.ACRVerify, + Nonce: func(context.Context) string { return tu.ValidNonce }, } tests := []struct { @@ -219,7 +219,7 @@ func TestVerifyIDToken(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { token, want := tt.tokenClaims() - verifier.clientID = tt.clientID + verifier.ClientID = tt.clientID got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier) if tt.wantErr { assert.Error(t, err) @@ -300,7 +300,7 @@ func TestNewIDTokenVerifier(t *testing.T) { tests := []struct { name string args args - want IDTokenVerifier + want *IDTokenVerifier }{ { name: "nil nonce", // otherwise assert.Equal will fail on the function @@ -317,16 +317,16 @@ func TestNewIDTokenVerifier(t *testing.T) { WithSupportedSigningAlgorithms("ABC", "DEF"), }, }, - want: &idTokenVerifier{ - issuer: tu.ValidIssuer, - offset: time.Minute, - maxAgeIAT: time.Hour, - clientID: tu.ValidClientID, - keySet: tu.KeySet{}, - nonce: nil, - acr: nil, - maxAge: 2 * time.Hour, - supportedSignAlgs: []string{"ABC", "DEF"}, + want: &IDTokenVerifier{ + Issuer: tu.ValidIssuer, + Offset: time.Minute, + MaxAgeIAT: time.Hour, + ClientID: tu.ValidClientID, + KeySet: tu.KeySet{}, + Nonce: nil, + ACR: nil, + MaxAge: 2 * time.Hour, + SupportedSignAlgs: []string{"ABC", "DEF"}, }, }, } diff --git a/pkg/oidc/token_request.go b/pkg/oidc/token_request.go index e63e0e51..6b6945a1 100644 --- a/pkg/oidc/token_request.go +++ b/pkg/oidc/token_request.go @@ -192,7 +192,7 @@ func (j *JWTTokenRequest) GetExpiration() time.Time { // GetIssuedAt implements the Claims interface func (j *JWTTokenRequest) GetIssuedAt() time.Time { - return j.ExpiresAt.AsTime() + return j.IssuedAt.AsTime() } // GetNonce implements the Claims interface diff --git a/pkg/oidc/types.go b/pkg/oidc/types.go index cb513a09..167f8b78 100644 --- a/pkg/oidc/types.go +++ b/pkg/oidc/types.go @@ -173,10 +173,16 @@ func NewEncoder() *schema.Encoder { type Time int64 func (ts Time) AsTime() time.Time { + if ts == 0 { + return time.Time{} + } return time.Unix(int64(ts), 0) } func FromTime(tt time.Time) Time { + if tt.IsZero() { + return 0 + } return Time(tt.Unix()) } diff --git a/pkg/oidc/types_test.go b/pkg/oidc/types_test.go index 2721e0b7..64f07f16 100644 --- a/pkg/oidc/types_test.go +++ b/pkg/oidc/types_test.go @@ -7,6 +7,7 @@ import ( "strconv" "strings" "testing" + "time" "github.com/gorilla/schema" "github.com/stretchr/testify/assert" @@ -467,6 +468,56 @@ func TestNewEncoder(t *testing.T) { assert.Equal(t, a, b) } +func TestTime_AsTime(t *testing.T) { + tests := []struct { + name string + ts Time + want time.Time + }{ + { + name: "unset", + ts: 0, + want: time.Time{}, + }, + { + name: "set", + ts: 1, + want: time.Unix(1, 0), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.ts.AsTime() + assert.Equal(t, tt.want, got) + }) + } +} + +func TestTime_FromTime(t *testing.T) { + tests := []struct { + name string + tt time.Time + want Time + }{ + { + name: "zero", + tt: time.Time{}, + want: 0, + }, + { + name: "set", + tt: time.Unix(1, 0), + want: 1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := FromTime(tt.tt) + assert.Equal(t, tt.want, got) + }) + } +} + func TestTime_UnmarshalJSON(t *testing.T) { type dst struct { UpdatedAt Time `json:"updated_at"` diff --git a/pkg/oidc/verifier.go b/pkg/oidc/verifier.go index ad82617d..2d4e7a67 100644 --- a/pkg/oidc/verifier.go +++ b/pkg/oidc/verifier.go @@ -61,10 +61,19 @@ var ( ErrAtHash = errors.New("at_hash does not correspond to access token") ) -type Verifier interface { - Issuer() string - MaxAgeIAT() time.Duration - Offset() time.Duration +// Verifier caries configuration for the various token verification +// functions. Use package specific constructor functions to know +// which values need to be set. +type Verifier struct { + Issuer string + MaxAgeIAT time.Duration + Offset time.Duration + ClientID string + SupportedSignAlgs []string + MaxAge time.Duration + ACR ACRVerifier + KeySet KeySet + Nonce func(ctx context.Context) string } // ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim @@ -121,6 +130,11 @@ func CheckAudience(claims Claims, clientID string) error { return nil } +// CheckAuthorizedParty checks azp (authorized party) claim requirements. +// +// If the ID Token contains multiple audiences, the Client SHOULD verify that an azp Claim is present. +// If an azp Claim is present, the Client SHOULD verify that its client_id is the Claim Value. +// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation func CheckAuthorizedParty(claims Claims, clientID string) error { if len(claims.GetAudience()) > 1 { if claims.GetAuthorizedParty() == "" { @@ -167,26 +181,26 @@ func CheckSignature(ctx context.Context, token string, payload []byte, claims Cl } func CheckExpiration(claims Claims, offset time.Duration) error { - expiration := claims.GetExpiration().Round(time.Second) - if !time.Now().UTC().Add(offset).Before(expiration) { + expiration := claims.GetExpiration() + if !time.Now().Add(offset).Before(expiration) { return ErrExpired } return nil } func CheckIssuedAt(claims Claims, maxAgeIAT, offset time.Duration) error { - issuedAt := claims.GetIssuedAt().Round(time.Second) + issuedAt := claims.GetIssuedAt() if issuedAt.IsZero() { return ErrIatMissing } - nowWithOffset := time.Now().UTC().Add(offset).Round(time.Second) + nowWithOffset := time.Now().Add(offset).Round(time.Second) if issuedAt.After(nowWithOffset) { return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, issuedAt, nowWithOffset) } if maxAgeIAT == 0 { return nil } - maxAge := time.Now().UTC().Add(-maxAgeIAT).Round(time.Second) + maxAge := time.Now().Add(-maxAgeIAT).Round(time.Second) if issuedAt.Before(maxAge) { return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrIatToOld, maxAge, issuedAt, maxAge.Sub(issuedAt)) } @@ -216,8 +230,8 @@ func CheckAuthTime(claims Claims, maxAge time.Duration) error { if claims.GetAuthTime().IsZero() { return ErrAuthTimeNotPresent } - authTime := claims.GetAuthTime().Round(time.Second) - maxAuthTime := time.Now().UTC().Add(-maxAge).Round(time.Second) + authTime := claims.GetAuthTime() + maxAuthTime := time.Now().Add(-maxAge).Round(time.Second) if authTime.Before(maxAuthTime) { return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrAuthTimeToOld, maxAge, authTime, maxAuthTime.Sub(authTime)) } diff --git a/pkg/oidc/verifier_parse_test.go b/pkg/oidc/verifier_parse_test.go new file mode 100644 index 00000000..105650f0 --- /dev/null +++ b/pkg/oidc/verifier_parse_test.go @@ -0,0 +1,128 @@ +package oidc_test + +import ( + "context" + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + tu "github.com/zitadel/oidc/v3/internal/testutil" + "github.com/zitadel/oidc/v3/pkg/oidc" +) + +func TestParseToken(t *testing.T) { + token, wantClaims := tu.ValidIDToken() + wantClaims.SignatureAlg = "" // unset, because is not part of the JSON payload + + wantPayload, err := json.Marshal(wantClaims) + require.NoError(t, err) + + tests := []struct { + name string + tokenString string + wantErr bool + }{ + { + name: "split error", + tokenString: "nope", + wantErr: true, + }, + { + name: "base64 error", + tokenString: "foo.~.bar", + wantErr: true, + }, + { + name: "success", + tokenString: token, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotClaims := new(oidc.IDTokenClaims) + gotPayload, err := oidc.ParseToken(tt.tokenString, gotClaims) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, wantClaims, gotClaims) + assert.JSONEq(t, string(wantPayload), string(gotPayload)) + }) + } +} + +func TestCheckSignature(t *testing.T) { + errCtx, cancel := context.WithCancel(context.Background()) + cancel() + + token, _ := tu.ValidIDToken() + payload, err := oidc.ParseToken(token, &oidc.IDTokenClaims{}) + require.NoError(t, err) + + type args struct { + ctx context.Context + token string + payload []byte + supportedSigAlgs []string + } + tests := []struct { + name string + args args + wantErr error + }{ + { + name: "parse error", + args: args{ + ctx: context.Background(), + token: "~", + payload: payload, + }, + wantErr: oidc.ErrParse, + }, + { + name: "default sigAlg", + args: args{ + ctx: context.Background(), + token: token, + payload: payload, + }, + }, + { + name: "unsupported sigAlg", + args: args{ + ctx: context.Background(), + token: token, + payload: payload, + supportedSigAlgs: []string{"foo", "bar"}, + }, + wantErr: oidc.ErrSignatureUnsupportedAlg, + }, + { + name: "verify error", + args: args{ + ctx: errCtx, + token: token, + payload: payload, + }, + wantErr: oidc.ErrSignatureInvalid, + }, + { + name: "inequal payloads", + args: args{ + ctx: context.Background(), + token: token, + payload: []byte{0, 1, 2}, + }, + wantErr: oidc.ErrSignatureInvalidPayload, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + claims := new(oidc.TokenClaims) + err := oidc.CheckSignature(tt.args.ctx, tt.args.token, tt.args.payload, claims, tt.args.supportedSigAlgs, tu.KeySet{}) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} diff --git a/pkg/oidc/verifier_test.go b/pkg/oidc/verifier_test.go new file mode 100644 index 00000000..93e71575 --- /dev/null +++ b/pkg/oidc/verifier_test.go @@ -0,0 +1,374 @@ +package oidc + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDecryptToken(t *testing.T) { + const tokenString = "ABC" + got, err := DecryptToken(tokenString) + require.NoError(t, err) + assert.Equal(t, tokenString, got) +} + +func TestDefaultACRVerifier(t *testing.T) { + acrVerfier := DefaultACRVerifier([]string{"foo", "bar"}) + + tests := []struct { + name string + acr string + wantErr string + }{ + { + name: "ok", + acr: "bar", + }, + { + name: "error", + acr: "hello", + wantErr: "expected one of: [foo bar], got: \"hello\"", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := acrVerfier(tt.acr) + if tt.wantErr != "" { + assert.EqualError(t, err, tt.wantErr) + return + } + require.NoError(t, err) + }) + } +} + +func TestCheckSubject(t *testing.T) { + tests := []struct { + name string + claims Claims + wantErr error + }{ + { + name: "missing", + claims: &TokenClaims{}, + wantErr: ErrSubjectMissing, + }, + { + name: "ok", + claims: &TokenClaims{ + Subject: "foo", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckSubject(tt.claims) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckIssuer(t *testing.T) { + const issuer = "foo.bar" + tests := []struct { + name string + claims Claims + wantErr error + }{ + { + name: "missing", + claims: &TokenClaims{}, + wantErr: ErrIssuerInvalid, + }, + { + name: "wrong", + claims: &TokenClaims{ + Issuer: "wrong", + }, + wantErr: ErrIssuerInvalid, + }, + { + name: "ok", + claims: &TokenClaims{ + Issuer: issuer, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckIssuer(tt.claims, issuer) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckAudience(t *testing.T) { + const clientID = "foo.bar" + tests := []struct { + name string + claims Claims + wantErr error + }{ + { + name: "missing", + claims: &TokenClaims{}, + wantErr: ErrAudience, + }, + { + name: "wrong", + claims: &TokenClaims{ + Audience: []string{"wrong"}, + }, + wantErr: ErrAudience, + }, + { + name: "ok", + claims: &TokenClaims{ + Audience: []string{clientID}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckAudience(tt.claims, clientID) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckAuthorizedParty(t *testing.T) { + const clientID = "foo.bar" + tests := []struct { + name string + claims Claims + wantErr error + }{ + { + name: "single audience, no azp", + claims: &TokenClaims{ + Audience: []string{clientID}, + }, + }, + { + name: "multiple audience, no azp", + claims: &TokenClaims{ + Audience: []string{clientID, "other"}, + }, + wantErr: ErrAzpMissing, + }, + { + name: "single audience, with azp", + claims: &TokenClaims{ + Audience: []string{clientID}, + AuthorizedParty: clientID, + }, + }, + { + name: "multiple audience, with azp", + claims: &TokenClaims{ + Audience: []string{clientID, "other"}, + AuthorizedParty: clientID, + }, + }, + { + name: "wrong azp", + claims: &TokenClaims{ + AuthorizedParty: "wrong", + }, + wantErr: ErrAzpInvalid, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckAuthorizedParty(tt.claims, clientID) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckExpiration(t *testing.T) { + const offset = time.Minute + tests := []struct { + name string + claims Claims + wantErr error + }{ + { + name: "missing", + claims: &TokenClaims{}, + wantErr: ErrExpired, + }, + { + name: "expired", + claims: &TokenClaims{ + Expiration: FromTime(time.Now().Add(-2 * offset)), + }, + wantErr: ErrExpired, + }, + { + name: "valid", + claims: &TokenClaims{ + Expiration: FromTime(time.Now().Add(2 * offset)), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckExpiration(tt.claims, offset) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckIssuedAt(t *testing.T) { + const offset = time.Minute + tests := []struct { + name string + maxAgeIAT time.Duration + claims Claims + wantErr error + }{ + { + name: "missing", + claims: &TokenClaims{}, + wantErr: ErrIatMissing, + }, + { + name: "future", + claims: &TokenClaims{ + IssuedAt: FromTime(time.Now().Add(time.Hour)), + }, + wantErr: ErrIatInFuture, + }, + { + name: "no max", + claims: &TokenClaims{ + IssuedAt: FromTime(time.Now()), + }, + }, + { + name: "past max", + maxAgeIAT: time.Minute, + claims: &TokenClaims{ + IssuedAt: FromTime(time.Now().Add(-time.Hour)), + }, + wantErr: ErrIatToOld, + }, + { + name: "within max", + maxAgeIAT: time.Hour, + claims: &TokenClaims{ + IssuedAt: FromTime(time.Now()), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckIssuedAt(tt.claims, tt.maxAgeIAT, offset) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckNonce(t *testing.T) { + const nonce = "123" + tests := []struct { + name string + claims Claims + wantErr error + }{ + { + name: "missing", + claims: &TokenClaims{}, + wantErr: ErrNonceInvalid, + }, + { + name: "wrong", + claims: &TokenClaims{ + Nonce: "wrong", + }, + wantErr: ErrNonceInvalid, + }, + { + name: "ok", + claims: &TokenClaims{ + Nonce: nonce, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckNonce(tt.claims, nonce) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckAuthorizationContextClassReference(t *testing.T) { + tests := []struct { + name string + acr ACRVerifier + wantErr error + }{ + { + name: "error", + acr: func(s string) error { return errors.New("oops") }, + wantErr: ErrAcrInvalid, + }, + { + name: "ok", + acr: func(s string) error { return nil }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckAuthorizationContextClassReference(&IDTokenClaims{}, tt.acr) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckAuthTime(t *testing.T) { + tests := []struct { + name string + claims Claims + maxAge time.Duration + wantErr error + }{ + { + name: "no max age", + claims: &TokenClaims{}, + }, + { + name: "missing", + claims: &TokenClaims{}, + maxAge: time.Minute, + wantErr: ErrAuthTimeNotPresent, + }, + { + name: "expired", + maxAge: time.Minute, + claims: &TokenClaims{ + AuthTime: FromTime(time.Now().Add(-time.Hour)), + }, + wantErr: ErrAuthTimeToOld, + }, + { + name: "ok", + maxAge: time.Minute, + claims: &TokenClaims{ + AuthTime: NowTime(), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckAuthTime(tt.claims, tt.maxAge) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index 1d1add53..b516909b 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -38,7 +38,7 @@ type Authorizer interface { Storage() Storage Decoder() httphelper.Decoder Encoder() httphelper.Encoder - IDTokenHintVerifier(context.Context) IDTokenHintVerifier + IDTokenHintVerifier(context.Context) *IDTokenHintVerifier Crypto() Crypto RequestObjectSupported() bool } @@ -47,7 +47,7 @@ type Authorizer interface { // implementing its own validation mechanism for the auth request type AuthorizeValidator interface { Authorizer - ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, IDTokenHintVerifier) (string, error) + ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, *IDTokenHintVerifier) (string, error) } func authorizeHandler(authorizer Authorizer) func(http.ResponseWriter, *http.Request) { @@ -204,7 +204,7 @@ func CopyRequestObjectToAuthRequest(authReq *oidc.AuthRequest, requestObject *oi } // ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed -func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier IDTokenHintVerifier) (sub string, err error) { +func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier *IDTokenHintVerifier) (sub string, err error) { authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge) if err != nil { return "", err @@ -384,7 +384,7 @@ func ValidateAuthReqResponseType(client Client, responseType oidc.ResponseType) // ValidateAuthReqIDTokenHint validates the id_token_hint (if passed as parameter in the request) // and returns the `sub` claim -func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier IDTokenHintVerifier) (string, error) { +func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier *IDTokenHintVerifier) (string, error) { if idTokenHint == "" { return "", nil } diff --git a/pkg/op/auth_request_test.go b/pkg/op/auth_request_test.go index 65c65dab..3179e258 100644 --- a/pkg/op/auth_request_test.go +++ b/pkg/op/auth_request_test.go @@ -12,6 +12,7 @@ import ( "github.com/gorilla/schema" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + tu "github.com/zitadel/oidc/v3/internal/testutil" httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/op" @@ -146,7 +147,7 @@ func TestValidateAuthRequest(t *testing.T) { type args struct { authRequest *oidc.AuthRequest storage op.Storage - verifier op.IDTokenHintVerifier + verifier *op.IDTokenHintVerifier } tests := []struct { name string @@ -1003,3 +1004,34 @@ func Test_parseAuthorizeCallbackRequest(t *testing.T) { }) } } + +func TestValidateAuthReqIDTokenHint(t *testing.T) { + token, _ := tu.ValidIDToken() + tests := []struct { + name string + idTokenHint string + want string + wantErr error + }{ + { + name: "empty", + }, + { + name: "verify err", + idTokenHint: "foo", + wantErr: oidc.ErrLoginRequired(), + }, + { + name: "ok", + idTokenHint: token, + want: tu.ValidSubject, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := op.ValidateAuthReqIDTokenHint(context.Background(), tt.idTokenHint, op.NewIDTokenHintVerifier(tu.ValidIssuer, tu.KeySet{})) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/op/client.go b/pkg/op/client.go index 175caec2..754636cc 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -81,7 +81,7 @@ var ( ) type ClientJWTProfile interface { - JWTProfileVerifier(context.Context) JWTProfileVerifier + JWTProfileVerifier(context.Context) *JWTProfileVerifier } func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier ClientJWTProfile) (clientID string, err error) { diff --git a/pkg/op/client_test.go b/pkg/op/client_test.go index 2e40d9af..bb17192a 100644 --- a/pkg/op/client_test.go +++ b/pkg/op/client_test.go @@ -22,7 +22,7 @@ import ( type testClientJWTProfile struct{} -func (testClientJWTProfile) JWTProfileVerifier(context.Context) op.JWTProfileVerifier { return nil } +func (testClientJWTProfile) JWTProfileVerifier(context.Context) *op.JWTProfileVerifier { return nil } func TestClientJWTAuth(t *testing.T) { type args struct { diff --git a/pkg/op/mock/authorizer.mock.go b/pkg/op/mock/authorizer.mock.go index 931b8969..a0c67e3d 100644 --- a/pkg/op/mock/authorizer.mock.go +++ b/pkg/op/mock/authorizer.mock.go @@ -79,10 +79,10 @@ func (mr *MockAuthorizerMockRecorder) Encoder() *gomock.Call { } // IDTokenHintVerifier mocks base method. -func (m *MockAuthorizer) IDTokenHintVerifier(arg0 context.Context) op.IDTokenHintVerifier { +func (m *MockAuthorizer) IDTokenHintVerifier(arg0 context.Context) *op.IDTokenHintVerifier { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "IDTokenHintVerifier", arg0) - ret0, _ := ret[0].(op.IDTokenHintVerifier) + ret0, _ := ret[0].(*op.IDTokenHintVerifier) return ret0 } diff --git a/pkg/op/mock/authorizer.mock.impl.go b/pkg/op/mock/authorizer.mock.impl.go index 6a5bdfd3..409683ab 100644 --- a/pkg/op/mock/authorizer.mock.impl.go +++ b/pkg/op/mock/authorizer.mock.impl.go @@ -49,7 +49,7 @@ func ExpectEncoder(a op.Authorizer) { func ExpectVerifier(a op.Authorizer, t *testing.T) { mockA := a.(*MockAuthorizer) mockA.EXPECT().IDTokenHintVerifier(gomock.Any()).DoAndReturn( - func() op.IDTokenHintVerifier { + func() *op.IDTokenHintVerifier { return op.NewIDTokenHintVerifier("", nil) }) } diff --git a/pkg/op/op.go b/pkg/op/op.go index 27c14103..9ed5662c 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -73,8 +73,8 @@ type OpenIDProvider interface { Storage() Storage Decoder() httphelper.Decoder Encoder() httphelper.Encoder - IDTokenHintVerifier(context.Context) IDTokenHintVerifier - AccessTokenVerifier(context.Context) AccessTokenVerifier + IDTokenHintVerifier(context.Context) *IDTokenHintVerifier + AccessTokenVerifier(context.Context) *AccessTokenVerifier Crypto() Crypto DefaultLogoutRedirectURI() string Probes() []ProbesFn @@ -342,15 +342,15 @@ func (o *Provider) Encoder() httphelper.Encoder { return o.encoder } -func (o *Provider) IDTokenHintVerifier(ctx context.Context) IDTokenHintVerifier { +func (o *Provider) IDTokenHintVerifier(ctx context.Context) *IDTokenHintVerifier { return NewIDTokenHintVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.idTokenHintVerifierOpts...) } -func (o *Provider) JWTProfileVerifier(ctx context.Context) JWTProfileVerifier { +func (o *Provider) JWTProfileVerifier(ctx context.Context) *JWTProfileVerifier { return NewJWTProfileVerifier(o.Storage(), IssuerFromContext(ctx), 1*time.Hour, time.Second) } -func (o *Provider) AccessTokenVerifier(ctx context.Context) AccessTokenVerifier { +func (o *Provider) AccessTokenVerifier(ctx context.Context) *AccessTokenVerifier { return NewAccessTokenVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.accessTokenVerifierOpts...) } diff --git a/pkg/op/session.go b/pkg/op/session.go index fbce125f..fd914d11 100644 --- a/pkg/op/session.go +++ b/pkg/op/session.go @@ -13,7 +13,7 @@ import ( type SessionEnder interface { Decoder() httphelper.Decoder Storage() Storage - IDTokenHintVerifier(context.Context) IDTokenHintVerifier + IDTokenHintVerifier(context.Context) *IDTokenHintVerifier DefaultLogoutRedirectURI() string } diff --git a/pkg/op/token_intospection.go b/pkg/op/token_intospection.go index 28df2175..21b79c3b 100644 --- a/pkg/op/token_intospection.go +++ b/pkg/op/token_intospection.go @@ -13,7 +13,7 @@ type Introspector interface { Decoder() httphelper.Decoder Crypto() Crypto Storage() Storage - AccessTokenVerifier(context.Context) AccessTokenVerifier + AccessTokenVerifier(context.Context) *AccessTokenVerifier } type IntrospectorJWTProfile interface { diff --git a/pkg/op/token_jwt_profile.go b/pkg/op/token_jwt_profile.go index 4563e16f..4cd7b1e4 100644 --- a/pkg/op/token_jwt_profile.go +++ b/pkg/op/token_jwt_profile.go @@ -11,7 +11,7 @@ import ( type JWTAuthorizationGrantExchanger interface { Exchanger - JWTProfileVerifier(context.Context) JWTProfileVerifier + JWTProfileVerifier(context.Context) *JWTProfileVerifier } // JWTProfile handles the OAuth 2.0 JWT Profile Authorization Grant https://tools.ietf.org/html/rfc7523#section-2.1 diff --git a/pkg/op/token_request.go b/pkg/op/token_request.go index 058a2029..c06a51bc 100644 --- a/pkg/op/token_request.go +++ b/pkg/op/token_request.go @@ -20,8 +20,8 @@ type Exchanger interface { GrantTypeJWTAuthorizationSupported() bool GrantTypeClientCredentialsSupported() bool GrantTypeDeviceCodeSupported() bool - AccessTokenVerifier(context.Context) AccessTokenVerifier - IDTokenHintVerifier(context.Context) IDTokenHintVerifier + AccessTokenVerifier(context.Context) *AccessTokenVerifier + IDTokenHintVerifier(context.Context) *IDTokenHintVerifier } func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) { diff --git a/pkg/op/token_revocation.go b/pkg/op/token_revocation.go index 34f8746f..fd1ee931 100644 --- a/pkg/op/token_revocation.go +++ b/pkg/op/token_revocation.go @@ -15,14 +15,14 @@ type Revoker interface { Decoder() httphelper.Decoder Crypto() Crypto Storage() Storage - AccessTokenVerifier(context.Context) AccessTokenVerifier + AccessTokenVerifier(context.Context) *AccessTokenVerifier AuthMethodPrivateKeyJWTSupported() bool AuthMethodPostSupported() bool } type RevokerJWTProfile interface { Revoker - JWTProfileVerifier(context.Context) JWTProfileVerifier + JWTProfileVerifier(context.Context) *JWTProfileVerifier } func revocationHandler(revoker Revoker) func(http.ResponseWriter, *http.Request) { diff --git a/pkg/op/userinfo.go b/pkg/op/userinfo.go index 52a2aa20..86205b5f 100644 --- a/pkg/op/userinfo.go +++ b/pkg/op/userinfo.go @@ -14,7 +14,7 @@ type UserinfoProvider interface { Decoder() httphelper.Decoder Crypto() Crypto Storage() Storage - AccessTokenVerifier(context.Context) AccessTokenVerifier + AccessTokenVerifier(context.Context) *AccessTokenVerifier } func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter, *http.Request) { diff --git a/pkg/op/verifier_access_token.go b/pkg/op/verifier_access_token.go index 7527ea69..120bfa71 100644 --- a/pkg/op/verifier_access_token.go +++ b/pkg/op/verifier_access_token.go @@ -2,62 +2,25 @@ package op import ( "context" - "time" "github.com/zitadel/oidc/v3/pkg/oidc" ) -type AccessTokenVerifier interface { - oidc.Verifier - SupportedSignAlgs() []string - KeySet() oidc.KeySet -} - -type accessTokenVerifier struct { - issuer string - maxAgeIAT time.Duration - offset time.Duration - supportedSignAlgs []string - keySet oidc.KeySet -} - -// Issuer implements oidc.Verifier interface -func (i *accessTokenVerifier) Issuer() string { - return i.issuer -} - -// MaxAgeIAT implements oidc.Verifier interface -func (i *accessTokenVerifier) MaxAgeIAT() time.Duration { - return i.maxAgeIAT -} - -// Offset implements oidc.Verifier interface -func (i *accessTokenVerifier) Offset() time.Duration { - return i.offset -} - -// SupportedSignAlgs implements AccessTokenVerifier interface -func (i *accessTokenVerifier) SupportedSignAlgs() []string { - return i.supportedSignAlgs -} - -// KeySet implements AccessTokenVerifier interface -func (i *accessTokenVerifier) KeySet() oidc.KeySet { - return i.keySet -} +type AccessTokenVerifier oidc.Verifier -type AccessTokenVerifierOpt func(*accessTokenVerifier) +type AccessTokenVerifierOpt func(*AccessTokenVerifier) func WithSupportedAccessTokenSigningAlgorithms(algs ...string) AccessTokenVerifierOpt { - return func(verifier *accessTokenVerifier) { - verifier.supportedSignAlgs = algs + return func(verifier *AccessTokenVerifier) { + verifier.SupportedSignAlgs = algs } } -func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTokenVerifierOpt) AccessTokenVerifier { - verifier := &accessTokenVerifier{ - issuer: issuer, - keySet: keySet, +// NewAccessTokenVerifier returns a AccessTokenVerifier suitable for access token verification. +func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTokenVerifierOpt) *AccessTokenVerifier { + verifier := &AccessTokenVerifier{ + Issuer: issuer, + KeySet: keySet, } for _, opt := range opts { opt(verifier) @@ -66,7 +29,7 @@ func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTok } // VerifyAccessToken validates the access token (issuer, signature and expiration). -func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v AccessTokenVerifier) (claims C, err error) { +func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v *AccessTokenVerifier) (claims C, err error) { var nilClaims C decrypted, err := oidc.DecryptToken(token) @@ -78,15 +41,15 @@ func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v Acces return nilClaims, err } - if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil { + if err := oidc.CheckIssuer(claims, v.Issuer); err != nil { return nilClaims, err } - if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil { + if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil { return nilClaims, err } - if err = oidc.CheckExpiration(claims, v.Offset()); err != nil { + if err = oidc.CheckExpiration(claims, v.Offset); err != nil { return nilClaims, err } diff --git a/pkg/op/verifier_access_token_test.go b/pkg/op/verifier_access_token_test.go index a1972f1c..66e32ceb 100644 --- a/pkg/op/verifier_access_token_test.go +++ b/pkg/op/verifier_access_token_test.go @@ -20,7 +20,7 @@ func TestNewAccessTokenVerifier(t *testing.T) { tests := []struct { name string args args - want AccessTokenVerifier + want *AccessTokenVerifier }{ { name: "simple", @@ -28,9 +28,9 @@ func TestNewAccessTokenVerifier(t *testing.T) { issuer: tu.ValidIssuer, keySet: tu.KeySet{}, }, - want: &accessTokenVerifier{ - issuer: tu.ValidIssuer, - keySet: tu.KeySet{}, + want: &AccessTokenVerifier{ + Issuer: tu.ValidIssuer, + KeySet: tu.KeySet{}, }, }, { @@ -42,10 +42,10 @@ func TestNewAccessTokenVerifier(t *testing.T) { WithSupportedAccessTokenSigningAlgorithms("ABC", "DEF"), }, }, - want: &accessTokenVerifier{ - issuer: tu.ValidIssuer, - keySet: tu.KeySet{}, - supportedSignAlgs: []string{"ABC", "DEF"}, + want: &AccessTokenVerifier{ + Issuer: tu.ValidIssuer, + KeySet: tu.KeySet{}, + SupportedSignAlgs: []string{"ABC", "DEF"}, }, }, } @@ -58,12 +58,12 @@ func TestNewAccessTokenVerifier(t *testing.T) { } func TestVerifyAccessToken(t *testing.T) { - verifier := &accessTokenVerifier{ - issuer: tu.ValidIssuer, - maxAgeIAT: 2 * time.Minute, - offset: time.Second, - supportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, - keySet: tu.KeySet{}, + verifier := &AccessTokenVerifier{ + Issuer: tu.ValidIssuer, + MaxAgeIAT: 2 * time.Minute, + Offset: time.Second, + SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, + KeySet: tu.KeySet{}, } tests := []struct { diff --git a/pkg/op/verifier_id_token_hint.go b/pkg/op/verifier_id_token_hint.go index 50c3ff6a..61432527 100644 --- a/pkg/op/verifier_id_token_hint.go +++ b/pkg/op/verifier_id_token_hint.go @@ -2,69 +2,24 @@ package op import ( "context" - "time" "github.com/zitadel/oidc/v3/pkg/oidc" ) -type IDTokenHintVerifier interface { - oidc.Verifier - SupportedSignAlgs() []string - KeySet() oidc.KeySet - ACR() oidc.ACRVerifier - MaxAge() time.Duration -} - -type idTokenHintVerifier struct { - issuer string - maxAgeIAT time.Duration - offset time.Duration - supportedSignAlgs []string - maxAge time.Duration - acr oidc.ACRVerifier - keySet oidc.KeySet -} - -func (i *idTokenHintVerifier) Issuer() string { - return i.issuer -} - -func (i *idTokenHintVerifier) MaxAgeIAT() time.Duration { - return i.maxAgeIAT -} - -func (i *idTokenHintVerifier) Offset() time.Duration { - return i.offset -} - -func (i *idTokenHintVerifier) SupportedSignAlgs() []string { - return i.supportedSignAlgs -} - -func (i *idTokenHintVerifier) KeySet() oidc.KeySet { - return i.keySet -} - -func (i *idTokenHintVerifier) ACR() oidc.ACRVerifier { - return i.acr -} - -func (i *idTokenHintVerifier) MaxAge() time.Duration { - return i.maxAge -} +type IDTokenHintVerifier oidc.Verifier -type IDTokenHintVerifierOpt func(*idTokenHintVerifier) +type IDTokenHintVerifierOpt func(*IDTokenHintVerifier) func WithSupportedIDTokenHintSigningAlgorithms(algs ...string) IDTokenHintVerifierOpt { - return func(verifier *idTokenHintVerifier) { - verifier.supportedSignAlgs = algs + return func(verifier *IDTokenHintVerifier) { + verifier.SupportedSignAlgs = algs } } -func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHintVerifierOpt) IDTokenHintVerifier { - verifier := &idTokenHintVerifier{ - issuer: issuer, - keySet: keySet, +func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHintVerifierOpt) *IDTokenHintVerifier { + verifier := &IDTokenHintVerifier{ + Issuer: issuer, + KeySet: keySet, } for _, opt := range opts { opt(verifier) @@ -74,7 +29,7 @@ func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHi // VerifyIDTokenHint validates the id token according to // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation -func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v IDTokenHintVerifier) (claims C, err error) { +func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v *IDTokenHintVerifier) (claims C, err error) { var nilClaims C decrypted, err := oidc.DecryptToken(token) @@ -86,27 +41,27 @@ func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v IDTok return nilClaims, err } - if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil { + if err := oidc.CheckIssuer(claims, v.Issuer); err != nil { return nilClaims, err } - if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil { + if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil { return nilClaims, err } - if err = oidc.CheckExpiration(claims, v.Offset()); err != nil { + if err = oidc.CheckExpiration(claims, v.Offset); err != nil { return nilClaims, err } - if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil { + if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT, v.Offset); err != nil { return nilClaims, err } - if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil { + if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR); err != nil { return nilClaims, err } - if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil { + if err = oidc.CheckAuthTime(claims, v.MaxAge); err != nil { return nilClaims, err } return claims, nil diff --git a/pkg/op/verifier_id_token_hint_test.go b/pkg/op/verifier_id_token_hint_test.go index 9f4c6c18..e514a76e 100644 --- a/pkg/op/verifier_id_token_hint_test.go +++ b/pkg/op/verifier_id_token_hint_test.go @@ -20,7 +20,7 @@ func TestNewIDTokenHintVerifier(t *testing.T) { tests := []struct { name string args args - want IDTokenHintVerifier + want *IDTokenHintVerifier }{ { name: "simple", @@ -28,9 +28,9 @@ func TestNewIDTokenHintVerifier(t *testing.T) { issuer: tu.ValidIssuer, keySet: tu.KeySet{}, }, - want: &idTokenHintVerifier{ - issuer: tu.ValidIssuer, - keySet: tu.KeySet{}, + want: &IDTokenHintVerifier{ + Issuer: tu.ValidIssuer, + KeySet: tu.KeySet{}, }, }, { @@ -42,10 +42,10 @@ func TestNewIDTokenHintVerifier(t *testing.T) { WithSupportedIDTokenHintSigningAlgorithms("ABC", "DEF"), }, }, - want: &idTokenHintVerifier{ - issuer: tu.ValidIssuer, - keySet: tu.KeySet{}, - supportedSignAlgs: []string{"ABC", "DEF"}, + want: &IDTokenHintVerifier{ + Issuer: tu.ValidIssuer, + KeySet: tu.KeySet{}, + SupportedSignAlgs: []string{"ABC", "DEF"}, }, }, } @@ -58,14 +58,14 @@ func TestNewIDTokenHintVerifier(t *testing.T) { } func TestVerifyIDTokenHint(t *testing.T) { - verifier := &idTokenHintVerifier{ - issuer: tu.ValidIssuer, - maxAgeIAT: 2 * time.Minute, - offset: time.Second, - supportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, - maxAge: 2 * time.Minute, - acr: tu.ACRVerify, - keySet: tu.KeySet{}, + verifier := &IDTokenHintVerifier{ + Issuer: tu.ValidIssuer, + MaxAgeIAT: 2 * time.Minute, + Offset: time.Second, + SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, + MaxAge: 2 * time.Minute, + ACR: tu.ACRVerify, + KeySet: tu.KeySet{}, } tests := []struct { diff --git a/pkg/op/verifier_jwt_profile.go b/pkg/op/verifier_jwt_profile.go index b7dfec71..1daa15fc 100644 --- a/pkg/op/verifier_jwt_profile.go +++ b/pkg/op/verifier_jwt_profile.go @@ -11,28 +11,25 @@ import ( "github.com/zitadel/oidc/v3/pkg/oidc" ) -type JWTProfileVerifier interface { +// JWTProfileVerfiier extends oidc.Verifier with +// a jwtProfileKeyStorage and a function to check +// the subject in a token. +type JWTProfileVerifier struct { oidc.Verifier - Storage() jwtProfileKeyStorage - CheckSubject(request *oidc.JWTTokenRequest) error -} - -type jwtProfileVerifier struct { - storage jwtProfileKeyStorage - subjectCheck func(request *oidc.JWTTokenRequest) error - issuer string - maxAgeIAT time.Duration - offset time.Duration + Storage JWTProfileKeyStorage + CheckSubject func(request *oidc.JWTTokenRequest) error } // NewJWTProfileVerifier creates a oidc.Verifier for JWT Profile assertions (authorization grant and client authentication) -func NewJWTProfileVerifier(storage jwtProfileKeyStorage, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) JWTProfileVerifier { - j := &jwtProfileVerifier{ - storage: storage, - subjectCheck: SubjectIsIssuer, - issuer: issuer, - maxAgeIAT: maxAgeIAT, - offset: offset, +func NewJWTProfileVerifier(storage JWTProfileKeyStorage, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) *JWTProfileVerifier { + j := &JWTProfileVerifier{ + Verifier: oidc.Verifier{ + Issuer: issuer, + MaxAgeIAT: maxAgeIAT, + Offset: offset, + }, + Storage: storage, + CheckSubject: SubjectIsIssuer, } for _, opt := range opts { @@ -42,53 +39,35 @@ func NewJWTProfileVerifier(storage jwtProfileKeyStorage, issuer string, maxAgeIA return j } -type JWTProfileVerifierOption func(*jwtProfileVerifier) +type JWTProfileVerifierOption func(*JWTProfileVerifier) +// SubjectCheck sets a custom function to check the subject. +// Defaults to SubjectIsIssuer() func SubjectCheck(check func(request *oidc.JWTTokenRequest) error) JWTProfileVerifierOption { - return func(verifier *jwtProfileVerifier) { - verifier.subjectCheck = check + return func(verifier *JWTProfileVerifier) { + verifier.CheckSubject = check } } -func (v *jwtProfileVerifier) Issuer() string { - return v.issuer -} - -func (v *jwtProfileVerifier) Storage() jwtProfileKeyStorage { - return v.storage -} - -func (v *jwtProfileVerifier) MaxAgeIAT() time.Duration { - return v.maxAgeIAT -} - -func (v *jwtProfileVerifier) Offset() time.Duration { - return v.offset -} - -func (v *jwtProfileVerifier) CheckSubject(request *oidc.JWTTokenRequest) error { - return v.subjectCheck(request) -} - // VerifyJWTAssertion verifies the assertion string from JWT Profile (authorization grant and client authentication) // // checks audience, exp, iat, signature and that issuer and sub are the same -func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerifier) (*oidc.JWTTokenRequest, error) { +func VerifyJWTAssertion(ctx context.Context, assertion string, v *JWTProfileVerifier) (*oidc.JWTTokenRequest, error) { request := new(oidc.JWTTokenRequest) payload, err := oidc.ParseToken(assertion, request) if err != nil { return nil, err } - if err = oidc.CheckAudience(request, v.Issuer()); err != nil { + if err = oidc.CheckAudience(request, v.Issuer); err != nil { return nil, err } - if err = oidc.CheckExpiration(request, v.Offset()); err != nil { + if err = oidc.CheckExpiration(request, v.Offset); err != nil { return nil, err } - if err = oidc.CheckIssuedAt(request, v.MaxAgeIAT(), v.Offset()); err != nil { + if err = oidc.CheckIssuedAt(request, v.MaxAgeIAT, v.Offset); err != nil { return nil, err } @@ -96,17 +75,18 @@ func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerif return nil, err } - keySet := &jwtProfileKeySet{storage: v.Storage(), clientID: request.Issuer} + keySet := &jwtProfileKeySet{storage: v.Storage, clientID: request.Issuer} if err = oidc.CheckSignature(ctx, assertion, payload, request, nil, keySet); err != nil { return nil, err } return request, nil } -type jwtProfileKeyStorage interface { +type JWTProfileKeyStorage interface { GetKeyByIDAndClientID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error) } +// SubjectIsIssuer func SubjectIsIssuer(request *oidc.JWTTokenRequest) error { if request.Issuer != request.Subject { return errors.New("delegation not allowed, issuer and sub must be identical") @@ -115,7 +95,7 @@ func SubjectIsIssuer(request *oidc.JWTTokenRequest) error { } type jwtProfileKeySet struct { - storage jwtProfileKeyStorage + storage JWTProfileKeyStorage clientID string } diff --git a/pkg/op/verifier_jwt_profile_test.go b/pkg/op/verifier_jwt_profile_test.go new file mode 100644 index 00000000..d96cbb43 --- /dev/null +++ b/pkg/op/verifier_jwt_profile_test.go @@ -0,0 +1,117 @@ +package op_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + tu "github.com/zitadel/oidc/v3/internal/testutil" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" +) + +func TestNewJWTProfileVerifier(t *testing.T) { + want := &op.JWTProfileVerifier{ + Verifier: oidc.Verifier{ + Issuer: tu.ValidIssuer, + MaxAgeIAT: time.Minute, + Offset: time.Second, + }, + Storage: tu.JWTProfileKeyStorage{}, + } + got := op.NewJWTProfileVerifier(tu.JWTProfileKeyStorage{}, tu.ValidIssuer, time.Minute, time.Second, op.SubjectCheck(func(request *oidc.JWTTokenRequest) error { + return oidc.ErrSubjectMissing + })) + assert.Equal(t, want.Verifier, got.Verifier) + assert.Equal(t, want.Storage, got.Storage) + assert.ErrorIs(t, got.CheckSubject(nil), oidc.ErrSubjectMissing) +} + +func TestVerifyJWTAssertion(t *testing.T) { + errCtx, cancel := context.WithCancel(context.Background()) + cancel() + + verifier := op.NewJWTProfileVerifier(tu.JWTProfileKeyStorage{}, tu.ValidIssuer, time.Minute, 0) + tests := []struct { + name string + ctx context.Context + newToken func() (string, *oidc.JWTTokenRequest) + wantErr bool + }{ + { + name: "parse error", + ctx: context.Background(), + newToken: func() (string, *oidc.JWTTokenRequest) { return "!", nil }, + wantErr: true, + }, + { + name: "wrong audience", + ctx: context.Background(), + newToken: func() (string, *oidc.JWTTokenRequest) { + return tu.NewJWTProfileAssertion( + tu.ValidClientID, tu.ValidClientID, []string{"wrong"}, + time.Now(), tu.ValidExpiration, + ) + }, + wantErr: true, + }, + { + name: "expired", + ctx: context.Background(), + newToken: func() (string, *oidc.JWTTokenRequest) { + return tu.NewJWTProfileAssertion( + tu.ValidClientID, tu.ValidClientID, []string{tu.ValidIssuer}, + time.Now(), time.Now().Add(-time.Hour), + ) + }, + wantErr: true, + }, + { + name: "invalid iat", + ctx: context.Background(), + newToken: func() (string, *oidc.JWTTokenRequest) { + return tu.NewJWTProfileAssertion( + tu.ValidClientID, tu.ValidClientID, []string{tu.ValidIssuer}, + time.Now().Add(time.Hour), tu.ValidExpiration, + ) + }, + wantErr: true, + }, + { + name: "invalid subject", + ctx: context.Background(), + newToken: func() (string, *oidc.JWTTokenRequest) { + return tu.NewJWTProfileAssertion( + tu.ValidClientID, "wrong", []string{tu.ValidIssuer}, + time.Now(), tu.ValidExpiration, + ) + }, + wantErr: true, + }, + { + name: "check signature fail", + ctx: errCtx, + newToken: tu.ValidJWTProfileAssertion, + wantErr: true, + }, + { + name: "ok", + ctx: context.Background(), + newToken: tu.ValidJWTProfileAssertion, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assertion, want := tt.newToken() + got, err := op.VerifyJWTAssertion(tt.ctx, assertion, verifier) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, want, got) + }) + } +}