diff --git a/reactor-netty-http/src/main/java/reactor/netty/http/server/HttpServerOperations.java b/reactor-netty-http/src/main/java/reactor/netty/http/server/HttpServerOperations.java index 9bda3c1f1f..a8dc61603e 100644 --- a/reactor-netty-http/src/main/java/reactor/netty/http/server/HttpServerOperations.java +++ b/reactor-netty-http/src/main/java/reactor/netty/http/server/HttpServerOperations.java @@ -938,8 +938,11 @@ public void operationComplete(io.netty.util.concurrent.Future futu } } - discard(); + terminateInternal(); + } + void terminateInternal() { + discard(); terminate(); } diff --git a/reactor-netty-http/src/main/java/reactor/netty/http/server/HttpTrafficHandler.java b/reactor-netty-http/src/main/java/reactor/netty/http/server/HttpTrafficHandler.java index 07ba9a170f..4eee4376b1 100644 --- a/reactor-netty-http/src/main/java/reactor/netty/http/server/HttpTrafficHandler.java +++ b/reactor-netty-http/src/main/java/reactor/netty/http/server/HttpTrafficHandler.java @@ -46,6 +46,7 @@ import reactor.netty.Connection; import reactor.netty.ConnectionObserver; import reactor.netty.ReactorNetty; +import reactor.netty.channel.ChannelOperations; import reactor.netty.http.logging.HttpMessageArgProviderFactory; import reactor.netty.http.logging.HttpMessageLogFactory; import reactor.util.annotation.Nullable; @@ -67,6 +68,9 @@ final class HttpTrafficHandler extends ChannelDuplexHandler implements Runnable static final HttpVersion H2 = HttpVersion.valueOf("HTTP/2.0"); + static final boolean LAST_FLUSH_WHEN_NO_READ = Boolean.parseBoolean( + System.getProperty("reactor.netty.http.server.lastFlushWhenNoRead", "false")); + final BiPredicate compress; final ServerCookieDecoder cookieDecoder; final ServerCookieEncoder cookieEncoder; @@ -96,6 +100,10 @@ final class HttpTrafficHandler extends ChannelDuplexHandler implements Runnable Boolean secure; + boolean read; + boolean needsFlush; + boolean finalizingResponse; + HttpTrafficHandler( @Nullable BiPredicate compress, ServerCookieDecoder decoder, @@ -142,6 +150,7 @@ public void channelActive(ChannelHandlerContext ctx) { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) { + read = true; if (secure == null) { secure = ctx.channel().pipeline().get(SslHandler.class) != null; } @@ -152,6 +161,8 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { } // read message and track if it was keepAlive if (msg instanceof HttpRequest) { + finalizingResponse = false; + if (idleTimeout != null) { IdleTimeoutHandler.removeIdleTimeoutHandler(ctx.pipeline()); } @@ -196,6 +207,16 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { else { overflow = false; + if (LAST_FLUSH_WHEN_NO_READ) { + ChannelOperations ops = ChannelOperations.get(ctx.channel()); + if (ops instanceof HttpServerOperations) { + if (HttpServerOperations.log.isDebugEnabled()) { + HttpServerOperations.log.debug(format(ctx.channel(), "Last HTTP packet was sent, terminating the channel")); + } + ((HttpServerOperations) ops).terminateInternal(); + } + } + DecoderResult decoderResult = request.decoderResult(); if (decoderResult.isFailure()) { sendDecodingFailures(decoderResult.cause(), msg); @@ -285,6 +306,38 @@ else if (overflow) { ctx.fireChannelRead(msg); } + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + endReadAndFlush(); + ctx.fireChannelReadComplete(); + } + + void endReadAndFlush() { + if (read) { + read = false; + if (LAST_FLUSH_WHEN_NO_READ && needsFlush) { + needsFlush = false; + ctx.flush(); + } + } + } + + @Override + public void flush(ChannelHandlerContext ctx) { + if (LAST_FLUSH_WHEN_NO_READ && finalizingResponse) { + if (needsFlush || !ctx.channel().isWritable()) { + needsFlush = false; + ctx.flush(); + } + else { + needsFlush = true; + } + } + else { + ctx.flush(); + } + } + void sendDecodingFailures(Throwable t, Object msg) { sendDecodingFailures(t, msg, null, null); } @@ -331,6 +384,12 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) } } if (msg instanceof LastHttpContent) { + finalizingResponse = true; + + if (LAST_FLUSH_WHEN_NO_READ) { + needsFlush = !read; + } + if (!shouldKeepAlive()) { if (HttpServerOperations.log.isDebugEnabled()) { HttpServerOperations.log.debug(format(ctx.channel(), "Detected non persistent http " + @@ -405,6 +464,18 @@ public void run() { HttpRequestHolder holder = (HttpRequestHolder) next; nextRequest = holder.request; + finalizingResponse = false; + + if (LAST_FLUSH_WHEN_NO_READ) { + ChannelOperations ops = ChannelOperations.get(ctx.channel()); + if (ops instanceof HttpServerOperations) { + if (HttpServerOperations.log.isDebugEnabled()) { + HttpServerOperations.log.debug(format(ctx.channel(), "Last HTTP packet was sent, terminating the channel")); + } + ((HttpServerOperations) ops).terminateInternal(); + } + } + DecoderResult decoderResult = nextRequest.decoderResult(); if (decoderResult.isFailure()) { sendDecodingFailures(decoderResult.cause(), nextRequest, holder.timestamp, null); diff --git a/reactor-netty-http/src/test/java/reactor/netty/http/server/HttpServerOutboundCompleteTest.java b/reactor-netty-http/src/test/java/reactor/netty/http/server/HttpServerOutboundCompleteTest.java index 9d7ed779dd..7838d1a3c4 100644 --- a/reactor-netty-http/src/test/java/reactor/netty/http/server/HttpServerOutboundCompleteTest.java +++ b/reactor-netty-http/src/test/java/reactor/netty/http/server/HttpServerOutboundCompleteTest.java @@ -25,7 +25,9 @@ import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpUtil; import io.netty.handler.codec.http.LastHttpContent; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.EnumSource; import org.reactivestreams.Publisher; @@ -33,12 +35,15 @@ import reactor.core.publisher.Mono; import reactor.core.publisher.Signal; import reactor.netty.BaseHttpTest; +import reactor.netty.Connection; import reactor.netty.DisposableServer; import reactor.netty.http.HttpProtocol; import reactor.netty.http.client.HttpClient; +import reactor.netty.tcp.TcpClient; import reactor.test.StepVerifier; import reactor.util.annotation.Nullable; +import java.net.InetSocketAddress; import java.nio.charset.Charset; import java.time.Duration; import java.util.Arrays; @@ -250,6 +255,108 @@ void httpPostRespondsSend(HttpProtocol protocol) throws Exception { assertThat(recorder.onTerminateIsReceived.get()).isEqualTo(2); } + @Test + void httpPipeliningGetRespondsSendObject() throws Exception { + String oldValue = System.getProperty("reactor.netty.http.server.lastFlushWhenNoRead", "false"); + System.setProperty("reactor.netty.http.server.lastFlushWhenNoRead", "true"); + try { + CountDownLatch latch = new CountDownLatch(64); + EventsRecorder recorder = new EventsRecorder(latch); + disposableServer = createServer(recorder, HttpProtocol.HTTP11, + r -> r.get("/1", (req, res) -> res.sendObject(Unpooled.wrappedBuffer(REPEAT.getBytes(Charset.defaultCharset()))) + .then().doOnEach(recorder).doOnCancel(recorder))); + + Connection client = + TcpClient.create() + .port(disposableServer.port()) + .wiretap(true) + .connectNow(); + + int port = disposableServer.port(); + String address = HttpUtil.formatHostnameForHttp((InetSocketAddress) disposableServer.address()) + ":" + port; + String request = repeatString("GET /1 HTTP/1.1\r\nHost: " + address + "\r\n\r\n"); + client.outbound() + .sendObject(Unpooled.wrappedBuffer(request.getBytes(Charset.defaultCharset()))) + .then() + .subscribe(); + + CountDownLatch responses = new CountDownLatch(16); + client.inbound() + .receive() + .asString() + .doOnNext(s -> { + int ind = 0; + while ((ind = s.indexOf("200", ind)) != -1) { + responses.countDown(); + ind += 3; + } + }) + .subscribe(); + + assertThat(responses.await(5, TimeUnit.SECONDS)).isTrue(); + + assertThat(latch.await(5, TimeUnit.SECONDS)).isTrue(); + assertThat(recorder.bufferIsReleased.get()).isEqualTo(16); + assertThat(recorder.fullResponseIsSent.get()).isEqualTo(16); + assertThat(recorder.onCompleteIsReceived.get()).isEqualTo(16); + assertThat(recorder.onTerminateIsReceived.get()).isEqualTo(16); + } + finally { + System.setProperty("reactor.netty.http.server.lastFlushWhenNoRead", oldValue); + } + } + + @Test + void httpPipeliningGetRespondsSendMono() throws Exception { + String oldValue = System.getProperty("reactor.netty.http.server.lastFlushWhenNoRead", "false"); + System.setProperty("reactor.netty.http.server.lastFlushWhenNoRead", "true"); + try { + CountDownLatch latch = new CountDownLatch(64); + EventsRecorder recorder = new EventsRecorder(latch); + disposableServer = createServer(recorder, HttpProtocol.HTTP11, + r -> r.get("/1", (req, res) -> res.sendString(Mono.just(REPEAT).delayElement(Duration.ofMillis(10)) + .doOnEach(recorder).doOnCancel(recorder)))); + + Connection client = + TcpClient.create() + .port(disposableServer.port()) + .wiretap(true) + .connectNow(); + + int port = disposableServer.port(); + String address = HttpUtil.formatHostnameForHttp((InetSocketAddress) disposableServer.address()) + ":" + port; + String request = repeatString("GET /1 HTTP/1.1\r\nHost: " + address + "\r\n\r\n"); + client.outbound() + .sendObject(Unpooled.wrappedBuffer(request.getBytes(Charset.defaultCharset()))) + .then() + .subscribe(); + + CountDownLatch responses = new CountDownLatch(16); + client.inbound() + .receive() + .asString() + .doOnNext(s -> { + int ind = 0; + while ((ind = s.indexOf("200", ind)) != -1) { + responses.countDown(); + ind += 3; + } + }) + .subscribe(); + + assertThat(responses.await(5, TimeUnit.SECONDS)).isTrue(); + + assertThat(latch.await(5, TimeUnit.SECONDS)).isTrue(); + assertThat(recorder.bufferIsReleased.get()).isEqualTo(16); + assertThat(recorder.fullResponseIsSent.get()).isEqualTo(16); + assertThat(recorder.onCompleteIsReceived.get()).isEqualTo(16); + assertThat(recorder.onTerminateIsReceived.get()).isEqualTo(16); + } + finally { + System.setProperty("reactor.netty.http.server.lastFlushWhenNoRead", oldValue); + } + } + @ParameterizedTest @EnumSource(value = HttpProtocol.class, names = {"HTTP11", "H2C"}) void httpPostRespondsSendFlux(HttpProtocol protocol) throws Exception { @@ -386,6 +493,14 @@ static String createString(int length) { return new String(chars); } + static String repeatString(String s) { + StringBuilder sb = new StringBuilder(16 * s.length()); + for (int i = 0; i < 16; i++) { + sb.append(s); + } + return sb.toString(); + } + static Mono> sendGetRequest(int port, HttpProtocol protocol) { return sendRequest(port, protocol, HttpMethod.GET, 1, null); }