-
Notifications
You must be signed in to change notification settings - Fork 0
/
grpc_middleware.go
149 lines (121 loc) · 4.47 KB
/
grpc_middleware.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
package core
import (
"context"
"fmt"
"strings"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
// chainInterceptors composes several unary server interceptors into a single
// grpc.UnaryServerInterceptor.
//
// Interceptors run in slice order: interceptors[0] is invoked first, the
// handler it receives enters interceptors[1], and so on. The handler passed to
// the last interceptor is the RPC's real handler. When no interceptors are
// given, the real handler is called directly.
func chainInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		// handlerAt(pos) yields a handler that runs the chain starting at
		// interceptors[pos]; past the end of the slice it is the real handler.
		var handlerAt func(pos int) grpc.UnaryHandler
		handlerAt = func(pos int) grpc.UnaryHandler {
			if pos >= len(interceptors) {
				return handler
			}
			return func(c context.Context, r interface{}) (interface{}, error) {
				return interceptors[pos](c, r, info, handlerAt(pos+1))
			}
		}
		return handlerAt(0)(ctx, req)
	}
}
// chainStreamInterceptors composes several stream server interceptors into a
// single grpc.StreamServerInterceptor.
//
// interceptors[0] runs first. The handler each interceptor receives continues
// with the next interceptor, and the handler received by the last one forwards
// to the RPC's real handler. With no interceptors the real handler is invoked
// directly with the original srv and stream.
func chainStreamInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor {
	return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		// runFrom executes the chain beginning at interceptors[pos] with the
		// given server value and stream.
		var runFrom func(pos int, s interface{}, ss grpc.ServerStream) error
		runFrom = func(pos int, s interface{}, ss grpc.ServerStream) error {
			if pos == len(interceptors) {
				return handler(s, ss)
			}
			next := func(nextSrv interface{}, nextStream grpc.ServerStream) error {
				return runFrom(pos+1, nextSrv, nextStream)
			}
			return interceptors[pos](s, ss, info, next)
		}
		return runFrom(0, srv, stream)
	}
}
// loggingMiddleware is a unary server interceptor that opens a "grpc" trace
// span named after info.FullMethod, attaches the span ID to the response
// header metadata under the "request-id" key, and then invokes handler with
// the traced context. The span is ended when the handler returns.
//
// The header is registered with grpc.SetHeader rather than flushed with
// grpc.SendHeader. Flushing here would make any later SetHeader/SendHeader
// call made by the handler fail. For a unary RPC an early flush gains nothing,
// because the client observes headers together with the response. SetHeader
// metadata is merged and sent with the response (or with the final status on
// error).
//
// An error from grpc.SetHeader (e.g. no server stream in ctx) is returned and
// the handler is not called.
func loggingMiddleware(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
	ctx, span := Trace(ctx, "grpc", info.FullMethod)
	defer span.End()

	requestID := span.SpanContext().SpanID().String()
	if err := grpc.SetHeader(ctx, metadata.Pairs("request-id", requestID)); err != nil {
		return nil, err
	}
	return handler(ctx, req)
}
// loggingStreamMiddleware is a stream server interceptor. It opens a "grpc"
// trace span named after info.FullMethod and immediately sends the span ID to
// the client as "request-id" header metadata. It then runs handler on a
// wrappedServerStream whose Context() returns the traced context. The span is
// ended when the handler returns. If sending the header fails, that error is
// returned and the handler is not called.
//
// NOTE(review): stream.SendHeader flushes headers right away, so a handler
// that later sets or sends its own header metadata on this stream will get an
// error from gRPC. Confirm the eager flush is intended (e.g. so clients can
// read the request-id before the first message).
func loggingStreamMiddleware(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
	tracedCtx, span := Trace(stream.Context(), "grpc", info.FullMethod)
	defer span.End()

	header := metadata.Pairs("request-id", span.SpanContext().SpanID().String())
	if err := stream.SendHeader(header); err != nil {
		return err
	}

	return handler(srv, &wrappedServerStream{
		ServerStream: stream,
		ctx:          tracedCtx,
	})
}
// wrappedServerStream is a wrapper around grpc.ServerStream to override the Context method.
// Every other grpc.ServerStream method is promoted from the embedded stream;
// only Context is replaced, so handlers observe ctx (for example the traced
// context installed by loggingStreamMiddleware) instead of the original
// stream context.
type wrappedServerStream struct {
	grpc.ServerStream
	// ctx is returned by Context in place of the embedded stream's context.
	ctx context.Context
}

