Skip to content

Commit

Permalink
Merge pull request #2902 from crankydillo/reactor-netty-client-provid…
Browse files Browse the repository at this point in the history
…er-fix

Avoid a payload byte[] copy when using reactor-netty HTTP client engine
  • Loading branch information
jamezp committed Sep 23, 2021
2 parents a1adb8a + 0afd328 commit a8e8042
Show file tree
Hide file tree
Showing 6 changed files with 681 additions and 60 deletions.
@@ -1,6 +1,5 @@
package org.jboss.resteasy.client.jaxrs.engines;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.time.Duration;
Expand All @@ -21,8 +20,11 @@
import javax.ws.rs.client.ResponseProcessingException;
import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.Response;
import javax.ws.rs.ext.Providers;

import io.netty.buffer.ByteBufOutputStream;
import io.netty.channel.group.ChannelGroup;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import static java.util.Objects.requireNonNull;

Expand All @@ -31,8 +33,13 @@
import org.jboss.resteasy.client.jaxrs.internal.ClientInvocation;
import org.jboss.resteasy.client.jaxrs.internal.ClientRequestHeaders;
import org.jboss.resteasy.client.jaxrs.internal.ClientResponse;
import org.jboss.resteasy.client.jaxrs.internal.TrackingClientRequestHeaders;
import org.jboss.resteasy.core.ResteasyContext;
import org.jboss.resteasy.tracing.RESTEasyTracingLogger;
import static org.jboss.resteasy.util.HttpHeaderNames.CONTENT_LENGTH;

