Skip to content

Commit

Permalink
fix message integrity and fingerprint encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
pixelbender committed Feb 19, 2017
1 parent 02f32a7 commit 3eed7f0
Show file tree
Hide file tree
Showing 10 changed files with 89 additions and 30 deletions.
6 changes: 3 additions & 3 deletions stun/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ package stun

import (
"errors"
"math/rand"
"net"
"sync"
"time"
"math/rand"
)

var DefaultConfig = &Config{
Expand Down Expand Up @@ -53,7 +53,7 @@ func (c *Config) attrs() []Attr {
return a
}

func (c *Config) clone() *Config {
func (c *Config) Clone() *Config {
r := *c
return &r
}
Expand Down Expand Up @@ -211,7 +211,7 @@ func (m *mux) serve(msg *Message, tr Transport) bool {
}

func (m *mux) newTx() *transaction {
tx := &transaction{id: newTransaction()}
tx := &transaction{id: NewTransaction()}
m.Lock()
if m.t == nil {
m.t = make(map[string]*transaction)
Expand Down
16 changes: 8 additions & 8 deletions stun/attribute.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,11 @@ func getString(b []byte) string {
}

func Addr(typ uint16, v net.Addr) Attr {
ip, port := sockAddr(v)
ip, port := SockAddr(v)
return &addr{typ, ip, port}
}

func sockAddr(v net.Addr) (net.IP, int) {
func SockAddr(v net.Addr) (net.IP, int) {
switch a := v.(type) {
case *net.UDPAddr:
return a.IP, a.Port
Expand All @@ -200,12 +200,12 @@ func sockAddr(v net.Addr) (net.IP, int) {
}

func sameAddr(a, b net.Addr) bool {
aip, aport := sockAddr(a)
bip, bport := sockAddr(b)
aip, aport := SockAddr(a)
bip, bport := SockAddr(b)
return aip.Equal(bip) && aport == bport
}

func newAddr(network string, ip net.IP, port int) net.Addr {
func NewAddr(network string, ip net.IP, port int) net.Addr {
switch network {
case "udp", "udp4", "udp6":
return &net.UDPAddr{IP: ip, Port: port}
Expand All @@ -226,7 +226,7 @@ type addr struct {
func (addr *addr) Type() uint16 { return addr.typ }

func (addr *addr) Addr(network string) net.Addr {
return newAddr(network, addr.IP, addr.Port)
return NewAddr(network, addr.IP, addr.Port)
}

func (addr *addr) Xored() bool {
Expand Down Expand Up @@ -360,7 +360,7 @@ func (attr *integrity) Unmarshal(b []byte) error {

func (attr *integrity) MarshalSum(p, raw []byte) []byte {
n := len(raw) - 4
be.PutUint16(raw[2:], uint16(n))
be.PutUint16(raw[2:], uint16(n+4))
return attr.Sum(attr.key, raw[:n], p)
}

Expand Down Expand Up @@ -412,7 +412,7 @@ func (attr *fingerprint) Unmarshal(b []byte) error {

func (attr *fingerprint) MarshalSum(p, raw []byte) []byte {
n := len(raw) - 4
be.PutUint16(raw[2:], uint16(n-16))
be.PutUint16(raw[2:], uint16(n-12))
v := attr.Sum(raw[:n])
r, b := grow(p, 4)
be.PutUint32(b, v)
Expand Down
2 changes: 1 addition & 1 deletion stun/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (c *Conn) RequestTransport(req *Message, to Transport) (res *Message, from
for {
msg := &Message{
req.Type,
newTransaction(),
NewTransaction(),
append(sess.attrs(), req.Attributes...),
}
res, from, err = c.agent.RoundTrip(msg, to)
Expand Down
6 changes: 5 additions & 1 deletion stun/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ package stun

import (
"testing"
"time"
)

func TestDiscover(t *testing.T) {
config := DefaultConfig
config.RetransmissionTimeout = 300 * time.Millisecond
config.TransactionTimeout = time.Second
if testing.Verbose() {
DefaultConfig.Logf = t.Logf
config.Logf = t.Logf
} else {
t.Parallel()
}
Expand Down
12 changes: 11 additions & 1 deletion stun/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,18 @@ func (m *Message) CheckIntegrity(key []byte) bool {
return false
}

func (m *Message) CheckFingerprint() bool {
if attr, ok := m.Get(AttrFingerprint).(*fingerprint); ok {
return attr.Check()
}
return false
}

func (m *Message) String() string {
sort.Sort(byPosition(m.Attributes))

// TODO: use sprintf

b := &bytes.Buffer{}
b.WriteString(MethodName(m.Type))
b.WriteByte('{')
Expand Down Expand Up @@ -306,7 +316,7 @@ func (d dict) rand(n int) string {
return string(b)
}

func newTransaction() []byte {
func NewTransaction() []byte {
id := make([]byte, 16)
copy(id, magicCookie)
rand.Read(id[4:]) // TODO: configure random source
Expand Down
34 changes: 34 additions & 0 deletions stun/message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,31 @@ func BenchmarkBuffer(b *testing.B) {
}
}

func TestIntegrity(t *testing.T) {
key := []byte("VOkJxbRl1RmTxUk/WvJxBt")
for _, it := range samples[:3] {
d, err := hex.DecodeString(it)
if err != nil {
t.Fatal(err)
}
m, err := UnmarshalMessage(d)
if err != nil {
t.Fatal(err)
}
m.Set(MessageIntegrity(key))
m, err = UnmarshalMessage(m.Marshal(nil))
if err != nil {
t.Fatal(err)
}
if !m.CheckIntegrity(key) {
t.Error("integrity check failed")
}
if !m.CheckFingerprint() {
t.Error("fingerprint check failed")
}
}
}

func TestVectorsSampleRequest(t *testing.T) {
b, err := hex.DecodeString(samples[0])
if err != nil {
Expand All @@ -75,6 +100,9 @@ func TestVectorsSampleRequest(t *testing.T) {
if !m.CheckIntegrity([]byte("VOkJxbRl1RmTxUk/WvJxBt")) {
t.Error("integrity check failed")
}
if !m.CheckFingerprint() {
t.Error("fingerprint check failed")
}
if m.Kind() != KindRequest || m.Method() != MethodBinding {
t.Error("wrong message type:", m.Type)
}
Expand All @@ -99,6 +127,9 @@ func TestVectorsSampleIPv4Response(t *testing.T) {
if !m.CheckIntegrity([]byte("VOkJxbRl1RmTxUk/WvJxBt")) {
t.Error("integrity check failed")
}
if !m.CheckFingerprint() {
t.Error("fingerprint check failed")
}
if m.Kind() != KindResponse || m.Method() != MethodBinding {
t.Error("wrong message type:", m.Type)
}
Expand All @@ -124,6 +155,9 @@ func TestVectorsSampleIPv6Response(t *testing.T) {
if !m.CheckIntegrity([]byte("VOkJxbRl1RmTxUk/WvJxBt")) {
t.Error("integrity check failed")
}
if !m.CheckFingerprint() {
t.Error("fingerprint check failed")
}
if m.Kind() != KindResponse || m.Method() != MethodBinding {
t.Error("wrong message type:", m.Type)
}
Expand Down
16 changes: 10 additions & 6 deletions stun/nat.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ func (d *Detector) DiscoverChange(change uint64) error {
if err != nil {
return err
}
ip, port := sockAddr(d.RemoteAddr())
chip, chport := sockAddr(from.RemoteAddr())
ip, port := SockAddr(d.RemoteAddr())
chip, chport := SockAddr(from.RemoteAddr())
if change&ChangeIP != 0 {
if ip.Equal(chip) {
return errors.New("stun: bad response, ip address is not changed")
Expand Down Expand Up @@ -113,7 +113,7 @@ func (d *Detector) Mapping() (string, error) {
if other == nil {
return "", errors.New("stun: bad response, no other address")
}
ip, _ := sockAddr(mapped)
ip, _ := SockAddr(mapped)
if ip.IsLoopback() {
return EndpointIndependent, nil
}
Expand All @@ -122,9 +122,9 @@ func (d *Detector) Mapping() (string, error) {
return EndpointIndependent, nil
}
}
ip, _ = sockAddr(other)
_, port := sockAddr(d.RemoteAddr())
a, err := d.DiscoverOther(newAddr(n, ip, port))
ip, _ = SockAddr(other)
_, port := SockAddr(d.RemoteAddr())
a, err := d.DiscoverOther(NewAddr(n, ip, port))
if err != nil {
return "", err
}
Expand All @@ -141,6 +141,10 @@ func (d *Detector) Mapping() (string, error) {
return AddressPortDependent, nil
}

func LocalAddrs() []*net.IPAddr {
return local
}

var local []*net.IPAddr

func init() {
Expand Down
13 changes: 10 additions & 3 deletions stun/nat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,20 @@ import (
var once sync.Once

func newDetector(t *testing.T) *Detector {
config := DefaultConfig.clone()
config := DefaultConfig.Clone()
config.RetransmissionTimeout = 300 * time.Millisecond
config.TransactionTimeout = time.Second
config.Software = "client"
if testing.Verbose() {
config.Logf = t.Logf
} else {
t.Parallel()
}
once.Do(func() {
srv := NewServer(nil)
c := config.Clone()
c.Software = "server"

srv := NewServer(c)
loop, _ := net.ResolveIPAddr("ip", "localhost")
for _, it := range append(local, loop) {
for _, port := range []string{"3478", "3479"} {
Expand All @@ -39,8 +43,9 @@ func TestHairpinning(t *testing.T) {
d := newDetector(t)
err := d.Hairpinning()
if err != nil {
t.Fatal(err)
t.Fatalf("hairpinning: %v", err)
}
t.Logf("hairpinning: success")
}

func TestFiltering(t *testing.T) {
Expand All @@ -52,6 +57,7 @@ func TestFiltering(t *testing.T) {
if v != EndpointIndependent {
t.Errorf("Wrong filtering type: %v", v)
}
t.Logf("filtering: %v", v)
}

func TestMapping(t *testing.T) {
Expand All @@ -63,4 +69,5 @@ func TestMapping(t *testing.T) {
if v != EndpointIndependent {
t.Errorf("Wrong mapping type: %v", v)
}
t.Logf("mapping: %v", v)
}
8 changes: 4 additions & 4 deletions stun/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (srv *Server) ServeSTUN(msg *Message, from Transport) {
if msg.Type == MethodBinding {
to := from
mapped := from.RemoteAddr()
ip, port := sockAddr(from.LocalAddr())
ip, port := SockAddr(from.LocalAddr())

res := &Message{
Type: MethodBinding | KindResponse,
Expand All @@ -53,7 +53,7 @@ func (srv *Server) ServeSTUN(msg *Message, from Transport) {

if ch, ok := msg.GetInt(AttrChangeRequest); ok && ch != 0 {
for _, c := range srv.conns {
chip, chport := sockAddr(c.LocalAddr())
chip, chport := SockAddr(c.LocalAddr())
if chip.IsUnspecified() {
continue
}
Expand All @@ -78,12 +78,12 @@ func (srv *Server) ServeSTUN(msg *Message, from Transport) {

other:
for _, a := range srv.conns {
aip, aport := sockAddr(a.LocalAddr())
aip, aport := SockAddr(a.LocalAddr())
if aip.IsUnspecified() || !ip.Equal(aip) || port == aport {
continue
}
for _, b := range srv.conns {
bip, bport := sockAddr(b.LocalAddr())
bip, bport := SockAddr(b.LocalAddr())
if bip.IsUnspecified() || bip.Equal(ip) || aport != bport {
continue
}
Expand Down
6 changes: 3 additions & 3 deletions stun/stun.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func Dial(uri string, config *Config) (*Conn, error) {
return nil, err
}
if auth != nil {
config = config.clone()
config = config.Clone()
config.AuthMethod = auth
}
return NewConn(conn, config), nil
Expand All @@ -95,7 +95,7 @@ func parseURI(uri string) (secure bool, network, addr string, auth AuthMethod, e
network = "udp"
}
switch u.Scheme {
case "stun":
case "stun", "turn":
if port == "" {
port = "3478"
}
Expand All @@ -104,7 +104,7 @@ func parseURI(uri string) (secure bool, network, addr string, auth AuthMethod, e
default:
err = errors.New("stun: unsupported transport: " + network)
}
case "stuns":
case "stuns", "turns":
if port == "" {
port = "5478"
}
Expand Down

0 comments on commit 3eed7f0

Please sign in to comment.