-
Notifications
You must be signed in to change notification settings - Fork 397
/
mux.go
167 lines (138 loc) · 3.22 KB
/
mux.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package listenmux
import (
"context"
"fmt"
"io"
"net"
"sync"
"github.com/zeebo/errs"
)
// Closed is the error that routed listeners report once the mux has shut down
// without an underlying Accept error from the base listener.
var (
	Closed = errs.New("listener closed")
)
// Mux splits one base listener into several listeners. Each accepted
// connection is dispatched according to its first prefixLen bytes.
type Mux struct {
	// Set by New.
	base      net.Listener // source of every incoming connection
	prefixLen int          // number of bytes read from a connection to pick a route
	addr      net.Addr     // base.Addr(), passed to every listener the mux creates
	def       *listener    // receives connections whose prefix matches no route

	// Route table, keyed by prefix; mu guards routes.
	mu     sync.Mutex
	routes map[string]*listener

	// Shutdown state: once ensures done is closed (and err recorded) only once.
	once sync.Once
	done chan struct{}
	err  error // base Accept error that stopped the mux; nil if stopped via context
}
// New returns a Mux that will accept connections from base, read the first
// prefixLen bytes of each, and hand the connection to the listener registered
// for that prefix, or to the default listener when no route matches. Nothing
// is accepted until Run is called.
func New(base net.Listener, prefixLen int) *Mux {
	m := &Mux{
		base:      base,
		prefixLen: prefixLen,
		addr:      base.Addr(),
		routes:    make(map[string]*listener),
		done:      make(chan struct{}),
	}
	m.def = newListener(m.addr)
	return m
}
//
// set up the routes
//
// Default returns the listener that receives every connection whose prefix
// does not match a registered route. Such connections are wrapped with
// newPrefixConn together with the prefix bytes already read from them.
func (m *Mux) Default() net.Listener {
	return m.def
}
// Route returns the listener that receives connections whose first bytes equal
// prefix, creating and registering it on first use. While that listener stays
// registered, later calls with the same prefix return it again.
//
// Route panics if len(prefix) differs from the prefixLen given to New.
func (m *Mux) Route(prefix string) net.Listener {
	m.mu.Lock()
	defer m.mu.Unlock()

	if want := m.prefixLen; len(prefix) != want {
		panic(fmt.Sprintf("invalid prefix: has %d but needs %d bytes", len(prefix), want))
	}

	if existing, found := m.routes[prefix]; found {
		return existing
	}

	created := newListener(m.addr)
	m.routes[prefix] = created
	// The monitor closes the listener if the mux stops, and unregisters the
	// route once the listener is closed for any reason.
	go m.monitorListener(prefix, created)
	return created
}
//
// run the muxer
//
// Run accepts connections from the base listener and dispatches each one to
// the matching routed listener (or the default listener). It keeps doing so
// until ctx is canceled or Accept on the base listener fails.
//
// When the mux stops, the default listener and every routed listener are
// closed. Each records the base Accept error as its error if that is what
// stopped the mux, or Closed otherwise. Run waits for the routed listeners
// registered at that moment to be closed. It returns the base Accept error,
// or nil if the mux was stopped through ctx.
func (m *Mux) Run(ctx context.Context) error {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	go m.monitorContext(ctx)
	go m.monitorBase()

	<-m.done

	// Routed listeners are closed by their monitorListener goroutines, but the
	// default listener has no such goroutine. Without closing it here, Accept
	// on Default() would block forever after shutdown. routeConn goroutines
	// blocked on sending to it would also leak. Reading m.err is safe: it is
	// written before m.done is closed.
	m.def.once.Do(func() {
		if m.err != nil {
			m.def.err = m.err
		} else {
			m.def.err = Closed
		}
		close(m.def.done)
	})

	m.mu.Lock()
	defer m.mu.Unlock()
	for _, lis := range m.routes {
		<-lis.done
	}
	return m.err
}
// monitorContext waits for ctx to be done, then stops the mux. It closes the
// base listener, which unblocks the pending Accept in monitorBase, and closes
// m.done. If the mux was already stopped, it does nothing further.
func (m *Mux) monitorContext(ctx context.Context) {
	<-ctx.Done()

	shutdown := func() {
		// TODO(jeff): do we care about this error?
		_ = m.base.Close()
		close(m.done)
	}
	m.once.Do(shutdown)
}
// monitorBase accepts connections from the base listener and routes each one
// on its own goroutine. The first Accept error stops the mux: the error is
// recorded in m.err, the base listener is closed, and m.done is closed. If the
// mux was already stopped (for example via the context), the error is
// discarded.
func (m *Mux) monitorBase() {
	for {
		conn, err := m.base.Accept()
		if err != nil {
			// TODO(jeff): temporary errors?
			m.once.Do(func() {
				m.err = err
				// Nothing accepts from base after this point, so release the
				// socket instead of leaving it open. This matches the shutdown
				// path in monitorContext. The close error is ignored because
				// err is the one worth reporting.
				_ = m.base.Close()
				close(m.done)
			})
			return
		}
		go m.routeConn(conn)
	}
}
// monitorListener ties a routed listener's lifetime to the mux. If the mux
// stops first, lis is closed with the mux's error, or with Closed when the mux
// recorded none. Once either has happened, the route for prefix is removed
// from m.routes.
func (m *Mux) monitorListener(prefix string, lis *listener) {
	select {
	case <-lis.done:
		// Closed on its own; nothing to propagate.
	case <-m.done:
		lis.once.Do(func() {
			reason := m.err
			if reason == nil {
				reason = Closed
			}
			lis.err = reason
			close(lis.done)
		})
	}

	m.mu.Lock()
	delete(m.routes, prefix)
	m.mu.Unlock()
}
// routeConn reads the routing prefix from conn and delivers conn to the
// listener registered for that prefix. A connection that matches no route goes
// to the default listener, wrapped with newPrefixConn along with the bytes
// already read. conn is closed instead if the prefix cannot be read in full,
// or if the chosen listener is closed before it takes the connection.
func (m *Mux) routeConn(conn net.Conn) {
	prefix := make([]byte, m.prefixLen)
	if _, err := io.ReadFull(conn, prefix); err != nil {
		// TODO(jeff): how to handle these errors?
		_ = conn.Close()
		return
	}

	m.mu.Lock()
	dest, routed := m.routes[string(prefix)]
	m.mu.Unlock()

	if !routed {
		dest = m.def
		conn = newPrefixConn(prefix, conn)
	}

	// TODO(jeff): a timeout for the listener to get to the conn?
	select {
	case dest.Conns() <- conn:
	case <-dest.done:
		// TODO(jeff): better way to signal to the caller the listener is closed?
		_ = conn.Close()
	}
}