forked from pixelbender/go-stun
/
conn.go
176 lines (154 loc) · 3.9 KB
/
conn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
package stun
import (
"bufio"
"io"
"net"
"time"
)
// Config represents a STUN connection configuration.
//
// The unexported helpers getAuthKey and getAttrCodec accept a nil *Config and
// treat it exactly like a Config whose fields are all left at their zero values.
type Config struct {
	// GetAuthKey returns a key for MESSAGE-INTEGRITY attribute generation and
	// validation:
	//
	//	Key = MD5(username ":" realm ":" SASLprep(password)) for long-term credentials.
	//	Key = SASLprep(password) for short-term credentials.
	//
	// SASLprep is defined in RFC 4013. When GetAuthKey is nil, getAuthKey
	// reports a nil key and a nil error.
	GetAuthKey func(attrs Attributes) ([]byte, error)
	// GetAttributeCodec returns the STUN attribute codec for the specified
	// attribute type. When it is nil, the package-level GetAttributeCodec
	// function is used instead.
	GetAttributeCodec func(at uint16) AttrCodec
	// Fingerprint controls whether a FINGERPRINT attribute will be generated.
	Fingerprint bool
	// Software is a value for the SOFTWARE attribute.
	Software string
}

// getAuthKey calls c.GetAuthKey when it is set; otherwise (including when c is
// nil) it returns a nil key and a nil error.
func (c *Config) getAuthKey(attrs Attributes) ([]byte, error) {
	if c != nil && c.GetAuthKey != nil {
		return c.GetAuthKey(attrs)
	}
	return nil, nil
}

// getAttrCodec returns the codec for attribute type at, preferring
// c.GetAttributeCodec and falling back to the package-level GetAttributeCodec
// when c is nil or the field is unset.
func (c *Config) getAttrCodec(at uint16) AttrCodec {
	if c != nil && c.GetAttributeCodec != nil {
		return c.GetAttributeCodec(at)
	}
	return GetAttributeCodec(at)
}

// DefaultConfig is the configuration NewConn uses when it is given a nil
// *Config. It only sets GetAttributeCodec, to the package-level
// GetAttributeCodec. Because NewConn stores this pointer, modifying
// DefaultConfig affects every Conn created with a nil config.
var DefaultConfig = &Config{
	GetAttributeCodec: GetAttributeCodec,
}
// Conn is a STUN connection layered on top of a net.Conn. The inner connection
// is embedded, so its methods stay available, and ReadMessage/WriteMessage add
// exchange of whole STUN messages.
type Conn struct {
	net.Conn // transport that messages are read from and written to

	config   *Config    // configuration handed to dec and enc by NewConn
	dec      *Decoder   // decodes inbound messages in ReadMessage
	enc      *Encoder   // encodes outbound messages in WriteMessage
	cr       connReader // extracts one message worth of bytes per call
	reliable bool       // true for stream transports; NOTE(review): not read in this file section
	key      []byte     // passed to dec.Decode; NOTE(review): never assigned in this file section
}

// NewConn wraps inner in a STUN Conn. A nil config selects DefaultConfig.
// When inner also implements net.PacketConn it is read one packet at a time;
// otherwise it is treated as a byte stream.
func NewConn(inner net.Conn, config *Config) *Conn {
	cfg := config
	if cfg == nil {
		cfg = DefaultConfig
	}
	conn := &Conn{Conn: inner, config: cfg}
	conn.dec = NewDecoder(cfg)
	conn.enc = NewEncoder(cfg)

	_, isPacket := inner.(net.PacketConn)
	conn.reliable = !isPacket
	if conn.reliable {
		conn.cr = newStreamReader(inner)
	} else {
		conn.cr = newPacketReader(inner)
	}
	return conn
}

// ReadMessage reads the next STUN message from the connection and decodes it.
func (c *Conn) ReadMessage() (*Message, error) {
	raw, err := c.cr.PeekMessageBytes()
	if err == nil {
		return c.dec.Decode(raw, c.key)
	}
	return nil, err
}

