Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

187 B3 Headers Support #188

Merged
merged 1 commit into from
Aug 7, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions baseplate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class TraceInfo(_TraceInfo):
collecting the trace context and passing it along to the server span.

"""

@classmethod
def new(cls):
"""Generate IDs for a new initial server span.
Expand Down Expand Up @@ -135,6 +136,62 @@ def from_upstream(cls, trace_id, parent_id, span_id, sampled, flags):

return cls(trace_id, parent_id, span_id, sampled, flags)

@classmethod
def extract_upstream_header_values(cls, upstream_header_names, headers):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense just to merge this into from_upstream ? I don't think we'd ever use one without the other?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah... I wanted to do that but there are subtle differences in how the two protocols handle the header values due to Thrift operating in bytes and HTTP using strings.

https://github.com/reddit/baseplate/pull/188/files#diff-db4cf04c08224f72cb9627e1f9866bbcR121
https://github.com/reddit/baseplate/pull/188/files#diff-d1382533c32ca95376190e31679acc05R162

Some type checking could be done. Or possibly a generic

True if extraced_values["sampled"] in ("1", b"1")

but this started to feel protocol specific again. If others are added, other protocol specific edge cases could arise making the from_upstream method kind of gnarly. I'm happy to take a stab at it though, if you'd like.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 alrighty, makes sense. thanks

"""Extract values from upstream headers.

This method thinks about upstream headers by a general name as oppposed to the header
name, i.e. "trace_id" instead of "X-Trace". These general names are "trace_id",
"span_id", "parent_span_id", "sampled" and "flags".

A dict mapping these general names to corresponding header names is expected.

For example:

{
"trace_id": ("X-Trace", "X-B3-TraceId"),
"span_id": ("X-Span", "X-B3-SpanId"),
"parent_span_id": ("X-Parent", "X-B3-ParentSpanId"),
"sampled": ("X-Sampled", "X-B3-Sampled"),
"flags": ("X-Flags", "X-B3-Flags"),
}

This structure is used to extract relevant values from the request headers resulting
in a dict mapping general names to values.

For example:

{
"trace_id": "2391921232992245445",
"span_id": "7638783876913511395",
"parent_span_id": "3383915029748331832",
"sampled": "1",
}

:param dict upstream_headers_name: Map of general upstream value labels to header names
:param dict headers: Headers sent with a request
:return: Values found in upstream trace headers
:rtype: dict

:raises: :py:exc:`ValueError` if conflicting values are found for the same header category

"""
extracted_values = {}
for name, header_names in upstream_header_names.items():
values = []
for header_name in header_names:
if header_name in headers:
values.append(headers[header_name])

if not values:
continue
elif not all(value == values[0] for value in values):
raise ValueError("Conflicting values found for %s header(s)".format(header_names))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we want to crash the whole request if the headers are iffy? (i could see either answer being reasonable)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I raised this to follow what is being done in new. Where this is called is wrapped so the request doesn't get nuked but rather the trace data is not propagated.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, perfect. thanks!

else:
# All the values are the same
extracted_values[name] = values[0]
return extracted_values


class AuthenticationTokenValidator(object):
"""Factory that knows how to validate raw authentication tokens."""
Expand Down
35 changes: 21 additions & 14 deletions baseplate/integration/pyramid.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ def make_app(app_config):
from ..server import make_app


TRACE_HEADER_NAMES = {
"trace_id": ("X-Trace", "X-B3-TraceId"),
"span_id": ("X-Span", "X-B3-SpanId"),
"parent_span_id": ("X-Parent", "X-B3-ParentSpanId"),
"sampled": ("X-Sampled", "X-B3-Sampled"),
"flags": ("X-Flags", "X-B3-Flags"),
}


def _make_baseplate_tween(handler, registry):
def baseplate_tween(request):
try:
Expand Down Expand Up @@ -116,20 +125,7 @@ def _on_new_request(self, event):
trace_info = None
if self.trust_trace_headers:
try:
sampled = request.headers.get("X-Sampled", None)
if sampled is not None:
sampled = True if sampled == "1" else False
flags = request.headers.get("X-Flags", None)
if flags is not None:
flags = int(flags)
trace_info = TraceInfo.from_upstream(
trace_id=int(request.headers["X-Trace"]),
parent_id=int(request.headers["X-Parent"]),
span_id=int(request.headers["X-Span"]),
sampled=sampled,
flags=flags,
)

