/
validator.go
161 lines (140 loc) · 4.23 KB
/
validator.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
package edgecontext
import (
"context"
"crypto/rsa"
"errors"
"fmt"
"github.com/golang-jwt/jwt/v4"
"github.com/reddit/baseplate.go/log"
"github.com/reddit/baseplate.go/secrets"
"golang.org/x/crypto/ssh"
)
// keysType holds the RSA public keys available for verifying JWT tokens.
type keysType struct {
	// m maps a key id (kid) to its public key.
	m map[string]*rsa.PublicKey

	// first is the fallback key (usually the current one). It is used when the
	// token carries no kid header, or when its kid has no entry in m.
	first *rsa.PublicKey
}

// getKey returns the public key registered under kid. When there is no
// non-nil key for that kid, it returns the fallback key instead.
func (kt *keysType) getKey(kid string) *rsa.PublicKey {
	matched, found := kt.m[kid]
	if !found || matched == nil {
		return kt.first
	}
	return matched
}
const (
	// authenticationPubKeySecretPath is the path of the versioned secret that
	// holds the public keys used to verify authentication tokens.
	authenticationPubKeySecretPath = "secret/authentication/public-key"

	// jwtAlg is the name of the only JWT signing algorithm accepted by
	// ValidateToken.
	jwtAlg = "RS256"
)
// JWTHeaderKeyID is the name of the JWT header parameter carrying the key id,
// as defined in RFC 7515 section 4.1.4
// (RFC 7517 section 4.5 defines the matching "kid" member of a JWK).
const JWTHeaderKeyID = "kid"
// ErrNoPublicKeysLoaded is the error returned by ValidateToken to indicate
// that it was called before any public keys were loaded from secrets.
var ErrNoPublicKeysLoaded = errors.New("edgecontext.ValidateToken: no public keys loaded")
// ErrEmptyToken is the error returned by ValidateToken to indicate that the
// JWT token is an empty string.
var ErrEmptyToken = errors.New("edgecontext.ValidateToken: empty JWT token")
// ValidateToken parses and validates a JWT token, and returns the decoded
// AuthenticationToken.
//
// It returns ErrNoPublicKeysLoaded when no public keys have been stored yet,
// ErrEmptyToken when token is the empty string, and otherwise the error
// reported while parsing or validating the token. Only tokens signed with the
// jwtAlg signing method are accepted.
func (impl *Impl) ValidateToken(token string) (*AuthenticationToken, error) {
	keys, ok := impl.keysValue.Load().(*keysType)
	if !ok {
		// This would only happen when all previous middleware parsing failed.
		return nil, ErrNoPublicKeysLoaded
	}

	if token == "" {
		// If we don't do the special handling here,
		// jwt.ParseWithClaims below will return an error with message
		// "token contains an invalid number of segments".
		// Although that's still true, it's less obvious what's actually going
		// on. Returning a different error for empty tokens also helps
		// highlight other invalid tokens that actually cause that invalid
		// number of segments error.
		return nil, ErrEmptyToken
	}

	tok, err := jwt.ParseWithClaims(
		token,
		&AuthenticationToken{},
		func(jt *jwt.Token) (interface{}, error) {
			// Reject unexpected signing methods before handing out a key, so
			// the signature is never checked with an algorithm picked by the
			// token itself.
			if jt.Method == nil || jt.Method.Alg() != jwtAlg {
				return nil, jwt.NewValidationError("wrong signing method", 0)
			}

			// A missing or non-string kid header leaves kid empty; getKey
			// falls back to the first key when kid has no match.
			kid, _ := jt.Header[JWTHeaderKeyID].(string)
			return keys.getKey(kid), nil
		},
	)
	if err != nil {
		return nil, err
	}
	if !tok.Valid {
		return nil, jwt.NewValidationError("invalid token", 0)
	}

	claims, ok := tok.Claims.(*AuthenticationToken)
	if !ok {
		return nil, jwt.NewValidationError("invalid token type", 0)
	}
	return claims, nil
}
// validatorMiddleware returns a secrets handler that tries to refresh the
// stored public keys from the secret at authenticationPubKeySecretPath, and
// then calls next.
//
// Failures are only logged: when the secret cannot be fetched, or when it
// yields no usable key, the previously stored keys are left untouched.
func (impl *Impl) validatorMiddleware(next secrets.SecretHandlerFunc) secrets.SecretHandlerFunc {
	return func(sec *secrets.Secrets) {
		// next runs on every exit path, whether or not the keys got refreshed.
		defer next(sec)

		ctx := context.Background()
		versioned, err := sec.GetVersionedSecret(authenticationPubKeySecretPath)
		if err != nil {
			msg := fmt.Sprintf(
				"Failed to get secrets %q: %v",
				authenticationPubKeySecretPath,
				err,
			)
			impl.logger.Log(ctx, msg)
			return
		}

		if parsed := parseVersionedKeys(ctx, versioned, impl.logger); parsed != nil {
			impl.keysValue.Store(parsed)
		}
	}
}
// parseVersionedKeys builds a keysType out of every value returned by
// versioned.GetAll().
//
// Each value is parsed as a PEM encoded RSA public key. Values that fail to
// parse are logged and skipped. The first successfully parsed key becomes the
// fallback key, and every parsed key whose fingerprint can be calculated is
// indexed by that fingerprint.
//
// It returns nil (after logging) when no value could be parsed at all.
func parseVersionedKeys(ctx context.Context, versioned secrets.VersionedSecret, logger log.Wrapper) *keysType {
	candidates := versioned.GetAll()

	var fallback *rsa.PublicKey
	byFingerprint := make(map[string]*rsa.PublicKey, len(candidates))

	for idx, raw := range candidates {
		pub, parseErr := jwt.ParseRSAPublicKeyFromPEM([]byte(raw))
		if parseErr != nil {
			logger.Log(ctx, fmt.Sprintf(
				"Failed to parse key #%d: %v",
				idx,
				parseErr,
			))
			continue
		}

		// The fallback is picked before fingerprinting, so a key without a
		// usable fingerprint can still serve as the fallback.
		if fallback == nil {
			fallback = pub
		}

		fingerprint, fpErr := RSAPublicKeyFingerprint(pub)
		if fpErr != nil {
			logger.Log(ctx, fmt.Sprintf(
				"Failed to get fingerprint of key #%d: %v",
				idx,
				fpErr,
			))
			continue
		}
		byFingerprint[fingerprint] = pub
	}

	if fallback == nil {
		logger.Log(ctx, "No valid keys in secrets store.")
		return nil
	}
	return &keysType{
		m:     byFingerprint,
		first: fallback,
	}
}
// RSAPublicKeyFingerprint calculates the fingerprint of an RSA public key,
// using ssh.FingerprintSHA256:
// https://pkg.go.dev/golang.org/x/crypto/ssh#FingerprintSHA256
//
// It returns an error when pubKey is nil, or when the key cannot be converted
// into an ssh public key.
func RSAPublicKeyFingerprint(pubKey *rsa.PublicKey) (string, error) {
	// Guard against nil explicitly and report it as an error, rather than
	// passing a nil key on to the ssh package.
	if pubKey == nil {
		return "", errors.New("edgecontext.RSAPublicKeyFingerprint: nil public key")
	}

	sshKey, err := ssh.NewPublicKey(pubKey)
	if err != nil {
		return "", err
	}
	return ssh.FingerprintSHA256(sshKey), nil
}