diff --git a/spring-cloud-sleuth-core/src/main/java/org/springframework/cloud/sleuth/instrument/web/client/TraceWebClientBeanPostProcessor.java b/spring-cloud-sleuth-core/src/main/java/org/springframework/cloud/sleuth/instrument/web/client/TraceWebClientBeanPostProcessor.java index e50f31fc07..97bb452a7e 100644 --- a/spring-cloud-sleuth-core/src/main/java/org/springframework/cloud/sleuth/instrument/web/client/TraceWebClientBeanPostProcessor.java +++ b/spring-cloud-sleuth-core/src/main/java/org/springframework/cloud/sleuth/instrument/web/client/TraceWebClientBeanPostProcessor.java @@ -16,18 +16,16 @@ package org.springframework.cloud.sleuth.instrument.web.client; -import java.util.Collections; import java.util.List; import java.util.concurrent.CancellationException; import java.util.function.Consumer; import java.util.function.Function; import brave.Span; -import brave.Tracer; -import brave.Tracing; import brave.http.HttpClientHandler; import brave.http.HttpTracing; -import brave.propagation.Propagation; +import brave.propagation.CurrentTraceContext; +import brave.propagation.CurrentTraceContext.Scope; import brave.propagation.TraceContext; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -38,11 +36,10 @@ import reactor.util.annotation.Nullable; import reactor.util.context.Context; -import org.springframework.beans.BeansException; import org.springframework.beans.factory.config.BeanPostProcessor; +import org.springframework.cloud.sleuth.internal.LazyBean; import org.springframework.context.ConfigurableApplicationContext; import org.springframework.core.io.buffer.DataBuffer; -import org.springframework.web.client.RestClientException; import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.ExchangeFilterFunction; @@ -60,21 +57,19 @@ */ final class TraceWebClientBeanPostProcessor implements BeanPostProcessor { - private final ConfigurableApplicationContext springContext; + final ConfigurableApplicationContext springContext; TraceWebClientBeanPostProcessor(ConfigurableApplicationContext springContext) { this.springContext = springContext; } @Override - public Object postProcessBeforeInitialization(Object bean, String beanName) - throws BeansException { + public Object postProcessBeforeInitialization(Object bean, String beanName) { return bean; } @Override - public Object postProcessAfterInitialization(Object bean, String beanName) - throws BeansException { + public Object postProcessAfterInitialization(Object bean, String beanName) { if (bean instanceof WebClient) { WebClient webClient = (WebClient) bean; return wrapBuilder(webClient.mutate()).build(); @@ -114,25 +109,6 @@ private boolean noneMatchTraceExchangeFunction( final class TraceExchangeFilterFunction implements ExchangeFilterFunction { private static final Log log = LogFactory.getLog(TraceExchangeFilterFunction.class); - static final Propagation.Setter SETTER = new Propagation.Setter() { - @Override - public void put(ClientRequest.Builder carrier, String key, String value) { - carrier.headers(httpHeaders -> { - if (log.isTraceEnabled()) { - log.trace("Replacing [" + key + "] with value [" + value + "]"); - } - httpHeaders.merge(key, Collections.singletonList(value), - (oldValue, newValue) -> newValue); - }); - } - - @Override - public String toString() { - return "ClientRequest.Builder::header"; - } - }; - - static final String CLIENT_SPAN_KEY = "sleuth.webclient.clientSpan"; static final Exception CANCELLED_ERROR = new CancellationException("CANCELLED") { @Override @@ -141,20 +117,17 @@ public Throwable fillInStackTrace() { } }; - final ConfigurableApplicationContext springContext; + final LazyBean httpTracing; final Function, ? extends Publisher> scopePassingTransformer; - Tracer tracer; - - HttpTracing httpTracing; - + // Lazy initialized fields HttpClientHandler handler; - TraceContext.Injector injector; + CurrentTraceContext currentTraceContext; TraceExchangeFilterFunction(ConfigurableApplicationContext springContext) { - this.springContext = springContext; + this.httpTracing = LazyBean.create(springContext, HttpTracing.class); this.scopePassingTransformer = scopePassingSpanOperator(springContext); } @@ -166,80 +139,55 @@ public static ExchangeFilterFunction create( @Override public Mono filter(ClientRequest request, ExchangeFunction next) { HttpClientRequest wrapper = new HttpClientRequest(request); + TraceContext parent = currentTraceContext().get(); + Span clientSpan = handler().handleSend(wrapper); if (log.isDebugEnabled()) { - log.debug("Instrumenting WebClient call"); + log.debug("HttpClientHandler::handleSend: " + clientSpan); } - Span parentSpan = tracer().currentSpan(); - Span span = handler().handleSend(wrapper); - if (log.isDebugEnabled()) { - log.debug("Handled send of " + span); + return new MonoWebClientTrace(next, wrapper.buildRequest(), this, parent, + clientSpan); + } + + CurrentTraceContext currentTraceContext() { + if (this.currentTraceContext == null) { + this.currentTraceContext = httpTracing.get().tracing().currentTraceContext(); } - MonoWebClientTrace trace = new MonoWebClientTrace(next, wrapper.buildRequest(), - this, span); - // TODO: investigate why this commit leaks a scope: - // 8f5bcdabd7af23df443e771432eb85597f3b3076 - tracer().withSpanInScope(parentSpan); - return trace; + return this.currentTraceContext; } - @SuppressWarnings("unchecked") HttpClientHandler handler() { if (this.handler == null) { - this.handler = HttpClientHandler - .create(this.springContext.getBean(HttpTracing.class)); + this.handler = HttpClientHandler.create(this.httpTracing.get()); } return this.handler; } - Tracer tracer() { - if (this.tracer == null) { - this.tracer = httpTracing().tracing().tracer(); - } - return this.tracer; - } - - HttpTracing httpTracing() { - if (this.httpTracing == null) { - this.httpTracing = this.springContext.getBean(HttpTracing.class); - } - return this.httpTracing; - } - - TraceContext.Injector injector() { - if (this.injector == null) { - this.injector = this.springContext.getBean(HttpTracing.class).tracing() - .propagation().injector(SETTER); - } - return this.injector; - } - private static final class MonoWebClientTrace extends Mono { final ExchangeFunction next; final ClientRequest request; - final Tracer tracer; - final HttpClientHandler handler; - final TraceContext.Injector injector; - - final Tracing tracing; + final CurrentTraceContext currentTraceContext; final Function, ? extends Publisher> scopePassingTransformer; + @Nullable + final TraceContext parent; + private final Span span; MonoWebClientTrace(ExchangeFunction next, ClientRequest request, - TraceExchangeFilterFunction parent, Span span) { + TraceExchangeFilterFunction filterFunction, @Nullable TraceContext parent, + Span span) { this.next = next; this.request = request; - this.tracer = parent.tracer(); - this.handler = parent.handler(); - this.injector = parent.injector(); - this.tracing = parent.httpTracing().tracing(); - this.scopePassingTransformer = parent.scopePassingTransformer; + this.handler = filterFunction.handler(); + this.currentTraceContext = filterFunction.currentTraceContext(); + this.scopePassingTransformer = filterFunction.scopePassingTransformer; + this.parent = parent; this.span = span; } @@ -248,177 +196,133 @@ public void subscribe(CoreSubscriber subscriber) { Context context = subscriber.currentContext(); - this.next.exchange(request).subscribe( - new WebClientTracerSubscriber(subscriber, context, span, this)); + this.next.exchange(request).subscribe(new WebClientTracerSubscriber( + subscriber, context, parent, span, this)); } - static final class WebClientTracerSubscriber - implements CoreSubscriber { + } - final CoreSubscriber actual; + private static final class WebClientTracerSubscriber + implements CoreSubscriber { - final Context context; + final CoreSubscriber actual; - final Span span; + final Context context; - final HttpClientHandler handler; + @Nullable + final TraceContext parent; - final Function, ? extends Publisher> scopePassingTransformer; + final Span clientSpan; - final Tracing tracing; + final HttpClientHandler handler; - boolean done; + final Function, ? extends Publisher> scopePassingTransformer; - WebClientTracerSubscriber(CoreSubscriber actual, - Context context, Span span, MonoWebClientTrace parent) { - this.actual = actual; - this.span = span; - this.handler = parent.handler; - this.tracing = parent.tracing; - this.scopePassingTransformer = parent.scopePassingTransformer; + final CurrentTraceContext currentTraceContext; + + boolean done; + + WebClientTracerSubscriber(CoreSubscriber actual, + Context ctx, @Nullable final TraceContext parent, Span clientSpan, + MonoWebClientTrace mono) { + this.actual = actual; + this.parent = parent; + this.clientSpan = clientSpan; + this.handler = mono.handler; + this.currentTraceContext = mono.currentTraceContext; + this.scopePassingTransformer = mono.scopePassingTransformer; + this.context = parent != null + && !parent.equals(ctx.getOrDefault(TraceContext.class, null)) + ? ctx.put(TraceContext.class, parent) : ctx; + } - if (!context.hasKey(TraceContext.class)) { - context = context.put(TraceContext.class, span.context()); - if (log.isDebugEnabled()) { - log.debug("Reactor Context got injected with the client span " - + span); + @Override + public void onSubscribe(Subscription subscription) { + this.actual.onSubscribe(new Subscription() { + @Override + public void request(long n) { + try (Scope scope = currentTraceContext.maybeScope(parent)) { + subscription.request(n); } } - this.context = context.put(CLIENT_SPAN_KEY, span); - } - - @Override - public void onSubscribe(Subscription subscription) { - this.actual.onSubscribe(new Subscription() { - @Override - public void request(long n) { - try (Tracer.SpanInScope ws = tracing.tracer() - .withSpanInScope(span)) { - if (log.isTraceEnabled()) { - log.trace("Request"); - } - subscription.request(n); - } + @Override + public void cancel() { + try (Scope scope = currentTraceContext.maybeScope(parent)) { + subscription.cancel(); } - - @Override - public void cancel() { - try (Tracer.SpanInScope ws = tracing.tracer() - .withSpanInScope(span)) { - if (log.isTraceEnabled()) { - log.trace("Cancel"); - } - terminateSpanOnCancel(); - subscription.cancel(); + finally { + if (log.isDebugEnabled()) { + log.debug("Subscription was cancelled. Will close the span [" + + clientSpan + "]"); } + handleReceive(null, CANCELLED_ERROR); } - }); - } + } + }); + } - @Override - public void onNext(ClientResponse response) { - try (Tracer.SpanInScope ws = tracing.tracer().withSpanInScope(span)) { - this.done = true; - try { - // decorate response body - this.actual.onNext(ClientResponse.from(response) + @Override + public void onNext(ClientResponse response) { + try (Scope scope = currentTraceContext.maybeScope(parent)) { + this.done = true; + // decorate response body + this.actual + .onNext(ClientResponse.from(response) .body(response.bodyToFlux(DataBuffer.class) .transform(this.scopePassingTransformer)) .build()); - } - finally { - terminateSpan(response, null); - } - } - } - - @Override - public void onError(Throwable t) { - try (Tracer.SpanInScope ws = tracing.tracer().withSpanInScope(span)) { - try { - this.actual.onError(t); - } - finally { - terminateSpan(null, t); - } - } } - - @Override - public void onComplete() { - try (Tracer.SpanInScope ws = tracing.tracer().withSpanInScope(span)) { - try { - this.actual.onComplete(); - } - finally { - if (!this.done) { - terminateSpan(null, null); - } - } - } + finally { + handleReceive(response, null); } + } - @Override - public Context currentContext() { - return this.context; + @Override + public void onError(Throwable t) { + try (Scope scope = currentTraceContext.maybeScope(parent)) { + this.actual.onError(t); } - - void handleReceive(Span clientSpan, @Nullable ClientResponse res, - @Nullable Throwable error) { - if (log.isTraceEnabled()) { - log.trace("Handling receive"); - } - HttpClientResponse response = res != null ? new HttpClientResponse(res) - : null; - this.handler.handleReceive(response, error, clientSpan); - if (log.isTraceEnabled()) { - log.trace("Closed scope"); - } + finally { + handleReceive(null, t); } + } - void terminateSpanOnCancel() { - if (log.isDebugEnabled()) { - log.debug("Subscription was cancelled. Will close the span [" - + this.span + "]"); - } - - handleReceive(this.span, null, CANCELLED_ERROR); + @Override + public void onComplete() { + try (Scope scope = currentTraceContext.maybeScope(parent)) { + this.actual.onComplete(); } - - void terminateSpan(@Nullable ClientResponse clientResponse, - @Nullable Throwable error) { - if (clientResponse == null) { + finally { + // TODO: onComplete should be after onNext. Why are we handling this? + if (!this.done) { // unknown state if (log.isDebugEnabled()) { - log.debug("No response was returned. Will close the span [" - + this.span + "]"); + log.debug("Reached OnComplete without finishing [" + + this.clientSpan + "]"); } - handleReceive(this.span, null, error); - return; + this.clientSpan.abandon(); } - int statusCode = clientResponse.rawStatusCode(); - boolean isHttpError = statusCode >= 400; - if (isHttpError) { - if (log.isDebugEnabled()) { - log.debug( - "Non positive status code was returned from the call. Will close the span [" - + this.span + "]"); - } - error = new RestClientException( - "Status code of the response is [" + statusCode + "]"); - } - handleReceive(this.span, clientResponse, error); } + } + @Override + public Context currentContext() { + return this.context; + } + + void handleReceive(@Nullable ClientResponse res, @Nullable Throwable error) { + HttpClientResponse response = res != null ? new HttpClientResponse(res) + : null; + this.handler.handleReceive(response, error, clientSpan); } } - static final class HttpClientRequest extends brave.http.HttpClientRequest { + private static final class HttpClientRequest extends brave.http.HttpClientRequest { - private final ClientRequest delegate; + final ClientRequest delegate; - private final ClientRequest.Builder builder; + final ClientRequest.Builder builder; HttpClientRequest(ClientRequest delegate) { this.delegate = delegate; @@ -463,7 +367,7 @@ ClientRequest buildRequest() { static final class HttpClientResponse extends brave.http.HttpClientResponse { - private final ClientResponse delegate; + final ClientResponse delegate; HttpClientResponse(ClientResponse delegate) { this.delegate = delegate; @@ -476,12 +380,8 @@ public Object unwrap() { @Override public int statusCode() { - try { - return delegate.rawStatusCode(); - } - catch (Exception dontCare) { - return 0; - } + // unlike statusCode(), this doesn't throw + return Math.max(delegate.rawStatusCode(), 0); } } diff --git a/spring-cloud-sleuth-core/src/test/java/org/springframework/cloud/sleuth/instrument/web/client/TraceExchangeFilterFunctionHttpClientResponseTests.java b/spring-cloud-sleuth-core/src/test/java/org/springframework/cloud/sleuth/instrument/web/client/TraceExchangeFilterFunctionHttpClientResponseTests.java index 031edc2352..b194cf6abc 100644 --- a/spring-cloud-sleuth-core/src/test/java/org/springframework/cloud/sleuth/instrument/web/client/TraceExchangeFilterFunctionHttpClientResponseTests.java +++ b/spring-cloud-sleuth-core/src/test/java/org/springframework/cloud/sleuth/instrument/web/client/TraceExchangeFilterFunctionHttpClientResponseTests.java @@ -20,6 +20,7 @@ import org.junit.Test; import org.mockito.BDDMockito; +import org.springframework.cloud.sleuth.instrument.web.client.TraceExchangeFilterFunction.HttpClientResponse; import org.springframework.web.reactive.function.client.ClientResponse; public class TraceExchangeFilterFunctionHttpClientResponseTests { @@ -27,10 +28,8 @@ public class TraceExchangeFilterFunctionHttpClientResponseTests { @Test public void should_return_0_when_invalid_status_code_is_returned() { ClientResponse clientResponse = BDDMockito.mock(ClientResponse.class); - BDDMockito.given(clientResponse.rawStatusCode()) - .willThrow(new IllegalStateException("Boom")); - TraceExchangeFilterFunction.HttpClientResponse response = new TraceExchangeFilterFunction.HttpClientResponse( - clientResponse); + BDDMockito.given(clientResponse.rawStatusCode()).willReturn(-1); + HttpClientResponse response = new HttpClientResponse(clientResponse); Integer statusCode = response.statusCode(); @@ -41,8 +40,7 @@ public void should_return_0_when_invalid_status_code_is_returned() { public void should_return_status_code_when_valid_status_code_is_returned() { ClientResponse clientResponse = BDDMockito.mock(ClientResponse.class); BDDMockito.given(clientResponse.rawStatusCode()).willReturn(200); - TraceExchangeFilterFunction.HttpClientResponse response = new TraceExchangeFilterFunction.HttpClientResponse( - clientResponse); + HttpClientResponse response = new HttpClientResponse(clientResponse); Integer statusCode = response.statusCode(); diff --git a/tests/spring-cloud-sleuth-instrumentation-reactor-tests/pom.xml b/tests/spring-cloud-sleuth-instrumentation-reactor-tests/pom.xml index 2e07b1f318..85fd25178f 100644 --- a/tests/spring-cloud-sleuth-instrumentation-reactor-tests/pom.xml +++ b/tests/spring-cloud-sleuth-instrumentation-reactor-tests/pom.xml @@ -59,6 +59,17 @@ org.springframework.cloud spring-cloud-starter-sleuth + + io.zipkin.brave + brave-instrumentation-http-tests + test + + + org.eclipse.jetty + * + + + org.springframework.boot spring-boot-starter-test diff --git a/tests/spring-cloud-sleuth-instrumentation-reactor-tests/src/test/java/org/springframework/cloud/sleuth/instrument/web/client/WebClientBraveTests.java b/tests/spring-cloud-sleuth-instrumentation-reactor-tests/src/test/java/org/springframework/cloud/sleuth/instrument/web/client/WebClientBraveTests.java new file mode 100644 index 0000000000..9d113a5bfd --- /dev/null +++ b/tests/spring-cloud-sleuth-instrumentation-reactor-tests/src/test/java/org/springframework/cloud/sleuth/instrument/web/client/WebClientBraveTests.java @@ -0,0 +1,180 @@ +/* + * Copyright 2013-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.cloud.sleuth.instrument.web.client; + +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import brave.http.HttpTracing; +import brave.test.http.ITHttpAsyncClient; +import io.netty.channel.ChannelOption; +import io.netty.handler.timeout.ReadTimeoutHandler; +import org.junit.After; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.netty.http.client.HttpClient; +import reactor.util.context.Context; +import zipkin2.Callback; + +import org.springframework.beans.factory.config.BeanPostProcessor; +import org.springframework.cloud.sleuth.instrument.reactor.ScopePassingSpanSubscriberTests; +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.http.client.reactive.ReactorClientHttpConnector; +import org.springframework.web.reactive.function.BodyInserters; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.WebClient; + +/** + * This runs Brave's integration tests without underlying instrumentation, which would + * happen when a 3rd party client like Jetty is in use. + */ +public class WebClientBraveTests extends ITHttpAsyncClient { + + @Before + @After + public void resetHooks() { + new ScopePassingSpanSubscriberTests().resetHooks(); + } + + /** + * This uses Spring to instrument the {@link WebClient} using a + * {@link BeanPostProcessor}. + */ + @Override + protected WebClient newClient(int port) { + AnnotationConfigApplicationContext result = new AnnotationConfigApplicationContext(); + result.registerBean(HttpTracing.class, () -> httpTracing); + result.register(WebClientBuilderConfiguration.class); + result.register(TraceWebClientBeanPostProcessor.class); + result.refresh(); + return result.getBean(WebClient.Builder.class).baseUrl("http://127.0.0.1:" + port) + .build(); + } + + @Override + protected void closeClient(WebClient client) { + // WebClient is not Closeable + } + + @Override + protected void get(WebClient client, String pathIncludingQuery) { + client.get().uri(pathIncludingQuery).exchange().block(); + } + + @Override + protected void post(WebClient client, String pathIncludingQuery, String body) { + client.post().uri(pathIncludingQuery).body(BodyInserters.fromValue(body)) + .exchange().block(); + } + + @Override + protected void getAsync(WebClient client, String path, Callback callback) { + Mono request = client.get().uri(path).exchange(); + + request.subscribe(new CoreSubscriber() { + + final AtomicReference ref = new AtomicReference<>(); + + @Override + public void onSubscribe(Subscription s) { + if (Operators.validate(ref.getAndSet(s), s)) { + s.request(Long.MAX_VALUE); + } + else { + s.cancel(); + } + } + + @Override + public void onNext(ClientResponse t) { + Subscription s = ref.getAndSet(null); + if (s != null) { + callback.onSuccess(null); + s.cancel(); + } + else { + Operators.onNextDropped(t, currentContext()); + } + } + + @Override + public void onError(Throwable t) { + if (ref.getAndSet(null) != null) { + callback.onError(t); + } + } + + @Override + public void onComplete() { + if (ref.getAndSet(null) != null) { + callback.onSuccess(null); + } + } + + @Override + public Context currentContext() { + return Context.empty(); + } + }); + } + + @Test + @Ignore("TODO: reactor/reactor-netty#1000") + @Override + public void redirect() { + } + + @Test + @Ignore("WebClient has no portable function to retrieve the server address") + @Override + public void reportsServerAddress() { + } + + /** + * This fakes auto-configuration which wouldn't configure reactor's trace + * instrumentation. + */ + @Configuration + static class WebClientBuilderConfiguration { + + @Bean + HttpClient httpClient() { + // TODO: ReactorNettyHttpClientBraveTests.testHttpClient() #1554 + return HttpClient.create() + .tcpConfiguration(tcpClient -> tcpClient + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 1000) + .doOnConnected(conn -> conn.addHandler( + new ReadTimeoutHandler(1, TimeUnit.SECONDS)))) + .followRedirect(true); + } + + @Bean + WebClient.Builder webClientBuilder(HttpClient httpClient) { + return WebClient.builder() + .clientConnector(new ReactorClientHttpConnector(httpClient)); + } + + } + +}