// WriteMessage encodes msg and writes the result to the connection, returning
// the first encoding or write error encountered.
func (c *Conn) WriteMessage(msg *Message) error {
	encoded, err := c.enc.Encode(msg)
	if err == nil {
		_, err = c.Write(encoded)
	}
	return err
}
// connReader abstracts how a Conn pulls the bytes of one STUN message out of
// its underlying transport.
type connReader interface {
	// PeekMessageBytes returns the bytes of the next STUN message. The slice
	// may alias an internal buffer and is only valid until the next read call.
	PeekMessageBytes() ([]byte, error)
}
var (
	// bufferSize is the size, in bytes, of the read buffer that each stream
	// and packet reader allocates.
	bufferSize = 1400
)
// streamReader reads STUN messages transmitted over a stream-oriented network.
// Messages are framed by the 20-byte STUN header: the 16-bit length field at
// offset 2 gives the number of bytes that follow the header.
type streamReader struct {
	*bufio.Reader
	r    io.Reader // unbuffered source passed to newStreamReader
	skip int       // bytes of the previously returned message still to discard
}

// newStreamReader wraps r in a buffered reader of bufferSize bytes. If r is a
// *net.TCPConn, TCP keep-alives are enabled with a 30-second period.
func newStreamReader(r io.Reader) *streamReader {
	if tcp, ok := r.(*net.TCPConn); ok {
		// Keep-alive is best effort: failing to enable it does not prevent
		// reading messages, so the errors are deliberately ignored.
		_ = tcp.SetKeepAlive(true)
		_ = tcp.SetKeepAlivePeriod(30 * time.Second)
	}
	return &streamReader{Reader: bufio.NewReaderSize(r, bufferSize), r: r}
}

// Read discards the previously peeked message, if any, and then reads raw
// bytes from the buffered stream.
func (c *streamReader) Read(b []byte) (int, error) {
	c.discard()
	// Call the embedded bufio.Reader explicitly: c.Read resolves to this very
	// method and would recurse forever.
	return c.Reader.Read(b)
}

// PeekMessageBytes returns the next complete STUN message (header plus body).
// It returns ErrFormat when the two most significant bits of the first header
// word are not zero.
//
// A message that fits in the read buffer is returned as a view into that
// buffer and consumed on the next call. A larger message is copied into a
// freshly allocated slice.
func (c *streamReader) PeekMessageBytes() ([]byte, error) {
	c.discard()
	h, err := c.Peek(4)
	if err != nil {
		return nil, err
	}
	if be.Uint16(h)&0xc000 != 0 {
		return nil, ErrFormat
	}
	n := int(be.Uint16(h[2:])) + 20
	if n > c.Size() {
		// bufio.Reader.Peek cannot look past its buffer size and would fail
		// with bufio.ErrBufferFull, so read oversized messages out directly.
		b := make([]byte, n)
		if _, err := io.ReadFull(c.Reader, b); err != nil {
			return nil, err
		}
		return b, nil
	}
	b, err := c.Peek(n)
	if err != nil {
		return nil, err
	}
	c.skip = n
	return b, nil
}

// discard drops the bytes of the message returned by the previous
// PeekMessageBytes call, so the next read starts at the following message.
func (c *streamReader) discard() {
	if c.skip > 0 {
		// These bytes were already peeked, so they are buffered and Discard
		// cannot come up short.
		_, _ = c.Discard(c.skip)
		c.skip = 0
	}
}
// packetReader reads STUN messages transmitted over a packet-oriented network.
// Each Read on the underlying reader is expected to return one whole datagram.
type packetReader struct {
	io.Reader
	buf []byte // receive buffer reused for every datagram
}

// newPacketReader returns a packetReader with a receive buffer of bufferSize
// bytes.
func newPacketReader(r io.Reader) *packetReader {
	return &packetReader{r, make([]byte, bufferSize)}
}

// PeekMessageBytes reads one datagram and returns it. The returned slice
// aliases the receive buffer and is overwritten by the next call.
//
// It returns ErrTruncated when the datagram is shorter than the 20-byte STUN
// header, or shorter than the message length announced in that header.
func (c *packetReader) PeekMessageBytes() ([]byte, error) {
	const headerSize = 20
	n, err := c.Read(c.buf)
	if err != nil {
		return nil, err
	}
	// A datagram shorter than the header cannot hold a STUN message. Checking
	// this first also keeps the length lookup below from indexing out of range
	// on tiny, possibly hostile, packets.
	if n < headerSize {
		return nil, ErrTruncated
	}
	b := c.buf[:n]
	if n < int(be.Uint16(b[2:]))+headerSize {
		return nil, ErrTruncated
	}
	return b, nil
}