/
auth.go
95 lines (85 loc) · 2.51 KB
/
auth.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
package auth
import (
"crypto/rsa"
"errors"
"fmt"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/lestrrat-go/jwx/jwk"
)
// Authenticator bundles the settings and key material used to sign and
// validate JWT claims.
//
// SignKey and JwkSet are filled in by SetupKeys; Validate looks keys up in
// JwkSet. The remaining fields are not referenced by the methods in this file
// section.
type Authenticator struct {
	// Issuer and Audience presumably become the "iss" and "aud" claims of
	// issued tokens — NOTE(review): confirm against the signing code.
	Issuer, Audience string
	// ExpiresAfter is presumably the lifetime given to issued tokens
	// (TODO confirm with the signing code).
	ExpiresAfter time.Duration
	// SignKey is the RSA private key; SetupKeys loads it from config.Key /
	// config.KeyFile or generates it.
	SignKey *rsa.PrivateKey
	// JwkSet holds the public keys that Validate selects by the token's
	// "kid" header.
	JwkSet jwk.Set
}
// Claims is the contract that custom claim types passed to Validate must meet.
//
// An implementation is a jwt.Claims value that can additionally hand out a
// pointer to its registered (standard) claims via GetRegisteredClaims.
type Claims interface {
	// GetRegisteredClaims exposes the standard registered claim set.
	GetRegisteredClaims() *jwt.RegisteredClaims

	jwt.Claims
}
// Validate parses tokenString into claims and verifies its signature.
//
// The verification key is looked up in auth.JwkSet using the token's "kid"
// header, which must be a string. Only tokens whose signing method is
// *jwt.SigningMethodRSA are accepted, so e.g. an HMAC-signed token can never be
// checked against an RSA public key misused as a shared secret.
// jwt.ParseWithClaims additionally runs the claims' own validation (e.g. expiry).
//
// On success it returns token.Valid, the parsed token and a nil error. Any
// parse, key-lookup, signature or claims failure yields (false, nil, err).
// A nil auth.JwkSet is reported as an error rather than causing a panic.
func (auth *Authenticator) Validate(tokenString string, claims Claims) (bool, *jwt.Token, error) {
	keyFunc := func(t *jwt.Token) (interface{}, error) {
		kid, ok := t.Header["kid"].(string)
		if !ok {
			return nil, fmt.Errorf("expecting JWT header kid to be string, but got %T", t.Header["kid"])
		}
		// The parser has already resolved the "alg" header into t.Method, so
		// asserting its concrete type is the authoritative algorithm check.
		if _, ok := t.Method.(*jwt.SigningMethodRSA); !ok {
			return nil, fmt.Errorf("expected RSA signing method, but got %v", t.Method.Alg())
		}
		// JwkSet is an interface: calling a method on a nil one would panic.
		if auth.JwkSet == nil {
			return nil, errors.New("no JWK set configured for token validation")
		}
		matchingKey, ok := auth.JwkSet.LookupKeyID(kid)
		if !ok {
			return nil, fmt.Errorf("unable to find key with id %q", kid)
		}
		var key rsa.PublicKey
		if err := matchingKey.Raw(&key); err != nil {
			return nil, fmt.Errorf("converting JWK %q to RSA public key: %w", kid, err)
		}
		return &key, nil
	}

	token, err := jwt.ParseWithClaims(tokenString, claims, keyFunc)
	if err != nil {
		return false, nil, err
	}
	return token.Valid, token, nil
}
// SetupKeys populates auth.JwkSet and auth.SignKey from config.
//
// Sources are applied in this order; when both forms of the same key are set,
// the later one wins:
//   - config.Jwks (inline JWK set), then config.JwksFile
//   - config.Key (inline PEM private key), then config.KeyFile
//
// Loading stops at the first source that fails, and that error is returned
// wrapped with the name of the source. A failed source never overwrites the
// field it targets.
//
// If SignKey or JwkSet is still nil after loading, a fresh RSA key pair is
// generated when config.Generate is true, and BOTH fields are replaced by it.
// The fields are only assigned once generation fully succeeds. If
// config.Generate is false, an error is returned instead.
// NOTE(review): a configured key whose counterpart is missing is discarded in
// favour of the generated pair — confirm that this is the intended policy.
func (auth *Authenticator) SetupKeys(config *KeyConfig) error {
	if config.Jwks != "" {
		set, err := ParseJwkSet([]byte(config.Jwks))
		if err != nil {
			return fmt.Errorf("parsing inline JWK set: %w", err)
		}
		auth.JwkSet = set
	}
	if config.JwksFile != "" {
		set, err := LoadJwkSetFromFile(config.JwksFile)
		if err != nil {
			return fmt.Errorf("loading JWK set file %q: %w", config.JwksFile, err)
		}
		auth.JwkSet = set
	}
	if config.Key != "" {
		key, err := ParseSigningKeyFromPEMData([]byte(config.Key))
		if err != nil {
			return fmt.Errorf("parsing inline signing key: %w", err)
		}
		auth.SignKey = key
	}
	if config.KeyFile != "" {
		key, err := ParseSigningKeyFromPEMFile(config.KeyFile)
		if err != nil {
			return fmt.Errorf("loading signing key file %q: %w", config.KeyFile, err)
		}
		auth.SignKey = key
	}

	if auth.SignKey != nil && auth.JwkSet != nil {
		return nil
	}
	if !config.Generate {
		return errors.New("missing signing key or jwk set and --generate disabled")
	}

	keyPair, err := GenerateRSAKeyPair()
	if err != nil {
		return fmt.Errorf("failed to generate RSA key pair: %w", err)
	}
	jwkSet, err := ToJwks(keyPair.PublicKey)
	if err != nil {
		return fmt.Errorf("failed to generate JWK set: %w", err)
	}
	// Commit both halves together so a failure above can never leave a signing
	// key whose public part is absent from JwkSet.
	auth.SignKey = keyPair.PrivateKey
	auth.JwkSet = jwkSet
	return nil
}