diff --git a/tchannel/tornado/connection.py b/tchannel/tornado/connection.py
index d8620d0a..8ebc211f 100644
--- a/tchannel/tornado/connection.py
+++ b/tchannel/tornado/connection.py
@@ -124,13 +124,13 @@ def __init__(self, connection, tchannel=None, direction=None):
         self._messages = queues.Queue()
 
         # Map from message ID to futures for responses of outgoing calls.
-        self._out_pending_call = {}
+        self._outbound_pending_call = {}
 
         # Map from message ID to outgoing request that is being sent.
-        self._out_pending_req = {}
+        self._outbound_pending_req = {}
 
         # Map from message ID to outgoing response that is being sent.
-        self._out_pending_res = {}
+        self._outbound_pending_res = {}
 
         # Whether _loop is running. The loop doesn't run until after the
         # handshake has been performed.
@@ -142,8 +142,9 @@ def __init__(self, connection, tchannel=None, direction=None):
         connection.set_close_callback(self._on_close)
 
     @property
-    def num_out_pendings(self):
-        return len(self._out_pending_res) + len(self._out_pending_req)
+    def total_outbound_pendings(self):
+        return (len(self._outbound_pending_res) +
+                len(self._outbound_pending_req))
 
     def set_close_callback(self, cb):
         """Specify a function to be called when this connection is closed.
@@ -164,13 +165,13 @@ def next_message_id(self):
     def _on_close(self):
         self.closed = True
 
-        for message_id, future in self._out_pending_call.iteritems():
+        for message_id, future in self._outbound_pending_call.iteritems():
             future.set_exception(
                 NetworkError(
                     "canceling outstanding request %d" % message_id
                 )
             )
-        self._out_pending_call = {}
+        self._outbound_pending_call = {}
 
         try:
             while True:
@@ -254,10 +255,10 @@ def _loop(self):
                 self._messages.put(message)
                 continue
 
-            elif message.id in self._out_pending_call:
+            elif message.id in self._outbound_pending_call:
                 # set exception if receive error message
                 if message.message_type == Types.ERROR:
-                    future = self._out_pending_call.pop(message.id)
+                    future = self._outbound_pending_call.pop(message.id)
                     if future.running():
                         error = TChannelError.from_code(
                             message.code,
@@ -284,9 +285,9 @@ def _loop(self):
                 if (message.message_type in self.CALL_RES_TYPES and
                         message.flags == FlagsType.fragment):
                     # still streaming, keep it for record
-                    future = self._out_pending_call.get(message.id)
+                    future = self._outbound_pending_call.get(message.id)
                 else:
-                    future = self._out_pending_call.pop(message.id)
+                    future = self._outbound_pending_call.pop(message.id)
 
                 if response and future.running():
                     future.set_result(response)
@@ -312,12 +313,12 @@ def send(self, message):
         )
 
         message.id = message.id or self.next_message_id()
-        assert message.id not in self._out_pending_call, (
+        assert message.id not in self._outbound_pending_call, (
            "Message ID '%d' already being used" % message.id
         )
 
         future = tornado.gen.Future()
-        self._out_pending_call[message.id] = future
+        self._outbound_pending_call[message.id] = future
         self.write(message)
         return future
 
@@ -607,7 +608,7 @@ def _stream(self, context, message_factory):
     @tornado.gen.coroutine
     def post_response(self, response):
         try:
-            self._out_pending_res[response.id] = response
+            self._outbound_pending_res[response.id] = response
             # TODO: before_send_response
             yield self._stream(response, self.response_message_factory)
 
@@ -617,7 +618,7 @@ def post_response(self, response):
                 response,
             )
         finally:
-            self._out_pending_res.pop(response.id, None)
+            self._outbound_pending_res.pop(response.id, None)
             response.close_argstreams(force=True)
 
     def stream_request(self, request):
@@ -645,15 +646,15 @@ def send_request(self, request):
         """
         assert self._loop_running, "Perform a handshake first."
