/
client_interceptor.go
82 lines (71 loc) · 3.09 KB
/
client_interceptor.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
package middleware
import (
"context"
"time"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
"github.com/pudongping/go-grpc-service/global"
"github.com/pudongping/go-grpc-service/pkg/metatext"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
// defaultContextTimeout guarantees that the returned context carries a deadline.
//
// If ctx already has a deadline it is returned unchanged, together with a
// no-op cancel function. Otherwise a child context with a default timeout of
// 60 seconds is created. The timeout applies to the whole call chain; callers
// that need a different value should set their own deadline on ctx before the
// interceptor runs.
//
// The returned CancelFunc is never nil, so callers can always defer it.
func defaultContextTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
	const defaultTimeout = 60 * time.Second

	// ctx.Deadline reports ok == false when no deadline has been set.
	if _, ok := ctx.Deadline(); ok {
		// The caller owns the existing deadline; there is nothing for us to release.
		return ctx, func() {}
	}
	return context.WithTimeout(ctx, defaultTimeout)
}
// UnaryContextTimeout returns a unary client interceptor that runs the call
// with the context produced by defaultContextTimeout and releases that
// context's cancel function (when one was returned) once the call completes.
func UnaryContextTimeout() grpc.UnaryClientInterceptor {
	interceptor := func(ctx context.Context, method string, req, resp interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
		callCtx, release := defaultContextTimeout(ctx)
		// A unary call has fully finished when invoker returns, so the
		// context can be released as soon as this function exits.
		defer func() {
			if release != nil {
				release()
			}
		}()

		return invoker(callCtx, method, req, resp, cc, opts...)
	}
	return interceptor
}
// StreamContextTimeout returns a streaming client interceptor that opens the
// stream with the context produced by defaultContextTimeout.
//
// Unlike a unary call, a stream is still in use after the interceptor
// returns, so the context must NOT be cancelled on return: doing so would
// abort every stream that received the default timeout before the caller
// could send or receive anything. Instead the cancel function is invoked:
//   - immediately, if creating the stream fails;
//   - otherwise when RecvMsg on the returned stream ends the RPC (see
//     timeoutClientStream.RecvMsg).
//
// If the caller abandons the stream without draining it, the context is
// still released when the timeout expires.
func StreamContextTimeout() grpc.StreamClientInterceptor {
	return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
		ctx, cancel := defaultContextTimeout(ctx)

		stream, err := streamer(ctx, desc, cc, method, opts...)
		if cancel == nil {
			// No timeout context was created here; nothing to release.
			return stream, err
		}
		if err != nil {
			// The stream was never established, so release the context now.
			cancel()
			return stream, err
		}

		return &timeoutClientStream{
			ClientStream: stream,
			cancel:       cancel,
			// When the descriptor is unavailable, be conservative and only
			// cancel on a RecvMsg error.
			serverStreams: desc == nil || desc.ServerStreams,
		}, nil
	}
}

// timeoutClientStream wraps a grpc.ClientStream and releases the timeout
// context created by StreamContextTimeout once the RPC has finished.
type timeoutClientStream struct {
	grpc.ClientStream
	cancel        context.CancelFunc
	serverStreams bool
}

// RecvMsg delegates to the wrapped stream and cancels the timeout context
// when the RPC is over: on any error (including io.EOF), or after the single
// response of a non-server-streaming RPC has been received. Calling the
// CancelFunc more than once is harmless.
func (s *timeoutClientStream) RecvMsg(m interface{}) error {
	err := s.ClientStream.RecvMsg(m)
	if err != nil || !s.serverStreams {
		s.cancel()
	}
	return err
}
// ClientTracing returns a unary client interceptor that wraps each call in an
// OpenTracing client span.
//
// The span is started from global.Tracer with the gRPC method as its operation
// name. If ctx already carries a span, the new span is created as its child.
// The span context is injected into the outgoing gRPC metadata, and the call
// is invoked with a context that carries both the updated metadata and the
// new span. When the call returns an error, the span is tagged as failed and
// the error message is logged on it before the span is finished.
func ClientTracing() grpc.UnaryClientInterceptor {
	return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
		var spanOpts []opentracing.StartSpanOption

		// If the incoming context already holds a span, use it as the parent
		// of the span created for this call.
		if parentSpan := opentracing.SpanFromContext(ctx); parentSpan != nil {
			spanOpts = append(spanOpts, opentracing.ChildOf(parentSpan.Context()))
		}
		spanOpts = append(spanOpts,
			opentracing.Tag{Key: string(ext.Component), Value: "gRPC"},
			ext.SpanKindRPCClient,
		)

		span := global.Tracer.StartSpan(method, spanOpts...)
		defer span.Finish()

		md, ok := metadata.FromOutgoingContext(ctx)
		if ok {
			// Work on a copy so the metadata stored in the caller's context is
			// never modified in place by Inject below.
			md = md.Copy()
		} else {
			md = metadata.New(nil)
		}

		// Write the span context into the outgoing metadata. Tracing is
		// best-effort: a failed injection is recorded on the span but does not
		// fail the RPC.
		if err := global.Tracer.Inject(span.Context(), opentracing.TextMap, metatext.MetadataTextMap{md}); err != nil {
			span.LogKV("event", "inject failed", "message", err.Error())
		}

		newCtx := opentracing.ContextWithSpan(metadata.NewOutgoingContext(ctx, md), span)
		err := invoker(newCtx, method, req, reply, cc, opts...)
		if err != nil {
			// Mark the span as failed so the error is visible in the trace.
			ext.Error.Set(span, true)
			span.LogKV("event", "error", "message", err.Error())
		}
		return err
	}
}