Skip to content

Commit

Permalink
add unit tests to oidc verifier
Browse files Browse the repository at this point in the history
  • Loading branch information
muhlemmer committed Mar 20, 2023
1 parent 9c7bcae commit d877539
Show file tree
Hide file tree
Showing 4 changed files with 515 additions and 8 deletions.
2 changes: 1 addition & 1 deletion internal/testutil/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type KeySet struct{}

// VerifySignature implments op.KeySet.
func (KeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) {
if ctx.Err() != nil {
if err = ctx.Err(); err != nil {
return nil, err
}

Expand Down
19 changes: 12 additions & 7 deletions pkg/oidc/verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ func CheckAudience(claims Claims, clientID string) error {
return nil
}

// CheckAuthorizedParty checks azp (authorized party) claim requirements.
//
// If the ID Token contains multiple audiences, the Client SHOULD verify that an azp Claim is present.
// If an azp Claim is present, the Client SHOULD verify that its client_id is the Claim Value.
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func CheckAuthorizedParty(claims Claims, clientID string) error {
if len(claims.GetAudience()) > 1 {
if claims.GetAuthorizedParty() == "" {
Expand Down Expand Up @@ -176,26 +181,26 @@ func CheckSignature(ctx context.Context, token string, payload []byte, claims Cl
}

func CheckExpiration(claims Claims, offset time.Duration) error {
expiration := claims.GetExpiration().Round(time.Second)
if !time.Now().UTC().Add(offset).Before(expiration) {
expiration := claims.GetExpiration()
if !time.Now().Add(offset).Before(expiration) {
return ErrExpired
}
return nil
}

func CheckIssuedAt(claims Claims, maxAgeIAT, offset time.Duration) error {
issuedAt := claims.GetIssuedAt().Round(time.Second)
issuedAt := claims.GetIssuedAt()
if issuedAt.IsZero() {
return ErrIatMissing
}
nowWithOffset := time.Now().UTC().Add(offset).Round(time.Second)
nowWithOffset := time.Now().Add(offset).Round(time.Second)
if issuedAt.After(nowWithOffset) {
return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, issuedAt, nowWithOffset)
}
if maxAgeIAT == 0 {
return nil
}
maxAge := time.Now().UTC().Add(-maxAgeIAT).Round(time.Second)
maxAge := time.Now().Add(-maxAgeIAT)
if issuedAt.Before(maxAge) {
return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrIatToOld, maxAge, issuedAt, maxAge.Sub(issuedAt))
}
Expand Down Expand Up @@ -225,8 +230,8 @@ func CheckAuthTime(claims Claims, maxAge time.Duration) error {
if claims.GetAuthTime().IsZero() {
return ErrAuthTimeNotPresent
}
authTime := claims.GetAuthTime().Round(time.Second)
maxAuthTime := time.Now().UTC().Add(-maxAge).Round(time.Second)
authTime := claims.GetAuthTime()
maxAuthTime := time.Now().Add(-maxAge).Round(time.Second)
if authTime.Before(maxAuthTime) {
return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrAuthTimeToOld, maxAge, authTime, maxAuthTime.Sub(authTime))
}
Expand Down
128 changes: 128 additions & 0 deletions pkg/oidc/verifier_parse_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package oidc_test

import (
"context"
"encoding/json"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
tu "github.com/zitadel/oidc/v2/internal/testutil"
"github.com/zitadel/oidc/v2/pkg/oidc"
)

func TestParseToken(t *testing.T) {
token, wantClaims := tu.ValidIDToken()
wantClaims.SignatureAlg = "" // unset, because is not part of the JSON payload

wantPayload, err := json.Marshal(wantClaims)
require.NoError(t, err)

tests := []struct {
name string
tokenString string
wantErr bool
}{
{
name: "split error",
tokenString: "nope",
wantErr: true,
},
{
name: "base64 error",
tokenString: "foo.~.bar",
wantErr: true,
},
{
name: "success",
tokenString: token,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotClaims := new(oidc.IDTokenClaims)
gotPayload, err := oidc.ParseToken(tt.tokenString, gotClaims)
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, wantClaims, gotClaims)
assert.JSONEq(t, string(wantPayload), string(gotPayload))
})
}
}

func TestCheckSignature(t *testing.T) {
errCtx, cancel := context.WithCancel(context.Background())
cancel()

token, _ := tu.ValidIDToken()
payload, err := oidc.ParseToken(token, &oidc.IDTokenClaims{})
require.NoError(t, err)

type args struct {
ctx context.Context
token string
payload []byte
supportedSigAlgs []string
}
tests := []struct {
name string
args args
wantErr error
}{
{
name: "parse error",
args: args{
ctx: context.Background(),
token: "~",
payload: payload,
},
wantErr: oidc.ErrParse,
},
{
name: "default sigAlg",
args: args{
ctx: context.Background(),
token: token,
payload: payload,
},
},
{
name: "unsupported sigAlg",
args: args{
ctx: context.Background(),
token: token,
payload: payload,
supportedSigAlgs: []string{"foo", "bar"},
},
wantErr: oidc.ErrSignatureUnsupportedAlg,
},
{
name: "verify error",
args: args{
ctx: errCtx,
token: token,
payload: payload,
},
wantErr: oidc.ErrSignatureInvalid,
},
{
name: "inequal payloads",
args: args{
ctx: context.Background(),
token: token,
payload: []byte{0, 1, 2},
},
wantErr: oidc.ErrSignatureInvalidPayload,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
claims := new(oidc.TokenClaims)
err := oidc.CheckSignature(tt.args.ctx, tt.args.token, tt.args.payload, claims, tt.args.supportedSigAlgs, tu.KeySet{})
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
Loading

0 comments on commit d877539

Please sign in to comment.