Skip to content

Commit

Permalink
add transport, handshake transport and server
Browse files Browse the repository at this point in the history
  • Loading branch information
pixelbender committed Mar 22, 2017
1 parent 19678ae commit 7855ee1
Show file tree
Hide file tree
Showing 12 changed files with 724 additions and 574 deletions.
11 changes: 0 additions & 11 deletions dtls/ciphersuite.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,3 @@ func (c *cipherSuite) prf(ver uint16, result, secret []byte, params ...[]byte) {
var (
errUnsupportedKeyExchangeAlgorithm = errors.New("dtls: unsupported key exchange algorithm")
)

/*
func cipherAES(key, iv []byte, read bool) interface{} {
block, _ := aes.NewCipher(key)
if read {
return cipher.NewCBCDecrypter(block, iv)
}
return cipher.NewCBCEncrypter(block, iv)
}
*/
41 changes: 17 additions & 24 deletions dtls/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@ import (

func NewClient(conn net.Conn, config *Config) (*Conn, error) {
c := newConn(conn, config)
h := &clientHandshake{handshakeProtocol: c.newHandshake()}
h := &clientHandshake{}
h.transport = c.transport
if err := h.handshake(); err != nil {
return nil, err
}
return c, nil
}

type clientHandshake struct {
*handshakeProtocol
ver uint16
handshakeTransport
suite *cipherSuite
masterSecret []byte
}
Expand Down Expand Up @@ -48,24 +48,22 @@ func (c *clientHandshake) handshake() (err error) {
ch.signatureAlgorithms = supportedSignatureAlgorithms
}
var (
req *helloVerifyRequest
sh *serverHello
skey *serverKeyExchange
scert *certificate
creq *certificateRequest
)
c.reset(true)
c.write(&handshake{typ: handshakeClientHello, message: ch})
if err = c.flight(func(m *handshake) (done bool, err error) {
c.prepare(&handshake{typ: handshakeClientHello, message: ch})
if err = c.roundTrip(func(m *handshake) (done bool, err error) {
switch m.typ {
case handshakeHelloVerifyRequest:
var r *helloVerifyRequest
if r, err = parseHelloVerifyRequest(m.raw); err != nil {
if req, err = parseHelloVerifyRequest(m.raw); err != nil {
break
}
// TODO: reset finished mac
ch.cookie = clone(r.cookie)
c.reset(true)
c.write(&handshake{typ: handshakeClientHello, message: ch})
ch.cookie = clone(req.cookie)
c.reset()
c.prepare(&handshake{typ: handshakeClientHello, message: ch})
case handshakeServerHello:
sh, err = parseServerHello(m.raw)
case handshakeCertificate:
Expand All @@ -78,7 +76,7 @@ func (c *clientHandshake) handshake() (err error) {
done = true
default:
c.sendAlert(alertUnexpectedMessage)
return false, fmt.Errorf("dtls: unexpected message: 0x%x", m.typ)
return false, errUnexpectedMessage
}
if err != nil {
c.sendAlert(alertDecodeError)
Expand Down Expand Up @@ -114,12 +112,8 @@ func (c *clientHandshake) handshake() (err error) {
}
// TODO: check renegotiation

c.reset(false)
//
////preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, c.peerCertificates[0])
//
if creq != nil {
c.write(&handshake{typ: handshakeCertificate, message: &certificate{}})
c.prepare(&handshake{typ: handshakeCertificate, message: &certificate{}})
// TODO: write peer certificate chain
}

Expand All @@ -138,7 +132,7 @@ func (c *clientHandshake) handshake() (err error) {
c.sendAlert(alertInternalError)
return
}
c.write(&handshake{typ: handshakeClientKeyExchange, message: ckey})
c.prepare(&handshake{typ: handshakeClientKeyExchange, message: ckey})
case keyECDH:
// TODO: implement
c.sendAlert(alertInternalError)
Expand All @@ -148,12 +142,11 @@ func (c *clientHandshake) handshake() (err error) {
return errUnsupportedKeyExchangeAlgorithm
}

c.changeCipherSpec()
c.prepareRecord(&record{typ: recordChangeCipherSpec, raw: changeCipherSpec})

c.tx.epoch++
c.write(&handshake{typ: handshakeFinished, raw: c.finishedHash()})
//c.write(&handshake{typ: handshakeFinished, raw: c.finishedHash()})

c.tx.writeFlight(c.enc.raw, c.enc.rec)
//c.tx.writeFlight(c.enc.raw, c.enc.rec)
time.Sleep(time.Second)
/*
return c.flight(func(m *handshake) (done bool, err error) {
Expand All @@ -163,7 +156,7 @@ func (c *clientHandshake) handshake() (err error) {
}

func (c *clientHandshake) finishedHash() []byte {
return c.suite.finishedHash(c.ver, c.masterSecret, clientFinished, c.buf.Bytes())
return c.suite.finishedHash(c.ver, c.masterSecret, clientFinished, c.log)
}

func (c *clientHandshake) newMasterSecretRSA(ch *clientHello, sh *serverHello, pub *rsa.PublicKey) ([]byte, error) {
Expand Down
36 changes: 16 additions & 20 deletions dtls/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,34 @@ package dtls

import (
"crypto/rand"
"crypto/tls"
"crypto/x509"
"errors"
"io"
"time"
)

var DefaultConfig = &Config{
var defaultConfig = &Config{
Rand: rand.Reader,
Time: time.Now,
MTU: 1400,
RetransmissionTimeout: 500 * time.Millisecond,
ReadTimeout: 15 * time.Second,
CipherSuites: []uint16{
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_RSA_WITH_AES_128_CBC_SHA,
tls.TLS_RSA_WITH_AES_256_CBC_SHA,
tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
},
SRTPProtectionProfiles: []uint16{
SRTP_AES128_CM_HMAC_SHA1_80,
SRTP_AES128_CM_HMAC_SHA1_32,
TLS_RSA_WITH_AES_128_CBC_SHA,
//tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
//tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
//tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
//tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
//tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
//tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
//tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
//tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
//tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
//tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
//tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
//tls.TLS_RSA_WITH_AES_128_CBC_SHA,
//tls.TLS_RSA_WITH_AES_256_CBC_SHA,
//tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
},
MinVersion: VersionDTLS10,
MaxVersion: VersionDTLS12,
Expand Down
121 changes: 3 additions & 118 deletions dtls/conn.go
Original file line number Diff line number Diff line change
@@ -1,139 +1,24 @@
package dtls

import (
"io"
"net"
)

type Conn struct {
net.Conn
config *Config
rx receiver
tx transmitter
*transport
}

func newConn(c net.Conn, config *Config) *Conn {
if config == nil {
config = DefaultConfig
config = defaultConfig
}
return &Conn{
c,
config,
receiver{r: c},
transmitter{w: c, mtu: config.MTU},
&transport{Conn: c, config: config},
}
}

func (c *Conn) newHandshake() *handshakeProtocol {
return &handshakeProtocol{
Conn: c,
enc: handshakeEncoder{
mtu: c.config.MTU,
seq: 0,
},
dec: handshakeDecoder{
seq: 0,
},
}
}

func (c *Conn) sendAlert(a uint8) error {
return c.tx.write(&record{
typ: recordAlert,
ver: VersionDTLS10,
raw: []byte{levelError, a},
})
}

func (c *Conn) Close() error {
// TODO: send alert only if handshake done
c.sendAlert(alertCloseNotify)
return c.Conn.Close()
}

type receiver struct {
r io.Reader
epoch uint16
seq int64
mask int64
buf []byte
raw []byte
}

func (rx *receiver) read() (r *record, err error) {
if rx.buf == nil {
rx.buf = make([]byte, 4096)
}
for {
if len(rx.raw) > 0 {
r, rx.raw, err = parseRecord(rx.raw)
if err == nil && rx.check(r) {
return r, nil
}
}
n, err := rx.r.Read(rx.buf)
if err != nil {
return nil, err
}
rx.raw = rx.buf[:n]
}
}

func (rx *receiver) check(r *record) bool {
if r.epoch != rx.epoch {
return false
}
d := r.seq - rx.seq
if d > 0 {
if d < 64 {
rx.mask = (rx.mask << uint(d)) | 1
} else {
rx.mask = 1
}
rx.seq = r.seq
return true
}
if d = -d; d >= 64 {
return false
}
if b := int64(1) << uint(d); rx.mask&b == 0 {
rx.mask |= b
return true
}
return false
}

type transmitter struct {
w io.Writer
mtu int
epoch uint16
seq int64
}

func (tx *transmitter) write(r *record) error {
r.epoch, r.seq = tx.epoch, tx.seq
tx.seq++
_, err := tx.w.Write(r.marshal(nil))
return err
}

func (tx *transmitter) writeFlight(raw []byte, rec []int) (err error) {
last, sent := 0, 0
for _, to := range rec {
v := raw[last:to]
put6(v[5:], tx.seq)
tx.seq++
if to-sent > tx.mtu {
if _, err = tx.w.Write(raw[sent:last]); err != nil {
return
}
sent = last
}
last = to
}
if sent == last {
return nil
}
_, err = tx.w.Write(raw[sent:last])
return
}

0 comments on commit 7855ee1

Please sign in to comment.