forked from deepmap/oapi-codegen
/
jwt_authenticator.go
125 lines (108 loc) · 3.78 KB
/
jwt_authenticator.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
package server
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"github.com/getkin/kin-openapi/openapi3filter"
"github.com/lestrrat-go/jwx/jwt"
)
// JWSValidator validates a JWS string taken from a request and, when it is
// valid, returns the JWT it carries. Implementations define the validation
// rules.
type JWSValidator interface {
	// ValidateJWS returns the token encoded in jws, or a non-nil error when
	// jws does not pass validation.
	ValidateJWS(jws string) (jwt.Token, error)
}
// Sentinel errors produced by the bearer-token helpers in this file. They are
// usually returned wrapped (Authenticate uses %w), so compare with errors.Is.
var (
	// ErrNoAuthHeader means the request's Authorization header is missing or
	// empty.
	ErrNoAuthHeader = errors.New("Authorization header is missing")
	// ErrInvalidAuthHeader means the Authorization header is present but is
	// not of the expected "Bearer <token>" form.
	ErrInvalidAuthHeader = errors.New("Authorization header is malformed")
	// ErrClaimsInvalid means the token does not carry every scope required by
	// the operation being authenticated.
	ErrClaimsInvalid = errors.New("Provided claims do not match expected scopes")
)
// GetJWSFromRequest extracts a JWS string from an "Authorization: Bearer <jws>"
// header.
//
// The authentication scheme is matched case-insensitively (RFC 7235 §2.1) and
// may be followed by one or more spaces before the token (RFC 6750 §2.1:
// `"Bearer" 1*SP b64token`). So "Bearer <jws>", "bearer <jws>" and
// "Bearer   <jws>" are all accepted.
//
// It returns ErrNoAuthHeader when the header is absent or empty. It returns
// ErrInvalidAuthHeader when the scheme is not Bearer, when no space follows
// the scheme, or when no token follows the spaces.
func GetJWSFromRequest(req *http.Request) (string, error) {
	authHdr := req.Header.Get("Authorization")
	if authHdr == "" {
		return "", ErrNoAuthHeader
	}

	const scheme = "Bearer"
	// Require the scheme name (any case) immediately followed by a space; the
	// length check also guarantees the index below is in range.
	if len(authHdr) <= len(scheme) ||
		!strings.EqualFold(authHdr[:len(scheme)], scheme) ||
		authHdr[len(scheme)] != ' ' {
		return "", ErrInvalidAuthHeader
	}

	// Skip the 1*SP separator; an empty remainder means there is no token.
	token := strings.TrimLeft(authHdr[len(scheme):], " ")
	if token == "" {
		return "", ErrInvalidAuthHeader
	}
	return token, nil
}
// NewAuthenticator builds an openapi3filter.AuthenticationFunc that hands every
// authentication request to Authenticate, using v to validate the bearer JWS.
func NewAuthenticator(v JWSValidator) openapi3filter.AuthenticationFunc {
	authFn := func(ctx context.Context, input *openapi3filter.AuthenticationInput) error {
		return Authenticate(v, ctx, input)
	}
	return authFn
}
// Authenticate checks a request against the "BearerAuth" security scheme.
//
// It runs in three steps, returning the first failure wrapped with %w:
//  1. extract the bearer JWS from the request's Authorization header;
//  2. have v validate it, yielding a JWT;
//  3. require the JWT's claims to include every scope in input.Scopes.
//
// Any other security scheme name is rejected with an error. ctx is part of the
// openapi3filter hook signature but is not used by this function.
func Authenticate(v JWSValidator, ctx context.Context, input *openapi3filter.AuthenticationInput) error {
	const supportedScheme = "BearerAuth"
	if scheme := input.SecuritySchemeName; scheme != supportedScheme {
		return fmt.Errorf("security scheme %s != 'BearerAuth'", scheme)
	}

	// Step 1: pull the JWS out of the incoming HTTP request.
	jws, err := GetJWSFromRequest(input.RequestValidationInput.Request)
	if err != nil {
		return fmt.Errorf("getting jws: %w", err)
	}

	// Step 2: a JWS that validates gives us a JWT and its claims.
	token, err := v.ValidateJWS(jws)
	if err != nil {
		return fmt.Errorf("validating JWS: %w", err)
	}

	// Step 3: every required scope must be present among the token's claims.
	if err := CheckTokenClaims(input.Scopes, token); err != nil {
		return fmt.Errorf("token claims don't match: %w", err)
	}
	return nil
}
// GetClaimsFromToken returns the list of permission claims carried by t. They
// are stored as a list under the PermissionsClaim claim ("perms", short for
// permissions, to keep the token shorter).
//
// A token without that claim is still valid (it has passed signature
// validation by the time this is called), and yields an empty, non-nil slice.
// The claim is accepted in either of two shapes:
//   - []interface{} whose elements are all strings, which is the shape an
//     untyped JSON list decodes to;
//   - []string, e.g. when the claim was set directly on an in-process token.
//
// Any other shape, or a non-string list element, is reported as an error.
func GetClaimsFromToken(t jwt.Token) ([]string, error) {
	rawPerms, found := t.Get(PermissionsClaim)
	if !found {
		return make([]string, 0), nil
	}

	switch perms := rawPerms.(type) {
	case []string:
		// Return a copy so callers cannot mutate the token's own slice.
		return append(make([]string, 0, len(perms)), perms...), nil
	case []interface{}:
		claims := make([]string, len(perms))
		for i, rawClaim := range perms {
			s, ok := rawClaim.(string)
			if !ok {
				return nil, fmt.Errorf("%s[%d] is not a string", PermissionsClaim, i)
			}
			claims[i] = s
		}
		return claims, nil
	default:
		// Include the dynamic type to make malformed tokens easier to diagnose.
		return nil, fmt.Errorf("'%s' claim is unexpected type %T", PermissionsClaim, rawPerms)
	}
}
// CheckTokenClaims verifies that t grants every permission in expectedClaims.
//
// It first reads the token's claims via GetClaimsFromToken and returns that
// error wrapped if reading fails. Otherwise it returns ErrClaimsInvalid on the
// first expected claim the token lacks, and nil when all are present. An empty
// expectedClaims is always satisfied.
func CheckTokenClaims(expectedClaims []string, t jwt.Token) error {
	granted, err := GetClaimsFromToken(t)
	if err != nil {
		return fmt.Errorf("getting claims from token: %w", err)
	}

	// Index the granted claims as a set for constant-time membership checks.
	have := make(map[string]struct{}, len(granted))
	for _, g := range granted {
		have[g] = struct{}{}
	}

	for _, want := range expectedClaims {
		if _, ok := have[want]; !ok {
			return ErrClaimsInvalid
		}
	}
	return nil
}