/
dial.go
438 lines (401 loc) · 12.8 KB
/
dial.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
package rpc
import (
"context"
"crypto/tls"
"fmt"
"net"
"strconv"
"strings"
"sync"
"time"
"github.com/edaniels/golog"
"github.com/edaniels/zeroconf"
"github.com/pkg/errors"
"go.uber.org/zap"
)
// Dial attempts to make the most convenient connection to the given address. It attempts to connect
// via WebRTC if a signaling server is detected or provided. Otherwise it attempts to connect directly.
//
// A nil logger is replaced with a no-op logger. The supplied DialOptions are applied in order, so a
// later option may override an earlier one.
//
// TODO(GOUT-7): figure out decent way to handle reconnect on connection termination.
func Dial(ctx context.Context, address string, logger golog.Logger, opts ...DialOption) (ClientConn, error) {
	if logger == nil {
		logger = zap.NewNop().Sugar()
	}

	var dOpts dialOptions
	for _, o := range opts {
		o.apply(&dOpts)
	}

	return dialInner(ctx, address, logger, dOpts)
}
// dialInner validates address and dials it through the "multi" dialFunc entry, which may
// hand back a previously established (cached) connection instead of dialing anew.
//
// When no auth entity was configured, it defaults to the external auth address if one is
// set, and to address otherwise. That defaulting happens inside the connect closure, after
// the cache key has already been computed from the caller's options.
func dialInner(
	ctx context.Context,
	address string,
	logger golog.Logger,
	dOpts dialOptions,
) (ClientConn, error) {
	if address == "" {
		return nil, errors.New("address empty")
	}

	connect := func() (ClientConn, error) {
		if dOpts.debug {
			logger.Debugw("starting to dial", "address", address)
		}
		if dOpts.authEntity == "" {
			switch dOpts.externalAuthAddr {
			case "":
				// Without external auth, the entity is assumed to be the dialed address itself.
				if dOpts.debug {
					logger.Debugw("auth entity empty; setting to address", "address", address)
				}
				dOpts.authEntity = address
			default:
				// With external auth, the entity is the external auth address.
				if dOpts.debug {
					logger.Debugw("auth entity empty; setting to external auth address", "address", dOpts.externalAuthAddr)
				}
				dOpts.authEntity = dOpts.externalAuthAddr
			}
		}
		c, _, dialErr := dial(ctx, address, address, logger, dOpts, true)
		return c, dialErr
	}

	conn, cached, err := dialFunc(ctx, "multi", address, dOpts.cacheKey(), connect)
	if err != nil {
		return nil, err
	}
	if cached && dOpts.debug {
		logger.Debugw("connected (cached)", "address", address)
	}
	return conn, nil
}
// ErrConnectionOptionsExhausted is returned in cases where the given options have all been used up with
// no way to connect on any of them. dial returns it when no parallel attempt (mDNS/WebRTC) produced a
// connection or a direct-blocking error and direct gRPC dialing is disabled.
var ErrConnectionOptionsExhausted = errors.New("exhausted all connection options with no way to connect")
// dialResult carries the outcome of one concurrent dial attempt (mDNS or WebRTC) back to dial.
// Each attempt sends either a connection or an error, never both.
type dialResult struct {
// conn is the successfully established connection; nil when err is set.
conn ClientConn
// cached reports whether conn was reused rather than newly established.
cached bool
// err is the connection error when the attempt failed.
err error
// skipDirect, when set together with err, tells dial not to fall back to a direct gRPC
// connection and to return err instead if no other attempt connects.
skipDirect bool
}
// dial connects to address by racing up to two parallel attempts and, if neither yields a
// connection, optionally falling back to a direct gRPC dial.
//
// The parallel attempts are:
//   - mDNS, only when tryLocal is set, mDNS is not disabled, and address is a bare host
//     name (no ':' and not an IP literal);
//   - WebRTC, unless disabled, using the configured signaling server or, when none is set
//     (or auto-detection is allowed), deriving it from address.
//
// The first connection to succeed is kept and the other attempt is cancelled; a connection
// that still completes afterwards is closed. If nothing connected and an attempt reported
// an error flagged skipDirect, the last such error is returned. Otherwise direct gRPC is
// tried, unless dOpts.disableDirect is set, in which case ErrConnectionOptionsExhausted is
// returned.
//
// "unix://" addresses disable mDNS and WebRTC and are dialed directly with insecure set.
//
// originalAddress is the address the caller originally asked for. It differs from address
// when dial is re-entered from the mDNS path with a resolved host:port, and it is the host
// targeted over WebRTC. The returned bool reports whether the connection was reused
// (cached) rather than newly established.
func dial(
	ctx context.Context,
	address string,
	originalAddress string,
	logger golog.Logger,
	dOpts dialOptions,
	tryLocal bool,
) (ClientConn, bool, error) {
	if ctx.Err() != nil {
		return nil, false, ctx.Err()
	}

	// isJustDomain is true only for a bare host name: not unix://, no port, not an IP literal.
	var isJustDomain bool
	switch {
	case strings.HasPrefix(address, "unix://"):
		dOpts.mdnsOptions.Disable = true
		dOpts.webrtcOpts.Disable = true
		dOpts.insecure = true
		dOpts.disableDirect = false
	case !strings.ContainsRune(address, ':'):
		isJustDomain = net.ParseIP(address) == nil
	}

	// We make concurrent dial attempts via mDNS and WebRTC, keeping the first connection
	// that succeeds. We then cancel the slower attempt and wait for its goroutine to
	// complete. If the slower attempt succeeds before it can be cancelled, we explicitly
	// close its connection to prevent a leak.
	var (
		wg                          sync.WaitGroup
		dialCh                      = make(chan dialResult)
		ctxParallel, cancelParallel = context.WithCancelCause(ctx)
	)
	defer cancelParallel(nil)

	if !dOpts.mdnsOptions.Disable && tryLocal && isJustDomain {
		wg.Add(1)
		go func(dOpts dialOptions) {
			defer wg.Done()
			conn, cached, err := dialMulticastDNS(ctxParallel, address, logger, dOpts)
			if err != nil {
				dialCh <- dialResult{err: err}
			} else {
				dialCh <- dialResult{conn: conn, cached: cached}
			}
		}(dOpts)
	}

	if !dOpts.webrtcOpts.Disable {
		wg.Add(1)
		go func(dOpts dialOptions) {
			defer wg.Done()
			signalingAddress := dOpts.webrtcOpts.SignalingServerAddress
			if signalingAddress == "" || dOpts.webrtcOpts.AllowAutoDetectAuthOptions {
				if signalingAddress == "" {
					// try WebRTC at same address
					signalingAddress = address
				}
				target, port, err := getWebRTCTargetFromAddressWithDefaults(signalingAddress)
				if err != nil {
					// TODO(RSDK-6493): Investigate if we must `skipDirect` here.
					dialCh <- dialResult{err: err, skipDirect: true}
					return
				}
				dOpts.fixupWebRTCOptions(target, port)
				// When connecting to an external signaler for WebRTC, we assume we can use the external auth's material.
				// This path is also called by an mdns direct connection and ignores that case.
				// This will skip all Authenticate/AuthenticateTo calls for the signaler.
				if !dOpts.usingMDNS && dOpts.authMaterial == "" && dOpts.webrtcOpts.SignalingExternalAuthAuthMaterial != "" {
					logger.Debug("using signaling's external auth as auth material")
					dOpts.authMaterial = dOpts.webrtcOpts.SignalingExternalAuthAuthMaterial
					dOpts.creds = Credentials{}
				}
			}
			if dOpts.debug {
				logger.Debugw(
					"trying WebRTC",
					"signaling_server", dOpts.webrtcOpts.SignalingServerAddress,
					"host", originalAddress,
				)
			}
			conn, cached, err := dialFunc(
				ctxParallel,
				"webrtc",
				fmt.Sprintf("%s->%s", dOpts.webrtcOpts.SignalingServerAddress, originalAddress),
				dOpts.cacheKey(),
				func() (ClientConn, error) {
					return dialWebRTC(
						ctxParallel,
						dOpts.webrtcOpts.SignalingServerAddress,
						originalAddress,
						dOpts,
						logger,
					)
				})
			switch {
			case err == nil:
				if dOpts.debug {
					logger.Debugw("connected via WebRTC",
						"address", address,
						"cached", cached,
						"using mDNS", dOpts.usingMDNS,
					)
				}
				dialCh <- dialResult{conn: conn, cached: cached}
			case !errors.Is(err, ErrNoWebRTCSignaler):
				// TODO(RSDK-6493): Investigate if we must `skipDirect` here.
				dialCh <- dialResult{err: err, skipDirect: true}
			case ctxParallel.Err() != nil:
				dialCh <- dialResult{err: ctxParallel.Err(), skipDirect: true}
			default:
				// No signaler found: not fatal, direct gRPC may still be tried.
				dialCh <- dialResult{err: err}
			}
		}(dOpts)
	}

	// Close dialCh once every attempt has reported, which ends the collection loop below.
	go func() {
		wg.Wait()
		close(dialCh)
	}()

	var (
		conn   ClientConn
		cached bool
		err    error
	)
	for result := range dialCh {
		switch {
		case result.err == nil && result.conn != nil && conn == nil:
			// First success wins; stop the other attempt.
			conn, cached = result.conn, result.cached
			cancelParallel(errors.New("using another established connection"))
		case result.err == nil && result.conn != nil:
			// A slower attempt finished before cancellation took effect: close it instead of
			// discarding the connection we already accepted.
			if errClose := result.conn.Close(); errClose != nil {
				logger.Warnw("unable to close redundant connection", "error", errClose)
			}
		case result.err != nil && result.skipDirect:
			err = result.err
		}
	}

	if conn != nil {
		return conn, cached, nil
	}
	if err != nil {
		return nil, false, err
	}
	if dOpts.disableDirect {
		return nil, false, ErrConnectionOptionsExhausted
	}

	if dOpts.debug {
		logger.Debugw("trying direct", "address", address)
	}
	conn, cached, err = dialDirectGRPC(ctx, address, dOpts, logger)
	if err != nil {
		return nil, false, err
	}
	if dOpts.debug {
		logger.Debugw("connected via gRPC",
			"address", address,
			"cached", cached,
			"using mDNS", dOpts.usingMDNS,
		)
	}
	return conn, cached, nil
}
// lookupMDNSCandidate looks up an "_rpc._tcp" service instance in the "local." domain via
// mDNS, selecting IPv4 records. It tries address as given and then address with every '.'
// replaced by '-', allowing each candidate up to one second.
//
// It returns the first entry received. It returns ctx's error if ctx is done, the
// resolver's error if a query cannot be issued, and an error if no candidate produced an
// entry.
func lookupMDNSCandidate(ctx context.Context, address string, logger golog.Logger) (*zeroconf.ServiceEntry, error) {
	resolver, err := zeroconf.NewResolver(logger, zeroconf.SelectIPRecordType(zeroconf.IPv4))
	if err != nil {
		return nil, err
	}
	defer resolver.Shutdown()

	// lookupOne queries a single candidate. A nil entry with a nil error means the
	// candidate produced nothing. Its deferred cancel releases the per-candidate timeout
	// as soon as the candidate is done, rather than at function return as a defer placed
	// directly in the loop would.
	lookupOne := func(candidate string) (*zeroconf.ServiceEntry, error) {
		entries := make(chan *zeroconf.ServiceEntry)
		lookupCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
		defer cancel()
		if err := resolver.Lookup(lookupCtx, candidate, "_rpc._tcp", "local.", entries); err != nil {
			logger.Errorw("error performing mDNS query", "error", err)
			return nil, err
		}
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		// entries gets closed after lookupCtx expires or there is a real entry
		case entry := <-entries:
			return entry, nil
		}
	}

	candidates := []string{address, strings.ReplaceAll(address, ".", "-")}
	for _, candidate := range candidates {
		entry, err := lookupOne(candidate)
		if err != nil {
			return nil, err
		}
		if entry != nil {
			return entry, nil
		}
	}
	if ctx.Err() != nil {
		return nil, ctx.Err()
	}
	return nil, errors.New("mDNS query failed to find a candidate")
}
// dialMulticastDNS resolves address over mDNS and, when the advertised service mentions
// grpc or webrtc in its TXT records and has an IPv4 address, dials the discovered local
// host:port by re-entering dial with tryLocal=false, so mDNS is not attempted again.
//
// The options are adjusted for the local hop:
//   - usingMDNS is set;
//   - credentials are dropped when mdnsOptions.RemoveAuthCredentials is set;
//   - WebRTC signaling is pointed at the discovered host, or disabled when not advertised;
//   - TLS verification uses address as the server name.
func dialMulticastDNS(
	ctx context.Context,
	address string,
	logger golog.Logger,
	dOpts dialOptions,
) (ClientConn, bool, error) {
	entry, err := lookupMDNSCandidate(ctx, address, logger)
	if err != nil {
		return nil, false, err
	}

	// TXT fields may follow RFC 1464 key=value form (e.g. "grpc="); only the presence of
	// the substring is checked.
	var advertisesGRPC, advertisesWebRTC bool
	for _, txt := range entry.Text {
		advertisesGRPC = advertisesGRPC || strings.Contains(txt, "grpc")
		advertisesWebRTC = advertisesWebRTC || strings.Contains(txt, "webrtc")
	}

	// An IPv4 address is required: scoped IPv6 addresses do not work with grpc-go.
	if len(entry.AddrIPv4) == 0 || (!advertisesGRPC && !advertisesWebRTC) {
		return nil, false, fmt.Errorf(
			`mDNS query found a service without an IPv4 address that does not support grpc or webrtc: %q`,
			entry.ServiceName(),
		)
	}

	ip := entry.AddrIPv4[0]
	localAddress := fmt.Sprintf("%s:%d", ip, entry.Port)
	if dOpts.debug {
		logger.Debugw("found address via mDNS", "address", localAddress)
	}

	// Downstream calls use this flag when deciding whether the external auth credentials
	// may be reused for signaling; that reuse does not apply to an mDNS-discovered host.
	dOpts.usingMDNS = true

	stripCreds := dOpts.mdnsOptions.RemoveAuthCredentials
	if stripCreds {
		dOpts.creds = Credentials{}
		dOpts.authEntity = ""
		dOpts.externalAuthToEntity = ""
		dOpts.externalAuthMaterial = ""
	}

	if !advertisesWebRTC {
		dOpts.webrtcOpts.Disable = true
	} else {
		dOpts.fixupWebRTCOptions(ip.String(), uint16(entry.Port))
		if stripCreds {
			dOpts.webrtcOpts.SignalingAuthEntity = ""
			dOpts.webrtcOpts.SignalingCreds = Credentials{}
			dOpts.webrtcOpts.SignalingExternalAuthAuthMaterial = ""
		}
	}

	// Verify TLS against the name the caller asked for, not the discovered IP.
	var tlsConfig *tls.Config
	if dOpts.tlsConfig != nil {
		tlsConfig = dOpts.tlsConfig.Clone()
	} else {
		tlsConfig = newDefaultTLSConfig()
	}
	tlsConfig.ServerName = address
	dOpts.tlsConfig = tlsConfig

	conn, cached, err := dial(ctx, localAddress, address, logger, dOpts, false)
	if err != nil {
		return nil, false, err
	}
	return conn, cached, nil
}
// fixupWebRTCOptions points WebRTC signaling at target:port and fills in signaling settings
// that were left unset, deriving them from the general (non-WebRTC) dial options. For
// example, dial credentials become signaling credentials.
//
// The insecure flags are copied only when the caller did not set WebRTC options
// explicitly. Every other inherited field is copied only when it is still empty.
func (dOpts *dialOptions) fixupWebRTCOptions(target string, port uint16) {
	webrtc := dOpts.webrtcOpts
	webrtc.SignalingServerAddress = fmt.Sprintf("%s:%d", target, port)

	if !dOpts.webrtcOptsSet {
		webrtc.SignalingInsecure = dOpts.insecure
		webrtc.SignalingExternalAuthInsecure = dOpts.externalAuthInsecure
	}

	// External auth settings for the signaler fall back to the general external auth ones.
	if webrtc.SignalingExternalAuthAddress == "" {
		webrtc.SignalingExternalAuthAddress = dOpts.externalAuthAddr
	}
	if webrtc.SignalingExternalAuthToEntity == "" {
		webrtc.SignalingExternalAuthToEntity = dOpts.externalAuthToEntity
	}
	if webrtc.SignalingExternalAuthAuthMaterial == "" {
		webrtc.SignalingExternalAuthAuthMaterial = dOpts.externalAuthMaterial
	}

	// Handing the entity and credentials to the signaler is considered safe because:
	//  1. From mDNS: insecure-downgrade rules apply and the TLS server name check stays
	//     intact, so credentials go to the same host unless the user opted out.
	//  2. From trying WebRTC without an explicit signaling address: insecure-downgrade
	//     rules apply and the host/target stays intact, so credentials go to the same
	//     host unless the user opted out.
	//  3. From the user explicitly allowing this.
	if webrtc.SignalingAuthEntity == "" {
		webrtc.SignalingAuthEntity = dOpts.authEntity
	}
	if webrtc.SignalingCreds.Type == "" {
		webrtc.SignalingCreds = dOpts.creds
	}

	dOpts.webrtcOpts = webrtc
}
// getWebRTCTargetFromAddressWithDefaults splits a signaling address into a host target and
// a port.
//
// An address without any ':' is taken as a bare host on port 443. Otherwise it must be a
// valid host:port with a port in 0-65535, or an error is returned. IPv6 hosts are returned
// re-wrapped in brackets so that target+":"+port stays a valid address.
func getWebRTCTargetFromAddressWithDefaults(signalingAddress string) (target string, port uint16, err error) {
	if !strings.Contains(signalingAddress, ":") {
		return signalingAddress, 443, nil
	}

	host, rawPort, splitErr := net.SplitHostPort(signalingAddress)
	if splitErr != nil {
		return "", 0, splitErr
	}
	parsedPort, parseErr := strconv.ParseUint(rawPort, 10, 16)
	if parseErr != nil {
		return "", 0, parseErr
	}

	// SplitHostPort strips the brackets around an IPv6 literal; put them back.
	if strings.Contains(host, ":") {
		host = "[" + host + "]"
	}
	return host, uint16(parsedPort), nil
}