-
Notifications
You must be signed in to change notification settings - Fork 0
/
jwt.go
151 lines (134 loc) · 4.61 KB
/
jwt.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
package jwt
import (
"context"
"fmt"
"strings"
"github.com/golang-jwt/jwt/v4"
"github.com/xqk/good/errors"
"github.com/xqk/good/middleware"
"github.com/xqk/good/transport"
)
// authKey is the unexported context key type under which NewContext stores
// the token claims and FromContext looks them up.
type authKey struct{}
const (
// bearerWord is the authorization scheme keyword expected in front of the
// token; Server compares it case-insensitively.
bearerWord string = "Bearer"
// bearerFormat is the format Client uses to build the Authorization header
// value from a signed token string.
bearerFormat string = "Bearer %s"
// authorizationKey is the name of the request header that carries the JWT token.
authorizationKey string = "Authorization"
)
// Errors returned by the Server and Client middleware. Each one is created
// with errors.Unauthorized using the "UNAUTHORIZED" reason.
var (
// ErrMissingJwtToken is returned by Server when the Authorization header is
// not of the form "Bearer <token>".
ErrMissingJwtToken = errors.Unauthorized("UNAUTHORIZED", "JWT token is missing")
// ErrMissingKeyFunc is returned by Server when it was built with a nil keyFunc.
ErrMissingKeyFunc = errors.Unauthorized("UNAUTHORIZED", "keyFunc is missing")
// ErrTokenInvalid is returned by Server when the token is malformed or is
// reported as not valid after parsing.
ErrTokenInvalid = errors.Unauthorized("UNAUTHORIZED", "Token is invalid")
// ErrTokenExpired is returned by Server when the token is expired or not
// yet valid.
ErrTokenExpired = errors.Unauthorized("UNAUTHORIZED", "JWT token has expired")
// ErrTokenParseFail is returned by Server for token validation errors that
// are neither "malformed" nor "expired / not yet valid".
ErrTokenParseFail = errors.Unauthorized("UNAUTHORIZED", "Fail to parse JWT token ")
// ErrUnSupportSigningMethod is returned by Server when the token's signing
// method differs from the configured one.
ErrUnSupportSigningMethod = errors.Unauthorized("UNAUTHORIZED", "Wrong signing method")
// ErrWrongContext is returned by Server and Client when the context does
// not carry the transport information they need.
ErrWrongContext = errors.Unauthorized("UNAUTHORIZED", "Wrong context for middleware")
// ErrNeedTokenProvider is returned by Client when it was built with a nil
// key provider.
ErrNeedTokenProvider = errors.Unauthorized("UNAUTHORIZED", "Token provider is missing")
// ErrSignToken is returned by Client when signing the token fails.
ErrSignToken = errors.Unauthorized("UNAUTHORIZED", "Can not sign token.Is the key correct?")
// ErrGetKey is returned by Client when the key provider returns an error.
ErrGetKey = errors.Unauthorized("UNAUTHORIZED", "Can not get key while signing token")
)
// Option is a functional option that adjusts the settings used by the
// Server and Client middleware.
type Option func(*options)
// options holds the settings shared by the Server and Client middleware.
type options struct {
// signingMethod is the method Client signs tokens with and Server
// requires incoming tokens to use.
signingMethod jwt.SigningMethod
// claims are the claims Client embeds in the tokens it signs.
claims jwt.Claims
}
// WithSigningMethod returns an Option that replaces the configured signing
// method with method.
func WithSigningMethod(method jwt.SigningMethod) Option {
	apply := func(cfg *options) {
		cfg.signingMethod = method
	}
	return apply
}
// WithClaims returns an Option that replaces the configured claims with the
// caller-supplied custom claims.
func WithClaims(claims jwt.Claims) Option {
	apply := func(cfg *options) {
		cfg.claims = claims
	}
	return apply
}
// Server is a server auth middleware. It checks the bearer token carried in
// the Authorization request header and, on success, stores the token's
// claims in the context passed to the next handler (see FromContext).
//
// keyFunc supplies the verification key for a token. opts may override the
// expected signing method (HS256 by default). The token is parsed with
// jwt.Parse; the claims option is not consulted here.
//
// The returned handler fails with:
//   - ErrWrongContext when ctx carries no server transport information;
//   - ErrMissingKeyFunc when keyFunc is nil;
//   - ErrMissingJwtToken when the header is not "Bearer <token>";
//   - ErrUnSupportSigningMethod when the token is not signed with the
//     configured method;
//   - ErrTokenInvalid when the token is malformed or not valid;
//   - ErrTokenExpired when the token is expired or not yet valid;
//   - ErrTokenParseFail for any other validation error;
//   - an Unauthorized error wrapping the message of any other parse error.
func Server(keyFunc jwt.Keyfunc, opts ...Option) middleware.Middleware {
	o := &options{
		signingMethod: jwt.SigningMethodHS256,
		claims:        jwt.StandardClaims{},
	}
	for _, opt := range opts {
		opt(o)
	}
	return func(handler middleware.Handler) middleware.Handler {
		return func(ctx context.Context, req interface{}) (interface{}, error) {
			header, ok := transport.FromServerContext(ctx)
			if !ok {
				return nil, ErrWrongContext
			}
			if keyFunc == nil {
				return nil, ErrMissingKeyFunc
			}
			auths := strings.SplitN(header.RequestHeader().Get(authorizationKey), " ", 2)
			if len(auths) != 2 || !strings.EqualFold(auths[0], bearerWord) {
				return nil, ErrMissingJwtToken
			}

			// Enforce the expected signing method before the key is looked
			// up, so the key is never used with an algorithm chosen by the
			// token's sender and the mismatch is reported as such instead of
			// surfacing as a generic signature failure.
			wrongMethod := false
			checkedKeyFunc := func(token *jwt.Token) (interface{}, error) {
				if token.Method != o.signingMethod {
					wrongMethod = true
					return nil, ErrUnSupportSigningMethod
				}
				return keyFunc(token)
			}

			tokenInfo, err := jwt.Parse(auths[1], checkedKeyFunc)
			if wrongMethod {
				return nil, ErrUnSupportSigningMethod
			}
			if err != nil {
				ve, isValidationErr := err.(*jwt.ValidationError)
				switch {
				case !isValidationErr:
					return nil, errors.Unauthorized("UNAUTHORIZED", err.Error())
				case ve.Errors&jwt.ValidationErrorMalformed != 0:
					return nil, ErrTokenInvalid
				case ve.Errors&(jwt.ValidationErrorExpired|jwt.ValidationErrorNotValidYet) != 0:
					return nil, ErrTokenExpired
				default:
					return nil, ErrTokenParseFail
				}
			}
			if !tokenInfo.Valid {
				return nil, ErrTokenInvalid
			}
			return handler(NewContext(ctx, tokenInfo.Claims), req)
		}
	}
}
// Client is a client jwt middleware. For every call it builds a token from
// the configured signing method and claims, signs it with the key obtained
// from keyProvider, and sets it as a bearer token in the Authorization
// request header before invoking the next handler.
//
// The returned handler fails with ErrNeedTokenProvider when keyProvider is
// nil, ErrGetKey when keyProvider returns an error, ErrSignToken when
// signing fails, and ErrWrongContext when ctx carries no client transport
// information.
func Client(keyProvider jwt.Keyfunc, opts ...Option) middleware.Middleware {
	cfg := &options{
		signingMethod: jwt.SigningMethodHS256,
		claims:        jwt.StandardClaims{},
	}
	for i := range opts {
		opts[i](cfg)
	}
	return func(next middleware.Handler) middleware.Handler {
		return func(ctx context.Context, req interface{}) (interface{}, error) {
			if keyProvider == nil {
				return nil, ErrNeedTokenProvider
			}

			unsigned := jwt.NewWithClaims(cfg.signingMethod, cfg.claims)
			signingKey, keyErr := keyProvider(unsigned)
			if keyErr != nil {
				return nil, ErrGetKey
			}
			signed, signErr := unsigned.SignedString(signingKey)
			if signErr != nil {
				return nil, ErrSignToken
			}

			tr, found := transport.FromClientContext(ctx)
			if !found {
				return nil, ErrWrongContext
			}
			tr.RequestHeader().Set(authorizationKey, fmt.Sprintf(bearerFormat, signed))
			return next(ctx, req)
		}
	}
}
// NewContext returns a copy of ctx that carries the auth info (token claims)
// under the package-private authKey.
func NewContext(ctx context.Context, info jwt.Claims) context.Context {
	key := authKey{}
	return context.WithValue(ctx, key, info)
}
// FromContext extracts the auth info (token claims) stored in ctx by
// NewContext. ok is false when ctx holds no value of type jwt.Claims under
// the package's key.
func FromContext(ctx context.Context) (token jwt.Claims, ok bool) {
	claims, found := ctx.Value(authKey{}).(jwt.Claims)
	return claims, found
}