From f38559826cdf6ab6394f5a13c972a4cfa5b440bf Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Mon, 4 May 2020 15:24:10 +0200 Subject: [PATCH] Send TDS fragments as individual TCP packets We now apply fragmentation also on the TCP level by representing each TDS message chunk in its own buffer and thus each TDS packet gets written as TCP packet. Representing fragmentation on the transport level is necessary as Azure SQL disconnects clients if a single TCP packet exceeds the TDS packet size. [closes #142] --- .../io/r2dbc/mssql/client/TdsEncoder.java | 66 ++++++++++--------- .../java/io/r2dbc/mssql/RpcBlobUnitTests.java | 13 ++-- .../mssql/client/TdsEncoderUnitTests.java | 18 +++-- .../mssql/util/EmbeddedChannelAssert.java | 2 +- 4 files changed, 58 insertions(+), 41 deletions(-) diff --git a/src/main/java/io/r2dbc/mssql/client/TdsEncoder.java b/src/main/java/io/r2dbc/mssql/client/TdsEncoder.java index 7720b2d2..77b05489 100644 --- a/src/main/java/io/r2dbc/mssql/client/TdsEncoder.java +++ b/src/main/java/io/r2dbc/mssql/client/TdsEncoder.java @@ -22,6 +22,7 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelOutboundHandlerAdapter; import io.netty.channel.ChannelPromise; +import io.netty.util.concurrent.PromiseCombiner; import io.r2dbc.mssql.message.header.Header; import io.r2dbc.mssql.message.header.HeaderOptions; import io.r2dbc.mssql.message.header.PacketIdProvider; @@ -256,51 +257,56 @@ private void writeSingleMessage(ChannelHandlerContext ctx, ChannelPromise promis private void writeChunkedMessage(ChannelHandlerContext ctx, ChannelPromise promise, ByteBuf body, HeaderOptions headerOptions, boolean lastLogicalPacket) { - ByteBuf chunked = body.alloc().buffer(estimateChunkedSize(getBytesToWrite(body.readableBytes()))); + PromiseCombiner combiner = new PromiseCombiner(ctx.executor()); - while (body.readableBytes() > 0) { + try { + while (body.readableBytes() > 0) { - ByteBuf chunk; - if (this.lastChunkRemainder != null) { + ByteBuf 
chunk = body.alloc().buffer(estimateChunkSize(getBytesToWrite(body.readableBytes()))); - int combinedSize = this.lastChunkRemainder.readableBytes() + body.readableBytes(); - HeaderOptions optionsToUse = isLastTransportPacket(combinedSize, lastLogicalPacket) ? getLastHeader(headerOptions) : getChunkedHeaderOptions(headerOptions); - Header.encode(chunked, optionsToUse, this.packetSize, this.packetIdProvider); + if (this.lastChunkRemainder != null) { - int actualBodyReadableBytes = this.packetSize - Header.LENGTH - this.lastChunkRemainder.readableBytes(); - chunked.writeBytes(this.lastChunkRemainder); - chunked.writeBytes(body, actualBodyReadableBytes); + int combinedSize = this.lastChunkRemainder.readableBytes() + body.readableBytes(); + HeaderOptions optionsToUse = isLastTransportPacket(combinedSize, lastLogicalPacket) ? getLastHeader(headerOptions) : getChunkedHeaderOptions(headerOptions); + Header.encode(chunk, optionsToUse, this.packetSize, this.packetIdProvider); - this.lastChunkRemainder.release(); - this.lastChunkRemainder = null; + int actualBodyReadableBytes = this.packetSize - Header.LENGTH - this.lastChunkRemainder.readableBytes(); + chunk.writeBytes(this.lastChunkRemainder); + chunk.writeBytes(body, actualBodyReadableBytes); - } else { + this.lastChunkRemainder.release(); + this.lastChunkRemainder = null; - if (!lastLogicalPacket && !requiresChunking(body.readableBytes())) { + } else { - // Prevent partial packets/buffer underrun if not the last packet. - this.lastChunkRemainder = body.alloc().compositeBuffer(); - this.lastChunkRemainder.addComponent(true, body.retain()); - break; - } + if (!lastLogicalPacket && !requiresChunking(body.readableBytes())) { - HeaderOptions optionsToUse = isLastTransportPacket(body.readableBytes(), lastLogicalPacket) ? getLastHeader(headerOptions) : getChunkedHeaderOptions(headerOptions); + // Prevent partial packets/buffer underrun if not the last packet. 
+ this.lastChunkRemainder = body.alloc().compositeBuffer(); + this.lastChunkRemainder.addComponent(true, body.retain()); + break; + } - chunk = body.readSlice(getEffectiveChunkSizeWithoutHeader(body.readableBytes())); + HeaderOptions optionsToUse = isLastTransportPacket(body.readableBytes(), lastLogicalPacket) ? getLastHeader(headerOptions) : getChunkedHeaderOptions(headerOptions); - Header.encode(chunked, optionsToUse, Header.LENGTH + chunk.readableBytes(), this.packetIdProvider); - chunked.writeBytes(chunk); - } - } + int byteCount = getEffectiveChunkSizeWithoutHeader(body.readableBytes()); + Header.encode(chunk, optionsToUse, Header.LENGTH + byteCount, this.packetIdProvider); - ctx.write(chunked, promise); - } + chunk.writeBytes(body, byteCount); + } - int estimateChunkedSize(int readableBytes) { + combiner.add(ctx.write(chunk, ctx.newPromise())); + } - int netPacketSize = this.packetSize + 1 - Header.LENGTH; + combiner.finish(promise); + } catch (RuntimeException e) { + promise.tryFailure(e); + throw e; + } + } - return readableBytes + (((readableBytes / netPacketSize) + 1) * Header.LENGTH); + int estimateChunkSize(int readableBytes) { + return Math.min(readableBytes + Header.LENGTH, this.packetSize); } private boolean requiresChunking(int readableBytes) { diff --git a/src/test/java/io/r2dbc/mssql/RpcBlobUnitTests.java b/src/test/java/io/r2dbc/mssql/RpcBlobUnitTests.java index f1bed604..ad13dca5 100644 --- a/src/test/java/io/r2dbc/mssql/RpcBlobUnitTests.java +++ b/src/test/java/io/r2dbc/mssql/RpcBlobUnitTests.java @@ -18,8 +18,10 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; +import io.netty.util.concurrent.ImmediateEventExecutor; import io.r2dbc.mssql.client.TdsEncoder; import io.r2dbc.mssql.codec.DefaultCodecs; import io.r2dbc.mssql.codec.PlpEncoded; @@ -72,18 +74,22 @@ void shouldEncodeChunkedStream() { 
TdsEncoder encoder = new TdsEncoder(PacketIdProvider.just(1), 8000); ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); when(ctx.alloc()).thenReturn(ByteBufAllocator.DEFAULT); + when(ctx.executor()).thenReturn(ImmediateEventExecutor.INSTANCE); ChannelPromise promise = mock(ChannelPromise.class); + when(ctx.newPromise()).thenReturn(promise); when(ctx.write(any(), any(ChannelPromise.class))).then(invocationOnMock -> { ByteBuf buf = invocationOnMock.getArgument(0); - int toRead = buf.readableBytes(); byte[] bytes = new byte[toRead]; buf.readBytes(bytes); - buf.release(); - return null; + if (buf != Unpooled.EMPTY_BUFFER) { + buf.release(); + } + + return invocationOnMock.getArgument(1); }); Flux.from(request.encode(ByteBufAllocator.DEFAULT, 8000)) @@ -92,6 +98,5 @@ void shouldEncodeChunkedStream() { .as(StepVerifier::create) .expectNextCount(32) .verifyComplete(); - } } diff --git a/src/test/java/io/r2dbc/mssql/client/TdsEncoderUnitTests.java b/src/test/java/io/r2dbc/mssql/client/TdsEncoderUnitTests.java index 6f58f83c..872079c2 100644 --- a/src/test/java/io/r2dbc/mssql/client/TdsEncoderUnitTests.java +++ b/src/test/java/io/r2dbc/mssql/client/TdsEncoderUnitTests.java @@ -128,12 +128,13 @@ void shouldEncodeAndSplitContextualTdsFragment() { channel.writeOutbound(fragment); + // Chunk 1 assertThat(channel).outbound().hasByteBufMessage().isEncodedAs(buffer -> { - - // Chunk 1 encodeExpectation(buffer, StatusBit.NORMAL, 0x0c, "foob"); + }); - // Chunk 2 + // Chunk 2 + assertThat(channel).outbound().hasByteBufMessage().isEncodedAs(buffer -> { encodeExpectation(buffer, StatusBit.EOM, 0x0a, "ar"); }); } @@ -288,6 +289,11 @@ void shouldChunkMessagesLargeLargeLarge() { assertThat(channel).outbound().hasByteBufMessage().isEncodedAs(buffer -> { encodeExpectation(buffer, StatusBit.NORMAL, 0x0c, "ijkl"); + }); + + // Chunk 4 + assertThat(channel).outbound().hasByteBufMessage().isEncodedAs(buffer -> { + encodeExpectation(buffer, StatusBit.EOM, 0x0c, "mnop"); }); } @@ 
-297,9 +303,9 @@ void shouldEstimateTdsPacketSize() { TdsEncoder encoder = new TdsEncoder(PacketIdProvider.just(42), 12); - Assertions.assertThat(encoder.estimateChunkedSize(1)).isEqualTo(9); - Assertions.assertThat(encoder.estimateChunkedSize(4)).isEqualTo(12); - Assertions.assertThat(encoder.estimateChunkedSize(5)).isEqualTo(21); + Assertions.assertThat(encoder.estimateChunkSize(1)).isEqualTo(9); + Assertions.assertThat(encoder.estimateChunkSize(4)).isEqualTo(12); + Assertions.assertThat(encoder.estimateChunkSize(5)).isEqualTo(12); } private static void encodeExpectation(ByteBuf buffer, StatusBit bit, int length, String content) { diff --git a/src/test/java/io/r2dbc/mssql/util/EmbeddedChannelAssert.java b/src/test/java/io/r2dbc/mssql/util/EmbeddedChannelAssert.java index 7ba69232..b605f74a 100644 --- a/src/test/java/io/r2dbc/mssql/util/EmbeddedChannelAssert.java +++ b/src/test/java/io/r2dbc/mssql/util/EmbeddedChannelAssert.java @@ -75,7 +75,7 @@ private MessagesAssert(String direction, Queue actual) { public EncodedAssert hasByteBufMessage() { isNotNull(); - Object poll = actual.poll(); + Object poll = this.actual.poll(); Assertions.assertThat(poll).describedAs(this.direction + " message").isNotNull().isInstanceOf(ByteBuf.class);