-
-
Notifications
You must be signed in to change notification settings - Fork 89
/
dataconn_server.go
256 lines (225 loc) · 7.1 KB
/
dataconn_server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
package dataconn
import (
"bytes"
"context"
"fmt"
"io"
"sync"
"google.golang.org/protobuf/proto"
"github.com/zrepl/zrepl/logger"
"github.com/zrepl/zrepl/replication/logic/pdu"
"github.com/zrepl/zrepl/rpc/dataconn/stream"
"github.com/zrepl/zrepl/transport"
)
// WireInterceptor has a chance to exchange the context and connection on each client connection.
// It runs before the connection is wrapped into a stream.Conn; returning a replacement
// context/connection substitutes them for the remainder of the connection's lifetime.
type WireInterceptor func(ctx context.Context, rawConn *transport.AuthConn) (context.Context, *transport.AuthConn)
// Handler implements the functionality that is exposed by Server to the Client.
type Handler interface {
	// Send handles a SendRequest.
	// The returned io.ReadCloser is allowed to be nil, for example if the requested Send is a dry-run.
	Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error)
	// Receive handles a ReceiveRequest.
	// It is guaranteed that Server calls Receive with a stream that holds the IdleConnTimeout
	// configured in ServerConfig.Shared.IdleConnTimeout.
	Receive(ctx context.Context, r *pdu.ReceiveReq, receive io.ReadCloser) (*pdu.ReceiveRes, error)
	// PingDataconn handles a PingReq.
	PingDataconn(ctx context.Context, r *pdu.PingReq) (*pdu.PingRes, error)
}
// Logger is the logging interface used throughout this package.
type Logger = logger.Logger

// ContextInterceptorData exposes per-request metadata (endpoint + client
// identity) to a ContextInterceptor.
type ContextInterceptorData interface {
	FullMethod() string
	ClientIdentity() string
}

// ContextInterceptor wraps the dispatch of a single request; it must invoke
// handler (possibly with a derived context) for the request to be served.
type ContextInterceptor = func(ctx context.Context, data ContextInterceptorData, handler func(ctx context.Context))
// Server accepts client connections from an AuthenticatedListener and serves
// dataconn requests (send / receive / ping) against the configured Handler.
type Server struct {
	h   Handler            // endpoint implementations
	wi  WireInterceptor    // optional; nil means no wire interception
	ci  ContextInterceptor // never nil after NewServer (defaults to no-op)
	log Logger
}

// noopContextInteceptor is the default ContextInterceptor: it dispatches
// directly without touching the context.
// NOTE(review): name has a typo ("Inteceptor"); unexported, so renaming is a
// local-only cleanup — left as-is here to keep this change self-contained.
var noopContextInteceptor = func(ctx context.Context, _ ContextInterceptorData, handler func(context.Context)) {
	handler(ctx)
}
// NewServer constructs a Server around the given handler.
// Both wi and ci may be nil; a nil ci is substituted with a pass-through
// interceptor that simply invokes the request handler.
func NewServer(wi WireInterceptor, ci ContextInterceptor, logger Logger, handler Handler) *Server {
	interceptor := ci
	if interceptor == nil {
		interceptor = noopContextInteceptor
	}
	srv := &Server{
		h:   handler,
		wi:  wi,
		ci:  interceptor,
		log: logger,
	}
	return srv
}
// Serve consumes the listener, closes it as soon as ctx is closed.
// No accept errors are returned: they are logged to the Logger passed
// to the constructor.
//
// Serve blocks until the listener is closed (via ctx cancellation) and all
// per-connection goroutines have finished.
func (s *Server) Serve(ctx context.Context, l transport.AuthenticatedListener) {
	var wg sync.WaitGroup
	defer wg.Wait()
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Closer goroutine: unblocks Accept by closing the listener once ctx is done.
	wg.Add(1)
	go func() {
		defer wg.Done()
		<-ctx.Done()
		s.log.Debug("context done, closing listener")
		if err := l.Close(); err != nil {
			s.log.WithError(err).Error("cannot close listener")
		}
	}()

	// Accept loop: feeds accepted connections into conns until ctx is canceled.
	conns := make(chan *transport.AuthConn)
	wg.Add(1)
	go func() {
		defer wg.Done()
		defer close(conns)
		for {
			conn, err := l.Accept(ctx)
			if err != nil {
				// BUG FIX: previously this checked `ctx.Done() != nil`,
				// which is always true for a WithCancel-derived context,
				// so ANY accept error terminated the accept loop instead
				// of being logged and retried. Check for actual
				// cancellation via ctx.Err().
				if ctx.Err() != nil {
					s.log.Debug("stop accepting after context is done")
					return
				}
				s.log.WithError(err).Error("accept error")
				continue
			}
			conns <- conn
		}
	}()

	// One goroutine per accepted connection; wg ensures we wait for all of
	// them before Serve returns.
	for conn := range conns {
		wg.Add(1)
		go func(conn *transport.AuthConn) {
			defer wg.Done()
			s.serveConn(conn)
		}(conn)
	}
}
// contextInterceptorData carries per-request metadata into the
// ContextInterceptor; it implements the ContextInterceptorData interface.
type contextInterceptorData struct {
	fullMethod     string // requested endpoint name (e.g. EndpointSend)
	clientIdentity string // identity reported by the transport.AuthConn
}

