This repository has been archived by the owner on Aug 28, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
intercept.go
124 lines (108 loc) · 4.61 KB
/
intercept.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
package grpchan
import (
"fmt"
"golang.org/x/net/context"
"google.golang.org/grpc"
)
// InterceptChannel wraps ch so that RPCs issued through the returned Channel
// pass through the given client interceptors: unary calls (Invoke) through
// unaryInt and streaming calls (NewStream) through streamInt. Either
// interceptor may be nil, in which case that kind of call goes straight to ch.
// When both are nil there is nothing to intercept, so ch itself is returned.
func InterceptChannel(ch Channel, unaryInt grpc.UnaryClientInterceptor, streamInt grpc.StreamClientInterceptor) Channel {
	if unaryInt != nil || streamInt != nil {
		return &interceptedChannel{
			ch:        ch,
			unaryInt:  unaryInt,
			streamInt: streamInt,
		}
	}
	return ch
}
// interceptedChannel is the Channel returned by InterceptChannel. It forwards
// RPCs to an underlying channel after routing them through the configured
// client interceptors.
type interceptedChannel struct {
	// ch is the channel that ultimately carries the RPCs.
	ch Channel
	// unaryInt, when non-nil, wraps calls made via Invoke.
	unaryInt grpc.UnaryClientInterceptor
	// streamInt, when non-nil, wraps calls made via NewStream.
	streamInt grpc.StreamClientInterceptor
}
// Invoke performs a unary RPC. Without a unary interceptor the call is
// delegated directly to the wrapped channel; otherwise the interceptor runs
// and is given intch.unaryInvoker to complete the call on the wrapped channel.
func (intch *interceptedChannel) Invoke(ctx context.Context, methodName string, req, resp interface{}, opts ...grpc.CallOption) error {
	interceptor := intch.unaryInt
	if interceptor == nil {
		return intch.ch.Invoke(ctx, methodName, req, resp, opts...)
	}
	// The interceptor signature wants a *grpc.ClientConn, but the wrapped
	// channel need not be one. In that case conn stays nil, which is fine
	// because unaryInvoker never uses it.
	var conn *grpc.ClientConn
	if c, ok := intch.ch.(*grpc.ClientConn); ok {
		conn = c
	}
	return interceptor(ctx, methodName, req, resp, conn, intch.unaryInvoker, opts...)
}
// unaryInvoker is the invoker handed to the unary client interceptor. It
// completes the call on the wrapped channel. The *grpc.ClientConn argument is
// ignored because the wrapped channel is what actually sends the RPC.
func (intch *interceptedChannel) unaryInvoker(ctx context.Context, method string, req, reply interface{}, _ *grpc.ClientConn, opts ...grpc.CallOption) error {
	target := intch.ch
	return target.Invoke(ctx, method, req, reply, opts...)
}
// NewStream opens a streaming RPC. Without a stream interceptor the call is
// delegated directly to the wrapped channel; otherwise the interceptor runs
// and is given intch.streamer to open the stream on the wrapped channel.
func (intch *interceptedChannel) NewStream(ctx context.Context, desc *grpc.StreamDesc, methodName string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
	interceptor := intch.streamInt
	if interceptor == nil {
		return intch.ch.NewStream(ctx, desc, methodName, opts...)
	}
	// As in Invoke: conn is nil unless the wrapped channel really is a
	// *grpc.ClientConn; streamer ignores it either way.
	var conn *grpc.ClientConn
	if c, ok := intch.ch.(*grpc.ClientConn); ok {
		conn = c
	}
	return interceptor(ctx, desc, conn, methodName, intch.streamer, opts...)
}
// streamer is the stream opener handed to the stream client interceptor. It
// opens the stream on the wrapped channel. The *grpc.ClientConn argument is
// ignored.
func (intch *interceptedChannel) streamer(ctx context.Context, desc *grpc.StreamDesc, _ *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
	target := intch.ch
	return target.NewStream(ctx, desc, method, opts...)
}
// Compile-time assertion that *interceptedChannel implements Channel.
var _ Channel = (*interceptedChannel)(nil)
// WithInterceptor returns a view of reg that wraps every service registered
// through it with the given server interceptors, using InterceptServer. If
// both interceptors are nil, reg itself is returned.
func WithInterceptor(reg ServiceRegistry, unaryInt grpc.UnaryServerInterceptor, streamInt grpc.StreamServerInterceptor) ServiceRegistry {
	if unaryInt != nil || streamInt != nil {
		return &interceptingRegistry{
			reg:       reg,
			unaryInt:  unaryInt,
			streamInt: streamInt,
		}
	}
	return reg
}
// interceptingRegistry is the ServiceRegistry returned by WithInterceptor.
// Before handing each service description to the underlying registry, it
// passes the description through InterceptServer.
type interceptingRegistry struct {
	// reg is the registry that actually receives the registrations.
	reg ServiceRegistry
	// unaryInt, when non-nil, wraps every unary method handler.
	unaryInt grpc.UnaryServerInterceptor
	// streamInt, when non-nil, wraps every streaming handler.
	streamInt grpc.StreamServerInterceptor
}
// RegisterService registers srv with the underlying registry. It uses the
// service description produced by InterceptServer for desc and the
// registry's interceptors.
func (r *interceptingRegistry) RegisterService(desc *grpc.ServiceDesc, srv interface{}) {
	wrapped := InterceptServer(desc, r.unaryInt, r.streamInt)
	r.reg.RegisterService(wrapped, srv)
}
// InterceptServer returns a new service description that intercepts RPCs with
// the given interceptors. If both interceptors are nil, svcDesc is returned
// as is.
//
// The result is a shallow copy of svcDesc. Only the slices that need new
// handlers are replaced: Methods when unaryInt is non-nil, and Streams when
// streamInt is non-nil. svcDesc itself is only read, never modified.
//
// For unary methods, the server may pass its own interceptor to the handler.
// In that case the server's interceptor runs first, and unaryInt runs inside
// it, just before the original method handler.
func InterceptServer(svcDesc *grpc.ServiceDesc, unaryInt grpc.UnaryServerInterceptor, streamInt grpc.StreamServerInterceptor) *grpc.ServiceDesc {
	if unaryInt == nil && streamInt == nil {
		return svcDesc
	}
	result := *svcDesc

	if unaryInt != nil {
		methods := make([]grpc.MethodDesc, len(svcDesc.Methods))
		for i := range svcDesc.Methods {
			md := svcDesc.Methods[i]
			// Capture the original handler before md.Handler is overwritten,
			// otherwise the wrapper would call itself.
			next := md.Handler
			md.Handler = func(srv interface{}, ctx context.Context, dec func(interface{}) error, outer grpc.UnaryServerInterceptor) (interface{}, error) {
				if outer == nil {
					return next(srv, ctx, dec, unaryInt)
				}
				// Chain the two interceptors. outer is called first and is
				// given a handler that calls unaryInt, which in turn calls
				// the real handler.
				chained := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
					inner := func(ctx context.Context, req interface{}) (interface{}, error) {
						return unaryInt(ctx, req, info, handler)
					}
					return outer(ctx, req, info, inner)
				}
				return next(srv, ctx, dec, chained)
			}
			methods[i] = md
		}
		result.Methods = methods
	}

	if streamInt != nil {
		streams := make([]grpc.StreamDesc, len(svcDesc.Streams))
		for i := range svcDesc.Streams {
			sd := svcDesc.Streams[i]
			next := sd.Handler
			// Stream info is fixed per stream, so build it once here rather
			// than on every call.
			info := &grpc.StreamServerInfo{
				FullMethod:     fmt.Sprintf("/%s/%s", svcDesc.ServiceName, sd.StreamName),
				IsClientStream: sd.ClientStreams,
				IsServerStream: sd.ServerStreams,
			}
			sd.Handler = func(srv interface{}, stream grpc.ServerStream) error {
				return streamInt(srv, stream, info, next)
			}
			streams[i] = sd
		}
		result.Streams = streams
	}

	return &result
}