-
Notifications
You must be signed in to change notification settings - Fork 1
/
context.go
122 lines (100 loc) · 3.61 KB
/
context.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
package rpc
import (
"context"
"encoding/base64"
"net/http"
"kon.nect.sh/specter/spec/protocol"
"kon.nect.sh/specter/spec/transport"
pool "github.com/libp2p/go-buffer-pool"
)
// HeaderRPCContext is the name of the HTTP header that carries the serialized
// RPC context of a request.
const HeaderRPCContext = "x-rpc-context"
// rpcContextKey is the unexported type used for every key this package stores
// in a context.Context. Because the type is private to the package, these keys
// cannot collide with keys defined by other packages.
type rpcContextKey string

// Keys under which request-scoped values are stored in a context.Context.
const (
	// contextNodeKey holds the *protocol.Node to connect to.
	contextNodeKey rpcContextKey = "dial-node"
	// contextRPCContextKey holds the *protocol.Context of the rpc request.
	contextRPCContextKey rpcContextKey = "rpc-context"
	// contextAuthorizationKey holds the "Authorization" header from the
	// client rpc request.
	contextAuthorizationKey rpcContextKey = "auth-header"
	// contextClientTokenKey holds the *protocol.ClientToken from the client
	// or parsed from the header.
	contextClientTokenKey rpcContextKey = "client-token"
	// contextClientIdentityKey holds the *protocol.Node of the client as
	// matched with delegation and token.
	contextClientIdentityKey rpcContextKey = "client-identity"
	// contextDelegationKey holds the *transport.StreamDelegate of the rpc
	// request.
	contextDelegationKey rpcContextKey = "stream-delegation"
	// contextDisablePoolKey flags that HTTP client pooling is disabled. Used
	// in tests to avoid lingering connections.
	contextDisablePoolKey rpcContextKey = "disable-http-pool"
)
// DisablePooling returns a context derived from baseCtx that is flagged to
// disable HTTP client pooling for the client using it. The flag is stored as
// the boolean true under contextDisablePoolKey.
func DisablePooling(baseCtx context.Context) context.Context {
	const disabled = true
	return context.WithValue(baseCtx, contextDisablePoolKey, disabled)
}
// WithNode returns a context derived from ctx that carries node as the node
// to connect to for this request. The node can be read back with GetNode.
func WithNode(ctx context.Context, node *protocol.Node) context.Context {
	withNode := context.WithValue(ctx, contextNodeKey, node)
	return withNode
}
// GetNode returns the node attached to ctx with WithNode, or nil when ctx
// carries no node.
func GetNode(ctx context.Context) *protocol.Node {
	// A failed assertion leaves node at its zero value, which is nil.
	node, _ := ctx.Value(contextNodeKey).(*protocol.Node)
	return node
}
// WithContext returns a context derived from ctx that carries rpcCtx as the
// RPC context of this request. The value can be read back with GetContext.
func WithContext(ctx context.Context, rpcCtx *protocol.Context) context.Context {
	withRPC := context.WithValue(ctx, contextRPCContextKey, rpcCtx)
	return withRPC
}
// GetContext returns the RPC context attached to ctx with WithContext.
//
// The result is never nil: when ctx carries no RPC context, or carries a nil
// *protocol.Context, a fresh empty protocol.Context is returned instead.
func GetContext(ctx context.Context) *protocol.Context {
	// The type assertion succeeds for a stored nil pointer as well, so the
	// pointer itself has to be checked to keep the non-nil guarantee.
	if r, ok := ctx.Value(contextRPCContextKey).(*protocol.Context); ok && r != nil {
		return r
	}
	return &protocol.Context{}
}
// WithDelegation returns a context derived from ctx that carries delegate as
// the delegation that triggered this request. The value can be read back with
// GetDelegation.
func WithDelegation(ctx context.Context, delegate *transport.StreamDelegate) context.Context {
	withDelegate := context.WithValue(ctx, contextDelegationKey, delegate)
	return withDelegate
}
// GetDelegation returns the delegation attached to ctx with WithDelegation,
// or nil when ctx carries no delegation.
func GetDelegation(ctx context.Context) *transport.StreamDelegate {
	// A failed assertion leaves delegate at its zero value, which is nil.
	delegate, _ := ctx.Value(contextDelegationKey).(*transport.StreamDelegate)
	return delegate
}
// SerializeContextHeader writes the RPC context carried by ctx into r under
// the HeaderRPCContext header, marshaled and then encoded with standard
// base64.
//
// It is best-effort: when ctx carries no RPC context, or marshaling fails, r
// is left untouched and no error is reported.
func SerializeContextHeader(ctx context.Context, r http.Header) {
	rpcCtx, found := ctx.Value(contextRPCContextKey).(*protocol.Context)
	if !found {
		return
	}

	// Borrow a scratch buffer sized for the marshaled message and hand it
	// back on return; the encoded string below does not alias it.
	buf := pool.Get(rpcCtx.SizeVT())
	defer pool.Put(buf)

	if _, err := rpcCtx.MarshalToSizedBufferVT(buf); err != nil {
		return
	}

	encoded := base64.StdEncoding.EncodeToString(buf)
	r.Set(HeaderRPCContext, encoded)
}
// DeserializeContextHeader reads the HeaderRPCContext header from r, decodes
// it from standard base64 and unmarshals it into a protocol.Context.
//
// On success it returns a context derived from ctx that carries the decoded
// RPC context (retrievable with GetContext) and true. When the header is
// missing or empty, is not valid base64, or cannot be unmarshaled, it returns
// ctx unchanged and false.
func DeserializeContextHeader(ctx context.Context, r http.Header) (context.Context, bool) {
	raw := r.Get(HeaderRPCContext)
	if raw == "" {
		return ctx, false
	}

	payload, decodeErr := base64.StdEncoding.DecodeString(raw)
	if decodeErr != nil {
		return ctx, false
	}

	rpcCtx := new(protocol.Context)
	unmarshalErr := rpcCtx.UnmarshalVT(payload)
	if unmarshalErr != nil {
		return ctx, false
	}

	return WithContext(ctx, rpcCtx), true
}
// ExtractContext is HTTP middleware that wraps base. For each request it
// tries to deserialize the RPC context from the request headers; when that
// succeeds, base is served a copy of the request whose context carries the
// RPC context. Otherwise base receives the request unchanged.
func ExtractContext(base http.Handler) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		if ctx, ok := DeserializeContextHeader(r.Context(), r.Header); ok {
			r = r.WithContext(ctx)
		}
		base.ServeHTTP(w, r)
	}
	return http.HandlerFunc(serve)
}