Skip to content

Commit

Permalink
jwk: update RS256JWTStrategy to adhere to the new interface
Browse files Browse the repository at this point in the history
Signed-off-by: Amir Aslaminejad <aslaminejad@gmail.com>
  • Loading branch information
aaslamin authored and aeneasr committed Sep 21, 2018
1 parent 5dda1a2 commit a190bee
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 25 deletions.
38 changes: 19 additions & 19 deletions jwk/jwt_strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import (
)

type JWTStrategy interface {
GetPublicKeyID() (string, error)
GetPublicKeyID(ctx context.Context) (string, error)

jwt.JWTStrategy
}
Expand All @@ -43,7 +43,7 @@ func NewRS256JWTStrategy(m Manager, set string) (*RS256JWTStrategy, error) {
RS256JWTStrategy: &jwt.RS256JWTStrategy{},
Set: set,
}
if err := j.refresh(); err != nil {
if err := j.refresh(context.TODO()); err != nil {
return nil, err
}
return j, nil
Expand All @@ -60,53 +60,53 @@ type RS256JWTStrategy struct {
privateKeyID string
}

func (j *RS256JWTStrategy) Hash(in []byte) ([]byte, error) {
return j.RS256JWTStrategy.Hash(in)
func (j *RS256JWTStrategy) Hash(ctx context.Context, in []byte) ([]byte, error) {
return j.RS256JWTStrategy.Hash(ctx, in)
}

// GetSigningMethodLength will return the length of the signing method
func (j *RS256JWTStrategy) GetSigningMethodLength() int {
return j.RS256JWTStrategy.GetSigningMethodLength()
}

func (j *RS256JWTStrategy) GetSignature(token string) (string, error) {
return j.RS256JWTStrategy.GetSignature(token)
func (j *RS256JWTStrategy) GetSignature(ctx context.Context, token string) (string, error) {
return j.RS256JWTStrategy.GetSignature(ctx, token)
}

func (j *RS256JWTStrategy) Generate(claims jwt2.Claims, header jwt.Mapper) (string, string, error) {
if err := j.refresh(); err != nil {
func (j *RS256JWTStrategy) Generate(ctx context.Context, claims jwt2.Claims, header jwt.Mapper) (string, string, error) {
if err := j.refresh(ctx); err != nil {
return "", "", err
}

return j.RS256JWTStrategy.Generate(claims, header)
return j.RS256JWTStrategy.Generate(ctx, claims, header)
}

func (j *RS256JWTStrategy) Validate(token string) (string, error) {
if err := j.refresh(); err != nil {
func (j *RS256JWTStrategy) Validate(ctx context.Context, token string) (string, error) {
if err := j.refresh(ctx); err != nil {
return "", err
}

return j.RS256JWTStrategy.Validate(token)
return j.RS256JWTStrategy.Validate(ctx, token)
}

func (j *RS256JWTStrategy) Decode(token string) (*jwt2.Token, error) {
if err := j.refresh(); err != nil {
func (j *RS256JWTStrategy) Decode(ctx context.Context, token string) (*jwt2.Token, error) {
if err := j.refresh(ctx); err != nil {
return nil, err
}

return j.RS256JWTStrategy.Decode(token)
return j.RS256JWTStrategy.Decode(ctx, token)
}

func (j *RS256JWTStrategy) GetPublicKeyID() (string, error) {
if err := j.refresh(); err != nil {
func (j *RS256JWTStrategy) GetPublicKeyID(ctx context.Context) (string, error) {
if err := j.refresh(ctx); err != nil {
return "", err
}

return j.publicKeyID, nil
}

func (j *RS256JWTStrategy) refresh() error {
keys, err := j.Manager.GetKeySet(context.TODO(), j.Set)
func (j *RS256JWTStrategy) refresh(ctx context.Context) error {
keys, err := j.Manager.GetKeySet(ctx, j.Set)
if err != nil {
return err
}
Expand Down
12 changes: 6 additions & 6 deletions jwk/jwt_strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,31 +45,31 @@ func TestRS256JWTStrategy(t *testing.T) {

s, err := NewRS256JWTStrategy(m, "foo-set")
require.NoError(t, err)
a, b, err := s.Generate(jwt2.MapClaims{"foo": "bar"}, &jwt.Headers{})
a, b, err := s.Generate(context.TODO(), jwt2.MapClaims{"foo": "bar"}, &jwt.Headers{})
require.NoError(t, err)
assert.NotEmpty(t, a)
assert.NotEmpty(t, b)

_, err = s.Validate(a)
_, err = s.Validate(context.TODO(), a)
require.NoError(t, err)

kid, err := s.GetPublicKeyID()
kid, err := s.GetPublicKeyID(context.TODO())
assert.NoError(t, err)
assert.Equal(t, "public:foo", kid)

ks, err = testGenerator.Generate("bar", "sig")
require.NoError(t, err)
require.NoError(t, m.AddKeySet(context.TODO(), "foo-set", ks))

a, b, err = s.Generate(jwt2.MapClaims{"foo": "bar"}, &jwt.Headers{})
a, b, err = s.Generate(context.TODO(), jwt2.MapClaims{"foo": "bar"}, &jwt.Headers{})
require.NoError(t, err)
assert.NotEmpty(t, a)
assert.NotEmpty(t, b)

_, err = s.Validate(a)
_, err = s.Validate(context.TODO(), a)
require.NoError(t, err)

kid, err = s.GetPublicKeyID()
kid, err = s.GetPublicKeyID(context.TODO())
assert.NoError(t, err)
assert.Equal(t, "public:bar", kid)
}

0 comments on commit a190bee

Please sign in to comment.