diff --git a/http-clients/src/main/java/com/palantir/remoting/http/FeignClientFactory.java b/http-clients/src/main/java/com/palantir/remoting/http/FeignClientFactory.java index 75c8c3d3ba..7de6bbf2a7 100644 --- a/http-clients/src/main/java/com/palantir/remoting/http/FeignClientFactory.java +++ b/http-clients/src/main/java/com/palantir/remoting/http/FeignClientFactory.java @@ -30,6 +30,7 @@ import feign.Feign; import feign.Logger.Level; import feign.Request; +import feign.TraceResponseDecoder; import feign.codec.Decoder; import feign.codec.Encoder; import feign.codec.ErrorDecoder; @@ -68,7 +69,7 @@ private FeignClientFactory( Request.Options options) { this.contract = contract; this.encoder = encoder; - this.decoder = decoder; + this.decoder = new TraceResponseDecoder(decoder); this.errorDecoder = errorDecoder; this.clientSupplier = clientSupplier; this.backoffStrategy = backoffStrategy; diff --git a/http-clients/src/main/java/feign/TraceResponseDecoder.java b/http-clients/src/main/java/feign/TraceResponseDecoder.java new file mode 100644 index 0000000000..445c72b33d --- /dev/null +++ b/http-clients/src/main/java/feign/TraceResponseDecoder.java @@ -0,0 +1,45 @@ +/* + * Copyright 2016 Palantir Technologies, Inc. All rights reserved. + */ + +package feign; + +import com.google.common.base.Optional; +import com.google.common.collect.Iterables; +import com.palantir.tracing.TraceState; +import com.palantir.tracing.Traces; +import feign.codec.DecodeException; +import feign.codec.Decoder; +import java.io.IOException; +import java.lang.reflect.Type; + +public final class TraceResponseDecoder implements Decoder { + + private final Decoder delegate; + + public TraceResponseDecoder(Decoder delegate) { + this.delegate = delegate; + } + + @Override + public Object decode(Response response, Type type) throws IOException, DecodeException, FeignException { + String traceId = safeGetOnlyElement(response.headers().get(Traces.Headers.TRACE_ID), null); + String spanId = safeGetOnlyElement(response.headers().get(Traces.Headers.SPAN_ID), null); + Optional trace = Traces.getTrace(); + if (traceId != null && spanId != null && trace.isPresent()) { + // there exists a trace, and the response included tracing information, so check the returned trace + // matches our current trace + if (trace.get().getTraceId().equals(traceId) + && trace.get().getSpanId().equals(spanId)) { + // this trace is for the traceId and spanId on top of the tracing stack, complete it + Traces.complete(); + } + } + return delegate.decode(response, type); + } + + private static T safeGetOnlyElement(Iterable iterable, T defaultValue) { + return iterable == null ? defaultValue : Iterables.getOnlyElement(iterable, defaultValue); + } + +} diff --git a/http-clients/src/test/java/com/palantir/remoting/http/TraceRequestInterceptorTest.java b/http-clients/src/test/java/com/palantir/remoting/http/TraceRequestInterceptorTest.java index a12b8fdaca..4360cdeb00 100644 --- a/http-clients/src/test/java/com/palantir/remoting/http/TraceRequestInterceptorTest.java +++ b/http-clients/src/test/java/com/palantir/remoting/http/TraceRequestInterceptorTest.java @@ -17,7 +17,6 @@ package com.palantir.remoting.http; import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.not; import static org.junit.Assert.assertThat; import com.google.common.base.Optional; @@ -26,7 +25,6 @@ import com.squareup.okhttp.mockwebserver.MockResponse; import com.squareup.okhttp.mockwebserver.MockWebServer; import com.squareup.okhttp.mockwebserver.RecordedRequest; -import java.util.UUID; import javax.net.ssl.SSLSocketFactory; import javax.ws.rs.GET; import javax.ws.rs.Path; @@ -50,33 +48,24 @@ public void before() { endpointUri, TestRequestInterceptorService.class); - server.enqueue(new MockResponse().setBody("\"ok\"")); + server.enqueue(new MockResponse().setBody("{}")); } @Test public void testTraceRequestInterceptor_sendsAValidTraceId() throws InterruptedException { service.get(); - RecordedRequest request = server.takeRequest(); - - String traceId = request.getHeader(Traces.Headers.TRACE_ID); - assertThat(UUID.fromString(traceId).toString(), is(traceId)); - } + TraceState expectedTrace = Traces.getTrace().get(); - @Test - public void testTraceRequestInterceptor_sendsExplicitTraceId() throws InterruptedException { - TraceState state = Traces.deriveTrace("operation"); - service.get(); RecordedRequest request = server.takeRequest(); - - assertThat(request.getHeader(Traces.Headers.TRACE_ID), is(state.getTraceId())); - assertThat(request.getHeader(Traces.Headers.PARENT_SPAN_ID), is(state.getSpanId())); - assertThat(request.getHeader(Traces.Headers.SPAN_ID), not(state.getSpanId())); + assertThat(request.getHeader(Traces.Headers.TRACE_ID), is(expectedTrace.getTraceId())); + assertThat(request.getHeader(Traces.Headers.SPAN_ID), is(expectedTrace.getSpanId())); + assertThat(request.getHeader(Traces.Headers.PARENT_SPAN_ID), is(expectedTrace.getParentSpanId().orNull())); } @Path("/") public interface TestRequestInterceptorService { @GET - String get(); + Object get(); } } diff --git a/http-clients/src/test/java/com/palantir/remoting/http/TraceRequestResponseTest.java b/http-clients/src/test/java/com/palantir/remoting/http/TraceRequestResponseTest.java new file mode 100644 index 0000000000..863c22a6cd --- /dev/null +++ b/http-clients/src/test/java/com/palantir/remoting/http/TraceRequestResponseTest.java @@ -0,0 +1,78 @@ +/* + * Copyright 2016 Palantir Technologies, Inc. All rights reserved. + */ + +package com.palantir.remoting.http; + +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; + +import com.google.common.base.Optional; +import com.palantir.remoting.http.server.TraceEnrichingFilter; +import com.palantir.tracing.TraceState; +import com.palantir.tracing.Traces; +import io.dropwizard.Application; +import io.dropwizard.Configuration; +import io.dropwizard.setup.Environment; +import io.dropwizard.testing.junit.DropwizardAppRule; +import javax.net.ssl.SSLSocketFactory; +import javax.ws.rs.Consumes; +import javax.ws.rs.GET; +import javax.ws.rs.Path; +import javax.ws.rs.Produces; +import javax.ws.rs.core.MediaType; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Test; + +public final class TraceRequestResponseTest { + + @ClassRule + public static final DropwizardAppRule APP = new DropwizardAppRule<>(TracingTestServer.class, + "src/test/resources/test-server.yml"); + + private TracingTestService service; + + @Before + public void before() { + String endpointUri = "http://localhost:" + APP.getLocalPort(); + + service = FeignClients.standard().createProxy( + Optional.absent(), + endpointUri, + TracingTestService.class); + } + + @Test + public void testTraceResponseDecoder_decoderPopsMatchingSpan() { + Optional before = Traces.getTrace(); + service.get(); + + assertThat(Traces.getTrace(), is(before)); + } + + public static class TracingTestServer extends Application { + @Override + public final void run(Configuration config, final Environment env) throws Exception { + env.jersey().register(new TraceEnrichingFilter()); + env.jersey().register(new TracingTestResource()); + } + } + + public static final class TracingTestResource implements TracingTestService { + @Override + public Object get() { + return "{}"; + } + } + + @Path("/") + @Consumes(MediaType.APPLICATION_JSON) + @Produces(MediaType.APPLICATION_JSON) + public interface TracingTestService { + @GET + @Path("/trace") + Object get(); + } + +} diff --git a/http-clients/src/test/java/com/palantir/remoting/http/TraceResponseDecoderTest.java b/http-clients/src/test/java/com/palantir/remoting/http/TraceResponseDecoderTest.java new file mode 100644 index 0000000000..4d56d457a2 --- /dev/null +++ b/http-clients/src/test/java/com/palantir/remoting/http/TraceResponseDecoderTest.java @@ -0,0 +1,82 @@ +/* + * Copyright 2016 Palantir Technologies, Inc. All rights reserved. + */ + +package com.palantir.remoting.http; + +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.mock; + +import com.google.common.base.Optional; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.palantir.tracing.TraceState; +import com.palantir.tracing.Traces; +import feign.FeignException; +import feign.Response; +import feign.TraceResponseDecoder; +import feign.codec.DecodeException; +import feign.codec.Decoder; +import java.io.IOException; +import java.lang.reflect.Type; +import java.util.Collection; +import org.junit.Test; + +public final class TraceResponseDecoderTest { + + @Test + public void testDecode_completesMatchingSpan() throws DecodeException, FeignException, IOException { + TraceState state = TraceState.builder() + .traceId("traceId") + .spanId("spanId") + .operation("any") + .build(); + Traces.setTrace(state); + + testDecodeInternal("traceId", "spanId"); + + assertThat(Traces.getTrace(), is(Optional.absent())); + } + + @Test + public void testDecode_doesNotPopNonMatchingSpan() throws DecodeException, FeignException, IOException { + TraceState state = TraceState.builder() + .traceId("traceId") + .spanId("spanId") + .operation("any") + .build(); + Traces.setTrace(state); + + testDecodeInternal("traceId", "otherSpanId"); + + assertThat(Traces.getTrace(), is(Optional.of(state))); + } + + @Test + public void testDecode_doesNotPopNonMatchingTrace() throws DecodeException, FeignException, IOException { + TraceState state = TraceState.builder() + .traceId("traceId") + .spanId("spanId") + .operation("any") + .build(); + Traces.setTrace(state); + + testDecodeInternal("otherTraceId", "spanId"); + + assertThat(Traces.getTrace(), is(Optional.of(state))); + } + + private void testDecodeInternal(String traceId, String spanId) throws IOException { + Decoder decoder = new TraceResponseDecoder(mock(Decoder.class)); + Response response = Response.create(200, + "", + ImmutableMap.of( + Traces.Headers.TRACE_ID, (Collection) ImmutableSet.of(traceId), + Traces.Headers.SPAN_ID, (Collection) ImmutableSet.of(spanId)), + new byte[0]); + + decoder.decode(response, mock(Type.class)); + } + +}