Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: merge the verifier types #336

Merged
merged 8 commits into from
Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion internal/testutil/token.go
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file is used only for tests. Changes here are to accommodate newly added tests.

Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"errors"
"time"

"github.com/muhlemmer/gu"
"github.com/zitadel/oidc/v3/pkg/oidc"
"gopkg.in/square/go-jose.v2"
)
Expand All @@ -17,7 +18,7 @@ type KeySet struct{}

// VerifySignature implments op.KeySet.
func (KeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) {
if ctx.Err() != nil {
if err = ctx.Err(); err != nil {
return nil, err
}

Expand Down Expand Up @@ -45,6 +46,16 @@ func init() {
}
}

type JWTProfileKeyStorage struct{}

func (JWTProfileKeyStorage) GetKeyByIDAndClientID(ctx context.Context, keyID string, clientID string) (*jose.JSONWebKey, error) {
if err := ctx.Err(); err != nil {
return nil, err
}

return gu.Ptr(WebKey.Public()), nil
}

func signEncodeTokenClaims(claims any) string {
payload, err := json.Marshal(claims)
if err != nil {
Expand Down Expand Up @@ -106,6 +117,25 @@ func NewAccessToken(issuer, subject string, audience []string, expiration time.T
return NewAccessTokenCustom(issuer, subject, audience, expiration, jwtid, clientID, skew, nil)
}

func NewJWTProfileAssertion(issuer, clientID string, audience []string, issuedAt, expiration time.Time) (string, *oidc.JWTTokenRequest) {
req := &oidc.JWTTokenRequest{
Issuer: issuer,
Subject: clientID,
Audience: audience,
ExpiresAt: oidc.FromTime(expiration),
IssuedAt: oidc.FromTime(issuedAt),
}
// make sure the private claim map is set correctly
data, err := json.Marshal(req)
if err != nil {
panic(err)
}
if err = json.Unmarshal(data, req); err != nil {
panic(err)
}
return signEncodeTokenClaims(req), req
}

const InvalidSignatureToken = `eyJhbGciOiJQUzUxMiJ9.eyJpc3MiOiJsb2NhbC5jb20iLCJzdWIiOiJ0aW1AbG9jYWwuY29tIiwiYXVkIjpbInVuaXQiLCJ0ZXN0IiwiNTU1NjY2Il0sImV4cCI6MTY3Nzg0MDQzMSwiaWF0IjoxNjc3ODQwMzcwLCJhdXRoX3RpbWUiOjE2Nzc4NDAzMTAsIm5vbmNlIjoiMTIzNDUiLCJhY3IiOiJzb21ldGhpbmciLCJhbXIiOlsiZm9vIiwiYmFyIl0sImF6cCI6IjU1NTY2NiJ9.DtZmvVkuE4Hw48ijBMhRJbxEWCr_WEYuPQBMY73J9TP6MmfeNFkjVJf4nh4omjB9gVLnQ-xhEkNOe62FS5P0BB2VOxPuHZUj34dNspCgG3h98fGxyiMb5vlIYAHDF9T-w_LntlYItohv63MmdYR-hPpAqjXE7KOfErf-wUDGE9R3bfiQ4HpTdyFJB1nsToYrZ9lhP2mzjTCTs58ckZfQ28DFHn_lfHWpR4rJBgvLx7IH4rMrUayr09Ap-PxQLbv0lYMtmgG1z3JK8MXnuYR0UJdZnEIezOzUTlThhCXB-nvuAXYjYxZZTR0FtlgZUHhIpYK0V2abf_Q_Or36akNCUg`

// These variables always result in a valid token
Expand Down Expand Up @@ -137,6 +167,10 @@ func ValidAccessToken() (string, *oidc.AccessTokenClaims) {
return NewAccessToken(ValidIssuer, ValidSubject, ValidAudience, ValidExpiration, ValidJWTID, ValidClientID, ValidSkew)
}

func ValidJWTProfileAssertion() (string, *oidc.JWTTokenRequest) {
return NewJWTProfileAssertion(ValidClientID, ValidClientID, []string{ValidIssuer}, time.Now(), ValidExpiration)
stebenz marked this conversation as resolved.
Show resolved Hide resolved
}

// ACRVerify is a oidc.ACRVerifier func.
func ACRVerify(acr string) error {
if acr != ValidACR {
Expand Down
8 changes: 4 additions & 4 deletions pkg/client/rp/relying_party.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ type RelyingParty interface {
// be used to start a DeviceAuthorization flow.
GetDeviceAuthorizationEndpoint() string

// IDTokenVerifier returns the verifier interface used for oidc id_token verification
IDTokenVerifier() IDTokenVerifier
// IDTokenVerifier returns the verifier used for oidc id_token verification
IDTokenVerifier() *IDTokenVerifier
// ErrorHandler returns the handler used for callback errors

ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string)
Expand All @@ -88,7 +88,7 @@ type relyingParty struct {
cookieHandler *httphelper.CookieHandler

errorHandler func(http.ResponseWriter, *http.Request, string, string, string)
idTokenVerifier IDTokenVerifier
idTokenVerifier *IDTokenVerifier
verifierOpts []VerifierOption
signer jose.Signer
}
Expand Down Expand Up @@ -137,7 +137,7 @@ func (rp *relyingParty) GetRevokeEndpoint() string {
return rp.endpoints.RevokeURL
}

func (rp *relyingParty) IDTokenVerifier() IDTokenVerifier {
func (rp *relyingParty) IDTokenVerifier() *IDTokenVerifier {
if rp.idTokenVerifier == nil {
rp.idTokenVerifier = NewIDTokenVerifier(rp.issuer, rp.oauthConfig.ClientID, NewRemoteKeySet(rp.httpClient, rp.endpoints.JKWsURL), rp.verifierOpts...)
}
Expand Down
127 changes: 35 additions & 92 deletions pkg/client/rp/verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,9 @@ import (
"github.com/zitadel/oidc/v3/pkg/oidc"
)

type IDTokenVerifier interface {
oidc.Verifier
ClientID() string
SupportedSignAlgs() []string
KeySet() oidc.KeySet
Nonce(context.Context) string
ACR() oidc.ACRVerifier
MaxAge() time.Duration
}

// VerifyTokens implement the Token Response Validation as defined in OIDC specification
// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v IDTokenVerifier) (claims C, err error) {
func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v *IDTokenVerifier) (claims C, err error) {
var nilClaims C

claims, err = VerifyIDToken[C](ctx, idToken, v)
Expand All @@ -36,7 +26,7 @@ func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken str

// VerifyIDToken validates the id token according to
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVerifier) (claims C, err error) {
func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v *IDTokenVerifier) (claims C, err error) {
var nilClaims C

decrypted, err := oidc.DecryptToken(token)
Expand All @@ -52,44 +42,46 @@ func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVe
return nilClaims, err
}

if err = oidc.CheckIssuer(claims, v.Issuer()); err != nil {
if err = oidc.CheckIssuer(claims, v.Issuer); err != nil {
return nilClaims, err
}

if err = oidc.CheckAudience(claims, v.ClientID()); err != nil {
if err = oidc.CheckAudience(claims, v.ClientID); err != nil {
return nilClaims, err
}

if err = oidc.CheckAuthorizedParty(claims, v.ClientID()); err != nil {
if err = oidc.CheckAuthorizedParty(claims, v.ClientID); err != nil {
return nilClaims, err
}

if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil {
return nilClaims, err
}

if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
if err = oidc.CheckExpiration(claims, v.Offset); err != nil {
return nilClaims, err
}

if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil {
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT, v.Offset); err != nil {
return nilClaims, err
}

if err = oidc.CheckNonce(claims, v.Nonce(ctx)); err != nil {
return nilClaims, err
}

if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil {
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR); err != nil {
return nilClaims, err
}

if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil {
if err = oidc.CheckAuthTime(claims, v.MaxAge); err != nil {
return nilClaims, err
}
return claims, nil
}

type IDTokenVerifier oidc.Verifier

// VerifyAccessToken validates the access token according to
// https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation
func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error {
Expand All @@ -107,15 +99,14 @@ func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAl
return nil
}

// NewIDTokenVerifier returns an implementation of `IDTokenVerifier`
// for `VerifyTokens` and `VerifyIDToken`
func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...VerifierOption) IDTokenVerifier {
v := &idTokenVerifier{
issuer: issuer,
clientID: clientID,
keySet: keySet,
offset: time.Second,
nonce: func(_ context.Context) string {
// NewIDTokenVerifier returns a oidc.Verifier suitable for ID token verification.
func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...VerifierOption) *IDTokenVerifier {
v := &IDTokenVerifier{
Issuer: issuer,
ClientID: clientID,
KeySet: keySet,
Offset: time.Second,
Nonce: func(_ context.Context) string {
return ""
},
}
Expand All @@ -128,95 +119,47 @@ func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...
}

// VerifierOption is the type for providing dynamic options to the IDTokenVerifier
type VerifierOption func(*idTokenVerifier)
type VerifierOption func(*IDTokenVerifier)

// WithIssuedAtOffset mitigates the risk of iat to be in the future
// because of clock skews with the ability to add an offset to the current time
func WithIssuedAtOffset(offset time.Duration) func(*idTokenVerifier) {
return func(v *idTokenVerifier) {
v.offset = offset
func WithIssuedAtOffset(offset time.Duration) VerifierOption {
return func(v *IDTokenVerifier) {
v.Offset = offset
}
}

// WithIssuedAtMaxAge provides the ability to define the maximum duration between iat and now
func WithIssuedAtMaxAge(maxAge time.Duration) func(*idTokenVerifier) {
return func(v *idTokenVerifier) {
v.maxAgeIAT = maxAge
func WithIssuedAtMaxAge(maxAge time.Duration) VerifierOption {
return func(v *IDTokenVerifier) {
v.MaxAgeIAT = maxAge
}
}

// WithNonce sets the function to check the nonce
func WithNonce(nonce func(context.Context) string) VerifierOption {
return func(v *idTokenVerifier) {
v.nonce = nonce
return func(v *IDTokenVerifier) {
v.Nonce = nonce
}
}

// WithACRVerifier sets the verifier for the acr claim
func WithACRVerifier(verifier oidc.ACRVerifier) VerifierOption {
return func(v *idTokenVerifier) {
v.acr = verifier
return func(v *IDTokenVerifier) {
v.ACR = verifier
}
}

// WithAuthTimeMaxAge provides the ability to define the maximum duration between auth_time and now
func WithAuthTimeMaxAge(maxAge time.Duration) VerifierOption {
return func(v *idTokenVerifier) {
v.maxAge = maxAge
return func(v *IDTokenVerifier) {
v.MaxAge = maxAge
}
}

// WithSupportedSigningAlgorithms overwrites the default RS256 signing algorithm
func WithSupportedSigningAlgorithms(algs ...string) VerifierOption {
return func(v *idTokenVerifier) {
v.supportedSignAlgs = algs
return func(v *IDTokenVerifier) {
v.SupportedSignAlgs = algs
}
}

type idTokenVerifier struct {
issuer string
maxAgeIAT time.Duration
offset time.Duration
clientID string
supportedSignAlgs []string
keySet oidc.KeySet
acr oidc.ACRVerifier
maxAge time.Duration
nonce func(ctx context.Context) string
}

func (i *idTokenVerifier) Issuer() string {
return i.issuer
}

func (i *idTokenVerifier) MaxAgeIAT() time.Duration {
return i.maxAgeIAT
}

func (i *idTokenVerifier) Offset() time.Duration {
return i.offset
}

func (i *idTokenVerifier) ClientID() string {
return i.clientID
}

func (i *idTokenVerifier) SupportedSignAlgs() []string {
return i.supportedSignAlgs
}

func (i *idTokenVerifier) KeySet() oidc.KeySet {
return i.keySet
}

func (i *idTokenVerifier) Nonce(ctx context.Context) string {
return i.nonce(ctx)
}

func (i *idTokenVerifier) ACR() oidc.ACRVerifier {
return i.acr
}

func (i *idTokenVerifier) MaxAge() time.Duration {
return i.maxAge
}
Loading