Skip to content

Commit

Permalink
The issue is actually related to the fact that Tomcat is by default c…
Browse files Browse the repository at this point in the history
…onfigured with maxSwallowSize=2 MB,

so we need to only test with Tomcat.
  • Loading branch information
pderop committed Sep 9, 2023
1 parent 972b528 commit dfad10b
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 266 deletions.
20 changes: 9 additions & 11 deletions reactor-netty-http/src/test/java/reactor/netty/TomcatServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -169,28 +169,27 @@ protected void service(HttpServletRequest req, HttpServletResponse resp) throws

static final class PayloadSizeServlet extends HttpServlet {

static final int MAX = 1024 * 64;
static final int MAX = 5000000;

@Override
protected void service(HttpServletRequest req, HttpServletResponse resp) throws IOException {
InputStream in = req.getInputStream();
byte[] buf = new byte[4096];
int count = 0;
int n;

if ((count = req.getContentLength()) != -1 && count >= MAX) {
sendResponse(resp, TOO_LARGE, HttpServletResponse.SC_BAD_REQUEST);
}

count = 0;
while ((n = in.read(buf, 0, buf.length)) != -1) {
while ((n = in.read()) != -1) {
count += n;
if (count >= MAX) {
// By default, Tomcat is configured with maxSwallowSize=2 MB (see https://tomcat.apache.org/tomcat-9.0-doc/config/http.html)
// This means that once the 400 bad request is sent, the client will still be able to continue writing (if it is currently writing)
// up to 2 MB. So, it is very likely that the client will be blocked and it will then be able to consume the 400 bad request and
// close itself the connection.
sendResponse(resp, TOO_LARGE, HttpServletResponse.SC_BAD_REQUEST);
return;
}
}
sendResponse(resp, "Request payload size: " + count, HttpServletResponse.SC_OK);

sendResponse(resp, String.valueOf(count), HttpServletResponse.SC_OK);
}

private void sendResponse(HttpServletResponse resp, String message, int status) throws IOException {
Expand All @@ -199,8 +198,7 @@ private void sendResponse(HttpServletResponse resp, String message, int status)
resp.setHeader("Content-Type", "text/plain");
PrintWriter out = resp.getWriter();
out.print(message);
resp.flushBuffer();
out.close(); // will send last-chunk header and will close
out.flush();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
Expand All @@ -56,8 +55,6 @@
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Stream;
import javax.net.ssl.SSLException;

import io.netty.buffer.ByteBuf;
Expand All @@ -66,9 +63,7 @@
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelId;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
Expand Down Expand Up @@ -99,10 +94,8 @@
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Named;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
Expand Down Expand Up @@ -150,29 +143,12 @@ class HttpClientTest extends BaseHttpTest {

static final Logger log = Loggers.getLogger(HttpClientTest.class);

static final byte[] PAYLOAD = String.join("", Collections.nCopies(1024 * 128, "X"))
.getBytes(Charset.defaultCharset());

static SelfSignedCertificate ssc;
static final EventExecutor executor = new DefaultEventExecutor();
static Http11SslContextSpec serverCtx11;
static Http2SslContextSpec serverCtx2;
static Http11SslContextSpec clientCtx11;
static Http2SslContextSpec clientCtx2;

@BeforeAll
static void createSelfSignedCertificate() throws CertificateException {
ssc = new SelfSignedCertificate();
serverCtx11 = Http11SslContextSpec.forServer(ssc.certificate(), ssc.privateKey())
.configure(builder -> builder.sslProvider(io.netty.handler.ssl.SslProvider.JDK));
serverCtx2 = Http2SslContextSpec.forServer(ssc.certificate(), ssc.privateKey())
.configure(builder -> builder.sslProvider(io.netty.handler.ssl.SslProvider.JDK));
clientCtx11 = Http11SslContextSpec.forClient()
.configure(builder -> builder.trustManager(InsecureTrustManagerFactory.INSTANCE)
.sslProvider(io.netty.handler.ssl.SslProvider.JDK));
clientCtx2 = Http2SslContextSpec.forClient()
.configure(builder -> builder.trustManager(InsecureTrustManagerFactory.INSTANCE)
.sslProvider(io.netty.handler.ssl.SslProvider.JDK));
}

@AfterAll
Expand All @@ -181,17 +157,6 @@ static void cleanup() throws ExecutionException, InterruptedException, TimeoutEx
.get(30, TimeUnit.SECONDS);
}

HttpServer customizeServerOptions(HttpServer httpServer, @Nullable SslProvider.ProtocolSslContextSpec ctx, HttpProtocol... protocols) {
return ctx == null ? httpServer.protocol(protocols) :
httpServer.protocol(protocols).secure(spec -> spec.sslContext(ctx)
.closeNotifyFlushTimeout(Duration.ofSeconds(10))
.closeNotifyReadTimeout(Duration.ofSeconds(10)));
}

HttpClient customizeClientOptions(HttpClient httpClient, @Nullable SslProvider.ProtocolSslContextSpec ctx, HttpProtocol... protocols) {
return ctx == null ? httpClient.protocol(protocols) : httpClient.protocol(protocols).secure(spec -> spec.sslContext(ctx));
}

@Test
void abort() {
disposableServer =
Expand Down Expand Up @@ -3361,209 +3326,4 @@ private void doTestIssue1943(HttpProtocol protocol) {
.block(Duration.ofSeconds(5));
}
}

static Stream<Arguments> issue2825Params() {
Supplier<Publisher<ByteBuf>> postMono = () -> Mono.just(Unpooled.wrappedBuffer(PAYLOAD));
Supplier<Publisher<ByteBuf>> postFlux = () -> Flux.just(Unpooled.wrappedBuffer(PAYLOAD));

return Stream.of(
Arguments.of(HttpProtocol.HTTP11, HttpProtocol.HTTP11,
null, // no SSL Context
null, // no SSL Context
Named.of("postMono", postMono),
Named.of("bytes", PAYLOAD.length)),

Arguments.of(HttpProtocol.HTTP11, HttpProtocol.HTTP11,
null, // no SSL Context
null, // no SSL Context
Named.of("postFlux", postFlux),
Named.of("bytes", PAYLOAD.length)),

Arguments.of(HttpProtocol.HTTP11, HttpProtocol.HTTP11,
Named.of("Http11SslContextSpec", serverCtx11),
Named.of("Http11SslContextSpec", clientCtx11),
Named.of("postMono", postMono),
Named.of("bytes", PAYLOAD.length)),

Arguments.of(HttpProtocol.HTTP11, HttpProtocol.HTTP11,
Named.of("Http11SslContextSpec", serverCtx11),
Named.of("Http11SslContextSpec", clientCtx11),
Named.of("postFlux", postFlux),
Named.of("bytes", PAYLOAD.length)),

Arguments.of(HttpProtocol.H2C, HttpProtocol.H2C,
null, // no SSL Context
null, // no SSL Context
Named.of("postMono", postMono),
Named.of("bytes", PAYLOAD.length)),

Arguments.of(HttpProtocol.H2C, HttpProtocol.H2C,
null, // no SSL Context
null, // no SSL Context
Named.of("postFlux", postFlux),
Named.of("bytes", PAYLOAD.length)),

Arguments.of(HttpProtocol.H2, HttpProtocol.H2,
Named.of("Http2SslContextSpec", serverCtx2),
Named.of("Http2SslContextSpec", clientCtx2),
Named.of("postMono", postMono),
Named.of("bytes", PAYLOAD.length)),

Arguments.of(HttpProtocol.H2, HttpProtocol.H2,
Named.of("Http2SslContextSpec", serverCtx2),
Named.of("Http2SslContextSpec", clientCtx2),
Named.of("postFlux", postFlux),
Named.of("bytes", PAYLOAD.length))
);
}

@ParameterizedTest
@MethodSource("issue2825Params")
void testIssue2825(HttpProtocol serverProtocols, HttpProtocol clientProtocols,
@Nullable SslProvider.ProtocolSslContextSpec serverCtx, @Nullable SslProvider.ProtocolSslContextSpec clientCtx,
Supplier<Publisher<ByteBuf>> payload, long bytesToSend) {
int maxSize = 1024 * 64; // 400 bad request is returned if payload exceeds this limit, and the socket is then closed
AtomicInteger accum = new AtomicInteger();
String tooLargeRequest = "Request too large";
byte[] tooLargeRequestBytes = tooLargeRequest.getBytes(Charset.defaultCharset());
byte[] requestFullyReceivedBytes = "Request fully received".getBytes(Charset.defaultCharset());

HttpServer httpServer = createServer()
.wiretap(false)
.route(r -> r.post("/large-payload", (req, res) -> req.receive()
.takeUntil(buf -> {
String clen = req.requestHeaders().get("Content-Length");
if (clen != null) {
int contentLength = Integer.parseInt(clen);
accum.set(contentLength);
return contentLength >= maxSize;
}
else {
return accum.addAndGet(buf.readableBytes()) >= maxSize;
}
})
.collectList()
.flatMapMany(byteBufs -> res.status(accum.get() < maxSize ? 200 : 400)
.header("Connection", "close")
.header("Content-Type", "text/plain")
.send(Mono.just(Unpooled.wrappedBuffer(accum.get() < maxSize ?
requestFullyReceivedBytes : tooLargeRequestBytes))))));

disposableServer = customizeServerOptions(httpServer, serverCtx, serverProtocols)
.bindNow();

AtomicReference<SocketAddress> serverAddress = new AtomicReference<>();
HttpClient client = customizeClientOptions(createClient(disposableServer.port()), clientCtx, clientProtocols)
.metrics(true, ClientMetricsRecorder::reset)
.doOnConnected(conn -> serverAddress.set(conn.address()))
.disableRetry(true)
// Needed to trigger many writability change events
.doOnConnected(connection -> connection.channel().config().setOption(ChannelOption.SO_SNDBUF, 128));

StepVerifier.create(client
.wiretap(false)
.headers(hdr -> hdr.set("Content-Type", "text/plain"))
.post()
.uri("/large-payload")
.send(payload.get())
.response((r, buf) -> buf.aggregate().asString()
.zipWith(Mono.just(r))))
.expectNextMatches(tuple -> tooLargeRequest.equals(tuple.getT1())
&& tuple.getT2().status().equals(HttpResponseStatus.BAD_REQUEST))
.expectComplete()
.verify(Duration.ofSeconds(30));

assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeMethod).isEqualTo("POST");
assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeTime).isNotNull();
assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeTime.isZero()).isFalse();
assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeUri).isEqualTo("/large-payload");
assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentTimeRemoteAddr).isEqualTo(serverAddress.get());

assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentRemoteAddr).isEqualTo(serverAddress.get());
assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentUri).isEqualTo("/large-payload");
assertThat(ClientMetricsRecorder.INSTANCE.recordDataSentBytes).isEqualTo(bytesToSend);
}

/**
* This Custom metrics recorder checks that the {@link AbstractHttpClientMetricsHandler#recordWrite(SocketAddress)} is properly invoked by
* (see {@link AbstractHttpClientMetricsHandler#channelRead(ChannelHandlerContext, Object)}) when
* an early response is received while the corresponding request it still being written.
*/
static final class ClientMetricsRecorder implements HttpClientMetricsRecorder {

static final ClientMetricsRecorder INSTANCE = new ClientMetricsRecorder();
volatile SocketAddress recordDataSentTimeRemoteAddr;
volatile String recordDataSentTimeUri;
volatile String recordDataSentTimeMethod;
volatile Duration recordDataSentTimeTime;
volatile SocketAddress recordDataSentRemoteAddr;
volatile String recordDataSentUri;
volatile long recordDataSentBytes;

static ClientMetricsRecorder reset() {
INSTANCE.recordDataSentTimeRemoteAddr = null;
INSTANCE.recordDataSentTimeUri = null;
INSTANCE.recordDataSentTimeMethod = null;
INSTANCE.recordDataSentTimeTime = null;
INSTANCE.recordDataSentRemoteAddr = null;
INSTANCE.recordDataSentUri = null;
INSTANCE.recordDataSentBytes = -1;
return INSTANCE;
}

@Override
public void recordDataReceived(SocketAddress remoteAddress, long bytes) {
}

@Override
public void recordDataSent(SocketAddress remoteAddress, long bytes) {
}

@Override
public void incrementErrorsCount(SocketAddress remoteAddress) {
}

@Override
public void recordTlsHandshakeTime(SocketAddress remoteAddress, Duration time, String status) {
}

@Override
public void recordConnectTime(SocketAddress remoteAddress, Duration time, String status) {
}

@Override
public void recordResolveAddressTime(SocketAddress remoteAddress, Duration time, String status) {
}

@Override
public void recordDataReceived(SocketAddress remoteAddress, String uri, long bytes) {
}

@Override
public void recordDataSent(SocketAddress remoteAddress, String uri, long bytes) {
this.recordDataSentRemoteAddr = remoteAddress;
this.recordDataSentUri = uri;
this.recordDataSentBytes = bytes;
}

@Override
public void incrementErrorsCount(SocketAddress remoteAddress, String uri) {
}

@Override
public void recordDataReceivedTime(SocketAddress remoteAddress, String uri, String method, String status, Duration time) {
}

@Override
public void recordDataSentTime(SocketAddress remoteAddress, String uri, String method, Duration time) {
this.recordDataSentTimeRemoteAddr = remoteAddress;
this.recordDataSentTimeUri = uri;
this.recordDataSentTimeMethod = method;
this.recordDataSentTimeTime = time;
}

@Override
public void recordResponseTime(SocketAddress remoteAddress, String uri, String method, String status, Duration time) {
}
}
}
Loading

0 comments on commit dfad10b

Please sign in to comment.