From b803562177d71bca0cfda39ac88da66da3614f6b Mon Sep 17 00:00:00 2001 From: Bryce Lampe Date: Mon, 24 Aug 2015 13:06:11 -0700 Subject: [PATCH] update zipkin to use new api --- tchannel/request.py | 4 ++ tchannel/schemes/json.py | 15 ++++- tchannel/schemes/raw.py | 17 +++++- tchannel/schemes/thrift.py | 14 ++++- tchannel/tchannel.py | 35 ++++++++++- tchannel/tornado/tchannel.py | 2 +- tchannel/zipkin/tracers.py | 11 ++-- tests/integration/json/test_json_server.py | 13 ++--- tests/integration/test_error_handling.py | 4 +- tests/integration/trace/test_zipkin_trace.py | 61 ++++++++++---------- tests/mock_server.py | 26 ++------- tests/testing/vcr/test_server.py | 2 +- 12 files changed, 127 insertions(+), 77 deletions(-) diff --git a/tchannel/request.py b/tchannel/request.py index a38187be..1c05c6a4 100644 --- a/tchannel/request.py +++ b/tchannel/request.py @@ -66,8 +66,10 @@ class TransportHeaders(object): 'retry_flags', 'scheme', 'speculative_exe', + 'shard_key', ) + def __init__(self, caller_name=None, claim_at_start=None, @@ -76,6 +78,7 @@ def __init__(self, retry_flags=None, scheme=None, speculative_exe=None, + shard_key=None, **kwargs): if scheme is None: @@ -88,3 +91,4 @@ def __init__(self, self.retry_flags = retry_flags self.scheme = scheme self.speculative_exe = speculative_exe + self.shard_key = shard_key diff --git a/tchannel/schemes/json.py b/tchannel/schemes/json.py index 230e886a..59f66742 100644 --- a/tchannel/schemes/json.py +++ b/tchannel/schemes/json.py @@ -38,8 +38,18 @@ def __init__(self, tchannel): self._tchannel = tchannel @gen.coroutine - def __call__(self, service, endpoint, body=None, headers=None, - timeout=None, retry_on=None, retry_limit=None, hostport=None): + def __call__( + self, + service, + endpoint, + body=None, + headers=None, + timeout=None, + retry_on=None, + retry_limit=None, + hostport=None, + shard_key=None, + ): """Make JSON TChannel Request. .. code-block: python @@ -93,6 +103,7 @@ def __call__(self, service, endpoint, body=None, headers=None, retry_on=retry_on, retry_limit=retry_limit, hostport=hostport, + shard_key=shard_key, ) # deserialize diff --git a/tchannel/schemes/raw.py b/tchannel/schemes/raw.py index 64086091..f1e4bb0f 100644 --- a/tchannel/schemes/raw.py +++ b/tchannel/schemes/raw.py @@ -33,10 +33,23 @@ class RawArgScheme(object): def __init__(self, tchannel): self._tchannel = tchannel - def __call__(self, service, endpoint, body=None, headers=None, - timeout=None, retry_on=None, retry_limit=None, hostport=None): + def __call__( + self, + service, + endpoint, + body=None, + headers=None, + timeout=None, + retry_on=None, + retry_limit=None, + hostport=None, + shard_key=None, + ): """Make Raw TChannel Request. + The request's headers and body are treated as raw bytes and not + serialized/deserialized. + .. code-block: python from tchannel import TChannel diff --git a/tchannel/schemes/thrift.py b/tchannel/schemes/thrift.py index 1773081c..166979ed 100644 --- a/tchannel/schemes/thrift.py +++ b/tchannel/schemes/thrift.py @@ -40,8 +40,15 @@ def __init__(self, tchannel): self._tchannel = tchannel @gen.coroutine - def __call__(self, request, headers=None, timeout=None, - retry_on=None, retry_limit=None): + def __call__( + self, + request, + headers=None, + timeout=None, + retry_on=None, + retry_limit=None, + shard_key=None, + ): if not headers: headers = {} @@ -66,7 +73,8 @@ def __call__(self, request, headers=None, timeout=None, timeout=timeout, retry_on=retry_on, retry_limit=retry_limit, - hostport=request.hostport + hostport=request.hostport, + shard_key=shard_key, ) # deserialize... diff --git a/tchannel/tchannel.py b/tchannel/tchannel.py index f133c06c..a2799d82 100644 --- a/tchannel/tchannel.py +++ b/tchannel/tchannel.py @@ -39,7 +39,22 @@ class TChannel(object): """Make requests to TChannel services.""" def __init__(self, name, hostport=None, process_name=None, - known_peers=None, trace=False): + known_peers=None, trace=True): + """ + **Note:** In general only one ``TChannel`` instance should be used at a + time. Multiple ``TChannel`` instances are not advisable and could + result in undefined behavior. + + :param string name: + How this application identifies itself. This is the name callers + will use to make contact, it is also what your downstream services + will see in their metrics. + + :param string hostport: + An optional host/port to serve on, e.g., ``"127.0.0.1:5555``. If + not provided an ephemeral port will be used. When advertising on + Hyperbahn you callers do not need to know your port. + """ # until we move everything here, # lets compose the old tchannel @@ -64,8 +79,19 @@ def hooks(self): return self._dep_tchannel.hooks @gen.coroutine - def call(self, scheme, service, arg1, arg2=None, arg3=None, - timeout=None, retry_on=None, retry_limit=None, hostport=None): + def call( + self, + scheme, + service, + arg1, + arg2=None, + arg3=None, + timeout=None, + retry_on=None, + retry_limit=None, + hostport=None, + shard_key=None, + ): """Make low-level requests to TChannel services. This method uses TChannel's protocol terminology for param naming. @@ -135,6 +161,9 @@ def call(self, scheme, service, arg1, arg2=None, arg3=None, transport.SCHEME: scheme, transport.CALLER_NAME: self.name, } + if shard_key: + transport_headers[transport.SHARD_KEY] = shard_key + response = yield operation.send( arg1=arg1, arg2=arg2, diff --git a/tchannel/tornado/tchannel.py b/tchannel/tornado/tchannel.py index 8ec72883..b57688d8 100644 --- a/tchannel/tornado/tchannel.py +++ b/tchannel/tornado/tchannel.py @@ -65,7 +65,7 @@ class TChannel(object): FALLBACK = RequestDispatcher.FALLBACK def __init__(self, name, hostport=None, process_name=None, - known_peers=None, trace=False, dispatcher=None): + known_peers=None, trace=True, dispatcher=None): """Build or re-use a TChannel. :param name: diff --git a/tchannel/zipkin/tracers.py b/tchannel/zipkin/tracers.py index d284d49e..09d5a675 100644 --- a/tchannel/zipkin/tracers.py +++ b/tchannel/zipkin/tracers.py @@ -43,13 +43,13 @@ from .formatters import i64_to_base64 from .thrift import TCollector from .thrift import constants -from ..thrift import client_for +from ..thrift import thrift_request_builder log = logging.getLogger('zipkin_tracing') zipkin_log = logging.getLogger('zipkin') -TCollectorClient = client_for('tcollector', TCollector) +TCollectorClient = thrift_request_builder('tcollector', TCollector) class EndAnnotationTracer(object): @@ -152,11 +152,10 @@ def submit_callback(f): fus = [] for (trace, annotations) in traces: - client = TCollectorClient( - self._tchannel, - protocol_headers={'shardKey': i64_to_base64(trace.trace_id)} + f = self._tchannel.thrift( + TCollectorClient.submit(thrift_formatter(trace, annotations)), + shard_key=i64_to_base64(trace.trace_id), ) - f = client.submit(thrift_formatter(trace, annotations)) f.add_done_callback(submit_callback) fus.append(f) diff --git a/tests/integration/json/test_json_server.py b/tests/integration/json/test_json_server.py index 83473ed2..3ba68bce 100644 --- a/tests/integration/json/test_json_server.py +++ b/tests/integration/json/test_json_server.py @@ -24,7 +24,7 @@ import tornado import tornado.gen -from tchannel import TChannel +from tchannel import TChannel, Response from tchannel.schemes import JSON from tests.mock_server import MockServer @@ -69,14 +69,13 @@ def sample_json(): def register(tchannel): - @tchannel.register("json_echo", "json") + @tchannel.json.register("json_echo") @tornado.gen.coroutine - def json_echo(request, response): - header = yield request.get_header() - body = yield request.get_body() + def json_echo(request): + headers = request.headers + body = request.body - response.write_header(header) - response.write_body(body) + return Response(body, headers) @pytest.yield_fixture diff --git a/tests/integration/test_error_handling.py b/tests/integration/test_error_handling.py index 21c03e97..e7723213 100644 --- a/tests/integration/test_error_handling.py +++ b/tests/integration/test_error_handling.py @@ -46,8 +46,8 @@ def handler2(request, response): def register(tchannel): - tchannel.register("endpoint1", "raw", handler1) - tchannel.register("endpoint2", "raw", handler2) + tchannel.register(endpoint="endpoint1", scheme="raw", handler=handler1) + tchannel.register(endpoint="endpoint2", scheme="raw", handler=handler2) @pytest.fixture diff --git a/tests/integration/trace/test_zipkin_trace.py b/tests/integration/trace/test_zipkin_trace.py index 72792746..54923cba 100644 --- a/tests/integration/trace/test_zipkin_trace.py +++ b/tests/integration/trace/test_zipkin_trace.py @@ -25,12 +25,11 @@ import tornado import tornado.gen -from tchannel.tornado import TChannel -from tchannel.tornado.stream import InMemStream +from tchannel import TChannel, Response from tchannel.zipkin.annotation import Endpoint from tchannel.zipkin.annotation import client_send from tchannel.zipkin.thrift import TCollector -from tchannel.zipkin.thrift.ttypes import Response +from tchannel.zipkin.thrift.ttypes import Response as TResponse from tchannel.zipkin.trace import Trace from tchannel.zipkin.tracers import TChannelZipkinTracer from tchannel.zipkin.zipkin_trace import ZipkinTraceHook @@ -42,10 +41,11 @@ from StringIO import StringIO -def submit(request, response): - span = request.args.span - r = Response() - r.ok = request.transport.headers['shardKey'] == base64.b64encode( +def submit(request): + span = request.body.span + r = TResponse() + + r.ok = request.transport.shard_key == base64.b64encode( span.traceId ) return r @@ -54,26 +54,24 @@ def submit(request, response): @pytest.fixture def register(tchannel): @tornado.gen.coroutine - def handler2(request, response): - response.set_body_s(InMemStream("from handler2")) + def handler2(request): + return "from handler2" @tornado.gen.coroutine - def handler1(request, response): - header = yield request.get_header() - res = yield tchannel.request(header).send( - "endpoint2", - "", - "", - traceflag=True + def handler1(request): + hostport = request.headers + + res = yield tchannel.raw( + service='handler2', + hostport=hostport, + endpoint="endpoint2", ) - body = yield res.get_body() - yield response.write_header("from handler1") - yield response.write_body(body) - response.flush() - tchannel.register("endpoint1", "raw", handler1) - tchannel.register("endpoint2", "raw", handler2) - tchannel.register(TCollector, "thrift", submit) + raise tornado.gen.Return(Response(res.body, "from handler1")) + + tchannel.register(endpoint="endpoint1", scheme="raw", handler=handler1) + tchannel.register(endpoint="endpoint2", scheme="raw", handler=handler2) + tchannel.register(endpoint=TCollector, scheme="thrift", handler=submit) trace_buf = StringIO() @@ -99,12 +97,15 @@ def test_zipkin_trace(trace_server): hostport = 'localhost:%d' % trace_server.port - response = yield tchannel.request(hostport).send(InMemStream(endpoint), - InMemStream(hostport), - InMemStream(), - traceflag=True) - header = yield response.get_header() - body = yield response.get_body() + response = yield tchannel.raw( + service='test-client', + hostport=hostport, + endpoint=endpoint, + headers=hostport, + ) + + header = response.headers + body = response.body assert header == "from handler1" assert body == "from handler2" traces = [] @@ -132,4 +133,4 @@ def test_tcollector_submit(trace_server): results = yield TChannelZipkinTracer(tchannel).record([(trace, anns)]) - assert results[0].ok + assert results[0].body.ok is True diff --git a/tests/mock_server.py b/tests/mock_server.py index 0eb097f0..f0debee6 100644 --- a/tests/mock_server.py +++ b/tests/mock_server.py @@ -69,8 +69,7 @@ def execute(request, response): def and_result(self, result): def execute(request, response): - response.body = result - #response.write_result(result) + return result self.execute = execute return self @@ -83,22 +82,6 @@ def execute(request, response): self.execute = execute return self - def and_error(self, protocoal_error): - - def execute(request, response): - # send error message for test purpose only - connection = response.connection - connection.send_error( - protocoal_error.code, - protocoal_error.description, - response.id, - ) - # stop normal response streams - response.set_exception(TChannelError("stop stream")) - - self.execute = execute - return self - def times(self, count): self.execute = _LimitCount(self.execute, count) return self @@ -143,8 +126,11 @@ def handle_expected_endpoint(request): response = Response() return expectation.execute(request, response) - getattr(self.tchannel, scheme).register(endpoint, **kwargs)( - handle_expected_endpoint + self.tchannel.register( + scheme=scheme, + endpoint=endpoint, + handler=handle_expected_endpoint, + **kwargs ) return expectation diff --git a/tests/testing/vcr/test_server.py b/tests/testing/vcr/test_server.py index 0145ccf0..47be6595 100644 --- a/tests/testing/vcr/test_server.py +++ b/tests/testing/vcr/test_server.py @@ -161,7 +161,7 @@ def test_protocol_error(vcr_service, cassette, call, mock_server): allow(cassette).can_replay.and_return(False) expect(cassette).record.never() - mock_server.expect_call('endpoint').and_error( + mock_server.expect_call('endpoint').and_raise( TChannelError.from_code(1, description='great sadness') )