trace_info = self._get_trace_info(request.headers)
edge_payload = request.headers.get("X-Edge-Request", None)
if self.edge_context_factory:
edge_context = self.edge_context_factory.from_upstream(
Expand All @@ -156,6 +152,17 @@ def _start_server_span(self, request, name, trace_info=None):
request.trace.start()
request.registry.notify(ServerSpanInitialized(request))

def _get_trace_info(self, headers):
extracted_values = TraceInfo.extract_upstream_header_values(TRACE_HEADER_NAMES, headers)
flags = extracted_values.get("flags", None)
return TraceInfo.from_upstream(
int(extracted_values["trace_id"]),
int(extracted_values["parent_span_id"]),
int(extracted_values["span_id"]),
True if extracted_values["sampled"] == "1" else False,
int(flags) if flags is not None else None,
)

def includeme(self, config):
config.add_subscriber(self._on_new_request, pyramid.events.ContextFound)

Expand Down
36 changes: 21 additions & 15 deletions baseplate/integration/thrift/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ def make_processor(app_config):
from ...core import TraceInfo


TRACE_HEADER_NAMES = {
"trace_id": (b"Trace", b"B3-TraceId"),
"span_id": (b"Span", b"B3-SpanId"),
"parent_span_id": (b"Parent", b"B3-ParentSpanId"),
"sampled": (b"Sampled", b"B3-Sampled"),
"flags": (b"Flags", b"B3-Flags"),
}


class RequestContext(object):
pass

Expand Down Expand Up @@ -56,21 +65,7 @@ def getHandlerContext(self, fn_name, server_context):
trace_info = None
headers = server_context.iprot.trans.get_headers()
try:
sampled = headers.get(b"Sampled", None)
if sampled is not None:
sampled = True if sampled.decode('utf-8') == "1" else False
flags = headers.get(b"Flags", None)
if flags is not None:
flags = int(flags)

trace_info = TraceInfo.from_upstream(
trace_id=int(headers[b"Trace"]),
parent_id=int(headers[b"Parent"]),
span_id=int(headers[b"Span"]),
sampled=sampled,
flags=flags,
)

trace_info = self._get_trace_info(headers)
edge_payload = headers.get(b"Edge-Request", None)
if self.edge_context_factory:
edge_context = self.edge_context_factory.from_upstream(
Expand Down Expand Up @@ -115,3 +110,14 @@ def handlerError(self, handler_context, fn_name, exception):
handler_context.trace.finish(exc_info=sys.exc_info())
handler_context.trace.is_finished = True
self.logger.exception("Unexpected exception in %r.", fn_name)

def _get_trace_info(self, headers):
extracted_values = TraceInfo.extract_upstream_header_values(TRACE_HEADER_NAMES, headers)
flags = extracted_values.get("flags", None)
return TraceInfo.from_upstream(
int(extracted_values["trace_id"]),
int(extracted_values["parent_span_id"]),
int(extracted_values["span_id"]),
True if extracted_values["sampled"].decode("utf-8") == "1" else False,
int(flags) if flags is not None else None,
)
25 changes: 25 additions & 0 deletions tests/integration/pyramid_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,31 @@ def test_trace_headers(self):
self.assertTrue(self.server_observer.on_finish.called)
self.assertTrue(self.context_init_event_subscriber.called)

def test_b3_trace_headers(self):
self.test_app.get("/example", headers={
"X-B3-TraceId": "1234",
"X-B3-ParentSpanId": "2345",
"X-B3-SpanId": "3456",
"X-B3-Sampled": "1",
"X-B3-Flags": "1",
})

self.assertEqual(self.observer.on_server_span_created.call_count, 1)

context, server_span = self.observer.on_server_span_created.call_args[0]
self.assertEqual(server_span.trace_id, 1234)
self.assertEqual(server_span.parent_id, 2345)
self.assertEqual(server_span.id, 3456)
self.assertEqual(server_span.sampled, True)
self.assertEqual(server_span.flags, 1)

with self.assertRaises(NoAuthenticationError):
context.request_context.user.id

self.assertTrue(self.server_observer.on_start.called)
self.assertTrue(self.server_observer.on_finish.called)
self.assertTrue(self.context_init_event_subscriber.called)

def test_edge_request_headers(self):
self.test_app.get("/example", headers={
"X-Trace": "1234",
Expand Down
33 changes: 33 additions & 0 deletions tests/integration/thrift_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,39 @@ def test_with_headers(self):
self.assertEqual(self.server_observer.on_finish.call_count, 1)
self.assertEqual(self.server_observer.on_finish.call_args[0], (None,))

def test_b3_trace_headers(self):
client_memory_trans = TMemoryBuffer()
client_prot = THeaderProtocol(client_memory_trans)
client_header_trans = client_prot.trans
client_header_trans.set_header("B3-TraceId", "1234")
client_header_trans.set_header("B3-ParentSpanId", "2345")
client_header_trans.set_header("B3-SpanId", "3456")
client_header_trans.set_header("B3-Sampled", "1")
client_header_trans.set_header("B3-Flags", "1")
client = TestService.Client(client_prot)
try:
client.example_simple()
except TTransportException:
pass # we don't have a test response for the client
self.itrans._readBuffer = StringIO(client_memory_trans.getvalue())

self.processor.process(self.iprot, self.oprot, self.server_context)
self.assertEqual(self.observer.on_server_span_created.call_count, 1)

context, server_span = self.observer.on_server_span_created.call_args[0]
self.assertEqual(server_span.trace_id, 1234)
self.assertEqual(server_span.parent_id, 2345)
self.assertEqual(server_span.id, 3456)
self.assertTrue(server_span.sampled)
self.assertEqual(server_span.flags, 1)

with self.assertRaises(NoAuthenticationError):
context.request_context.user.id

self.assertEqual(self.server_observer.on_start.call_count, 1)
self.assertEqual(self.server_observer.on_finish.call_count, 1)
self.assertEqual(self.server_observer.on_finish.call_args[0], (None,))

def test_edge_request_headers(self):
client_memory_trans = TMemoryBuffer()
client_prot = THeaderProtocol(client_memory_trans)
Expand Down