From 4e5b2fc1ccf536aea0e25ac9852b516fa17bb084 Mon Sep 17 00:00:00 2001 From: Bryce Lampe Date: Wed, 26 Aug 2015 13:58:31 -0700 Subject: [PATCH] make client_for compatible with new TChannel --- CHANGES.rst | 1 + tchannel/thrift/client.py | 38 ++++++++++++++++++++++++++---------- tests/schemes/test_thrift.py | 25 ++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 10 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 1b616c86..1250c996 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -7,6 +7,7 @@ Changes by Version - Fixed a bug where the 'not found' handler would incorrectly return serialization mismatch errors.. +- Made ``client_for`` compatible with ``tchannel.TChannel``. 0.16.0 (2015-08-25) diff --git a/tchannel/thrift/client.py b/tchannel/thrift/client.py index 9f683475..5541f2b1 100644 --- a/tchannel/thrift/client.py +++ b/tchannel/thrift/client.py @@ -26,6 +26,7 @@ from thrift import Thrift from tornado import gen +from tchannel import schemes from tchannel.errors import OneWayNotSupportedError from ..serializer.thrift import ThriftSerializer @@ -166,17 +167,34 @@ def send(self, *args, **kwargs): body = serializer.serialize_body(call_args) header = serializer.serialize_header({}) - response = yield self.tchannel.request( - hostport=self.hostport, service=self.service - ).send( - arg1=endpoint, - arg2=header, - arg3=body, # body - headers=self.protocol_headers, - traceflag=self.trace - ) - body = yield response.get_body() + + # Glue for old API. + if hasattr(self.tchannel, 'request'): + response = yield self.tchannel.request( + hostport=self.hostport, service=self.service + ).send( + arg1=endpoint, + arg2=header, + arg3=body, # body + headers=self.protocol_headers, + traceflag=self.trace + ) + body = yield response.get_body() + else: + response = yield self.tchannel.call( + scheme=schemes.THRIFT, + service=self.service, + arg1=endpoint, + arg2=header, + arg3=body, + hostport=self.hostport, + #headers=self.protocol_headers, + #traceflag=self.trace, + ) + body = response.body + call_result = serializer.deserialize_body(body) + if not result_spec: # void return type and no exceptions allowed raise gen.Return(None) diff --git a/tests/schemes/test_thrift.py b/tests/schemes/test_thrift.py index 8dfa3896..1f02d3bd 100644 --- a/tests/schemes/test_thrift.py +++ b/tests/schemes/test_thrift.py @@ -35,8 +35,10 @@ from tchannel.errors import OneWayNotSupportedError from tchannel.errors import UnexpectedError from tchannel.errors import ValueExpectedError +from tchannel.thrift import client_for from tchannel.testing.data.generated.ThriftTest import SecondService from tchannel.testing.data.generated.ThriftTest import ThriftTest +from tchannel.tornado import TChannel as DeprecatedTChannel # TODO - where possible, in req/res style test, create parameterized tests, @@ -1183,3 +1185,26 @@ def test_headers_should_be_a_map_of_strings(headers): request=mock.MagicMock(), headers=headers, ) + + +@pytest.mark.gen_test +@pytest.mark.call +@pytest.mark.parametrize('ClientTChannel', [TChannel, DeprecatedTChannel]) +def test_client_for(ClientTChannel): + server = TChannel(name='server') + + @server.thrift.register(ThriftTest) + def testString(request): + return request.body.thing.encode('rot13') + + server.listen() + + tchannel = ClientTChannel(name='client') + + client = client_for('server', ThriftTest)( + tchannel=tchannel, + hostport=server.hostport, + ) + + resp = yield client.testString(thing='foo') + assert resp == 'sbb'