diff --git a/dtls/ciphersuite.go b/dtls/ciphersuite.go index f1e1976..d4a490c 100644 --- a/dtls/ciphersuite.go +++ b/dtls/ciphersuite.go @@ -87,14 +87,3 @@ func (c *cipherSuite) prf(ver uint16, result, secret []byte, params ...[]byte) { var ( errUnsupportedKeyExchangeAlgorithm = errors.New("dtls: unsupported key exchange algorithm") ) - -/* -func cipherAES(key, iv []byte, read bool) interface{} { - block, _ := aes.NewCipher(key) - if read { - return cipher.NewCBCDecrypter(block, iv) - } - return cipher.NewCBCEncrypter(block, iv) -} - -*/ diff --git a/dtls/client.go b/dtls/client.go index 3bfbf35..a34e6df 100644 --- a/dtls/client.go +++ b/dtls/client.go @@ -10,7 +10,8 @@ import ( func NewClient(conn net.Conn, config *Config) (*Conn, error) { c := newConn(conn, config) - h := &clientHandshake{handshakeProtocol: c.newHandshake()} + h := &clientHandshake{} + h.transport = c.transport if err := h.handshake(); err != nil { return nil, err } @@ -18,8 +19,7 @@ func NewClient(conn net.Conn, config *Config) (*Conn, error) { } type clientHandshake struct { - *handshakeProtocol - ver uint16 + handshakeTransport suite *cipherSuite masterSecret []byte } @@ -48,24 +48,22 @@ func (c *clientHandshake) handshake() (err error) { ch.signatureAlgorithms = supportedSignatureAlgorithms } var ( + req *helloVerifyRequest sh *serverHello skey *serverKeyExchange scert *certificate creq *certificateRequest ) - c.reset(true) - c.write(&handshake{typ: handshakeClientHello, message: ch}) - if err = c.flight(func(m *handshake) (done bool, err error) { + c.prepare(&handshake{typ: handshakeClientHello, message: ch}) + if err = c.roundTrip(func(m *handshake) (done bool, err error) { switch m.typ { case handshakeHelloVerifyRequest: - var r *helloVerifyRequest - if r, err = parseHelloVerifyRequest(m.raw); err != nil { + if req, err = parseHelloVerifyRequest(m.raw); err != nil { break } - // TODO: reset finished mac - ch.cookie = clone(r.cookie) - c.reset(true) - c.write(&handshake{typ: handshakeClientHello, message: ch}) + ch.cookie = clone(req.cookie) + c.reset() + c.prepare(&handshake{typ: handshakeClientHello, message: ch}) case handshakeServerHello: sh, err = parseServerHello(m.raw) case handshakeCertificate: @@ -78,7 +76,7 @@ func (c *clientHandshake) handshake() (err error) { done = true default: c.sendAlert(alertUnexpectedMessage) - return false, fmt.Errorf("dtls: unexpected message: 0x%x", m.typ) + return false, errUnexpectedMessage } if err != nil { c.sendAlert(alertDecodeError) @@ -114,12 +112,8 @@ func (c *clientHandshake) handshake() (err error) { } // TODO: check renegotiation - c.reset(false) - // - ////preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, c.peerCertificates[0]) - // if creq != nil { - c.write(&handshake{typ: handshakeCertificate, message: &certificate{}}) + c.prepare(&handshake{typ: handshakeCertificate, message: &certificate{}}) // TODO: write peer certificate chain } @@ -138,7 +132,7 @@ func (c *clientHandshake) handshake() (err error) { c.sendAlert(alertInternalError) return } - c.write(&handshake{typ: handshakeClientKeyExchange, message: ckey}) + c.prepare(&handshake{typ: handshakeClientKeyExchange, message: ckey}) case keyECDH: // TODO: implement c.sendAlert(alertInternalError) @@ -148,12 +142,11 @@ func (c *clientHandshake) handshake() (err error) { return errUnsupportedKeyExchangeAlgorithm } - c.changeCipherSpec() + c.prepareRecord(&record{typ: recordChangeCipherSpec, raw: changeCipherSpec}) - c.tx.epoch++ - c.write(&handshake{typ: handshakeFinished, raw: c.finishedHash()}) + //c.write(&handshake{typ: handshakeFinished, raw: c.finishedHash()}) - c.tx.writeFlight(c.enc.raw, c.enc.rec) + //c.tx.writeFlight(c.enc.raw, c.enc.rec) time.Sleep(time.Second) /* return c.flight(func(m *handshake) (done bool, err error) { @@ -163,7 +156,7 @@ func (c *clientHandshake) handshake() (err error) { } func (c *clientHandshake) finishedHash() []byte { - return c.suite.finishedHash(c.ver, c.masterSecret, clientFinished, c.buf.Bytes()) + return c.suite.finishedHash(c.ver, c.masterSecret, clientFinished, c.log) } func (c *clientHandshake) newMasterSecretRSA(ch *clientHello, sh *serverHello, pub *rsa.PublicKey) ([]byte, error) { diff --git a/dtls/config.go b/dtls/config.go index 09794d7..37cce58 100644 --- a/dtls/config.go +++ b/dtls/config.go @@ -2,38 +2,34 @@ package dtls import ( "crypto/rand" - "crypto/tls" "crypto/x509" "errors" "io" "time" ) -var DefaultConfig = &Config{ +var defaultConfig = &Config{ Rand: rand.Reader, Time: time.Now, MTU: 1400, RetransmissionTimeout: 500 * time.Millisecond, ReadTimeout: 15 * time.Second, CipherSuites: []uint16{ - tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, - tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, - tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - tls.TLS_RSA_WITH_AES_128_GCM_SHA256, - tls.TLS_RSA_WITH_AES_128_CBC_SHA, - tls.TLS_RSA_WITH_AES_256_CBC_SHA, - tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, - }, - SRTPProtectionProfiles: []uint16{ - SRTP_AES128_CM_HMAC_SHA1_80, - SRTP_AES128_CM_HMAC_SHA1_32, + TLS_RSA_WITH_AES_128_CBC_SHA, + //tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + //tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + //tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + //tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + //tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + //tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + //tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + //tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + //tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + //tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + //tls.TLS_RSA_WITH_AES_128_GCM_SHA256, + //tls.TLS_RSA_WITH_AES_128_CBC_SHA, + //tls.TLS_RSA_WITH_AES_256_CBC_SHA, + //tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, }, MinVersion: VersionDTLS10, MaxVersion: VersionDTLS12, diff --git a/dtls/conn.go b/dtls/conn.go index cba8d01..4c1de01 100644 --- a/dtls/conn.go +++ b/dtls/conn.go @@ -1,139 +1,24 @@ package dtls import ( - "io" "net" ) type Conn struct { - net.Conn - config *Config - rx receiver - tx transmitter + *transport } func newConn(c net.Conn, config *Config) *Conn { if config == nil { - config = DefaultConfig + config = defaultConfig } return &Conn{ - c, - config, - receiver{r: c}, - transmitter{w: c, mtu: config.MTU}, + &transport{Conn: c, config: config}, } } -func (c *Conn) newHandshake() *handshakeProtocol { - return &handshakeProtocol{ - Conn: c, - enc: handshakeEncoder{ - mtu: c.config.MTU, - seq: 0, - }, - dec: handshakeDecoder{ - seq: 0, - }, - } -} - -func (c *Conn) sendAlert(a uint8) error { - return c.tx.write(&record{ - typ: recordAlert, - ver: VersionDTLS10, - raw: []byte{levelError, a}, - }) -} - func (c *Conn) Close() error { // TODO: send alert only if handshake done c.sendAlert(alertCloseNotify) return c.Conn.Close() } - -type receiver struct { - r io.Reader - epoch uint16 - seq int64 - mask int64 - buf []byte - raw []byte -} - -func (rx *receiver) read() (r *record, err error) { - if rx.buf == nil { - rx.buf = make([]byte, 4096) - } - for { - if len(rx.raw) > 0 { - r, rx.raw, err = parseRecord(rx.raw) - if err == nil && rx.check(r) { - return r, nil - } - } - n, err := rx.r.Read(rx.buf) - if err != nil { - return nil, err - } - rx.raw = rx.buf[:n] - } -} - -func (rx *receiver) check(r *record) bool { - if r.epoch != rx.epoch { - return false - } - d := r.seq - rx.seq - if d > 0 { - if d < 64 { - rx.mask = (rx.mask << uint(d)) | 1 - } else { - rx.mask = 1 - } - rx.seq = r.seq - return true - } - if d = -d; d >= 64 { - return false - } - if b := int64(1) << uint(d); rx.mask&b == 0 { - rx.mask |= b - return true - } - return false -} - -type transmitter struct { - w io.Writer - mtu int - epoch uint16 - seq int64 -} - -func (tx *transmitter) write(r *record) error { - r.epoch, r.seq = tx.epoch, tx.seq - tx.seq++ - _, err := tx.w.Write(r.marshal(nil)) - return err -} - -func (tx *transmitter) writeFlight(raw []byte, rec []int) (err error) { - last, sent := 0, 0 - for _, to := range rec { - v := raw[last:to] - put6(v[5:], tx.seq) - tx.seq++ - if to-sent > tx.mtu { - if _, err = tx.w.Write(raw[sent:last]); err != nil { - return - } - sent = last - } - last = to - } - if sent == last { - return nil - } - _, err = tx.w.Write(raw[sent:last]) - return -} diff --git a/dtls/conn_test.go b/dtls/conn_test.go index 2bcf140..ea5ef6e 100644 --- a/dtls/conn_test.go +++ b/dtls/conn_test.go @@ -2,152 +2,59 @@ package dtls import ( "encoding/hex" + "math/rand" "net" "testing" ) -type dumpConn struct { - net.Conn - t *testing.T -} - -func (c *dumpConn) Read(b []byte) (n int, err error) { - n, err = c.Conn.Read(b) - if err == nil { - c.t.Logf("Read: %s", hex.Dump(b[:n])) - } - return -} - -func (c *dumpConn) Write(b []byte) (n int, err error) { - n, err = c.Conn.Write(b) - if err == nil { - c.t.Logf("Write: %s", hex.Dump(b)) - } - return -} - func _TestClientWithOpenSSL(t *testing.T) { conn, err := net.Dial("udp", "127.0.0.1:4444") if err != nil { t.Fatal(err) } - config := DefaultConfig.Clone() + conn = &logConn{ + conn, + t.Logf, + } + config := defaultConfig.Clone() config.InsecureSkipVerify = true - c, err := NewClient(&dumpConn{conn, t}, config) + c, err := NewClient(conn, config) if err != nil { t.Fatal(err) } defer c.Close() } -/* -func generateCerts() error { - pk, err := rsa.GenerateKey(rand.Reader, 4096) - if err != nil { - return err - } - sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) - if err != nil { - return err - } - tpl := x509.Certificate{ - SerialNumber: sn, - Subject: pkix.Name{ - CommonName: "go-dtls", - }, - IsCA: true, +type lossConn struct { + net.Conn + rate float64 +} - NotBefore: time.Now(), - NotAfter: time.Now().Add(time.Minute), - KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - DNSNames: []string{"go-dtls"}, - } - der, err := x509.CreateCertificate(rand.Reader, &tpl, &tpl, pk.PublicKey, pk) - if err != nil { - return err +func (c *lossConn) Read(b []byte) (n int, err error) { + for { + n, err = c.Conn.Read(b) + if rand.Float64() > c.rate { + break + } } - dir, err := ioutil.TempDir("", "test") - defer os.RemoveAll(dir) + return } -*/ -func TestHandshakeDecoder(t *testing.T) { - frag := []string{ - "0b0002c700010000000000e60002c40002c1308202bd308201a5a003020102020100300d06092a864886f70d01010b05003022310b30090603550406130253453113301106035504030c0a4f70656e576562525443301e170d3137303330373132303235355a170d3138303330373132303235355a3022310b30090603550406130253453113301106035504030c0a4f70656e57656252544330820122300d06092a864886f70d01010105000382010f003082010a0282010100c2717a632ea4618e599ed6173dfafef22b4f8df27120e30978052c3532c41532ef7466cdf1fe70f6d0554069cb0dfec3ac99f93fabece26a", - "0b0002c700010000e60000e7bb9fcefdae4197cee480c5dd0aa76ca2a9ae85287176180778ed7ce4b9c10bf3ee6426827cb4f4c933c6dd9c4e94dd43aa59d7c60a8a33db961a6dba5243de7ddeab2d9f13ed74a6c0259aa4358e8b25632a5f11e9692118ed1f084fb6953c9a1507825d919394c438cf277c149488c0628e6e3ddf2c1de4a4570b711cc51a6e0747e9aea0fc4687eeb10f45945eee41b147a0d697a825e3817e6b7d0a0ec5bd382c60e0f7c1ef1acb820ed28fdb2c5fa5abb1c8d5cddf9bf3f4309687baec0b2cb97cbf62f22fb30203010001300d06092a864886f70d01010b0500038201010061aa714fdc32", - "0b0002c700010001cd0000e76b9a4b20a46e7264713326d9f4e3e5ca6b972daa4bdf318fc3e9c6b1de1b1f136272b6768ca74d49c7a1ea1296244e4f5a6b01e8938106b8d80fa43ebe0794c9d81c35d65cb62f40754e7a0d2d1ccd46fe5d79670be3c9b9c1fc30245542557f39222bec1a688445ff0f74015ecb7b4cfebc60916a48b48415d064c873fe68838d1cb7f00ecd8b3a0b9069c8a820ce75f7675275cafc50e30cab3c97400cef81475b984ec1f71676e55a6275a919f2a3d3e6d6da23a2eb91442693796e1ab69143700b7bcfa41cec8f5a0ce1ae15bbc671be681308e4f0f40d82deafbdb818d1eac53fa1f57c91", - "0b0002c700010002b4000013bfd8f25c142f1d8416053b375e9ef44fbd06fd", - } - for _, seq := range [][]int{ - {0, 1, 2, 3}, - {3, 2, 1, 0}, - {2, 0, 1, 0, 2, 0, 3}, - {0, 1, 2, 1, 0, 1, 3}, - } { - d := &handshakeDecoder{seq: 1} - for _, i := range seq { - b, _ := hex.DecodeString(frag[i]) - d.parse(b) - } - h := d.read() - if h == nil || d.seq != 2 { - t.Fatal("defragmentation:", seq) - } - c, err := parseCertificate(h.raw) - if err != nil { - t.Fatal(err) - } - if len(c.cert) == 0 { - t.Fatal("no certificate") - } +type logConn struct { + net.Conn + Logf func(format string, args ...interface{}) +} + +func (c *logConn) Read(b []byte) (n int, err error) { + if n, err = c.Conn.Read(b); err == nil { + c.Logf("Read: %s", hex.Dump(b[:n])) } + return } -func TestHandshakeEncoder(t *testing.T) { - for _, mtu := range []int{1024, 512, 256, 128, 64, 32} { - e := &handshakeEncoder{mtu: mtu} - e.writeRecord(&record{ - typ: recordHandshake, - ver: DefaultConfig.MinVersion, - payload: &handshake{ - typ: handshakeClientHello, - message: &clientHello{ - ver: DefaultConfig.MaxVersion, - random: make([]byte, 32), - cipherSuites: DefaultConfig.CipherSuites, - compMethods: supportedCompression, - extensions: &extensions{ - renegotiationSupported: true, - srtpProtectionProfiles: DefaultConfig.SRTPProtectionProfiles, - extendedMasterSecret: true, - sessionTicket: true, - signatureAlgorithms: supportedSignatureAlgorithms, - supportedPoints: supportedPointFormats, - supportedCurves: supportedCurves, - }, - }, - }, - }) - b := e.raw - d := &handshakeDecoder{} - for len(b) > 0 { - r, next, err := parseRecord(b) - if err != nil { - t.Fatal(err) - } - d.parse(r.raw) - b = next - } - h := d.read() - if h == nil { - t.Fatal("no message") - } - _, err := parseClientHello(h.raw) - if err != nil { - t.Fatal(err) - } +func (c *logConn) Write(b []byte) (n int, err error) { + if n, err = c.Conn.Write(b); err == nil { + c.Logf("Write: %s", hex.Dump(b)) } + return } diff --git a/dtls/handshake.go b/dtls/handshake.go index 4cede2e..de17cd3 100644 --- a/dtls/handshake.go +++ b/dtls/handshake.go @@ -1,29 +1,25 @@ package dtls import ( - "bytes" "crypto/x509" "errors" - "io" - "log" - "sort" - "time" ) var ( errHandshakeFormat = errors.New("dtls: handshake format error") - errCertificateRequestFormat = errors.New("dtls: certificateRequest format error") - errClientHelloFormat = errors.New("dtls: clientHello format error") - errServerHelloFormat = errors.New("dtls: serverHello format error") - errHelloVerifyRequestFormat = errors.New("dtls: helloVerifyRequest format error") + errCertificateRequestFormat = errors.New("dtls: certificate_request format error") + errClientHelloFormat = errors.New("dtls: client_hello format error") + errServerHelloFormat = errors.New("dtls: server_hello format error") + errHelloVerifyRequestFormat = errors.New("dtls: hello_verify_request format error") errCertificateFormat = errors.New("dtls: certificate format error") - errServerKeyExchangeFormat = errors.New("dtls: serverKeyExchange format error") - errClientKeyExchangeFormat = errors.New("dtls: clientkeyExchange format error") - errCertificateVerifyFormat = errors.New("dtls: certificateVerify format error") + errServerKeyExchangeFormat = errors.New("dtls: server_key_exchange format error") + errClientKeyExchangeFormat = errors.New("dtls: client_key_exchange format error") + errCertificateVerifyFormat = errors.New("dtls: certificate_verify format error") - errHandshakeSequence = errors.New("dtls: handshake sequence error") - errHandshakeTimeout = errors.New("dtls: handshake timeout") - errHandshakeError = errors.New("dtls: handshake error") + errHandshakeSequence = errors.New("dtls: handshake sequence error") + errHandshakeMessageOutOfBounds = errors.New("dtls: handshake message is out of bounds") + errHandshakeMessageTooBig = errors.New("dtls: handshake message is too big") + errHandshakeTimeout = errors.New("dtls: handshake timeout") errUnexpectedMessage = errors.New("dtls: unexpected message") ) @@ -54,98 +50,6 @@ var supportedCompression = []uint8{ compNone, } -type handshakeProtocol struct { - *Conn - enc handshakeEncoder - dec handshakeDecoder - buf bytes.Buffer -} - -func (c *handshakeProtocol) reset(clearHash bool) { - if clearHash { - c.buf.Reset() - } - c.enc.reset() -} - -func (c *handshakeProtocol) write(h *handshake) { - c.enc.w = &c.buf - c.enc.write(&record{ - typ: recordHandshake, - ver: c.config.MinVersion, - epoch: c.tx.epoch, - }, h) -} - -func (c *handshakeProtocol) changeCipherSpec() { - c.enc.w = &c.buf - c.enc.writeRecord(&record{ - typ: recordChangeCipherSpec, - ver: c.config.MinVersion, - epoch: c.tx.epoch, - raw: changeCipherSpec, - }) -} - -func (c *handshakeProtocol) flight(handle func(h *handshake) (bool, error)) error { - start, done, rto := time.Now(), false, c.config.RetransmissionTimeout - if err := c.tx.writeFlight(c.enc.raw, c.enc.rec); err != nil { - return err - } - for !done { - d := c.config.ReadTimeout - time.Since(start) - if d < 0 { - return errHandshakeTimeout - } - if d > rto { - d = rto - } - c.SetReadDeadline(time.Now().Add(d)) - r, err := c.rx.read() - if err != nil { - if t, ok := err.(interface { - Timeout() bool - }); ok && t.Timeout() { - rto <<= 1 - if err = c.tx.writeFlight(c.enc.raw, c.enc.rec); err == nil { - continue - } - } - return err - } - switch r.typ { - case recordHandshake: - if !c.dec.parse(r.raw) { - break - } - if h := c.dec.read(); h != nil { - h.raw = clone(h.raw) // TODO: append and slice finish hash buffer - c.buf.Write(h.raw) - done, err = handle(h) - if err != nil { - return err - } - } - case recordAlert: - a, err := parseAlert(r.raw) - if err != nil { - return err - } - if a.level == levelError { - // TODO: check if warnings corrupt handshake - return a - } - default: - log.Printf("Unexpected record: %v", r.typ) - } - } - return nil -} - -type marshaler interface { - marshal([]byte) []byte -} - type handshake struct { typ uint8 len int @@ -462,165 +366,3 @@ func (r *certificateRequest) marshal(b []byte) []byte { b = pack(b, r.types, nil) return pack2(b, r.names, nil) } - -type handshakeEncoder struct { - mtu int - pos int - raw []byte - rec []int - seq int - w io.Writer -} - -func (e *handshakeEncoder) reset() { - if len(e.raw) > 0 { - e.raw = e.raw[:0] - } - if len(e.rec) > 0 { - e.rec = e.rec[:0] - } - e.pos = 0 -} - -func (e *handshakeEncoder) write(r *record, h *handshake) { - r.payload, h.seq = h, e.seq - e.seq++ - e.writeRecord(r) -} - -func (e *handshakeEncoder) writeRecord(r *record) { - if e.mtu < 26 { - panic("dtls: mtu is too small") - } - b := e.raw - from := len(b) - b = r.prepare(b) - to := len(b) - if e.w != nil && r.typ == recordHandshake { - e.w.Write(b[from:to]) - } - n, max := to-from, e.mtu-e.pos - l := n - 25 - if r.typ == recordHandshake { - put3(b[from+14:], l) - } - if n > max { - e.pos, max = 0, e.mtu - } - if n <= max || r.typ != recordHandshake { - e.pos += n - e.rec = append(e.rec, to) - e.raw = b - return - } - m := max - 25 - c := l / m - if l > m*c { - c++ - } - _, b = grow(b, (c-1)*25) - put2(b[from+11:], m+12) - put3(b[from+22:], m) - v := b[from:] - for i := c - 1; i > 0; i-- { - p, off := v[i*max:], i*m - if len(p) > max { - p = p[:max] - } - s := copy(p[25:], v[25+off:]) - copy(p, v[:19]) - put2(p[11:], s+12) - put3(p[19:], off) - put3(p[22:], s) - } - for i, m := 0, len(b); i < c; i++ { - to := from + (i+1)*max - if to > m { - to = m - } - e.rec = append(e.rec, to) - } - e.raw = b -} - -type handshakeDecoder struct { - seq int - que [16]*queue -} - -func (d *handshakeDecoder) parse(b []byte) bool { - h, err := parseHandshake(b) - if err != nil { - log.Printf("dtls: handshake parse error: %v", err) - return false - } - ds := h.seq - d.seq - if ds < 0 || ds > 15 { - log.Printf("dtls: handshake sequence %d > %d", h.seq, d.seq) - return false - } - i := h.seq & 0xf - q := d.que[i] - if q == nil { - if h.len < 0 || h.len > 0x1000 { - log.Printf("dtls: handshake message is too big: %d bytes", h.len) - return false - } - q = &queue{raw: make([]byte, h.len)} - d.que[i] = q - } else { - for _, it := range q.h { - if it.off == h.off && len(h.raw) == len(it.raw) { - log.Printf("dtls: handshake message duplicate") - return false - } - } - } - if m := h.off + len(h.raw); h.off < 0 || m > len(q.raw) { - log.Printf("dtls: handshake message out of bounds %d:%d max %d", h.off, m, len(q.raw)) - return false - } - copy(q.raw[h.off:], h.raw) - q.h = append(q.h, h) - sort.Sort(q) - return true -} - -func (d *handshakeDecoder) read() *handshake { - n, q := 0, d.que[d.seq&0xf] - if q == nil { - return nil - } - for _, h := range q.h { - if next := h.off + len(h.raw); h.off <= n && next > n { - n = next - } - } - if n == len(q.raw) { - h := q.h[0] - h.off, h.raw = 0, q.raw - d.que[d.seq&0xf] = nil - d.seq++ - return h - } - return nil -} - -type queue struct { - h []*handshake - raw []byte -} - -func (q *queue) Len() int { - return len(q.h) -} - -func (q *queue) Swap(i, j int) { - r := q.h - r[i], r[j] = r[j], r[i] -} - -func (q *queue) Less(i, j int) bool { - a, b := q.h[i], q.h[j] - return a.off < b.off -} diff --git a/dtls/record.go b/dtls/record.go index f74a30e..97d0cdd 100644 --- a/dtls/record.go +++ b/dtls/record.go @@ -16,6 +16,10 @@ const ( var changeCipherSpec = []byte{1} +type marshaler interface { + marshal([]byte) []byte +} + type record struct { typ uint8 ver, epoch uint16 @@ -24,9 +28,10 @@ type record struct { payload marshaler } -func parseRecord(b []byte) (*record, []byte, error) { - if len(b) < 13 { - return nil, nil, errRecordFormat +func parseRecord(b []byte) (*record, int, error) { + n := len(b) + if n < 13 { + return nil, 0, errRecordFormat } _ = b[10] r := &record{ @@ -36,9 +41,9 @@ func parseRecord(b []byte) (*record, []byte, error) { seq: int64(b[5])<<40 | int64(b[6])<<32 | int64(b[7])<<24 | int64(b[8])<<16 | int64(b[9])<<8 | int64(b[10]), } if r.raw, b = split2(b[11:]); r.raw == nil { - return nil, nil, errRecordFormat + return nil, 0, errRecordFormat } - return r, b, nil + return r, n - len(b), nil } func (r *record) prepare(b []byte) []byte { diff --git a/dtls/record_test.go b/dtls/record_test.go index 7560ded..2dcae08 100644 --- a/dtls/record_test.go +++ b/dtls/record_test.go @@ -102,16 +102,19 @@ func TestHelloVerifyRequest(t *testing.T) { } func TestServerKeyExchange(t *testing.T) { - d := &handshakeDecoder{seq: 2} + d := &handshakeTransport{} + d.in.seq = 2 for _, it := range []string{ "16feff000000000000000e00f20c00014700020000000000e60300174104cbaecd1b61e5a9480a702836f0a8a0a44f8f0c88e8009f45acfacf654d8e47fe4005cd215f9a5c38cb8ad5f5d528bea7ec2ff3f09633c57941287fee09e5effd01003738960f95e19967fbf1e36d8082ae9c8311126a0f695134feeb06ab205b34e201cf59bb07b1e57bcf809c7452f5824854c0c51a5471f93d03430bdc61a5a21b45bde88b967e22ce5549bed6bce8c3696fa5f9c7f4662eaa039cd904a6e9e6aaf4618db14b46f35057a54ec04121c5ba9b4c2d1de61d588fe2ddd04913f9f880f5fe3cebb26c49647d2a5c898fabf34edfea5c4cc9b4991c1de62be4dc3aa8", "16feff000000000000000f006d0c00014700020000e60000611b89d720a8722ced8270a728a34fb49d01b3ae61fbff3e85bb6f15fb09a4d406e9146f5122d51c9beee570e999db2238c2e55df2a801f355bf73d02a1e154b2f859a3579e5a3927a16c0d0794780db346381342cc72ddb7f6ab75cff18533c9ed7", } { b, _ := hex.DecodeString(it) r, _, _ := parseRecord(b) - d.parse(r.raw) + if err := d.parse(r.raw); err != nil { + t.Fatal(err) + } } - h := d.read() + h := d.next() if h == nil { t.Fatal("no message") } diff --git a/dtls/server.go b/dtls/server.go new file mode 100644 index 0000000..c0b8a01 --- /dev/null +++ b/dtls/server.go @@ -0,0 +1,178 @@ +package dtls + +import ( + "encoding/hex" + "errors" + "io" + "log" + "net" + "sync" + "time" +) + +var ( + errNotImplemented = errors.New("dtls: not implemented") +) + +func Listen(network, laddr string, config *Config) (net.Listener, error) { + addr, err := net.ResolveUDPAddr(network, laddr) + if err != nil { + return nil, err + } + // TODO: use self signed certificate if not specified + c, err := net.ListenUDP(network, addr) + if err != nil { + return nil, err + } + return NewListener(c, config), nil +} + +func NewListener(c *net.UDPConn, config *Config) *Listener { + if config == nil { + config = defaultConfig + } + l := &Listener{ + c: c, + config: config, + accept: make(chan *conn, 16), + conns: make(map[string]*conn), + } + go l.serveConn() + return l +} + +type Listener struct { + c *net.UDPConn + config *Config + mu sync.RWMutex + accept chan *conn + conns map[string]*conn +} + +func (l *Listener) Accept() (net.Conn, error) { + // TODO: handle multiple goroutines + c, ok := <-l.accept + if !ok { + return nil, io.EOF + } + return c, nil +} + +func (l *Listener) Addr() net.Addr { + return l.c.LocalAddr() +} + +func (l *Listener) Close() error { + // TODO: close acceptors and readers + return l.c.Close() +} + +func (l *Listener) serveConn() error { + var ( + m, buf []byte + v [18]byte + ) + for { + if len(buf) < 4096 { + buf = make([]byte, 1<<20) + } + n, addr, err := l.c.ReadFromUDP(buf) + if err != nil { + return err + } + if n == 0 { + continue + } + m, buf = buf[:n], buf[n:] + v[0], v[1] = uint8(addr.Port>>8), uint8(addr.Port) + id := v[:copy(v[2:], addr.IP)] + l.mu.RLock() + c := l.conns[string(id)] + l.mu.RUnlock() + if c == nil { + l.mu.Lock() + if c = l.conns[string(id)]; c == nil { + c = newServerConn(l, addr, string(id)) + l.conns[string(id)] = c + } + l.mu.Unlock() + } + if err = c.serve(m); err != nil { + c.Close() + } + } +} + +func (l *Listener) closeConn(id string) { + l.mu.Lock() + delete(l.conns, id) + l.mu.Unlock() +} + +type conn struct { + l *Listener + addr *net.UDPAddr + id string + serve func(b []byte) error + in chan []byte +} + +func newServerConn(l *Listener, addr *net.UDPAddr, id string) *conn { + return &conn{ + l: l, + addr: addr, + id: id, + serve: func(b []byte) error { + log.Printf("Read: %s", hex.Dump(b)) + return nil + }, + in: make(chan []byte, 64), + } +} + +func (c *conn) SetDeadline(t time.Time) error { + return errNotImplemented +} + +func (c *conn) SetReadDeadline(t time.Time) error { + return errNotImplemented +} + +func (c *conn) SetWriteDeadline(t time.Time) error { + return errNotImplemented +} + +func (c *conn) Read(p []byte) (n int, err error) { + b, ok := <-c.in + if !ok { + return 0, io.EOF + } + return copy(p, b), nil +} + +func (c *conn) Write(p []byte) (int, error) { + return c.l.c.WriteToUDP(p, c.addr) +} + +func (c *conn) LocalAddr() net.Addr { + return c.l.Addr() +} + +func (c *conn) RemoteAddr() net.Addr { + return c.addr +} + +func (c *conn) Close() error { + c.l.closeConn(c.id) + return nil +} + +type serverHandshake struct { + handshakeTransport + suite *cipherSuite + masterSecret []byte +} + +func (c *serverHandshake) handshake() error { + return nil +} diff --git a/dtls/server_test.go b/dtls/server_test.go new file mode 100644 index 0000000..e7fb866 --- /dev/null +++ b/dtls/server_test.go @@ -0,0 +1,20 @@ +package dtls + +import ( + "log" + "testing" +) + +func _TestServer(t *testing.T) { + l, err := Listen("udp4", "127.0.0.1:4444", nil) + if err != nil { + t.Fatal(err) + } + log.Printf("%v", l.Addr()) + for { + _, err := l.Accept() + if err != nil { + t.Fatal(err) + } + } +} diff --git a/dtls/transport.go b/dtls/transport.go new file mode 100644 index 0000000..d6f3867 --- /dev/null +++ b/dtls/transport.go @@ -0,0 +1,343 @@ +package dtls + +import ( + "net" + "sort" + "sync/atomic" + "time" +) + +type transport struct { + net.Conn + config *Config + ver uint16 + epoch uint16 + rx struct { + seq int64 + mask int64 + buf []byte + pos int + } + tx struct { + seq int64 + } +} + +func (t *transport) readRecord() (*record, error) { + b := t.rx.buf + if b == nil { + b = make([]byte, 0, 4096) + t.rx.buf = b + } + for { + for t.rx.pos < len(b) { + r, n, err := parseRecord(b[t.rx.pos:]) + if err != nil { + t.rx.pos = len(b) + return r, err + } + t.rx.pos += n + if t.canReceive(r) { + return r, nil + } + } + n, err := t.Read(b[:cap(b)]) + if err != nil { + return nil, err + } + t.rx.buf = b[:n] + } +} + +func (t *transport) canReceive(r *record) bool { + if r.epoch != t.epoch { + return false + } + // TODO: prevent from seq corruption + d := r.seq - t.rx.seq + if d > 0 { + if d < 64 { + t.rx.mask = (t.rx.mask << uint(d)) | 1 + } else { + t.rx.mask = 1 + } + t.rx.seq = r.seq + return true + } + if d = -d; d >= 64 { + return false + } + if b := int64(1) << uint(d); t.rx.mask&b == 0 { + t.rx.mask |= b + return true + } + return false +} + +func (t *transport) sendAlert(alert uint8) error { + return t.writeRecord(&record{ + typ: recordAlert, + ver: t.ver, + raw: []byte{levelError, alert}, + }) +} + +func (t *transport) writeRecord(r *record) error { + r.ver, r.epoch, r.seq = t.ver, t.epoch, atomic.AddInt64(&t.tx.seq, 1)-1 + _, err := t.Write(r.marshal(nil)) + return err +} + +func (t *transport) writeFlight(raw []byte, rec []int) error { + sent, last := 0, 0 + for _, to := range rec { + v := raw[last:to] + put6(v[5:], atomic.AddInt64(&t.tx.seq, 1)-1) + if to-sent > t.config.getMTU() { + if _, err := t.Write(raw[sent:last]); err != nil { + return err + } + sent = last + } + last = to + } + if sent == last { + return nil + } + _, err := t.Write(raw[sent:last]) + return err +} + +type handshakeTransport struct { + *transport + log []byte + out struct { + seq int + raw []byte + rec []int + last int + } + in struct { + seq int + queue [16]*handshakeQueue + } +} + +func (t *handshakeTransport) reset() { + if t.log != nil { + t.log = t.log[:0] + } + t.clearFlight() +} + +func (t *handshakeTransport) roundTrip(handle func(h *handshake) (bool, error)) error { + defer t.clearFlight() + if err := t.sendFlight(); err != nil { + return err + } + var ( + start = time.Now() + rto = t.config.RetransmissionTimeout + ) + defer t.SetReadDeadline(time.Time{}) + for { + d := t.config.ReadTimeout - time.Since(start) + if d < 0 { + return errHandshakeTimeout + } + if d > rto { + d = rto + } + t.SetReadDeadline(time.Now().Add(d)) + h, err := t.readHandshake() + if err != nil { + if e, ok := err.(interface { + Timeout() bool + }); ok && e.Timeout() { + rto <<= 1 + if err = t.sendFlight(); err != nil { + continue + } + } + return err + } + if done, err := handle(h); err != nil { + return err + } else if done { + break + } + } + return nil +} + +func (t *handshakeTransport) prepare(h *handshake) { + h.seq = t.out.seq + t.out.seq++ + t.prepareRecord(&record{typ: recordHandshake, payload: h}) +} + +func (t *handshakeTransport) prepareRecord(r *record) { + r.ver, r.epoch = t.ver, t.epoch + b := t.out.raw + pos := len(b) + b = r.prepare(b) + v := b[pos:] + mtu := t.config.getMTU() + n, max := len(v), mtu-len(b)+t.out.last + if n > max { + t.out.last, max = len(b), mtu + } + if r.typ == recordHandshake { + put3(v[14:], n-25) + t.log = append(t.log, v[25:]...) + } + if n <= max || r.typ != recordHandshake { + t.out.rec = append(t.out.rec, t.out.last+n) + t.out.raw = b + return + } + l, m := n-25, max-25 + c := l / m + s := l - m*c + if s == 0 { + c-- + } + _, b = grow(b, c*25) + dst := b[pos:] + for i := c; i >= 0; i-- { + d, off, s := dst[i*max:], i*m, m + if len(d) > max { + d = d[:max] + } + if i > 0 { + s = copy(d[25:], v[25+off:]) + copy(d, v[:19]) + } + put2(d[11:], s+12) + put3(d[19:], off) + put3(d[22:], s) + } + for i := 0; i < c; i++ { + t.out.last += max + t.out.rec = append(t.out.rec, t.out.last) + } + if s > 0 { + t.out.rec = append(t.out.rec, t.out.last+s) + } + t.out.raw = b +} + +func (t *handshakeTransport) sendFlight() error { + return t.writeFlight(t.out.raw, t.out.rec) +} + +func (t *handshakeTransport) clearFlight() { + if t.out.raw != nil { + t.out.raw = t.out.raw[:0] + } + if t.out.rec != nil { + t.out.rec = t.out.rec[:0] + } + t.out.last = 0 +} + +func (t *handshakeTransport) parse(b []byte) error { + h, err := parseHandshake(b) + if err != nil { + return err + } + ds := h.seq - t.in.seq + if ds < 0 || ds > 15 { + return errHandshakeSequence + } + i := h.seq & 0xf + q := t.in.queue[i] + if q == nil { + if h.len < 0 || h.len > 0x1000 { + return errHandshakeMessageTooBig + } + q = &handshakeQueue{ + raw: make([]byte, h.len), + } + t.in.queue[i] = q + } + if m := h.off + len(h.raw); h.off < 0 || m > len(q.raw) { + return errHandshakeMessageOutOfBounds + } + copy(q.raw[h.off:], h.raw) + q.h = append(q.h, h) + sort.Sort(q) + return nil +} + +func (t *handshakeTransport) next() *handshake { + id := t.in.seq & 0xf + q := t.in.queue[id] + if q == nil { + return nil + } + last := 0 + for _, h := range q.h { + if next := h.off + len(h.raw); h.off <= last && next > last { + last = next + } + } + if last == len(q.raw) { + h := q.h[0] + h.off, h.raw = 0, q.raw + t.in.queue[id] = nil + t.in.seq++ + return h + } + return nil +} + +func (t *handshakeTransport) readHandshake() (*handshake, error) { + for { + h := t.next() + if h != nil { + t.log = append(t.log, h.raw...) + h.raw = clone(h.raw) + return h, nil + } + r, err := t.readRecord() + if err != nil { + return nil, err + } + switch r.typ { + case recordAlert: + a, err := parseAlert(r.raw) + if err != nil { + return nil, err + } + if a.level == levelError { + return nil, a + } + case recordHandshake: + if err = t.parse(r.raw); err != nil { + return nil, err + } + default: + return nil, errUnexpectedMessage + } + } +} + +type handshakeQueue struct { + h []*handshake + raw []byte +} + +func (q *handshakeQueue) Len() int { + return len(q.h) +} + +func (q *handshakeQueue) Swap(i, j int) { + r := q.h + r[i], r[j] = r[j], r[i] +} + +func (q *handshakeQueue) Less(i, j int) bool { + a, b := q.h[i], q.h[j] + return a.off < b.off +} diff --git a/dtls/transport_test.go b/dtls/transport_test.go new file mode 100644 index 0000000..89d2d34 --- /dev/null +++ b/dtls/transport_test.go @@ -0,0 +1,89 @@ +package dtls + +import ( + "encoding/hex" + "testing" +) + +func TestHandshakeDefragment(t *testing.T) { + frag := []string{ + "0b0002c700010000000000e60002c40002c1308202bd308201a5a003020102020100300d06092a864886f70d01010b05003022310b30090603550406130253453113301106035504030c0a4f70656e576562525443301e170d3137303330373132303235355a170d3138303330373132303235355a3022310b30090603550406130253453113301106035504030c0a4f70656e57656252544330820122300d06092a864886f70d01010105000382010f003082010a0282010100c2717a632ea4618e599ed6173dfafef22b4f8df27120e30978052c3532c41532ef7466cdf1fe70f6d0554069cb0dfec3ac99f93fabece26a", + "0b0002c700010000e60000e7bb9fcefdae4197cee480c5dd0aa76ca2a9ae85287176180778ed7ce4b9c10bf3ee6426827cb4f4c933c6dd9c4e94dd43aa59d7c60a8a33db961a6dba5243de7ddeab2d9f13ed74a6c0259aa4358e8b25632a5f11e9692118ed1f084fb6953c9a1507825d919394c438cf277c149488c0628e6e3ddf2c1de4a4570b711cc51a6e0747e9aea0fc4687eeb10f45945eee41b147a0d697a825e3817e6b7d0a0ec5bd382c60e0f7c1ef1acb820ed28fdb2c5fa5abb1c8d5cddf9bf3f4309687baec0b2cb97cbf62f22fb30203010001300d06092a864886f70d01010b0500038201010061aa714fdc32", + "0b0002c700010001cd0000e76b9a4b20a46e7264713326d9f4e3e5ca6b972daa4bdf318fc3e9c6b1de1b1f136272b6768ca74d49c7a1ea1296244e4f5a6b01e8938106b8d80fa43ebe0794c9d81c35d65cb62f40754e7a0d2d1ccd46fe5d79670be3c9b9c1fc30245542557f39222bec1a688445ff0f74015ecb7b4cfebc60916a48b48415d064c873fe68838d1cb7f00ecd8b3a0b9069c8a820ce75f7675275cafc50e30cab3c97400cef81475b984ec1f71676e55a6275a919f2a3d3e6d6da23a2eb91442693796e1ab69143700b7bcfa41cec8f5a0ce1ae15bbc671be681308e4f0f40d82deafbdb818d1eac53fa1f57c91", + "0b0002c700010002b4000013bfd8f25c142f1d8416053b375e9ef44fbd06fd", + } + for _, seq := range [][]int{ + {0, 1, 2, 3}, + {3, 2, 1, 0}, + {2, 0, 1, 0, 2, 0, 3}, + {0, 1, 2, 1, 0, 1, 3}, + } { + tr := &handshakeTransport{} + tr.in.seq = 1 + for _, i := range seq { + b, _ := hex.DecodeString(frag[i]) + if err := tr.parse(b); err != nil && err != errHandshakeSequence { + t.Fatal(err) + } + } + h := tr.next() + if h == nil || tr.in.seq != 2 { + t.Fatal("defragment:", seq) + } + c, err := parseCertificate(h.raw) + if err != nil { + t.Fatal(err) + } + if len(c.cert) == 0 { + t.Fatal("no certificate") + } + } +} + +func TestHandshakeFragment(t *testing.T) { + for mtu := 30; mtu < 256; mtu++ { + c := defaultConfig.Clone() + c.MTU = mtu + tr := &handshakeTransport{ + transport: &transport{ + config: c, + ver: c.MinVersion, + }, + } + tr.prepare(&handshake{ + typ: handshakeClientHello, + message: &clientHello{ + ver: c.MaxVersion, + random: make([]byte, 32), + cipherSuites: c.CipherSuites, + compMethods: supportedCompression, + extensions: &extensions{ + renegotiationSupported: true, + srtpProtectionProfiles: c.SRTPProtectionProfiles, + extendedMasterSecret: true, + sessionTicket: true, + signatureAlgorithms: supportedSignatureAlgorithms, + supportedPoints: supportedPointFormats, + supportedCurves: supportedCurves, + }, + }, + }) + b := tr.out.raw + for len(b) > 0 { + r, n, err := parseRecord(b) + if err != nil { + t.Fatal(err) + } + tr.parse(r.raw) + b = b[n:] + } + h := tr.next() + if h == nil { + t.Fatal("no message") + } + _, err := parseClientHello(h.raw) + if err != nil { + t.Fatal(err) + } + } +}