Skip to content
This repository has been archived by the owner on Jan 5, 2024. It is now read-only.

Commit

Permalink
refactor message_factory class
Browse files Browse the repository at this point in the history
  • Loading branch information
jc-fireball committed Sep 24, 2015
1 parent bcecb95 commit 06fc526
Show file tree
Hide file tree
Showing 15 changed files with 276 additions and 186 deletions.
3 changes: 3 additions & 0 deletions tchannel/tchannel.py
Expand Up @@ -319,3 +319,6 @@ def advertise(self, routers=None, name=None, timeout=None,
response = Response(json.loads(body), headers or {})

raise gen.Return(response)

def stop(self, reason=None):
    """Stop this channel by delegating to the deprecated underlying channel.

    :param reason: optional reason to forward to the underlying ``stop``.
    :returns: whatever the deprecated channel's ``stop`` returns.
    """
    dep_channel = self._dep_tchannel
    return dep_channel.stop(reason)
2 changes: 1 addition & 1 deletion tchannel/testing/vcr/patch.py
Expand Up @@ -85,7 +85,7 @@ def send(self, arg1, arg2, arg3,
vcr_request = VCRProxy.Request(
serviceName=self.service.encode('utf-8'),
hostPort=self.hostport,
knownPeers=self.original_tchannel.peers.hosts,
knownPeers=self.original_tchannel.peer_group.hosts,
endpoint=endpoint,
headers=(yield read_full(arg2)),
body=(yield read_full(arg3)),
Expand Down
141 changes: 82 additions & 59 deletions tchannel/tornado/connection.py
Expand Up @@ -32,8 +32,8 @@
from .. import frame
from .. import glossary
from .. import messages
from ..errors import NetworkError
from ..errors import FatalProtocolError
from ..errors import NetworkError
from ..errors import TChannelError
from ..event import EventType
from ..io import BytesIO
Expand Down Expand Up @@ -103,12 +103,9 @@ def __init__(self, connection, tchannel=None):
# Tracks message IDs for this connection.
self._id_sequence = 0

# We need to use two separate message factories to avoid message ID
# collision while assembling fragmented messages.
self.request_message_factory = MessageFactory(self.remote_host,
self.remote_host_port)
self.response_message_factory = MessageFactory(self.remote_host,
self.remote_host_port)
self.message_factory = MessageFactory(
self.remote_host, self.remote_host_port
)

# Queue of unprocessed incoming calls.
self._messages = queues.Queue()
Expand All @@ -122,6 +119,15 @@ def __init__(self, connection, tchannel=None):

self.tchannel = tchannel

# Map from message ID to inflight requests on the server side.
self.incoming_requests = {}
# Map from message ID to inflight responses on the client side.
# When client sends a request, it will first put ID and a future for
# response in the _outstanding map. After it gets response message,
# but not a completed response, it will put ID and response obj in the
# incoming_response.
self.incoming_responses = {}

connection.set_close_callback(self._on_close)

def next_message_id(self):
Expand All @@ -130,15 +136,13 @@ def next_message_id(self):

def _on_close(self):
self.closed = True

for message_id, future in self._outstanding.iteritems():
future.set_exception(
NetworkError(
"canceling outstanding request %d" % message_id
)
)
self._outstanding = {}

try:
while True:
message = self._messages.get_nowait()
Expand Down Expand Up @@ -192,6 +196,8 @@ def on_error(future):

if isinstance(exception, tornado.iostream.StreamClosedError):
self.close()
if message_future.running():
message_future.set_exception(exception)

size_width = frame.frame_rw.size_rw.width()
self.connection.read_bytes(size_width).add_done_callback(on_read_size)
Expand All @@ -205,50 +211,64 @@ def _loop(self):
#
# Must be started only after the handshake has been performed.
self._loop_running = True

while not self.closed:
message = yield self._recv()
# TODO: There should probably be a try-catch on the yield.
try:
message = yield self._recv()
except tornado.iostream.StreamClosedError:
log.warning("Stream has been closed.")
break

if message.message_type in self.CALL_REQ_TYPES:
self._messages.put(message)
continue

elif message.id in self._outstanding:
# set exception if receive error message
if message.message_type == Types.ERROR:
future = self._outstanding.pop(message.id)
if future.running():
error = TChannelError.from_code(
message.code,
description=message.description,
)
future.set_exception(error)
else:
protocol_exception = (
self.response_message_factory.build(message)
)
if protocol_exception:
self.event_emitter.fire(
EventType.after_receive_error,
protocol_exception,
)
continue

response = self.response_message_factory.build(message)

# keep continue message in the list
# pop all other type messages including error message
if (message.message_type in self.CALL_RES_TYPES and
message.flags == FlagsType.fragment):
# still streaming, keep it for record
future = self._outstanding.get(message.id)
else:
future = self._outstanding.pop(message.id)

if response and future.running():
future.set_result(response)
continue

log.warn('Unconsumed message %s', message)
# This should just handle CALL_RES or ERROR message.
# It is invalid if any CALL_RES_CONTINUE message goes here.
self._handle_response(message)
elif message.id in self.incoming_responses:
# This should just handle CALL_RES_CONTINUE or ERROR message.
# It is invalid if any CALL_RES message goes here.
self._handle_continuing_response(message)
else:
log.warn('Unconsumed message %s', message)

def _handle_continuing_response(self, message):
    """Feed a CALL_RES_CONTINUE (or ERROR) frame into its in-flight response.

    The partially built response is removed from ``incoming_responses``
    first; it is only re-registered when the frame says more fragments
    are still expected.

    :param message: an inbound CALL_RES_CONTINUE or ERROR message whose
        ID is currently tracked in ``incoming_responses``.
    """
    in_progress = self.incoming_responses.pop(message.id)

    if message.message_type == Types.ERROR:
        # Terminate the stream with a protocol error and notify listeners.
        error = self.message_factory.build_inbound_error(message)
        in_progress.set_exception(error)
        self.tchannel.event_emitter.fire(
            EventType.after_receive_error, error,
        )
        return

    # Merge this fragment into the existing response object.
    self.message_factory.build_inbound_response(message, in_progress)

    still_streaming = (
        message.message_type in self.CALL_RES_TYPES and
        message.flags == FlagsType.fragment
    )
    if still_streaming:
        # More fragments are on the way; keep tracking this response.
        self.incoming_responses[message.id] = in_progress

def _handle_response(self, message):
    """Resolve an outstanding request with its first CALL_RES (or ERROR).

    Pops the waiting future for ``message.id`` from ``_outstanding``.
    An ERROR frame fails the future and fires ``after_receive_error``;
    otherwise a new response is built and set as the result. If the
    frame is a fragment, the response is kept in ``incoming_responses``
    so later CALL_RES_CONTINUE frames can be appended to it.

    :param message: an inbound CALL_RES or ERROR message whose ID is
        currently tracked in ``_outstanding``.
    """
    pending = self._outstanding.pop(message.id)

    if message.message_type == Types.ERROR:
        error = self.message_factory.build_inbound_error(message)
        if pending.running():
            pending.set_exception(error)
        # Fire the error event even if the future was already resolved.
        self.tchannel.event_emitter.fire(
            EventType.after_receive_error, error,
        )
        return

    response = self.message_factory.build_inbound_response(message, None)
    if pending.running():
        pending.set_result(response)

    still_streaming = (
        message.message_type in self.CALL_RES_TYPES and
        message.flags == FlagsType.fragment
    )
    if still_streaming:
        # More fragments are on the way; keep tracking this response.
        self.incoming_responses[message.id] = response

# Basically, the only difference between send and write is that send
# sets up a Future to get the response. That's ideal for peers making
Expand Down Expand Up @@ -290,12 +310,7 @@ def write(self, message):

message.id = message.id or self.next_message_id()

if message.message_type in self.CALL_REQ_TYPES:
message_factory = self.request_message_factory
else:
message_factory = self.response_message_factory

fragments = message_factory.fragment(message)
fragments = self.message_factory.fragment(message)

return chain(fragments, self._write)

Expand Down Expand Up @@ -452,10 +467,8 @@ def serve(self, handler):
"""
assert handler, "handler is required"
assert self._loop_running, "Finish the handshake first"

while not self.closed:
message = yield self.await()

try:
handler(message, self)
except Exception:
Expand Down Expand Up @@ -489,6 +502,16 @@ def ping(self):
def pong(self):
return self._write(messages.PingResponseMessage())

def add_incoming_request(self, request):
    """Start tracking an in-flight inbound request, keyed by message ID.

    :param request: inbound request object exposing an ``id`` attribute.
    """
    self.incoming_requests[request.id] = request

def get_incoming_request(self, id):
    """Return the tracked inbound request for *id*, or ``None`` if unknown."""
    # dict.get already defaults to None; no explicit default needed.
    return self.incoming_requests.get(id)

def remove_incoming_request(self, id):
    """Stop tracking the inbound request for *id*.

    :returns: the removed request, or ``None`` if *id* was not tracked.
    """
    return self.incoming_requests.pop(id, None)


class StreamConnection(TornadoConnection):
"""Streaming request/response into protocol messages and sent by tornado
Expand Down Expand Up @@ -561,7 +584,7 @@ def _stream(self, context, message_factory):
def post_response(self, response):
try:
# TODO: before_send_response
yield self._stream(response, self.response_message_factory)
yield self._stream(response, self.message_factory)

# event: send_response
self.tchannel.event_emitter.fire(
Expand All @@ -575,7 +598,7 @@ def stream_request(self, request):
"""send the given request and response is not required"""
request.close_argstreams()

stream_future = self._stream(request, self.request_message_factory)
stream_future = self._stream(request, self.message_factory)

stream_future.add_done_callback(
lambda f: request.close_argstreams(force=True),
Expand Down
27 changes: 14 additions & 13 deletions tchannel/tornado/dispatch.py
Expand Up @@ -77,7 +77,6 @@ def __init__(self, _handler_returns_response=False):
def handle(self, message, connection):
# TODO assert that the handshake was already completed
assert message, "message must not be None"

if message.message_type not in self._HANDLER_NAMES:
# TODO handle this more gracefully
raise NotImplementedError("Unexpected message: %s" % str(message))
Expand All @@ -101,13 +100,20 @@ def handle_pre_call(self, message, connection):
:param connection: tornado connection
"""
try:
req = connection.request_message_factory.build(message)
new_req = connection.message_factory.build_inbound_request(
message, connection.get_incoming_request(message.id)
)
# message_factory will create Request only when it receives
# CallRequestMessage. It will return None, if it receives
# CallRequestContinueMessage.
if req:
self.handle_call(req, connection)

if new_req:
# process the new request
connection.add_incoming_request(new_req)
self.handle_call(new_req, connection).add_done_callback(
lambda _: connection.remove_incoming_request(
new_req.id
)
)
except TChannelError as e:
log.warn('Received a bad request.', exc_info=True)

Expand All @@ -127,13 +133,13 @@ def handle_call(self, request, connection):
# request.endpoint. The original argstream[0] is no longer valid. If
# user still tries read from it, it will return empty.
chunk = yield request.argstreams[0].read()
response = None
while chunk:
request.endpoint += chunk
chunk = yield request.argstreams[0].read()

log.debug('Received a call to %s.', request.endpoint)

response = None
tchannel = connection.tchannel

# event: receive_request
Expand Down Expand Up @@ -168,7 +174,6 @@ def handle_call(self, request, connection):
headers={'as': request.headers.get('as', 'raw')},
serializer=handler.resp_serializer,
)

connection.post_response(response)

try:
Expand Down Expand Up @@ -220,17 +225,13 @@ def handle_call(self, request, connection):

response.flush()
except TChannelError as e:
connection.send_error(
e.code,
e.message,
request.id,
)
response.set_exception(e)
connection.send_error(e.code, e.message, request.id)
except Exception as e:
msg = "An unexpected error has occurred from the handler"
log.exception(msg)

response.set_exception(TChannelError(e.message))
connection.request_message_factory.remove_buffer(response.id)

connection.send_error(ErrorCode.unexpected, msg, response.id)
tchannel.event_emitter.fire(EventType.on_exception, request, e)
Expand Down
2 changes: 1 addition & 1 deletion tchannel/tornado/hyperbahn.py
Expand Up @@ -123,7 +123,7 @@ def advertise(tchannel, service, routers=None, timeout=None, router_file=None):
for router in routers:
# We use .get here instead of .add because we don't want to fail if a
# TChannel already knows about some of the routers.
tchannel.peers.get(router)
tchannel.peer_group.get(router)

result = yield _advertise_with_backoff(
tchannel, service, timeout=timeout
Expand Down

0 comments on commit 06fc526

Please sign in to comment.