// Context returns the overriding context stored in w.ctx.
func (w *wrappedServerStream) Context() context.Context {
	return w.ctx
}
// applySelectedMiddleware runs the middlewares named in selectedMiddlewares,
// in order, against res. Each middleware's Apply receives the context and
// value produced by the previous one. Each call is wrapped in a "middleware"
// trace span named after the middleware.
//
// It returns the context and value produced by the last middleware, or the
// inputs unchanged when selectedMiddlewares is empty. On the first failure it
// returns nil, nil and the error. A name with no entry in middlewares is
// reported as an error instead of calling Apply on the map's zero value, which
// would panic if Middleware is an interface type.
func applySelectedMiddleware(ctx context.Context, res interface{}, selectedMiddlewares []string, middlewares map[string]Middleware) (context.Context, interface{}, error) {
	for _, name := range selectedMiddlewares {
		mw, ok := middlewares[name]
		if !ok {
			return nil, nil, fmt.Errorf("unknown middleware %q", name)
		}

		// NOTE(review): the span's context is discarded, so spans started
		// inside Apply are not children of this span — confirm intended.
		_, span := Trace(ctx, "middleware", name)
		var err error
		ctx, res, err = mw.Apply(ctx, res)
		span.End()
		if err != nil {
			return nil, nil, err
		}
	}
	return ctx, res, nil
}
// applyMiddleware runs the per-method middlewares configured in mdConf and
// then invokes handler.
//
// handler must be one of two combinations:
//   - a grpc.UnaryHandler, with info a *grpc.UnaryServerInfo and req the
//     request message;
//   - a grpc.StreamHandler, with info a *grpc.StreamServerInfo and req the
//     grpc.ServerStream.
//
// The middleware list is looked up as mdConf[service][method]. service is the
// bare service name: the text after the last '.' of the fully-qualified
// service in FullMethod, e.g. "Greeter" for "/helloworld.Greeter/SayHello".
// method is the final path element of FullMethod.
//
// For unary calls, the value returned by the middleware chain replaces req as
// the handler's request, and the handler's response is returned. For streaming
// calls, the middleware result is ignored and the handler runs on the stream
// wrapped by wrapServerStream with the handler span's context. The returned
// value is always nil.
//
// Malformed method names, mismatched info/req types, middleware failures and
// unsupported handler types are returned as errors rather than panicking. The
// "middlewares" span is ended exactly once on every path.
func applyMiddleware(ctx context.Context, srv interface{}, req interface{}, info interface{}, handler interface{}, middlewares map[string]Middleware, mdConf map[string]map[string][]string) (interface{}, error) {
	// NOTE(review): the span category "gprc" looks like a typo for "grpc"
	// (used by loggingMiddleware). It is kept as-is because trace queries may
	// depend on it.
	middlewaresCtx, middlewaresSpan := Trace(ctx, "gprc", "middlewares")

	// runMiddlewares resolves fullMethod to its mdConf entry, applies those
	// middlewares to value and returns the resulting value. It always ends
	// middlewaresSpan, so no error path leaks the span.
	runMiddlewares := func(fullMethod string, value interface{}) (interface{}, error) {
		defer middlewaresSpan.End()

		// fullMethod has the form "/<package>.<Service>/<Method>".
		parts := strings.Split(fullMethod, "/")
		if len(parts) < 3 {
			return nil, fmt.Errorf("malformed gRPC method name %q", fullMethod)
		}
		servicePath := parts[1]
		// Taking the last dot-segment works for any package depth, including
		// services without a package.
		service := servicePath[strings.LastIndex(servicePath, ".")+1:]
		method := parts[len(parts)-1]

		// NOTE(review): the context returned by the middleware chain is
		// discarded, so values a middleware stores in it never reach the
		// handler — confirm this is intended.
		_, out, err := applySelectedMiddleware(middlewaresCtx, value, mdConf[service][method], middlewares)
		return out, err
	}

	switch h := handler.(type) {
	case grpc.UnaryHandler:
		unaryInfo, ok := info.(*grpc.UnaryServerInfo)
		if !ok {
			middlewaresSpan.End()
			return nil, fmt.Errorf("expected *grpc.UnaryServerInfo, got %T", info)
		}
		res, err := runMiddlewares(unaryInfo.FullMethod, req)
		if err != nil {
			return nil, err
		}
		handleCtx, handlerSpan := Trace(ctx, "gprc", "handler")
		defer handlerSpan.End()
		return h(handleCtx, res)

	case grpc.StreamHandler:
		streamInfo, ok := info.(*grpc.StreamServerInfo)
		if !ok {
			middlewaresSpan.End()
			return nil, fmt.Errorf("expected *grpc.StreamServerInfo, got %T", info)
		}
		if _, err := runMiddlewares(streamInfo.FullMethod, req); err != nil {
			return nil, err
		}
		ss, ok := req.(grpc.ServerStream)
		if !ok {
			return nil, fmt.Errorf("expected grpc.ServerStream, got %T", req)
		}
		handleCtx, handlerSpan := Trace(ctx, "gprc", "handler")
		defer handlerSpan.End()
		return nil, h(srv, wrapServerStream(ss, handleCtx))

	default:
		middlewaresSpan.End()
		return nil, fmt.Errorf("Request type is not implemented")
	}
}