diff --git a/tchannel/tchannel.py b/tchannel/tchannel.py index cfd08f47..06287fa8 100644 --- a/tchannel/tchannel.py +++ b/tchannel/tchannel.py @@ -319,3 +319,6 @@ def advertise(self, routers=None, name=None, timeout=None, response = Response(json.loads(body), headers or {}) raise gen.Return(response) + + def stop(self, reason=None): + return self._dep_tchannel.stop(reason) diff --git a/tchannel/testing/vcr/patch.py b/tchannel/testing/vcr/patch.py index d596b800..f5109243 100644 --- a/tchannel/testing/vcr/patch.py +++ b/tchannel/testing/vcr/patch.py @@ -85,7 +85,7 @@ def send(self, arg1, arg2, arg3, vcr_request = VCRProxy.Request( serviceName=self.service.encode('utf-8'), hostPort=self.hostport, - knownPeers=self.original_tchannel.peers.hosts, + knownPeers=self.original_tchannel.peer_group.hosts, endpoint=endpoint, headers=(yield read_full(arg2)), body=(yield read_full(arg3)), diff --git a/tchannel/tornado/connection.py b/tchannel/tornado/connection.py index a2722135..83846b27 100644 --- a/tchannel/tornado/connection.py +++ b/tchannel/tornado/connection.py @@ -32,8 +32,8 @@ from .. import frame from .. import glossary from .. import messages -from ..errors import NetworkError from ..errors import FatalProtocolError +from ..errors import NetworkError from ..errors import TChannelError from ..event import EventType from ..io import BytesIO @@ -103,12 +103,9 @@ def __init__(self, connection, tchannel=None): # Tracks message IDs for this connection. self._id_sequence = 0 - # We need to use two separate message factories to avoid message ID - # collision while assembling fragmented messages. - self.request_message_factory = MessageFactory(self.remote_host, - self.remote_host_port) - self.response_message_factory = MessageFactory(self.remote_host, - self.remote_host_port) + self.message_factory = MessageFactory( + self.remote_host, self.remote_host_port + ) # Queue of unprocessed incoming calls. self._messages = queues.Queue() @@ -122,6 +119,15 @@ def __init__(self, connection, tchannel=None): self.tchannel = tchannel + # Map from message ID to inflight requests on the server side. + self.incoming_requests = {} + # Map from message ID to inflight responses on the client side. + # When client sends a request, it will first put ID and a future for + # response in the _outstanding map. After it gets response message, + # but not a completed response, it will put ID and response obj in the + # incoming_response. + self.incoming_responses = {} + connection.set_close_callback(self._on_close) def next_message_id(self): @@ -130,7 +136,6 @@ def next_message_id(self): def _on_close(self): self.closed = True - for message_id, future in self._outstanding.iteritems(): future.set_exception( NetworkError( @@ -138,7 +143,6 @@ def _on_close(self): ) ) self._outstanding = {} - try: while True: message = self._messages.get_nowait() @@ -192,6 +196,8 @@ def on_error(future): if isinstance(exception, tornado.iostream.StreamClosedError): self.close() + if message_future.running(): + message_future.set_exception(exception) size_width = frame.frame_rw.size_rw.width() self.connection.read_bytes(size_width).add_done_callback(on_read_size) @@ -205,50 +211,64 @@ def _loop(self): # # Must be started only after the handshake has been performed. self._loop_running = True + while not self.closed: - message = yield self._recv() - # TODO: There should probably be a try-catch on the yield. + try: + message = yield self._recv() + except tornado.iostream.StreamClosedError: + log.warning("Stream has been closed.") + break + if message.message_type in self.CALL_REQ_TYPES: self._messages.put(message) - continue - elif message.id in self._outstanding: - # set exception if receive error message - if message.message_type == Types.ERROR: - future = self._outstanding.pop(message.id) - if future.running(): - error = TChannelError.from_code( - message.code, - description=message.description, - ) - future.set_exception(error) - else: - protocol_exception = ( - self.response_message_factory.build(message) - ) - if protocol_exception: - self.event_emitter.fire( - EventType.after_receive_error, - protocol_exception, - ) - continue - - response = self.response_message_factory.build(message) - - # keep continue message in the list - # pop all other type messages including error message - if (message.message_type in self.CALL_RES_TYPES and - message.flags == FlagsType.fragment): - # still streaming, keep it for record - future = self._outstanding.get(message.id) - else: - future = self._outstanding.pop(message.id) - - if response and future.running(): - future.set_result(response) - continue - - log.warn('Unconsumed message %s', message) + # This should just handle CALL_RES or ERROR message. + # It is invalid if any CALL_RES_CONTINUE message goes here. + self._handle_response(message) + elif message.id in self.incoming_responses: + # This should just handle CALL_RES_CONTINUE or ERROR message. + # It is invalid if any CALL_RES message goes here. + self._handle_continuing_response(message) + else: + log.warn('Unconsumed message %s', message) + + def _handle_continuing_response(self, message): + response = self.incoming_responses.pop(message.id) + if message.message_type == Types.ERROR: + error = self.message_factory.build_inbound_error(message) + response.set_exception(error) + + self.tchannel.event_emitter.fire( + EventType.after_receive_error, error, + ) + return + + self.message_factory.build_inbound_response(message, response) + + if (message.message_type in self.CALL_RES_TYPES and + message.flags == FlagsType.fragment): + # still streaming, keep it for record + self.incoming_responses[message.id] = response + + def _handle_response(self, message): + future = self._outstanding.pop(message.id) + if message.message_type == Types.ERROR: + error = self.message_factory.build_inbound_error(message) + if future.running(): + future.set_exception(error) + + self.tchannel.event_emitter.fire( + EventType.after_receive_error, error, + ) + return + + response = self.message_factory.build_inbound_response(message, None) + if future.running(): + future.set_result(response) + if (message.message_type in self.CALL_RES_TYPES and + message.flags == FlagsType.fragment): + # still streaming, keep it for record + self.incoming_responses[message.id] = response # Basically, the only difference between send and write is that send # sets up a Future to get the response. That's ideal for peers making @@ -290,12 +310,7 @@ def write(self, message): message.id = message.id or self.next_message_id() - if message.message_type in self.CALL_REQ_TYPES: - message_factory = self.request_message_factory - else: - message_factory = self.response_message_factory - - fragments = message_factory.fragment(message) + fragments = self.message_factory.fragment(message) return chain(fragments, self._write) @@ -452,10 +467,8 @@ def serve(self, handler): """ assert handler, "handler is required" assert self._loop_running, "Finish the handshake first" - while not self.closed: message = yield self.await() - try: handler(message, self) except Exception: @@ -489,6 +502,16 @@ def ping(self): def pong(self): return self._write(messages.PingResponseMessage()) + def add_incoming_request(self, request): + self.incoming_requests[request.id] = request + + def get_incoming_request(self, id): + return self.incoming_requests.get(id, None) + + def remove_incoming_request(self, id): + req = self.incoming_requests.pop(id, None) + return req + class StreamConnection(TornadoConnection): """Streaming request/response into protocol messages and sent by tornado @@ -561,7 +584,7 @@ def _stream(self, context, message_factory): def post_response(self, response): try: # TODO: before_send_response - yield self._stream(response, self.response_message_factory) + yield self._stream(response, self.message_factory) # event: send_response self.tchannel.event_emitter.fire( @@ -575,7 +598,7 @@ def stream_request(self, request): """send the given request and response is not required""" request.close_argstreams() - stream_future = self._stream(request, self.request_message_factory) + stream_future = self._stream(request, self.message_factory) stream_future.add_done_callback( lambda f: request.close_argstreams(force=True), diff --git a/tchannel/tornado/dispatch.py b/tchannel/tornado/dispatch.py index a927d814..df60d52c 100644 --- a/tchannel/tornado/dispatch.py +++ b/tchannel/tornado/dispatch.py @@ -77,7 +77,6 @@ def __init__(self, _handler_returns_response=False): def handle(self, message, connection): # TODO assert that the handshake was already completed assert message, "message must not be None" - if message.message_type not in self._HANDLER_NAMES: # TODO handle this more gracefully raise NotImplementedError("Unexpected message: %s" % str(message)) @@ -101,13 +100,20 @@ def handle_pre_call(self, message, connection): :param connection: tornado connection """ try: - req = connection.request_message_factory.build(message) + new_req = connection.message_factory.build_inbound_request( + message, connection.get_incoming_request(message.id) + ) # message_factory will create Request only when it receives # CallRequestMessage. It will return None, if it receives # CallRequestContinueMessage. - if req: - self.handle_call(req, connection) - + if new_req: + # process the new request + connection.add_incoming_request(new_req) + self.handle_call(new_req, connection).add_done_callback( + lambda _: connection.remove_incoming_request( + new_req.id + ) + ) except TChannelError as e: log.warn('Received a bad request.', exc_info=True) @@ -127,13 +133,13 @@ def handle_call(self, request, connection): # request.endpoint. The original argstream[0] is no longer valid. If # user still tries read from it, it will return empty. chunk = yield request.argstreams[0].read() - response = None while chunk: request.endpoint += chunk chunk = yield request.argstreams[0].read() log.debug('Received a call to %s.', request.endpoint) + response = None tchannel = connection.tchannel # event: receive_request @@ -168,7 +174,6 @@ def handle_call(self, request, connection): headers={'as': request.headers.get('as', 'raw')}, serializer=handler.resp_serializer, ) - connection.post_response(response) try: @@ -220,17 +225,13 @@ def handle_call(self, request, connection): response.flush() except TChannelError as e: - connection.send_error( - e.code, - e.message, - request.id, - ) + response.set_exception(e) + connection.send_error(e.code, e.message, request.id) except Exception as e: msg = "An unexpected error has occurred from the handler" log.exception(msg) response.set_exception(TChannelError(e.message)) - connection.request_message_factory.remove_buffer(response.id) connection.send_error(ErrorCode.unexpected, msg, response.id) tchannel.event_emitter.fire(EventType.on_exception, request, e) diff --git a/tchannel/tornado/hyperbahn.py b/tchannel/tornado/hyperbahn.py index fbb5804b..27d88286 100644 --- a/tchannel/tornado/hyperbahn.py +++ b/tchannel/tornado/hyperbahn.py @@ -123,7 +123,7 @@ def advertise(tchannel, service, routers=None, timeout=None, router_file=None): for router in routers: # We use .get here instead of .add because we don't want to fail if a # TChannel already knows about some of the routers. - tchannel.peers.get(router) + tchannel.peer_group.get(router) result = yield _advertise_with_backoff( tchannel, service, timeout=timeout diff --git a/tchannel/tornado/message_factory.py b/tchannel/tornado/message_factory.py index 767fa820..e915b8b0 100644 --- a/tchannel/tornado/message_factory.py +++ b/tchannel/tornado/message_factory.py @@ -70,9 +70,6 @@ class MessageFactory(object): """ def __init__(self, remote_host=None, remote_host_port=None): - # key: message_id - # value: incomplete streaming messages - self.message_buffer = {} self.remote_host = remote_host self.remote_host_port = remote_host_port @@ -239,94 +236,120 @@ def build_response(self, message): ) return res - def build_context(self, message): - if message.message_type == Types.CALL_REQ: - return self.build_request(message) - elif message.message_type == Types.CALL_RES: - return self.build_response(message) + def build_inbound_response(self, message, response): + """buffer all the streaming messages based on the + message id. Reconstruct all fragments together. + + :param message: + incoming message + :param response: + incoming response + :return: next complete message or None if streaming or fragmentation + is not done + """ + if message.message_type == Types.CALL_RES: + self.verify_message(message) + response = self.build_response(message) + num = self._find_incompleted_stream(response) + self.close_argstream(response, num) + return response + + elif message.message_type == Types.CALL_RES_CONTINUE: + if response is None: + # missing call msg before continue msg + raise FatalProtocolError( + "missing call message after receiving continue message") - def build(self, message): + dst = self._find_incompleted_stream(response) + try: + self.verify_message(message) + except InvalidChecksumError as e: + response.argstreams[dst].set_exception(e) + raise + + src = 0 + while src < len(message.args): + response.argstreams[dst].write(message.args[src]) + dst += 1 + src += 1 + + if message.flags != FlagsType.fragment: + # get last fragment. mark it as completed + assert (len(response.argstreams) == + CallContinueMessage.max_args_num) + response.flags = FlagsType.none + + self.close_argstream(response, dst - 1) + return None + + @classmethod + def build_inbound_error(cls, message): + """convert error message to TChannelError type.""" + return TChannelError.from_code( + message.code, + description=message.description, + tracing=message.tracing + ) + + def build_inbound_request(self, message, request): """buffer all the streaming messages based on the message id. Reconstruct all fragments together. :param message: incoming message - :return: next complete message or None if streaming + :param request: + incoming request + :return: next complete message or None if streaming or fragmentation is not done """ - context = None + if message.message_type == Types.CALL_REQ: + if request: + raise FatalProtocolError( + "Already got an request with same message id.") - if message.message_type in [Types.CALL_REQ, - Types.CALL_RES]: self.verify_message(message) + request = self.build_request(message) + num = self._find_incompleted_stream(request) + self.close_argstream(request, num) + return request - context = self.build_context(message) - # streaming message - if message.flags == common.FlagsType.fragment: - self.message_buffer[message.id] = context - - # find the incompleted stream - num = 0 - for i, arg in enumerate(context.argstreams): - if arg.state != StreamState.completed: - num = i - break - - self.close_argstream(context, num) - return context - - elif message.message_type in [Types.CALL_REQ_CONTINUE, - Types.CALL_RES_CONTINUE]: - context = self.message_buffer.get(message.id) - if context is None: + if message.message_type == Types.CALL_REQ_CONTINUE: + if request is None: # missing call msg before continue msg raise FatalProtocolError( "missing call message after receiving continue message") - # find the incompleted stream - dst = 0 - for i, arg in enumerate(context.argstreams): - if arg.state != StreamState.completed: - dst = i - break - + dst = self._find_incompleted_stream(request) try: self.verify_message(message) except InvalidChecksumError as e: - context.argstreams[dst].set_exception(e) + request.argstreams[dst].set_exception(e) raise src = 0 while src < len(message.args): - context.argstreams[dst].write(message.args[src]) + request.argstreams[dst].write(message.args[src]) dst += 1 src += 1 if message.flags != FlagsType.fragment: # get last fragment. mark it as completed - assert (len(context.argstreams) == + assert (len(request.argstreams) == CallContinueMessage.max_args_num) - self.message_buffer.pop(message.id, None) - context.flags = FlagsType.none + request.flags = FlagsType.none - self.close_argstream(context, dst - 1) + self.close_argstream(request, dst - 1) return None - elif message.message_type == Types.ERROR: - context = self.message_buffer.pop(message.id, None) - if context is None: - log.warn('Unconsumed error %s', context) - return None - else: - error = TChannelError.from_code( - message.code, - description=message.description, - tracing=context.tracing, - ) - - context.set_exception(error) - return error - else: - return message + + @classmethod + def _find_incompleted_stream(cls, reqres): + # find the incompleted stream + num = 0 + for i, arg in enumerate(reqres.argstreams): + if arg.state != StreamState.completed: + num = i + break + return num def fragment(self, message): """Fragment message based on max payload size @@ -399,24 +422,3 @@ def close_argstream(request, num): for i in range(num): request.argstreams[i].close() - - def remove_buffer(self, message_id): - self.message_buffer.pop(message_id, None) - - def set_inbound_exception(self, protocol_error): - reqres = self.message_buffer.get(protocol_error.id) - if reqres is None: - # missing call msg before continue msg - raise FatalProtocolError( - "missing call message after receiving continue message") - - # find the incompleted stream - dst = 0 - for i, arg in enumerate(reqres.argstreams): - if arg.state != StreamState.completed: - dst = i - break - - reqres.argstreams[dst].set_exception(protocol_error) - - self.message_buffer.pop(protocol_error.id, None) diff --git a/tchannel/tornado/tchannel.py b/tchannel/tornado/tchannel.py index bba0834a..4acbc447 100644 --- a/tchannel/tornado/tchannel.py +++ b/tchannel/tornado/tchannel.py @@ -97,7 +97,7 @@ def __init__(self, name, hostport=None, process_name=None, else: self._handler = dispatcher - self.peers = PeerGroup(self) + self.peer_group = PeerGroup(self) self._port = 0 self._host = None @@ -123,7 +123,7 @@ def __init__(self, name, hostport=None, process_name=None, if known_peers: for peer_hostport in known_peers: - self.peers.add(peer_hostport) + self.peer_group.add(peer_hostport) # server created from calling listen() self._server = None @@ -183,7 +183,7 @@ def close(self): self._state = State.closing try: - yield self.peers.clear() + yield self.peer_group.clear() finally: self._state = State.closed @@ -221,11 +221,13 @@ def request(self, # TODO disallow certain parameters or don't propagate them backwards. # For example, blacklist and score threshold aren't really # user-configurable right now. - return self.peers.request(hostport=hostport, - service=service, - arg_scheme=arg_scheme, - retry=retry, - **kwargs) + return self.peer_group.request( + hostport=hostport, + service=service, + arg_scheme=arg_scheme, + retry=retry, + **kwargs + ) def listen(self, port=None): """Start listening for incoming connections. @@ -394,9 +396,16 @@ def hello_handler(request, response): else: return decorator + @tornado.gen.coroutine + def stop(self, reason=None, exempt=None): + self._server.drain() + yield [ + peer.stop(reason, exempt) for peer in self.peer_group.peers() + ] + class TChannelServer(tornado.tcpserver.TCPServer): - __slots__ = ('tchannel',) + __slots__ = ('tchannel', 'draining') def __init__(self, tchannel): super(TChannelServer, self).__init__() @@ -419,11 +428,10 @@ def handle_stream(self, stream, address): conn.remote_host_port, conn.remote_process_name) - self.tchannel.peers.get( + self.tchannel.peer_group.get( "%s:%s" % (conn.remote_host, conn.remote_host_port) ).register_incoming(conn) - yield conn.serve(handler=self._handle) def _handle(self, message, connection): diff --git a/tests/integration/test_retry.py b/tests/integration/test_retry.py index a46ab585..1cc76d82 100644 --- a/tests/integration/test_retry.py +++ b/tests/integration/test_retry.py @@ -69,7 +69,7 @@ def score(self): def chain(number_of_peers, endpoint): tchannel = TChannel(name='test') for i in range(number_of_peers): - p = tchannel.peers.get(server(endpoint).hostport) + p = tchannel.peer_group.get(server(endpoint).hostport) # Gaurantee error servers have score in order to pick first. p.state = FakeState() @@ -156,7 +156,7 @@ def test_retry_on_error_success(): tchannel_success = TChannel(name='test', hostport='localhost:0') tchannel_success.register(endpoint, 'raw', handler_success) tchannel_success.listen() - tchannel.peers.get(tchannel_success.hostport) + tchannel.peer_group.get(tchannel_success.hostport) with ( patch( diff --git a/tests/integration/tornado/test_connection_reuse.py b/tests/integration/tornado/test_connection_reuse.py index e729b4b6..50c60b51 100644 --- a/tests/integration/tornado/test_connection_reuse.py +++ b/tests/integration/tornado/test_connection_reuse.py @@ -61,10 +61,10 @@ def loop1(n): yield loop1(2) # Peer representing 2 for 1's point-of-view - peer_1_2 = server1.peers.lookup(hostport2) + peer_1_2 = server1.peer_group.lookup(hostport2) # Peer representing 1 from 2's point-of-view - peer_2_1 = server2.peers.lookup(hostport1) + peer_2_1 = server2.peer_group.lookup(hostport1) assert len(peer_1_2.outgoing_connections) == 1 assert len(peer_2_1.incoming_connections) == 1 diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 89dab7a8..dee9bb0d 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -70,7 +70,7 @@ def test_advertise_should_result_in_peer_connections(mock_server): assert result.headers == {} assert result.body == body - assert client._dep_tchannel.peers.hosts == routers + assert client._dep_tchannel.peer_group.hosts == routers def test_failing_advertise_should_raise(mock_server): diff --git a/tests/test_message_factory.py b/tests/test_message_factory.py index 8690d6fa..1afb38bc 100644 --- a/tests/test_message_factory.py +++ b/tests/test_message_factory.py @@ -19,9 +19,14 @@ # THE SOFTWARE. from __future__ import absolute_import -from tchannel.messages import CallRequestMessage, CallResponseMessage -from tchannel.messages.common import StreamState, FlagsType -from tchannel.tornado import Request, Response + +from tchannel.messages import CallRequestMessage +from tchannel.messages import CallResponseMessage +from tchannel.messages import ErrorMessage +from tchannel.messages.common import FlagsType +from tchannel.messages.common import StreamState +from tchannel.tornado import Request +from tchannel.tornado import Response from tchannel.tornado.message_factory import MessageFactory from tchannel.tornado.response import StatusCode from tchannel.zipkin.trace import Trace @@ -93,3 +98,12 @@ def test_build_response(): assert req.flags == message.flags assert req.headers == message.headers assert req.id == message.id + + +def test_build_inbound_error(): + message = ErrorMessage(code=0, tracing=Trace(), description="test") + error = MessageFactory.build_inbound_error(message) + + assert error.code == message.code + assert error.description == message.description + assert error.tracing == message.tracing diff --git a/tests/test_messages.py b/tests/test_messages.py index f0aa28f6..85d8e3ed 100644 --- a/tests/test_messages.py +++ b/tests/test_messages.py @@ -27,6 +27,7 @@ from tchannel import messages from tchannel.io import BytesIO from tchannel.messages import CallRequestMessage +from tchannel.messages import CallResponseMessage from tchannel.messages.common import PROTOCOL_VERSION from tchannel.tornado.message_factory import MessageFactory from tests.util import big_arg @@ -266,17 +267,45 @@ def test_equality_check_against_none(init_request_with_headers): ], ids=lambda arg: str(len(arg)) ) -def test_message_fragment(arg2, arg3, connection): +def test_message_fragment_request(arg2, arg3, connection): msg = CallRequestMessage(args=["", arg2, arg3]) origin_msg = CallRequestMessage(args=["", arg2, arg3]) message_factory = MessageFactory(connection) fragments = message_factory.fragment(msg) - recv_msg = None + request = None for fragment in fragments: - output = message_factory.build(fragment) + output = message_factory.build_inbound_request(fragment, request) if output: - recv_msg = output - header = yield recv_msg.get_header() - body = yield recv_msg.get_body() + request = output + header = yield request.get_header() + body = yield request.get_body() + assert header == origin_msg.args[1] + assert body == origin_msg.args[2] + + +@pytest.mark.gen_test +@pytest.mark.parametrize('arg2, arg3', [ + ("", big_arg()), + (big_arg(), ""), + ("test", big_arg()), + (big_arg(), "test"), + (big_arg(), big_arg()), + ("", ""), + ("test", "test"), +], + ids=lambda arg: str(len(arg)) +) +def test_message_fragment_response(arg2, arg3, connection): + msg = CallResponseMessage(args=["", arg2, arg3]) + origin_msg = CallResponseMessage(args=["", arg2, arg3]) + message_factory = MessageFactory(connection) + fragments = message_factory.fragment(msg) + response = None + for fragment in fragments: + output = message_factory.build_inbound_response(fragment, response) + if output: + response = output + header = yield response.get_header() + body = yield response.get_body() assert header == origin_msg.args[1] assert body == origin_msg.args[2] diff --git a/tests/tornado/test_connection.py b/tests/tornado/test_connection.py index 8d0b9067..47e6c303 100644 --- a/tests/tornado/test_connection.py +++ b/tests/tornado/test_connection.py @@ -22,6 +22,7 @@ import pytest import tornado.ioloop +from tornado.iostream import StreamClosedError import tornado.testing from tchannel.messages import Types @@ -61,3 +62,12 @@ def test_pings(self): pong = yield self.client.await() assert pong.message_type == Types.PING_RES + + @tornado.testing.gen_test + def test_close(self): + """Verify the error got thrown when connection is closed""" + self.client.ping() + yield self.server.await() + self.client.connection.close() + with pytest.raises(StreamClosedError): + yield self.server.await() diff --git a/tests/tornado/test_hyperbahn.py b/tests/tornado/test_hyperbahn.py index 89c1ca88..41d07eea 100644 --- a/tests/tornado/test_hyperbahn.py +++ b/tests/tornado/test_hyperbahn.py @@ -43,7 +43,7 @@ def test_new_client_establishes_peers(): ) for router in routers: - assert channel.peers.lookup(router) + assert channel.peer_group.lookup(router) def test_new_client_establishes_peers_from_file(): @@ -67,7 +67,7 @@ def test_new_client_establishes_peers_from_file(): with open(host_path, 'r') as json_data: routers = json.load(json_data) for router in routers: - assert channel.peers.lookup(router) + assert channel.peer_group.lookup(router) @pytest.mark.gen_test diff --git a/tests/tornado/test_tchannel.py b/tests/tornado/test_tchannel.py index 0c3989a3..7fcd8879 100644 --- a/tests/tornado/test_tchannel.py +++ b/tests/tornado/test_tchannel.py @@ -40,8 +40,8 @@ def peer(tchannel): @pytest.mark.gen_test def test_peer_caching(tchannel, peer): "Connections are long-lived and should not be recreated.""" - tchannel.peers.add(peer) - assert tchannel.peers.get("localhost:4040") is peer + tchannel.peer_group.add(peer) + assert tchannel.peer_group.get("localhost:4040") is peer def test_known_peers(): @@ -49,7 +49,7 @@ def test_known_peers(): tchannel = TChannel('test', known_peers=peers) for peer in peers: - assert tchannel.peers.lookup(peer) + assert tchannel.peer_group.lookup(peer) def test_is_listening_should_return_false_when_listen_not_called(tchannel):