Skip to content
This repository has been archived by the owner on Jan 5, 2024. It is now read-only.

Commit

Permalink
Update test and rename variable
Browse files Browse the repository at this point in the history
  • Loading branch information
jc-fireball committed Dec 3, 2015
1 parent eaf9158 commit ddfa8f3
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 36 deletions.
41 changes: 21 additions & 20 deletions tchannel/tornado/connection.py
Expand Up @@ -124,13 +124,13 @@ def __init__(self, connection, tchannel=None, direction=None):
self._messages = queues.Queue()

# Map from message ID to futures for responses of outgoing calls.
self._out_pending_call = {}
self._outbound_pending_call = {}

# Map from message ID to outgoing request that is being sent.
self._out_pending_req = {}
self._outbound_pending_req = {}

# Map from message ID to outgoing response that is being sent.
self._out_pending_res = {}
self._outbound_pending_res = {}

# Whether _loop is running. The loop doesn't run until after the
# handshake has been performed.
Expand All @@ -142,8 +142,9 @@ def __init__(self, connection, tchannel=None, direction=None):
connection.set_close_callback(self._on_close)

@property
def num_out_pendings(self):
return len(self._out_pending_res) + len(self._out_pending_req)
def total_outbound_pendings(self):
return (len(self._outbound_pending_res) +
len(self._outbound_pending_req))

def set_close_callback(self, cb):
"""Specify a function to be called when this connection is closed.
Expand All @@ -164,13 +165,13 @@ def next_message_id(self):
def _on_close(self):
self.closed = True

for message_id, future in self._out_pending_call.iteritems():
for message_id, future in self._outbound_pending_call.iteritems():
future.set_exception(
NetworkError(
"canceling outstanding request %d" % message_id
)
)
self._out_pending_call = {}
self._outbound_pending_call = {}

try:
while True:
Expand Down Expand Up @@ -254,10 +255,10 @@ def _loop(self):
self._messages.put(message)
continue

elif message.id in self._out_pending_call:
elif message.id in self._outbound_pending_call:
# set exception if receive error message
if message.message_type == Types.ERROR:
future = self._out_pending_call.pop(message.id)
future = self._outbound_pending_call.pop(message.id)
if future.running():
error = TChannelError.from_code(
message.code,
Expand All @@ -284,9 +285,9 @@ def _loop(self):
if (message.message_type in self.CALL_RES_TYPES and
message.flags == FlagsType.fragment):
# still streaming, keep it for record
future = self._out_pending_call.get(message.id)
future = self._outbound_pending_call.get(message.id)
else:
future = self._out_pending_call.pop(message.id)
future = self._outbound_pending_call.pop(message.id)

if response and future.running():
future.set_result(response)
Expand All @@ -312,12 +313,12 @@ def send(self, message):
)

message.id = message.id or self.next_message_id()
assert message.id not in self._out_pending_call, (
assert message.id not in self._outbound_pending_call, (
"Message ID '%d' already being used" % message.id
)

future = tornado.gen.Future()
self._out_pending_call[message.id] = future
self._outbound_pending_call[message.id] = future
self.write(message)
return future

Expand Down Expand Up @@ -607,7 +608,7 @@ def _stream(self, context, message_factory):
@tornado.gen.coroutine
def post_response(self, response):
try:
self._out_pending_res[response.id] = response
self._outbound_pending_res[response.id] = response
# TODO: before_send_response
yield self._stream(response, self.response_message_factory)

Expand All @@ -617,7 +618,7 @@ def post_response(self, response):
response,
)
finally:
self._out_pending_res.pop(response.id, None)
self._outbound_pending_res.pop(response.id, None)
response.close_argstreams(force=True)

def stream_request(self, request):
Expand Down Expand Up @@ -645,15 +646,15 @@ def send_request(self, request):
"""
assert self._loop_running, "Perform a handshake first."

assert request.id not in self._out_pending_call, (
assert request.id not in self._outbound_pending_call, (
"Message ID '%d' already being used" % request.id
)

future = tornado.gen.Future()
self._out_pending_call[request.id] = future
self._out_pending_req[request.id] = request
self._outbound_pending_call[request.id] = future
self._outbound_pending_req[request.id] = request
self.stream_request(request).add_done_callback(
lambda f: self._out_pending_req.pop(request.id, None)
lambda f: self._outbound_pending_req.pop(request.id, None)
)

# the actual future that caller will yield
Expand Down Expand Up @@ -682,4 +683,4 @@ def adapt_result(self, f, request, response_future):

def remove_outstanding_request(self, request):
"""Remove request from pending request list"""
self._out_pending_call.pop(request.id, None)
self._outbound_pending_call.pop(request.id, None)
9 changes: 9 additions & 0 deletions tchannel/tornado/peer.py
Expand Up @@ -180,6 +180,15 @@ def incoming_connections(self):
takewhile(lambda c: c.direction == INCOMING, self._connections)
)

@property
def total_outbound_pendings(self):
"""Return the total number of out pending req/res among connections"""
num = 0
for con in self.connections:
num += con.total_outbound_pendings

return num

@property
def is_ephemeral(self):
"""Whether this Peer is ephemeral."""
Expand Down
9 changes: 0 additions & 9 deletions tchannel/tornado/util.py
Expand Up @@ -90,12 +90,3 @@ def go():
go()

return all_done_future


def num_out_pendings(conns):
"""Return the total number of out pending req/res among connections"""
num = 0
for con in conns:
num += con.num_out_pendings

return num
34 changes: 27 additions & 7 deletions tests/tornado/test_connection.py
Expand Up @@ -28,7 +28,7 @@
from tchannel.messages import Types
from tchannel import TChannel
from tchannel.tornado.connection import StreamConnection
from tchannel.tornado.util import num_out_pendings
from tchannel.tornado.message_factory import MessageFactory


def dummy_headers():
Expand Down Expand Up @@ -91,6 +91,7 @@ def test_pending_outgoing():

@server.raw.register
def hello(request):
assert server._dep_tchannel.peers.peers[0].total_outbound_pendings == 1
return 'hi'

client = TChannel('client')
Expand All @@ -101,10 +102,29 @@ def hello(request):
service='server'
)

assert num_out_pendings(
client._dep_tchannel.peers.get(server.hostport).connections
) == 0
client_peer = client._dep_tchannel.peers.peers[0]
server_peer = server._dep_tchannel.peers.peers[0]
assert client_peer.total_outbound_pendings == 0
assert server_peer.total_outbound_pendings == 0

class FakeMessageFactory(MessageFactory):
def build_raw_message(self, context, args, is_completed=True):
assert client_peer.total_outbound_pendings == 1
return super(FakeMessageFactory, self).build_raw_message(
context, args, is_completed,
)

client_conn = client_peer.connections[0]
client_conn.request_message_factory = FakeMessageFactory(
client_conn.remote_host,
client_conn.remote_host_port,
)
yield client.raw(
hostport=server.hostport,
body='work',
endpoint='hello',
service='server'
)

assert num_out_pendings(
server._dep_tchannel.peers.peers[0].connections
) == 0
assert client_peer.total_outbound_pendings == 0
assert server_peer.total_outbound_pendings == 0

0 comments on commit ddfa8f3

Please sign in to comment.