-        assert request.id not in self._out_pending_call, (
+        assert request.id not in self._outbound_pending_call, (
            "Message ID '%d' already being used" % request.id
         )
 
         future = tornado.gen.Future()
-        self._out_pending_call[request.id] = future
-        self._out_pending_req[request.id] = request
+        self._outbound_pending_call[request.id] = future
+        self._outbound_pending_req[request.id] = request
 
         self.stream_request(request).add_done_callback(
-            lambda f: self._out_pending_req.pop(request.id, None)
+            lambda f: self._outbound_pending_req.pop(request.id, None)
         )
 
         # the actual future that caller will yield
@@ -682,4 +683,4 @@ def adapt_result(self, f, request, response_future):
 
     def remove_outstanding_request(self, request):
         """Remove request from pending request list"""
-        self._out_pending_call.pop(request.id, None)
+        self._outbound_pending_call.pop(request.id, None)
diff --git a/tchannel/tornado/peer.py b/tchannel/tornado/peer.py
index e57d7f7c..2a4b6741 100644
--- a/tchannel/tornado/peer.py
+++ b/tchannel/tornado/peer.py
@@ -180,6 +180,15 @@ def incoming_connections(self):
             takewhile(lambda c: c.direction == INCOMING, self._connections)
         )
 
+    @property
+    def total_outbound_pendings(self):
+        """Return the total number of out pending req/res among connections"""
+        num = 0
+        for con in self.connections:
+            num += con.total_outbound_pendings
+
+        return num
+
     @property
     def is_ephemeral(self):
         """Whether this Peer is ephemeral."""
diff --git a/tchannel/tornado/util.py b/tchannel/tornado/util.py
index 181c63a9..3c2197e1 100644
--- a/tchannel/tornado/util.py
+++ b/tchannel/tornado/util.py
@@ -90,12 +90,3 @@ def go():
     go()
 
     return all_done_future
-
-
-def num_out_pendings(conns):
-    """Return the total number of out pending req/res among connections"""
-    num = 0
-    for con in conns:
-        num += con.num_out_pendings
-
-    return num
diff --git a/tests/tornado/test_connection.py b/tests/tornado/test_connection.py
index 2fc42f65..a498020a 100644
--- a/tests/tornado/test_connection.py
+++ b/tests/tornado/test_connection.py
@@ -28,7 +28,7 @@
 from tchannel.messages import Types
 from tchannel import TChannel
 from tchannel.tornado.connection import StreamConnection
-from tchannel.tornado.util import num_out_pendings
+from tchannel.tornado.message_factory import MessageFactory
 
 
 def dummy_headers():
@@ -91,6 +91,7 @@ def test_pending_outgoing():
 
     @server.raw.register
     def hello(request):
+        assert server._dep_tchannel.peers.peers[0].total_outbound_pendings == 1
         return 'hi'
 
     client = TChannel('client')
@@ -101,10 +102,29 @@ def hello(request):
         service='server'
     )
 
-    assert num_out_pendings(
-        client._dep_tchannel.peers.get(server.hostport).connections
-    ) == 0
+    client_peer = client._dep_tchannel.peers.peers[0]
+    server_peer = server._dep_tchannel.peers.peers[0]
+    assert client_peer.total_outbound_pendings == 0
+    assert server_peer.total_outbound_pendings == 0
+
+    class FakeMessageFactory(MessageFactory):
+        def build_raw_message(self, context, args, is_completed=True):
+            assert client_peer.total_outbound_pendings == 1
+            return super(FakeMessageFactory, self).build_raw_message(
+                context, args, is_completed,
+            )
+
+    client_conn = client_peer.connections[0]
+    client_conn.request_message_factory = FakeMessageFactory(
+        client_conn.remote_host,
+        client_conn.remote_host_port,
+    )
+    yield client.raw(
+        hostport=server.hostport,
+        body='work',
+        endpoint='hello',
+        service='server'
+    )
 
-    assert num_out_pendings(
-        server._dep_tchannel.peers.peers[0].connections
-    ) == 0
+    assert client_peer.total_outbound_pendings == 0
+    assert server_peer.total_outbound_pendings == 0