func (d contextInterceptorData) FullMethod() string     { return d.fullMethod }
func (d contextInterceptorData) ClientIdentity() string { return d.clientIdentity }
// serveConn drives a single client connection: it lets the optional
// WireInterceptor rewrite the context/connection, wraps the connection into a
// heartbeat-supervised stream.Conn, reads the request header (the endpoint
// name), and dispatches through the ContextInterceptor to serveConnRequest.
// The wrapped connection is always closed before returning.
func (s *Server) serveConn(nc *transport.AuthConn) {
	s.log.Debug("serveConn begin")
	defer s.log.Debug("serveConn done")

	ctx := context.Background()
	if s.wi != nil {
		ctx, nc = s.wi(ctx, nc)
	}

	sc := stream.Wrap(nc, HeartbeatInterval, HeartbeatPeerTimeout)
	defer func() {
		s.log.Debug("close client connection")
		if closeErr := sc.Close(); closeErr != nil {
			s.log.WithError(closeErr).Error("cannot close client connection")
		}
	}()

	headerBytes, err := sc.ReadStreamedMessage(ctx, RequestHeaderMaxSize, ReqHeader)
	if err != nil {
		s.log.WithError(err).Error("error reading structured part")
		return
	}
	endpoint := string(headerBytes)

	data := contextInterceptorData{
		fullMethod:     endpoint,
		clientIdentity: nc.ClientIdentity(),
	}
	s.ci(ctx, data, func(ctx context.Context) {
		s.serveConnRequest(ctx, endpoint, sc)
	})
}
// serveConnRequest reads the structured request part from c, dispatches to the
// endpoint-specific Handler method, and writes back the response header, the
// marshaled protobuf response and, for sends, the ZFS data stream.
//
// Any handler-level error (including an unknown endpoint and a protobuf
// marshaling failure) is reported to the client via the response header; the
// structured response and stream are only written on success.
func (s *Server) serveConnRequest(ctx context.Context, endpoint string, c *stream.Conn) {
	reqStructured, err := c.ReadStreamedMessage(ctx, RequestStructuredMaxSize, ReqStructured)
	if err != nil {
		s.log.WithError(err).Error("error reading structured part")
		return
	}

	s.log.WithField("endpoint", endpoint).Debug("calling handler")

	// These are assigned (not re-declared) inside the switch cases below;
	// be careful not to use := there, which would shadow them.
	var res proto.Message
	var sendStream io.ReadCloser
	var handlerErr error
	switch endpoint {
	case EndpointSend:
		var req pdu.SendReq
		if err := proto.Unmarshal(reqStructured, &req); err != nil {
			s.log.WithError(err).Error("cannot unmarshal send request")
			return
		}
		res, sendStream, handlerErr = s.h.Send(ctx, &req)
		// ensure that we always close the sendStream
		if sendStream != nil {
			defer func() {
				if err := sendStream.Close(); err != nil {
					s.log.WithError(err).Error("cannot close send stream")
				}
			}()
		}
	case EndpointRecv:
		var req pdu.ReceiveReq
		if err := proto.Unmarshal(reqStructured, &req); err != nil {
			s.log.WithError(err).Error("cannot unmarshal receive request")
			return
		}
		// renamed from `stream` to avoid shadowing the imported stream package
		recvStream, err := c.ReadStream(ZFSStream, false)
		if err != nil {
			s.log.WithError(err).Error("cannot open stream in receive request")
			return
		}
		res, handlerErr = s.h.Receive(ctx, &req, recvStream)
	case EndpointPing:
		var req pdu.PingReq
		if err := proto.Unmarshal(reqStructured, &req); err != nil {
			s.log.WithError(err).Error("cannot unmarshal ping request")
			return
		}
		res, handlerErr = s.h.PingDataconn(ctx, &req)
	default:
		s.log.WithField("endpoint", endpoint).Error("unknown endpoint")
		handlerErr = fmt.Errorf("requested endpoint does not exist")
	}

	s.log.WithField("endpoint", endpoint).WithField("errType", fmt.Sprintf("%T", handlerErr)).Debug("handler returned")

	// prepare protobuf now to return the protobuf error in the header
	// if marshaling fails. We consider failed marshaling a handler error
	var protobuf *bytes.Buffer
	if handlerErr == nil {
		if res == nil {
			handlerErr = fmt.Errorf("implementation error: handler for endpoint %q returns nil error and nil result", endpoint)
			// BUG FIX: previously logged the outer `err`, which is
			// necessarily nil at this point; log the actual handler error.
			s.log.WithError(handlerErr).Error("handle implementation error")
		} else {
			protobufBytes, marshalErr := proto.Marshal(res)
			if marshalErr != nil {
				s.log.WithError(marshalErr).Error("cannot marshal handler protobuf")
				handlerErr = marshalErr
			}
			protobuf = bytes.NewBuffer(protobufBytes)
		}
	}

	// The header is either the OK marker or the error prefix followed by the
	// handler error's message.
	var resHeaderBuf bytes.Buffer
	if handlerErr == nil {
		resHeaderBuf.WriteString(responseHeaderHandlerOk)
	} else {
		resHeaderBuf.WriteString(responseHeaderHandlerErrorPrefix)
		resHeaderBuf.WriteString(handlerErr.Error())
	}
	if err := c.WriteStreamedMessage(ctx, &resHeaderBuf, ResHeader); err != nil {
		s.log.WithError(err).Error("cannot write response header")
		return
	}

	if handlerErr != nil {
		s.log.Debug("early exit after handler error")
		return
	}

	if err := c.WriteStreamedMessage(ctx, protobuf, ResStructured); err != nil {
		s.log.WithError(err).Error("cannot write structured part of response")
		return
	}

	if sendStream != nil {
		if err := c.SendStream(ctx, sendStream, ZFSStream); err != nil {
			s.log.WithError(err).Error("cannot write send stream")
		}
		// sendStream.Close() done via defer above
	}
}