diff --git a/reactor-netty-http/src/test/java/reactor/netty/TomcatServer.java b/reactor-netty-http/src/test/java/reactor/netty/TomcatServer.java index a73cb7f219..a5117f48ea 100644 --- a/reactor-netty-http/src/test/java/reactor/netty/TomcatServer.java +++ b/reactor-netty-http/src/test/java/reactor/netty/TomcatServer.java @@ -169,28 +169,27 @@ protected void service(HttpServletRequest req, HttpServletResponse resp) throws static final class PayloadSizeServlet extends HttpServlet { - static final int MAX = 1024 * 64; + static final int MAX = 5000000; @Override protected void service(HttpServletRequest req, HttpServletResponse resp) throws IOException { InputStream in = req.getInputStream(); - byte[] buf = new byte[4096]; int count = 0; int n; - if ((count = req.getContentLength()) != -1 && count >= MAX) { - sendResponse(resp, TOO_LARGE, HttpServletResponse.SC_BAD_REQUEST); - } - - count = 0; - while ((n = in.read(buf, 0, buf.length)) != -1) { + while ((n = in.read()) != -1) { count += n; if (count >= MAX) { + // By default, Tomcat is configured with maxSwallowSize=2 MB (see https://tomcat.apache.org/tomcat-9.0-doc/config/http.html) + // This means that once the 400 bad request is sent, the client will still be able to continue writing (if it is currently writing) + // up to 2 MB. So, it is very likely that the client will be blocked and it will then be able to consume the 400 bad request and + // close itself the connection. sendResponse(resp, TOO_LARGE, HttpServletResponse.SC_BAD_REQUEST); return; } } - sendResponse(resp, "Request payload size: " + count, HttpServletResponse.SC_OK); + + sendResponse(resp, String.valueOf(count), HttpServletResponse.SC_OK); } private void sendResponse(HttpServletResponse resp, String message, int status) throws IOException { @@ -199,8 +198,7 @@ private void sendResponse(HttpServletResponse resp, String message, int status) resp.setHeader("Content-Type", "text/plain"); PrintWriter out = resp.getWriter(); out.print(message); - resp.flushBuffer(); - out.close(); // will send last-chunk header and will close + out.flush(); } } } diff --git a/reactor-netty-http/src/test/java/reactor/netty/http/client/HttpClientTest.java b/reactor-netty-http/src/test/java/reactor/netty/http/client/HttpClientTest.java index f19a0f9617..65755bb963 100644 --- a/reactor-netty-http/src/test/java/reactor/netty/http/client/HttpClientTest.java +++ b/reactor-netty-http/src/test/java/reactor/netty/http/client/HttpClientTest.java @@ -34,7 +34,6 @@ import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.Set; @@ -56,8 +55,6 @@ import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Function; -import java.util.function.Supplier; -import java.util.stream.Stream; import javax.net.ssl.SSLException; import io.netty.buffer.ByteBuf; @@ -66,9 +63,7 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerAdapter; -import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelId; -import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.group.ChannelGroup; import io.netty.channel.group.DefaultChannelGroup; @@ -99,10 +94,8 @@ import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Named; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; @@ -150,29 +143,12 @@ class HttpClientTest extends BaseHttpTest { static final Logger log = Loggers.getLogger(HttpClientTest.class); - static final byte[] PAYLOAD = String.join("", Collections.nCopies(1024 * 128, "X")) - .getBytes(Charset.defaultCharset()); - static SelfSignedCertificate ssc; static final EventExecutor executor = new DefaultEventExecutor(); - static Http11SslContextSpec serverCtx11; - static Http2SslContextSpec serverCtx2; - static Http11SslContextSpec clientCtx11; - static Http2SslContextSpec clientCtx2; @BeforeAll static void createSelfSignedCertificate() throws CertificateException { ssc = new SelfSignedCertificate(); - serverCtx11 = Http11SslContextSpec.forServer(ssc.certificate(), ssc.privateKey()) - .configure(builder -> builder.sslProvider(io.netty.handler.ssl.SslProvider.JDK)); - serverCtx2 = Http2SslContextSpec.forServer(ssc.certificate(), ssc.privateKey()) - .configure(builder -> builder.sslProvider(io.netty.handler.ssl.SslProvider.JDK)); - clientCtx11 = Http11SslContextSpec.forClient() - .configure(builder -> builder.trustManager(InsecureTrustManagerFactory.INSTANCE) - .sslProvider(io.netty.handler.ssl.SslProvider.JDK)); - clientCtx2 = Http2SslContextSpec.forClient() - .configure(builder -> builder.trustManager(InsecureTrustManagerFactory.INSTANCE) - .sslProvider(io.netty.handler.ssl.SslProvider.JDK)); } @AfterAll @@ -181,17 +157,6 @@ static void cleanup() throws ExecutionException, InterruptedException, TimeoutEx .get(30, TimeUnit.SECONDS); } - HttpServer customizeServerOptions(HttpServer httpServer, @Nullable SslProvider.ProtocolSslContextSpec ctx, HttpProtocol... protocols) { - return ctx == null ? httpServer.protocol(protocols) : - httpServer.protocol(protocols).secure(spec -> spec.sslContext(ctx) - .closeNotifyFlushTimeout(Duration.ofSeconds(10)) - .closeNotifyReadTimeout(Duration.ofSeconds(10))); - } - - HttpClient customizeClientOptions(HttpClient httpClient, @Nullable SslProvider.ProtocolSslContextSpec ctx, HttpProtocol... protocols) { - return ctx == null ? httpClient.protocol(protocols) : httpClient.protocol(protocols).secure(spec -> spec.sslContext(ctx)); - } - @Test void abort() { disposableServer = @@ -3361,209 +3326,4 @@ private void doTestIssue1943(HttpProtocol protocol) { .block(Duration.ofSeconds(5)); } } - - static Stream issue2825Params() { - Supplier> postMono = () -> Mono.just(Unpooled.wrappedBuffer(PAYLOAD)); - Supplier> postFlux = () -> Flux.just(Unpooled.wrappedBuffer(PAYLOAD)); - - return Stream.of( - Arguments.of(HttpProtocol.HTTP11, HttpProtocol.HTTP11, - null, // no SSL Context - null, // no SSL Context - Named.of("postMono", postMono), - Named.of("bytes", PAYLOAD.length)), - - Arguments.of(HttpProtocol.HTTP11, HttpProtocol.HTTP11, - null, // no SSL Context - null, // no SSL Context - Named.of("postFlux", postFlux), - Named.of("bytes", PAYLOAD.length)), - - Arguments.of(HttpProtocol.HTTP11, HttpProtocol.HTTP11, - Named.of("Http11SslContextSpec", serverCtx11), - Named.of("Http11SslContextSpec", clientCtx11), - Named.of("postMono", postMono), - Named.of("bytes", PAYLOAD.length)), - - Arguments.of(HttpProtocol.HTTP11, HttpProtocol.HTTP11, - Named.of("Http11SslContextSpec", serverCtx11), - Named.of("Http11SslContextSpec", clientCtx11), - Named.of("postFlux", postFlux), - Named.of("bytes", PAYLOAD.length)), - - Arguments.of(HttpProtocol.H2C, HttpProtocol.H2C, - null, // no SSL Context - null, // no SSL Context - Named.of("postMono", postMono), - Named.of("bytes", PAYLOAD.length)), - - Arguments.of(HttpProtocol.H2C, HttpProtocol.H2C, - null, // no SSL Context - null, // no SSL Context - Named.of("postFlux", postFlux), - Named.of("bytes", PAYLOAD.length)), - - Arguments.of(HttpProtocol.H2, HttpProtocol.H2, - Named.of("Http2SslContextSpec", serverCtx2), - Named.of("Http2SslContextSpec", clientCtx2), - Named.of("postMono", postMono), - Named.of("bytes", PAYLOAD.length)), - - Arguments.of(HttpProtocol.H2, HttpProtocol.H2, - Named.of("Http2SslContextSpec", serverCtx2), - Named.of("Http2SslContextSpec", clientCtx2), - Named.of("postFlux", postFlux), - Named.of("bytes", PAYLOAD.length)) - ); - } - - @ParameterizedTest - @MethodSource("issue2825Params") - void testIssue2825(HttpProtocol serverProtocols, HttpProtocol clientProtocols, - @Nullable SslProvider.ProtocolSslContextSpec serverCtx, @Nullable SslProvider.ProtocolSslContextSpec clientCtx, - Supplier> payload, long bytesToSend) { - int maxSize = 1024 * 64; // 400 bad request is returned if payload exceeds this limit, and the socket is then closed - AtomicInteger accum = new AtomicInteger(); - String tooLargeRequest = "Request too large"; - byte[] tooLargeRequestBytes = tooLargeRequest.getBytes(Charset.defaultCharset()); - byte[] requestFullyReceivedBytes = "Request fully received".getBytes(Charset.defaultCharset()); - - HttpServer httpServer = createServer() - .wiretap(false) - .route(r -> r.post("/large-payload", (req, res) -> req.receive() - .takeUntil(buf -> { - String clen = req.requestHeaders().get("Content-Length"); - if (clen != null) { - int contentLength = Integer.parseInt(clen); - accum.set(contentLength); - return contentLength >= maxSize; - } - else { - return accum.addAndGet(buf.readableBytes()) >= maxSize; - } - }) - .collectList() - .flatMapMany(byteBufs -> res.status(accum.get() < maxSize ? 200 : 400) - .header("Connection", "close") - .header("Content-Type", "text/plain") - .send(Mono.just(Unpooled.wrappedBuffer(accum.get() < maxSize ? - requestFullyReceivedBytes : tooLargeRequestBytes)))))); - - disposableServer = customizeServerOptions(httpServer, serverCtx, serverProtocols) - .bindNow(); - - AtomicReference serverAddress = new AtomicReference<>(); - HttpClient client = customizeClientOptions(createClient(disposableServer.port()), clientCtx, clientProtocols) - .metrics(true, ClientMetricsRecorder::reset) - .doOnConnected(conn -> serverAddress.set(conn.address())) - .disableRetry(true) - // Needed to trigger many writability change events - .doOnConnected(connection -> connection.channel().config().setOption(ChannelOption.SO_SNDBUF, 128)); - - StepVerifier.create(client - .wiretap(false) - .headers(hdr -> hdr.set("Content-Type", "text/plain")) - .post() - .uri("/large-payload") - .send(payload.get()) - .response((r, buf) -> buf.aggregate().asString() - .zipWith(Mono.just(r)))) - .expectNextMatches(tuple -> tooLargeRequest.equals(tuple.getT1()) - && tuple.getT2().status().equals(HttpResponseStatus.BAD_REQUEST)) - .expectComplete() - .verify(Duration.ofSeconds(30)); - - assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeMethod).isEqualTo("POST"); - assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeTime).isNotNull(); - assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeTime.isZero()).isFalse(); - assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeUri).isEqualTo("/large-payload"); - assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeRemoteAddr).isEqualTo(serverAddress.get()); - - assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentRemoteAddr).isEqualTo(serverAddress.get()); - assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentUri).isEqualTo("/large-payload"); - assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentBytes).isEqualTo(bytesToSend); - } - - /** - * This Custom metrics recorder checks that the {@link AbstractHttpClientMetricsHandler#recordWrite(SocketAddress)} is properly invoked by - * (see {@link AbstractHttpClientMetricsHandler#channelRead(ChannelHandlerContext, Object)}) when - * an early response is received while the corresponding request it still being written. - */ - static final class ClientMetricsRecorder implements HttpClientMetricsRecorder { - - static final ClientMetricsRecorder INSTANCE = new ClientMetricsRecorder(); - volatile SocketAddress recordDataSentTimeRemoteAddr; - volatile String recordDataSentTimeUri; - volatile String recordDataSentTimeMethod; - volatile Duration recordDataSentTimeTime; - volatile SocketAddress recordDataSentRemoteAddr; - volatile String recordDataSentUri; - volatile long recordDataSentBytes; - - static ClientMetricsRecorder reset() { - INSTANCE.recordDataSentTimeRemoteAddr = null; - INSTANCE.recordDataSentTimeUri = null; - INSTANCE.recordDataSentTimeMethod = null; - INSTANCE.recordDataSentTimeTime = null; - INSTANCE.recordDataSentRemoteAddr = null; - INSTANCE.recordDataSentUri = null; - INSTANCE.recordDataSentBytes = -1; - return INSTANCE; - } - - @Override - public void recordDataReceived(SocketAddress remoteAddress, long bytes) { - } - - @Override - public void recordDataSent(SocketAddress remoteAddress, long bytes) { - } - - @Override - public void incrementErrorsCount(SocketAddress remoteAddress) { - } - - @Override - public void recordTlsHandshakeTime(SocketAddress remoteAddress, Duration time, String status) { - } - - @Override - public void recordConnectTime(SocketAddress remoteAddress, Duration time, String status) { - } - - @Override - public void recordResolveAddressTime(SocketAddress remoteAddress, Duration time, String status) { - } - - @Override - public void recordDataReceived(SocketAddress remoteAddress, String uri, long bytes) { - } - - @Override - public void recordDataSent(SocketAddress remoteAddress, String uri, long bytes) { - this.recordDataSentRemoteAddr = remoteAddress; - this.recordDataSentUri = uri; - this.recordDataSentBytes = bytes; - } - - @Override - public void incrementErrorsCount(SocketAddress remoteAddress, String uri) { - } - - @Override - public void recordDataReceivedTime(SocketAddress remoteAddress, String uri, String method, String status, Duration time) { - } - - @Override - public void recordDataSentTime(SocketAddress remoteAddress, String uri, String method, Duration time) { - this.recordDataSentTimeRemoteAddr = remoteAddress; - this.recordDataSentTimeUri = uri; - this.recordDataSentTimeMethod = method; - this.recordDataSentTimeTime = time; - } - - @Override - public void recordResponseTime(SocketAddress remoteAddress, String uri, String method, String status, Duration time) { - } - } } diff --git a/reactor-netty-http/src/test/java/reactor/netty/http/client/HttpClientWithTomcatTest.java b/reactor-netty-http/src/test/java/reactor/netty/http/client/HttpClientWithTomcatTest.java index 9c1a60dda8..19b77ce490 100644 --- a/reactor-netty-http/src/test/java/reactor/netty/http/client/HttpClientWithTomcatTest.java +++ b/reactor-netty-http/src/test/java/reactor/netty/http/client/HttpClientWithTomcatTest.java @@ -18,7 +18,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; -import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaderValues; import io.netty.handler.codec.http.HttpHeaders; @@ -66,14 +66,13 @@ import static org.assertj.core.api.Assertions.assertThat; import static reactor.netty.http.client.HttpClientOperations.SendForm.DEFAULT_FACTORY; -import static reactor.netty.http.client.HttpClientTest.ClientMetricsRecorder; /** * @author Violeta Georgieva */ class HttpClientWithTomcatTest { private static TomcatServer tomcat; - private static final byte[] PAYLOAD = String.join("", Collections.nCopies(1024 * 128, "X")) + private static final byte[] PAYLOAD = String.join("", Collections.nCopies((5 * 1024 * 1024) + (1024 * 1024), "X")) .getBytes(Charset.defaultCharset()); @BeforeAll @@ -347,28 +346,24 @@ static Stream testIssue2825Args() { @ParameterizedTest @MethodSource("testIssue2825Args") - void testIssue2825_Http11(@Nullable Supplier> payload, long bytesToSend) { + void testIssue2825(@Nullable Supplier> payload, long bytesToSend) { AtomicReference serverAddress = new AtomicReference<>(); HttpClient client = HttpClient.create() .port(getPort()) .wiretap(false) - .disableRetry(true) .metrics(true, ClientMetricsRecorder::reset) - // Needed to trigger many writability change events .doOnConnected(conn -> { - conn.channel().config().setOption(ChannelOption.SO_SNDBUF, 128); serverAddress.set(conn.address()); }); StepVerifier.create(client - .headers(hdr -> hdr.set("Content-Type", "text/plain")) - .post() - .uri("/payload-size") - .send(payload.get()) - .response((r, buf) -> buf.aggregate().asString() - .zipWith(Mono.just(r)))) - .expectNextMatches(tuple -> TomcatServer.TOO_LARGE.equals(tuple.getT1()) - && tuple.getT2().status().equals(HttpResponseStatus.BAD_REQUEST)) + .headers(hdr -> hdr.set("Content-Type", "text/plain")) + .post() + .uri("/payload-size") + .send(payload.get()) + .response((r, buf) -> buf.aggregate().asString() + .then(Mono.just(r)))) + .expectNextMatches(r -> r.status().equals(HttpResponseStatus.BAD_REQUEST)) .expectComplete() .verify(Duration.ofSeconds(30)); @@ -390,4 +385,87 @@ private int getPort() { private String getURL() { return "http://localhost:" + tomcat.port(); } + + /** + * This Custom metrics recorder checks that the {@link AbstractHttpClientMetricsHandler#recordWrite(SocketAddress)} is properly invoked by + * (see {@link AbstractHttpClientMetricsHandler#channelRead(ChannelHandlerContext, Object)}) when + * an early response is received while the corresponding request it still being written. + */ + static final class ClientMetricsRecorder implements HttpClientMetricsRecorder { + + static final ClientMetricsRecorder INSTANCE = new ClientMetricsRecorder(); + volatile SocketAddress recordDataSentTimeRemoteAddr; + volatile String recordDataSentTimeUri; + volatile String recordDataSentTimeMethod; + volatile Duration recordDataSentTimeTime; + volatile SocketAddress recordDataSentRemoteAddr; + volatile String recordDataSentUri; + volatile long recordDataSentBytes; + + static ClientMetricsRecorder reset() { + INSTANCE.recordDataSentTimeRemoteAddr = null; + INSTANCE.recordDataSentTimeUri = null; + INSTANCE.recordDataSentTimeMethod = null; + INSTANCE.recordDataSentTimeTime = null; + INSTANCE.recordDataSentRemoteAddr = null; + INSTANCE.recordDataSentUri = null; + INSTANCE.recordDataSentBytes = -1; + return INSTANCE; + } + + @Override + public void recordDataReceived(SocketAddress remoteAddress, long bytes) { + } + + @Override + public void recordDataSent(SocketAddress remoteAddress, long bytes) { + } + + @Override + public void incrementErrorsCount(SocketAddress remoteAddress) { + } + + @Override + public void recordTlsHandshakeTime(SocketAddress remoteAddress, Duration time, String status) { + } + + @Override + public void recordConnectTime(SocketAddress remoteAddress, Duration time, String status) { + } + + @Override + public void recordResolveAddressTime(SocketAddress remoteAddress, Duration time, String status) { + } + + @Override + public void recordDataReceived(SocketAddress remoteAddress, String uri, long bytes) { + } + + @Override + public void recordDataSent(SocketAddress remoteAddress, String uri, long bytes) { + this.recordDataSentRemoteAddr = remoteAddress; + this.recordDataSentUri = uri; + this.recordDataSentBytes = bytes; + } + + @Override + public void incrementErrorsCount(SocketAddress remoteAddress, String uri) { + } + + @Override + public void recordDataReceivedTime(SocketAddress remoteAddress, String uri, String method, String status, Duration time) { + } + + @Override + public void recordDataSentTime(SocketAddress remoteAddress, String uri, String method, Duration time) { + this.recordDataSentTimeRemoteAddr = remoteAddress; + this.recordDataSentTimeUri = uri; + this.recordDataSentTimeMethod = method; + this.recordDataSentTimeTime = time; + } + + @Override + public void recordResponseTime(SocketAddress remoteAddress, String uri, String method, String status, Duration time) { + } + } }