Skip to content

Commit

Permalink
properly set deadline during TLS handshake
Browse files Browse the repository at this point in the history
  • Loading branch information
jkralik committed May 13, 2021
1 parent 7c58901 commit 4dd6c25
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 27 deletions.
8 changes: 6 additions & 2 deletions dtls/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,11 @@ func (s *Session) WriteMessage(req *pool.Message) error {
if err != nil {
return fmt.Errorf("cannot marshal: %w", err)
}
return s.connection.WriteWithContext(req.Context(), data)
err = s.connection.WriteWithContext(req.Context(), data)
if err != nil {
return fmt.Errorf("cannot write to connection: %w", err)
}
return err
}

func (s *Session) MaxMessageSize() int {
Expand Down Expand Up @@ -121,7 +125,7 @@ func (s *Session) Run(cc *client.ClientConn) (err error) {
readBuf := m
readLen, err := s.connection.ReadWithContext(s.Context(), readBuf)
if err != nil {
return err
return fmt.Errorf("cannot read from connection: %w", err)
}
readBuf = readBuf[:readLen]
err = cc.Process(readBuf)
Expand Down
69 changes: 63 additions & 6 deletions net/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type Conn struct {
onReadTimeout func() error
onWriteTimeout func() error

handshake func() error
readBuffer *bufio.Reader
lock sync.Mutex
}
Expand Down Expand Up @@ -50,6 +51,10 @@ func NewConn(c net.Conn, opts ...ConnOption) *Conn {
onReadTimeout: cfg.onReadTimeout,
onWriteTimeout: cfg.onWriteTimeout,
}
if v, ok := c.(interface{ Handshake() error }); ok {
connection.handshake = v.Handshake
}

return &connection
}

Expand Down Expand Up @@ -84,8 +89,12 @@ func (c *Conn) WriteWithContext(ctx context.Context, data []byte) error {
return ctx.Err()
default:
}
err := c.doHandshakeLocked(ctx, c.onWriteTimeout)
if err != nil {
return fmt.Errorf("cannot TLS handshake: %w", err)
}
deadline := time.Now().Add(c.heartBeat)
err := c.connection.SetWriteDeadline(deadline)
err = c.connection.SetWriteDeadline(deadline)
if err != nil {
return fmt.Errorf("cannot set write deadline for connection: %w", err)
}
Expand All @@ -104,7 +113,7 @@ func (c *Conn) WriteWithContext(ctx context.Context, data []byte) error {
}
continue
}
return fmt.Errorf("cannot write to connection: %w", err)
return err
}
written += n
}
Expand All @@ -124,20 +133,68 @@ func (c *Conn) ReadFullWithContext(ctx context.Context, buffer []byte) error {
return nil
}

// During handshake wee need to use setDeadline because https://github.com/golang/go/issues/31224
// added comment in https://github.com/golang/go/commit/c9b9cd73bb7a7828d34f4a7844f16c3fbc0674dd
func (c *Conn) doHandshakeLocked(ctx context.Context, onTimeout func() error) error {
if c.handshake == nil {
return nil
}
for {
select {
case <-ctx.Done():
if ctx.Err() != nil {
return ctx.Err()
}
return fmt.Errorf("handshake failed")
default:
}
deadline := time.Now().Add(c.heartBeat)
err := c.connection.SetDeadline(deadline)
if err != nil {
return fmt.Errorf("cannot set deadline for handshake: %w", err)
}
err = c.handshake()
if err != nil {
if isTemporary(err, deadline) {
if onTimeout != nil {
err := onTimeout()
if err != nil {
return fmt.Errorf("on timeout returns error: %w", err)
}
}
continue
}
}
return err
}
}

func (c *Conn) doHandshake(ctx context.Context, onTimeout func() error) error {
if c.handshake == nil {
return nil
}
c.lock.Lock()
defer c.lock.Unlock()
return c.doHandshakeLocked(ctx, onTimeout)
}

// ReadWithContext reads stream with context.
func (c *Conn) ReadWithContext(ctx context.Context, buffer []byte) (int, error) {
for {
select {
case <-ctx.Done():
if ctx.Err() != nil {
return -1, fmt.Errorf("cannot read from connection: %v", ctx.Err())
return -1, ctx.Err()
}
return -1, fmt.Errorf("cannot read from connection")
default:
}

err := c.doHandshake(ctx, c.onReadTimeout)
if err != nil {
return -1, fmt.Errorf("cannot TLS handshake: %w", err)
}
deadline := time.Now().Add(c.heartBeat)
err := c.connection.SetReadDeadline(deadline)
err = c.connection.SetReadDeadline(deadline)
if err != nil {
return -1, fmt.Errorf("cannot set read deadline for connection: %w", err)
}
Expand All @@ -152,7 +209,7 @@ func (c *Conn) ReadWithContext(ctx context.Context, buffer []byte) (int, error)
}
continue
}
return -1, fmt.Errorf("cannot read from connection: %w", err)
return -1, err
}
return n, err
}
Expand Down
34 changes: 17 additions & 17 deletions tcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,25 +201,25 @@ func (s *Server) Serve(l Listener) error {
}
if rw != nil {
wg.Add(1)
var cc *ClientConn
monitor := s.createInactivityMonitor()
opts := []coapNet.ConnOption{
coapNet.WithHeartBeat(s.heartBeat),
coapNet.WithOnReadTimeout(func() error {
monitor.CheckInactivity(cc)
return nil
}),
}
cc = s.createClientConn(coapNet.NewConn(rw, opts...), monitor)
if s.onNewClientConn != nil {
if tlscon, ok := rw.(*tls.Conn); ok {
s.onNewClientConn(cc, tlscon)
} else {
s.onNewClientConn(cc, nil)
}
}
go func() {
defer wg.Done()
var cc *ClientConn
monitor := s.createInactivityMonitor()
opts := []coapNet.ConnOption{
coapNet.WithHeartBeat(s.heartBeat),
coapNet.WithOnReadTimeout(func() error {
monitor.CheckInactivity(cc)
return nil
}),
}
cc = s.createClientConn(coapNet.NewConn(rw, opts...), monitor)
if s.onNewClientConn != nil {
if tlscon, ok := rw.(*tls.Conn); ok {
s.onNewClientConn(cc, tlscon)
} else {
s.onNewClientConn(cc, nil)
}
}
err := cc.Run()
if err != nil {
s.errors(fmt.Errorf("%v: %w", cc.RemoteAddr(), err))
Expand Down
8 changes: 6 additions & 2 deletions tcp/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,11 @@ func (s *Session) WriteMessage(req *pool.Message) error {
if err != nil {
return fmt.Errorf("cannot marshal: %w", err)
}
return s.connection.WriteWithContext(req.Context(), data)
err = s.connection.WriteWithContext(req.Context(), data)
if err != nil {
return fmt.Errorf("cannot write to connection: %w", err)
}
return err
}

func (s *Session) sendCSM() error {
Expand Down Expand Up @@ -359,7 +363,7 @@ func (s *Session) Run(cc *ClientConn) (err error) {
}
readLen, err := s.connection.ReadWithContext(s.Context(), readBuf)
if err != nil {
return err
return fmt.Errorf("cannot read from connection: %w", err)
}
if readLen > 0 {
buffer.Write(readBuf[:readLen])
Expand Down

0 comments on commit 4dd6c25

Please sign in to comment.