/
packet_conn.go
407 lines (355 loc) · 11.4 KB
/
packet_conn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
// Package udp implements DTLS specific UDP networking primitives.
// NOTE: this package is an adaptation of pion/transport/udp that allows for
// routing datagrams based on identifiers other than the remote address. The
// primary use case for this functionality is routing based on DTLS connection
// IDs. In order to allow for consumers of this package to treat connections as
// generic net.PacketConn, routing and identifier establishment is based on
// custom introspection of datagrams, rather than direct intervention by
// consumers. If possible, the updates made in this repository will be reflected
// back upstream. If not, it is likely that this will be moved to a public
// package in this repository.
//
// This package was migrated from pion/transport/udp at
// https://github.com/pion/transport/commit/6890c795c807a617c054149eee40a69d7fdfbfdb
package udp
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"time"
idtlsnet "github.com/pion/dtls/v2/internal/net"
dtlsnet "github.com/pion/dtls/v2/pkg/net"
"github.com/pion/transport/v3/deadline"
)
const (
// receiveMTU is the size, in bytes, of the buffer that readLoop reads each
// incoming datagram into.
receiveMTU = 8192
// defaultListenBacklog is the accept-queue capacity used when
// ListenConfig.Backlog is left at zero.
defaultListenBacklog = 128 // same as Linux default
)
// Typed errors
var (
// ErrClosedListener is returned by Accept once the listener has been
// closed. getConn also produces it when a datagram from an unknown remote
// address arrives while the listener is no longer accepting.
ErrClosedListener = errors.New("udp: listener closed")
// ErrListenQueueExceeded is produced by getConn when a new connection
// cannot be queued because the accept backlog is full; readLoop then drops
// the datagram.
ErrListenQueueExceeded = errors.New("udp: listen queue exceeded")
)
// listener augments a connection-oriented Listener over a UDP PacketConn
type listener struct {
// pConn is the single UDP socket shared by the listener and every
// PacketConn it creates; readLoop reads from it and PacketConn.WriteTo
// writes to it.
pConn *net.UDPConn
// accepting is true from Listen until Close; while false, getConn refuses
// to create connections for unknown remote addresses.
accepting atomic.Value // bool
// acceptCh queues newly created connections until Accept picks them up.
// Its capacity is the configured backlog.
acceptCh chan *PacketConn
// doneCh is closed by Close to unblock pending Accept calls.
doneCh chan struct{}
// doneOnce ensures the body of Close runs at most once.
doneOnce sync.Once
// acceptFilter, datagramRouter and connIdentifier are copied from the
// ListenConfig fields AcceptFilter, DatagramRouter and
// ConnectionIdentifier respectively; each may be nil.
acceptFilter func([]byte) bool
datagramRouter func([]byte) (string, bool)
connIdentifier func([]byte) (string, bool)
// connLock guards conns.
connLock sync.Mutex
// conns maps a routing key to its connection. A key is either a remote
// address string or an alternate identifier registered by
// PacketConn.WriteTo.
conns map[string]*PacketConn
// connWG holds one count for the listener itself (added in Listen) plus
// one per accepted connection (added in Accept). When it drains, the
// goroutine started in Listen closes pConn.
connWG sync.WaitGroup
// readWG is released by readLoop on exit and by the goroutine that closes
// pConn.
readWG sync.WaitGroup
// errClose stores the error returned by pConn.Close, if any.
errClose atomic.Value // error
// readDoneCh is closed when readLoop exits.
readDoneCh chan struct{}
// errRead stores the read error that terminated readLoop.
errRead atomic.Value // error
}
// Accept blocks until a queued connection is available, the read loop has
// terminated, or the listener has been closed.
//
// On success it returns the connection together with the remote address that
// created it. If the read loop has stopped, the stored read error (possibly
// nil) is returned; if the listener was closed, ErrClosedListener is returned.
func (l *listener) Accept() (net.PacketConn, net.Addr, error) {
	select {
	case <-l.doneCh:
		return nil, nil, ErrClosedListener
	case <-l.readDoneCh:
		var readErr error
		if stored, isErr := l.errRead.Load().(error); isErr {
			readErr = stored
		}
		return nil, nil, readErr
	case pending := <-l.acceptCh:
		// Each accepted connection keeps the shared socket alive until it is
		// closed.
		l.connWG.Add(1)
		return pending, pending.raddr, nil
	}
}
// Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors.
//
// The shutdown work runs at most once; subsequent calls return nil. Close
// stops new remote addresses from being accepted, closes doneCh, and discards
// every connection still waiting in acceptCh, removing its entries from the
// connection map. It then releases the listener's own connWG count.
//
// If no connections remain in the map afterwards, Close waits on readWG and
// returns the error recorded when closing the socket, if any. Otherwise it
// returns nil without waiting.
func (l *listener) Close() error {
var err error
l.doneOnce.Do(func() {
// Refuse connections for unknown peers from now on and wake up any
// blocked Accept.
l.accepting.Store(false)
close(l.doneCh)
l.connLock.Lock()
// Close unaccepted connections
lclose:
for {
select {
case c := <-l.acceptCh:
close(c.doneCh)
// If we have an alternate identifier, remove it from the connection
// map.
if id := c.id.Load(); id != nil {
delete(l.conns, id.(string)) //nolint:forcetypeassert
}
// If we haven't already removed the remote address, remove it
// from the connection map.
if !c.rmraddr.Load() {
delete(l.conns, c.raddr.String())
c.rmraddr.Store(true)
}
default:
// acceptCh is drained.
break lclose
}
}
// Snapshot the number of remaining (accepted) connections while still
// holding the lock.
nConns := len(l.conns)
l.connLock.Unlock()
// Release the count added in Listen on behalf of the listener itself.
l.connWG.Done()
if nConns == 0 {
// Wait if this is the final connection.
l.readWG.Wait()
if errClose, ok := l.errClose.Load().(error); ok {
err = errClose
}
} else {
err = nil
}
})
return err
}
// Addr reports the local address of the UDP socket the listener is bound to.
func (l *listener) Addr() net.Addr {
	sock := l.pConn
	return sock.LocalAddr()
}
// ListenConfig stores options for listening to an address.
type ListenConfig struct {
// Backlog defines the maximum length of the queue of pending
// connections. It is the equivalent of the backlog argument of the
// POSIX listen function.
// If a connection request arrives when the queue is full,
// the request will be silently discarded, unlike TCP.
// Set zero to use the default value 128, which is the same as the Linux
// default.
Backlog int
// AcceptFilter determines whether a new conn should be made for
// the incoming packet. If not set, any packet creates a new conn.
AcceptFilter func([]byte) bool
// DatagramRouter routes an incoming datagram to a connection by extracting
// an identifier from its payload.
DatagramRouter func([]byte) (string, bool)
// ConnectionIdentifier extracts an identifier from an outgoing packet. If
// the identifier is not already associated with the connection, it will be
// added.
ConnectionIdentifier func([]byte) (string, bool)
}
// Listen creates a new listener based on the ListenConfig.
//
// It binds a UDP socket on laddr for the given network, starts the read loop
// that dispatches incoming datagrams to connections, and starts a goroutine
// that closes the socket once the listener and every accepted connection have
// been closed.
//
// If Backlog is not positive it is replaced on the receiver with the default
// (128) before the accept queue is allocated.
//
// An error is returned only when the UDP socket cannot be opened.
func (lc *ListenConfig) Listen(network string, laddr *net.UDPAddr) (dtlsnet.PacketListener, error) {
	// A negative capacity would make the channel allocation below panic at
	// run time, after the socket had already been opened, so treat every
	// non-positive backlog as a request for the default.
	if lc.Backlog <= 0 {
		lc.Backlog = defaultListenBacklog
	}

	conn, err := net.ListenUDP(network, laddr)
	if err != nil {
		return nil, err
	}

	l := &listener{
		pConn:          conn,
		acceptCh:       make(chan *PacketConn, lc.Backlog),
		conns:          make(map[string]*PacketConn),
		doneCh:         make(chan struct{}),
		acceptFilter:   lc.AcceptFilter,
		datagramRouter: lc.DatagramRouter,
		connIdentifier: lc.ConnectionIdentifier,
		readDoneCh:     make(chan struct{}),
	}

	l.accepting.Store(true)
	// One connWG count belongs to the listener itself; it is released by
	// listener.Close.
	l.connWG.Add(1)
	l.readWG.Add(2) // wait readLoop and Close execution routine

	go l.readLoop()
	go func() {
		// Close the shared socket only after the listener and every accepted
		// connection have been closed.
		l.connWG.Wait()
		if closeErr := l.pConn.Close(); closeErr != nil {
			l.errClose.Store(closeErr)
		}
		l.readWG.Done()
	}()

	return l, nil
}
// Listen opens a listener on laddr using a zero-valued ListenConfig, i.e. the
// default backlog and no accept filter, datagram router or connection
// identifier.
func Listen(network string, laddr *net.UDPAddr) (dtlsnet.PacketListener, error) {
	defaults := &ListenConfig{}
	return defaults.Listen(network, laddr)
}
// readLoop reads datagrams from the shared socket and hands each one to the
// connection selected by getConn, which creates a connection when needed.
//
// The loop ends on the first read error, which is recorded in errRead. On
// exit readDoneCh is closed and readWG is released.
func (l *listener) readLoop() {
	defer l.readWG.Done()
	defer close(l.readDoneCh)

	packet := make([]byte, receiveMTU)
	for {
		size, from, readErr := l.pConn.ReadFrom(packet)
		if readErr != nil {
			l.errRead.Store(readErr)
			return
		}

		payload := packet[:size]
		target, found, routeErr := l.getConn(from, payload)
		if routeErr != nil || !found {
			// No connection is available for this datagram; drop it.
			continue
		}

		// Delivery is best-effort: a failure to buffer the datagram is
		// deliberately ignored.
		_, _ = target.buffer.WriteTo(payload, from)
	}
}
// getConn resolves the connection that should receive the datagram buf sent
// from raddr, creating and queueing a new connection when none exists.
//
// Lookup order: the identifier extracted by the datagram router (if one is
// configured and it matches a known connection), then the remote address.
//
// The boolean result reports whether a connection was found or created. A nil
// error with a false result means the accept filter rejected the datagram.
// ErrClosedListener is returned when the listener no longer accepts new
// peers, and ErrListenQueueExceeded when the accept queue is full.
func (l *listener) getConn(raddr net.Addr, buf []byte) (*PacketConn, bool, error) {
	l.connLock.Lock()
	defer l.connLock.Unlock()

	// Prefer the custom router when it yields a known identifier.
	if l.datagramRouter != nil {
		if routeID, routed := l.datagramRouter(buf); routed {
			if existing, found := l.conns[routeID]; found {
				return existing, true, nil
			}
		}
	}

	// Otherwise fall back to routing by remote address.
	if existing, found := l.conns[raddr.String()]; found {
		return existing, true, nil
	}

	// Unknown peer: only create a connection while the listener is open.
	if accepting, isBool := l.accepting.Load().(bool); !(isBool && accepting) {
		return nil, false, ErrClosedListener
	}
	if l.acceptFilter != nil && !l.acceptFilter(buf) {
		return nil, false, nil
	}

	created := l.newPacketConn(raddr)
	select {
	case l.acceptCh <- created:
		// Register the connection only once it has a slot in the accept
		// queue.
		l.conns[raddr.String()] = created
		return created, true, nil
	default:
		return nil, false, ErrListenQueueExceeded
	}
}
// PacketConn is a net.PacketConn implementation that is able to dictate its
// routing ID via an alternate identifier from its remote address. Internal
// buffering is performed for reads, and writes are passed through to the
// underlying net.PacketConn.
type PacketConn struct {
// listener is the listener that created this connection; its socket is
// used for writes and its connection map holds this connection's routing
// entries.
listener *listener
// raddr is the remote address of the datagram that caused this connection
// to be created.
raddr net.Addr
// rmraddr is true once the raddr entry has been removed from the
// listener's connection map.
rmraddr atomic.Bool
// id holds the alternate routing identifier (a string) once WriteTo has
// established one; it is unset until then.
id atomic.Value
// buffer holds inbound datagrams queued by the listener's read loop until
// ReadFrom consumes them.
buffer *idtlsnet.PacketBuffer
// doneCh is closed when the connection is closed.
doneCh chan struct{}
// doneOnce ensures the body of Close runs at most once.
doneOnce sync.Once
// writeDeadline is checked by WriteTo before each write; it is kept per
// connection because the underlying socket is shared.
writeDeadline *deadline.Deadline
}
// newPacketConn builds a PacketConn bound to this listener for the given
// remote address, with a fresh packet buffer, done channel and write
// deadline.
func (l *listener) newPacketConn(raddr net.Addr) *PacketConn {
	pc := new(PacketConn)
	pc.listener = l
	pc.raddr = raddr
	pc.buffer = idtlsnet.NewPacketBuffer()
	pc.doneCh = make(chan struct{})
	pc.writeDeadline = deadline.New()
	return pc
}
// ReadFrom copies the next buffered packet payload into p and returns the
// number of bytes copied together with the packet's remote address, as
// reported by the connection's packet buffer.
func (c *PacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
	n, from, err := c.buffer.ReadFrom(p)
	return n, from, err
}
// WriteTo writes len(p) bytes from p to the specified address.
//
// When the listener has a connection identifier function configured, WriteTo
// also maintains this connection's routing entries: the first outgoing packet
// that yields an identifier registers it in the listener's connection map,
// and a later write to an address other than the original remote address
// releases the original remote address entry.
//
// It returns context.DeadlineExceeded without writing if the write deadline
// has already expired; otherwise it returns the result of writing to the
// listener's shared socket.
func (c *PacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
// If we have a connection identifier, check to see if the outgoing packet
// sets it.
if c.listener.connIdentifier != nil {
id := c.id.Load()
// Only establish the identifier if we haven't already done so.
if id == nil {
candidate, ok := c.listener.connIdentifier(p)
// If we have an identifier, add entry to connection map.
if ok {
c.listener.connLock.Lock()
c.listener.conns[candidate] = c
c.listener.connLock.Unlock()
c.id.Store(candidate)
}
}
// If we are writing to a remote address that differs from the initial,
// we have an alternate identifier established, and we haven't already
// freed the remote address, free the remote address to be used by
// another connection.
// Note: this strategy results in holding onto a remote address after it
// is potentially no longer in use by the client. However, releasing
// earlier means that we could miss some packets that should have been
// routed to this connection. Ideally, we would drop the connection
// entry for the remote address as soon as the client starts sending
// using an alternate identifier, but in practice this proves
// challenging because any client could spoof a connection identifier,
// resulting in the remote address entry being dropped prior to the
// "real" client transitioning to sending using the alternate
// identifier.
//
// id is the value loaded before any registration above, so the release
// can only happen on a write after the one that established the
// identifier.
if id != nil && !c.rmraddr.Load() && addr.String() != c.raddr.String() {
c.listener.connLock.Lock()
delete(c.listener.conns, c.raddr.String())
c.rmraddr.Store(true)
c.listener.connLock.Unlock()
}
}
// The deadline is checked after the identifier bookkeeping, so the map
// updates above take effect even when the write itself is rejected.
select {
case <-c.writeDeadline.Done():
return 0, context.DeadlineExceeded
default:
}
return c.listener.pConn.WriteTo(p, addr)
}
// Close closes the conn and releases any Read calls
//
// The shutdown work runs at most once; subsequent calls return nil. Close
// releases this connection's connWG count on the listener, closes doneCh, and
// removes the connection's alternate identifier and remote address entries
// from the listener's connection map.
//
// If no connections remain and the listener is no longer accepting, Close
// waits on the listener's readWG and returns the error recorded when closing
// the socket, if any. The packet buffer is always closed; its error is
// returned only when no earlier error was recorded.
func (c *PacketConn) Close() error {
var err error
c.doneOnce.Do(func() {
// Release the count taken by Accept for this connection.
c.listener.connWG.Done()
close(c.doneCh)
c.listener.connLock.Lock()
// If we have an alternate identifier, remove it from the connection
// map.
if id := c.id.Load(); id != nil {
delete(c.listener.conns, id.(string)) //nolint:forcetypeassert
}
// If we haven't already removed the remote address, remove it from the
// connection map.
if !c.rmraddr.Load() {
delete(c.listener.conns, c.raddr.String())
c.rmraddr.Store(true)
}
// Snapshot the number of remaining connections while still holding the
// lock.
nConns := len(c.listener.conns)
c.listener.connLock.Unlock()
if isAccepting, ok := c.listener.accepting.Load().(bool); nConns == 0 && !isAccepting && ok {
// Wait if this is the final connection
c.listener.readWG.Wait()
if errClose, ok := c.listener.errClose.Load().(error); ok {
err = errClose
}
} else {
err = nil
}
// A socket close error takes precedence over a buffer close error.
if errBuf := c.buffer.Close(); errBuf != nil && err == nil {
err = errBuf
}
})
return err
}
// LocalAddr implements net.PacketConn.LocalAddr by reporting the local address
// of the listener's shared UDP socket.
func (c *PacketConn) LocalAddr() net.Addr {
	sock := c.listener.pConn
	return sock.LocalAddr()
}
// SetDeadline implements net.PacketConn.SetDeadline. It applies t as the
// connection's write deadline and as the read deadline of its packet buffer,
// returning the result of setting the read deadline.
func (c *PacketConn) SetDeadline(t time.Time) error {
	c.writeDeadline.Set(t)
	readErr := c.buffer.SetReadDeadline(t)
	return readErr
}
// SetReadDeadline implements net.PacketConn.SetReadDeadline by forwarding t to
// the connection's packet buffer.
func (c *PacketConn) SetReadDeadline(t time.Time) error {
	err := c.buffer.SetReadDeadline(t)
	return err
}
// SetWriteDeadline implements net.PacketConn.SetWriteDeadline.
//
// Only this connection's own deadline is updated. The deadline of the
// underlying socket is left untouched because that socket can be shared with
// other connections. The returned error is always nil.
func (c *PacketConn) SetWriteDeadline(t time.Time) error {
	c.writeDeadline.Set(t)
	return nil
}