import org.jboss.resteasy.util.CaseInsensitiveMap;
import org.jboss.resteasy.client.jaxrs.internal.TrackingMap;
import reactor.core.publisher.Mono;
import reactor.netty.http.client.HttpClient;
import reactor.netty.http.client.HttpClientResponse;
Expand Down Expand Up @@ -111,42 +118,8 @@ public ReactorNettyClientHttpEngine(final HttpClient httpClient,
public <T> Mono<T> submitRx(final ClientInvocation request,
final boolean buffered,
final ResultExtractor<T> extractor) {
final Optional<byte[]> payload =
Optional.ofNullable(request.getEntity()).map(entity -> requestContent(request));

final HttpClient.RequestSender requestSender =
httpClient
.headers(headerBuilder -> {
final ClientRequestHeaders resteasyHeaders = request.getHeaders();
resteasyHeaders.getHeaders().entrySet().forEach(entry -> {
final String key = entry.getKey();
final List<Object> valueList = entry.getValue();
valueList.forEach(value -> headerBuilder.add(key, value != null ? value : ""));
});

payload.ifPresent(bytes -> {

headerBuilder.set(CONTENT_LENGTH, bytes.length);

if (log.isDebugEnabled() &&
isContentLengthInvalid(resteasyHeaders.getHeader(CONTENT_LENGTH), bytes)) {
log.debug("The request's Content-Length header is replaced " +
" by the size of the byte array computed from the request entity.");
}
});
})
.request(HttpMethod.valueOf(request.getMethod()))
.uri(request.getUri().toString());

// Please see https://github.com/reactor/reactor-netty/issues/585 to see why
// we do not use outbound.sendObject(object) API.
final HttpClient.ResponseReceiver<?> responseReceiver =
payload.<HttpClient.ResponseReceiver<?>>map(bytes -> requestSender.send(
(httpClientRequest, outbound) ->
outbound.sendByteArray(Mono.just(bytes)))
).orElse(requestSender);

final Mono<ClientResponse> responseMono = responseReceiver
final Mono<ClientResponse> responseMono =
send(request)
.responseSingle((response, bytes) -> bytes
.asInputStream()
.map(is -> toRestEasyResponse(request.getClientConfiguration(), response, is))
Expand All @@ -173,7 +146,7 @@ public <T> Mono<T> submitRx(final ClientInvocation request,
);

return requestTimeout
.map(duration -> responseMono.timeout(duration))
.map(responseMono::timeout)
.orElse(responseMono)
.handle((response, sink) -> {
try {
Expand All @@ -196,6 +169,81 @@ public <T> Mono<T> submitRx(final ClientInvocation request,
});
}

/**
* The main business logic mapping RestEasy's {@link ClientInvocation request} to Reactor Netty's concept
* of it.
*/
private HttpClient.ResponseReceiver<?> send(final ClientInvocation resteasyReq) {
final Optional<Object> reqPayload = Optional.ofNullable(resteasyReq.getEntity());

final HttpClient.RequestSender requestSender = httpClient.headers(headers -> addHeaders(resteasyReq, headers))
.request(HttpMethod.valueOf(resteasyReq.getMethod()))
.uri(resteasyReq.getUri().toString());

return reqPayload.<HttpClient.ResponseReceiver<?>>map(ignore -> {
return requestSender.send((reactorReq, outbound) -> {
final ByteBufOutputStream byteBufOutputStream =
new ByteBufOutputStream(outbound.alloc().buffer());

/* Replacing the ClientRequestHeaders with TrackingClientRequestHeaders
to track the changes to headers by the sendRequestBody method. */
resteasyReq.setHeaders(
new TrackingClientRequestHeaders(
resteasyReq.getClientConfiguration(),
resteasyReq.getHeaders().getHeaders()
)
);

try {
sendRequestBody(resteasyReq, byteBufOutputStream);
} catch (final IOException ioe) {
return Mono.error(ioe);
}

// Updating the HttpClientRequest with the headers modified by the sendRequestBody method.
final TrackingMap<?> trackingMap = (TrackingMap<?>) resteasyReq.getHeaders().getHeaders();
trackingMap.getAddedOrUpdatedKeys()
.forEach(key -> updateHeader(
key,
resteasyReq.getHeaders().getHeaders(),
reactorReq.requestHeaders()
)
);
trackingMap.getRemovedKeys()
.forEach(reactorReq.requestHeaders()::remove);

final int length = byteBufOutputStream.writtenBytes();
reactorReq.header(CONTENT_LENGTH, Integer.toString(length));

if (log.isDebugEnabled() &&
isContentLengthInvalid(
resteasyReq.getHeaders().getHeader(CONTENT_LENGTH), length)) {

log.debug("The request's Content-Length header is replaced " +
" by the size of the byte array computed from the request entity.");
}

return outbound.send(Mono.defer(() -> Mono.just(byteBufOutputStream.buffer())));
});
}).orElse(requestSender);
}

private static void addHeaders(final ClientInvocation resteasyReq, final HttpHeaders reactorHeaders) {
final ClientRequestHeaders resteasyHeaders = resteasyReq.getHeaders();
resteasyHeaders.getHeaders().entrySet().forEach(entry -> {
final String key = entry.getKey();
final List<Object> valueList = entry.getValue();
valueList.forEach(value -> reactorHeaders.add(key, value != null ? value : ""));
});
}

private static void updateHeader(final String key,
final CaseInsensitiveMap<Object> headers,
final HttpHeaders reactorHeaders) {
List<Object> valueList = headers.get(key);
reactorHeaders.set(key, valueList);
}

@Override
public <T> Mono<T> fromCompletionStage(final CompletionStage<T> cs) {
return Mono.fromCompletionStage(() -> cs);
Expand Down Expand Up @@ -235,10 +283,10 @@ public <K> CompletableFuture<K> submit(final ClientInvocation request,
return submitRx(request, buffered, extractor).toFuture();
}

private static boolean isContentLengthInvalid(final String headerValue, final byte[] payload) {
private static boolean isContentLengthInvalid(final String headerValue, final int length) {

try {
return headerValue != null && Long.parseLong(headerValue) != payload.length;
return headerValue != null && Long.parseLong(headerValue) != length;
} catch (Exception e) {
log.warn("Problem parsing the Content-Length header value.", e);
}
Expand Down Expand Up @@ -299,19 +347,24 @@ static RuntimeException clientException(Throwable ex, Response clientResponse) {
return ret;
}

private static byte[] requestContent(ClientInvocation request)
{
final ByteArrayOutputStream baos = new ByteArrayOutputStream();
request.getDelegatingOutputStream().setDelegate(baos);
try {
request.writeRequestBody(request.getEntityStream());
baos.close();
return baos.toByteArray();
} catch (IOException e) {
throw new RuntimeException("Failed to write the request body!", e);
private static void sendRequestBody(final ClientInvocation req, final ByteBufOutputStream out) throws IOException {
req.getDelegatingOutputStream().setDelegate(out);

if (ResteasyContext.getContextData(Providers.class) == null) {
try (ResteasyContext.CloseableContext cc = pushProvidersContext(req)) {
req.writeRequestBody(req.getEntityStream());
}
} else {
req.writeRequestBody(req.getEntityStream());
}
}

private static ResteasyContext.CloseableContext pushProvidersContext(final ClientInvocation req) {
ResteasyContext.CloseableContext ret = ResteasyContext.addCloseableContextDataLevel();
ResteasyContext.pushContext(Providers.class, req.getClientConfiguration());
return ret;
}

private ClientResponse toRestEasyResponse(final ClientConfiguration clientConfiguration,
final HttpClientResponse reactorNettyResponse,
final InputStream inputStream) {
Expand Down
@@ -0,0 +1,19 @@
package org.jboss.resteasy.client.jaxrs.internal;

import org.jboss.resteasy.util.CaseInsensitiveMap;

/**
* An extension of ClientRequestHeaders that helps decorate the headers with a TrackingMap.
*/
public class TrackingClientRequestHeaders extends ClientRequestHeaders {

public TrackingClientRequestHeaders(final ClientConfiguration configuration, final CaseInsensitiveMap<Object> headers) {
super(configuration);
this.headers = new TrackingMap<>(headers);
}

@Override
public TrackingMap<Object> getHeaders() {
return (TrackingMap<Object>) super.getHeaders();
}
}

0 comments on commit a8e8042

Please sign in to comment.