Skip to content

Commit

Permalink
Consistent error logging on init handshake errors
Browse files Browse the repository at this point in the history
Currently, the inbound side logging may log different messages in
different cases. We may log "Couldn't create new TChannelConnection",
"Failed during connection handshake", or both depending on where the
error occurred in the handshake.

Instead, make this consistent with the outbound side to always only
log "Failed during connection handshake". This also ensures that we send
back an error when we get the wrong message type as the first frame on
an init req.
  • Loading branch information
prashantv committed Mar 1, 2017
1 parent ec3f3f8 commit c18a0bd
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 17 deletions.
2 changes: 0 additions & 2 deletions channel.go
Expand Up @@ -454,8 +454,6 @@ func (ch *Channel) serve() {
OnExchangeUpdated: ch.exchangeUpdated,
}
if _, err := ch.inboundHandshake(context.Background(), netConn, events); err != nil {
// Server is getting overloaded - begin rejecting new connections
ch.log.WithFields(ErrField(err)).Error("Couldn't create new TChannelConnection for incoming conn.")
netConn.Close()
}
}()
Expand Down
2 changes: 1 addition & 1 deletion connection_test.go
Expand Up @@ -777,7 +777,7 @@ func TestConnectTimeout(t *testing.T) {
}

func TestParallelConnectionAccepts(t *testing.T) {
opts := testutils.NewOpts().AddLogFilter("Couldn't create new TChannelConnection", 1)
opts := testutils.NewOpts().AddLogFilter("Failed during connection handshake", 1)
testutils.WithTestServer(t, opts, func(ts *testutils.TestServer) {
testutils.RegisterEcho(ts.Server(), nil)

Expand Down
10 changes: 10 additions & 0 deletions init_test.go
Expand Up @@ -83,6 +83,16 @@ func TestUnexpectedInitReq(t *testing.T) {
errCode: ErrCodeProtocol,
},
},
{
name: "unexpected message type",
initMsg: &pingReq{
id: 1,
},
expectedError: errorMessage{
id: 1,
errCode: ErrCodeProtocol,
},
},
}

for _, tt := range tests {
Expand Down
33 changes: 19 additions & 14 deletions preinit_connection.go
Expand Up @@ -24,6 +24,7 @@ import (
"encoding/binary"
"fmt"
"io"
"math"
"net"
"strconv"
"time"
Expand All @@ -34,9 +35,7 @@ import (
func (ch *Channel) outboundHandshake(ctx context.Context, c net.Conn, outboundHP string, events connectionEvents) (_ *Connection, err error) {
defer setInitDeadline(ctx, c)()
defer func() {
if err != nil {
err = ch.initError(c, outbound, 1, err)
}
err = ch.initError(c, outbound, 1, err)
}()

msg := &initReq{initMessage: ch.getInitMessage(ctx, 1)}
Expand Down Expand Up @@ -67,19 +66,19 @@ func (ch *Channel) outboundHandshake(ctx context.Context, c net.Conn, outboundHP
}

func (ch *Channel) inboundHandshake(ctx context.Context, c net.Conn, events connectionEvents) (_ *Connection, err error) {
id := uint32(math.MaxUint32)

defer setInitDeadline(ctx, c)()
defer func() {
err = ch.initError(c, inbound, id, err)
}()

req := &initReq{}
id, err := ch.readMessage(c, req)
id, err = ch.readMessage(c, req)
if err != nil {
return nil, err
}

defer func() {
if err != nil {
err = ch.initError(c, inbound, id, err)
}
}()
if req.Version < CurrentProtocolVersion {
return nil, unsupportedProtocolVersion(req.Version)
}
Expand Down Expand Up @@ -122,10 +121,16 @@ func (ch *Channel) getInitMessage(ctx context.Context, id uint32) initMessage {
}

func (ch *Channel) initError(c net.Conn, connDir connectionDirection, id uint32, err error) error {
ch.log.WithFields(
LogField{"connectionDirection", connDir},
if err == nil {
return nil
}

ch.log.WithFields(LogFields{
{"connectionDirection", connDir},
{"localAddr", c.LocalAddr()},
{"remoteAddr", c.RemoteAddr()},
ErrField(err),
).Error("Failed during connection handshake.")
}...).Error("Failed during connection handshake.")

if ne, ok := err.(net.Error); ok && ne.Timeout() {
err = ErrTimeout
Expand Down Expand Up @@ -162,9 +167,9 @@ func (ch *Channel) readMessage(c net.Conn, msg message) (uint32, error) {

if frame.Header.messageType != msg.messageType() {
if frame.Header.messageType == messageTypeError {
return 0, readError(frame)
return frame.Header.ID, readError(frame)
}
return 0, NewSystemError(ErrCodeProtocol, "expected message type %v, got %v", msg.messageType(), frame.Header.messageType)
return frame.Header.ID, NewSystemError(ErrCodeProtocol, "expected message type %v, got %v", msg.messageType(), frame.Header.messageType)
}

return frame.Header.ID, frame.read(msg)
Expand Down

0 comments on commit c18a0bd

Please sign in to comment.