-
Notifications
You must be signed in to change notification settings - Fork 3
/
jwtAuth.go
154 lines (126 loc) · 3.86 KB
/
jwtAuth.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
package interceptor
import (
"context"
"github.com/zhufuyi/pkg/jwt"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// ---------------------------------- server interceptor ----------------------------------

// Package-level authentication settings. UnaryServerJwtAuth and
// StreamServerJwtAuth overwrite all three with the options they resolve.
var (
	// authScheme is the prefix expected in front of the token in the
	// authorization value, e.g. "Bearer <token>".
	authScheme = "Bearer"
	// authCtxClaimsName is the context key under which verified claims are stored.
	authCtxClaimsName = "tokenInfo"
	// authIgnoreMethods is the set of full method names that skip authentication,
	// as last assigned by an interceptor constructor.
	authIgnoreMethods = make(map[string]struct{})
)
type (
	// AuthOption mutates an AuthOptions value; it is the functional-option type
	// accepted by UnaryServerJwtAuth and StreamServerJwtAuth.
	AuthOption func(*AuthOptions)

	// AuthOptions holds the authentication settings for a server interceptor.
	AuthOptions struct {
		authScheme    string              // prefix expected before the token
		ctxClaimsName string              // context key for the verified claims
		ignoreMethods map[string]struct{} // full method names that skip authentication
	}
)
// defaultAuthOptions returns a fresh AuthOptions with the following defaults:
//   - scheme and claims key copied from the current package-level
//     authScheme / authCtxClaimsName. These may already have been changed by a
//     previously constructed interceptor.
//   - an empty, non-nil set of ignored methods.
func defaultAuthOptions() *AuthOptions {
	opts := new(AuthOptions)
	opts.authScheme = authScheme
	opts.ctxClaimsName = authCtxClaimsName
	opts.ignoreMethods = map[string]struct{}{} // no method skips authentication by default
	return opts
}
// apply invokes every option on o in the order given.
func (o *AuthOptions) apply(opts ...AuthOption) {
	for i := range opts {
		opts[i](o)
	}
}
// WithAuthScheme sets the prefix that must precede the token in the
// authorization value (the package default is "Bearer").
func WithAuthScheme(scheme string) AuthOption {
	var setScheme AuthOption = func(o *AuthOptions) {
		o.authScheme = scheme
	}
	return setScheme
}
// WithAuthClaimsName sets the context key under which verified claims are
// stored (the package default is "tokenInfo").
func WithAuthClaimsName(claimsName string) AuthOption {
	var setClaimsName AuthOption = func(o *AuthOptions) {
		o.ctxClaimsName = claimsName
	}
	return setClaimsName
}
// WithAuthIgnoreMethods adds methods that bypass JWT verification in the
// server interceptors.
//
// Each name must be a full gRPC method name of the form
// /packageName.serviceName/methodName, for example
// /api.userExample.v1.userExampleService/GetByID.
// Names accumulate if the option is applied more than once.
func WithAuthIgnoreMethods(fullMethodNames ...string) AuthOption {
	return func(o *AuthOptions) {
		// AuthOptions is exported, so a caller may apply this option to a
		// zero value whose map is nil. Assigning into a nil map panics, so
		// create the map first.
		if o.ignoreMethods == nil {
			o.ignoreMethods = make(map[string]struct{}, len(fullMethodNames))
		}
		for _, method := range fullMethodNames {
			o.ignoreMethods[method] = struct{}{}
		}
	}
}
// GetAuthorization builds the authorization value "<authScheme> <token>".
// It uses the package-level scheme, which the server interceptor
// constructors may have changed.
func GetAuthorization(token string) string {
	const separator = " "
	return authScheme + separator + token
}
// GetAuthCtxKey returns the context key under which JwtVerify stores the
// verified claims.
func GetAuthCtxKey() (key string) {
	key = authCtxClaimsName
	return
}
// JwtVerify checks the authorization carried in ctx, which must have the form
// "<authScheme> <token>".
//
// The token is extracted with grpc_auth.AuthFromMD and checked with
// jwt.VerifyToken. On success it returns a child context holding the claims
// under the package-level key. Handlers can read them with:
//
//	ctx.Value(interceptor.GetAuthCtxKey()).(*jwt.CustomClaims)
//
// Extraction errors are returned unchanged. Verification errors are returned
// as codes.Unauthenticated.
func JwtVerify(ctx context.Context) (context.Context, error) {
	token, extractErr := grpc_auth.AuthFromMD(ctx, authScheme)
	if extractErr != nil {
		return nil, extractErr
	}

	claims, verifyErr := jwt.VerifyToken(token)
	if verifyErr != nil {
		return nil, status.Errorf(codes.Unauthenticated, "%v", verifyErr)
	}

	return context.WithValue(ctx, authCtxClaimsName, claims), nil //nolint // string key is intentional: callers look it up via GetAuthCtxKey()
}
// UnaryServerJwtAuth returns a unary server interceptor that verifies the JWT
// (via JwtVerify) before calling the handler. Methods in the ignore set are
// passed through unchanged.
//
// Configuration:
//   - opts are applied on top of defaultAuthOptions().
//   - The resolved scheme and claims key are written to the package-level
//     variables, because JwtVerify, GetAuthorization and GetAuthCtxKey read
//     them there. The most recently constructed interceptor therefore decides
//     those two values for the whole package.
//   - The ignore set is captured by this interceptor, so constructing another
//     interceptor later does not change which methods this one skips.
func UnaryServerJwtAuth(opts ...AuthOption) grpc.UnaryServerInterceptor {
	o := defaultAuthOptions()
	o.apply(opts...)
	authScheme = o.authScheme
	authCtxClaimsName = o.ctxClaimsName
	authIgnoreMethods = o.ignoreMethods

	// Do not read the package-level map per request. Every later interceptor
	// constructor replaces it, which would silently drop this interceptor's
	// ignore list. Reading it while it is being reassigned is also a data race.
	ignoreMethods := o.ignoreMethods

	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		if _, skip := ignoreMethods[info.FullMethod]; skip {
			return handler(ctx, req)
		}

		verifiedCtx, err := JwtVerify(ctx)
		if err != nil {
			return nil, err
		}
		return handler(verifiedCtx, req)
	}
}
// StreamServerJwtAuth returns a stream server interceptor that verifies the
// JWT (via JwtVerify) before calling the handler. The handler always receives
// a wrapped stream whose Context() is either the verified context or, for
// methods in the ignore set, the original stream context.
//
// Configuration:
//   - opts are applied on top of defaultAuthOptions().
//   - The resolved scheme and claims key are written to the package-level
//     variables, because JwtVerify, GetAuthorization and GetAuthCtxKey read
//     them there. The most recently constructed interceptor therefore decides
//     those two values for the whole package.
//   - The ignore set is captured by this interceptor, so constructing another
//     interceptor later does not change which methods this one skips.
func StreamServerJwtAuth(opts ...AuthOption) grpc.StreamServerInterceptor {
	o := defaultAuthOptions()
	o.apply(opts...)
	authScheme = o.authScheme
	authCtxClaimsName = o.ctxClaimsName
	authIgnoreMethods = o.ignoreMethods

	// Do not read the package-level map per request. Every later interceptor
	// constructor replaces it, which would silently drop this interceptor's
	// ignore list. Reading it while it is being reassigned is also a data race.
	ignoreMethods := o.ignoreMethods

	return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		ctx := stream.Context()
		if _, skip := ignoreMethods[info.FullMethod]; !skip {
			verifiedCtx, err := JwtVerify(ctx)
			if err != nil {
				return err
			}
			ctx = verifiedCtx
		}

		// Wrap in both paths so handlers see the same stream type regardless
		// of whether the method was ignored.
		wrapped := grpc_middleware.WrapServerStream(stream)
		wrapped.WrappedContext = ctx
		return handler(srv, wrapped)
	}
}