Skip to content
This repository has been archived by the owner on Jan 5, 2024. It is now read-only.

Commit

Permalink
bad handshakes are a fatal protocol error
Browse files Browse the repository at this point in the history
  • Loading branch information
blampe committed Oct 26, 2015
1 parent befde1c commit f20c46d
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 13 deletions.
3 changes: 2 additions & 1 deletion CHANGES.rst
Expand Up @@ -4,7 +4,8 @@ Changes by Version
0.18.1 (unreleased)
-------------------

- Nothing changed yet.
- Fixed a bug where ``InvalidMessageError`` was being raised instead of
``FatalProtocolError``.


0.18.0 (2015-10-20)
Expand Down
21 changes: 13 additions & 8 deletions tchannel/tornado/connection.py
Expand Up @@ -344,7 +344,7 @@ def initiate_handshake(self, headers):
))
init_res = yield self._recv()
if init_res.message_type != Types.INIT_RES:
raise errors.InvalidMessageError(
raise errors.FatalProtocolError(
"Expected handshake response, got %s" % repr(init_res)
)
self._extract_handshake_headers(init_res)
Expand All @@ -365,7 +365,7 @@ def expect_handshake(self, headers):
"""
init_req = yield self._recv()
if init_req.message_type != Types.INIT_REQ:
raise errors.InvalidMessageError(
raise errors.FatalProtocolError(
"You need to shake my hand first. Got %s" % repr(init_req)
)
self._extract_handshake_headers(init_req)
Expand All @@ -381,12 +381,12 @@ def expect_handshake(self, headers):

def _extract_handshake_headers(self, message):
if not message.host_port:
raise errors.InvalidMessageError(
raise errors.FatalProtocolError(
'Missing required header: host_port'
)

if not message.process_name:
raise errors.InvalidMessageError(
raise errors.FatalProtocolError(
'Missing required header: process_name'
)

Expand Down Expand Up @@ -434,10 +434,15 @@ def outgoing(cls, hostport, process_name=None, serve_hostport=None,

connection = cls(stream, tchannel)
log.debug("Performing handshake with %s", hostport)
yield connection.initiate_handshake(headers={
'host_port': serve_hostport,
'process_name': process_name,
})

try:
yield connection.initiate_handshake(headers={
'host_port': serve_hostport,
'process_name': process_name,
})
except errors.FatalProtocolError:
stream.close()
raise

if handler:
connection.serve(handler)
Expand Down
12 changes: 8 additions & 4 deletions tchannel/tornado/tchannel.py
Expand Up @@ -37,6 +37,7 @@
from ..deprecate import deprecate
from ..enum import enum
from ..errors import AlreadyListeningError
from ..errors import FatalProtocolError
from ..event import EventEmitter
from ..event import EventRegistrar
from ..net import local_ip
Expand Down Expand Up @@ -427,10 +428,13 @@ def handle_stream(self, stream, address):

conn = StreamConnection(connection=stream, tchannel=self.tchannel)

yield conn.expect_handshake(headers={
'host_port': self.tchannel.hostport,
'process_name': self.tchannel.process_name,
})
try:
yield conn.expect_handshake(headers={
'host_port': self.tchannel.hostport,
'process_name': self.tchannel.process_name,
})
except FatalProtocolError:
raise tornado.gen.Return(stream.close())

log.debug(
"Successfully completed handshake with %s:%s (%s)",
Expand Down
79 changes: 79 additions & 0 deletions tests/test_tchannel.py
Expand Up @@ -23,6 +23,7 @@
from __future__ import print_function
from __future__ import unicode_literals

import socket
import subprocess
import textwrap
from mock import MagicMock, patch, ANY
Expand All @@ -32,8 +33,14 @@
import psutil
import pytest
from tornado import gen
from tornado import iostream
from tornado import tcpserver
from tornado import testing

from tchannel import TChannel, Request, Response, schemes, errors
from tchannel import messages
from tchannel import io
from tchannel import frame
from tchannel.errors import AlreadyListeningError, TimeoutError
from tchannel.event import EventHook
from tchannel.response import TransportHeaders
Expand Down Expand Up @@ -376,3 +383,75 @@ def test_listen_duplicate_ports():
port = int(server.hostport.rsplit(":")[1])
server.listen(port)
server.listen()


def payload(message):
payload = messages.RW[message.message_type].write(
message,
io.BytesIO(),
).getvalue()

f = frame.Frame(
header=frame.FrameHeader(
message_type=message.message_type,
message_id=1,
),
payload=payload,
)

return bytes(frame.frame_rw.write(f, io.BytesIO()).getvalue())


@pytest.mark.parametrize('response', [
payload(messages.InitRequestMessage()),
payload(messages.InitResponseMessage(headers={'process_name': 'foo'})),
payload(messages.InitResponseMessage(headers={'host_port': '123'})),
])
@pytest.mark.gen_test
def test_client_doesnt_get_valid_handshake_response(response):

class BadServer(tcpserver.TCPServer):
@gen.coroutine
def handle_stream(self, stream, address):
yield stream.write(bytes(response))

(socket, port) = testing.bind_unused_port()

server = BadServer()
server.add_socket(socket)
server.start()

tchannel = TChannel(name='client')

with pytest.raises(errors.FatalProtocolError):
yield tchannel.call(
scheme=schemes.RAW,
service='server',
arg1='endpoint',
hostport='localhost:%d' % port,
)


@pytest.mark.parametrize('req', [
payload(messages.InitResponseMessage()),
payload(messages.InitResponseMessage(headers={'process_name': 'foo'})),
payload(messages.InitResponseMessage(headers={'host_port': '123'})),
])
@pytest.mark.gen_test
def test_server_doesnt_get_valid_handshake_request(req):
server = TChannel(name='server')
server.listen()

host, port = server.hostport.split(":")

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

bad_client = iostream.IOStream(sock)

yield bad_client.connect((host, int(port)))

# yield bad_client.write(payload(messages.InitResponseMessage()))
yield bad_client.write(req)

response = yield bad_client.read_until_close()
assert not response

0 comments on commit f20c46d

Please sign in to comment.