Send TDS fragments as individual TCP packets
We now apply fragmentation on the TCP level as well by representing each TDS message chunk in its own buffer, so that each TDS packet is written as its own TCP packet. Representing fragmentation on the transport level is necessary because Azure SQL disconnects clients if a single TCP packet exceeds the TDS packet size.

[closes #142]
mp911de committed May 4, 2020
1 parent 4ed9988 commit f385598
Showing 4 changed files with 58 additions and 41 deletions.
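The essence of the change: TdsEncoder no longer copies all TDS chunks into one aggregate buffer that is written once; instead, every chunk becomes its own channel write, and the per-chunk write promises are aggregated so the caller still observes a single completion. The following is a minimal sketch of that write pattern, for illustration only — the PerChunkWrites class and its writeChunks method are hypothetical names, not part of the driver:

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.util.concurrent.PromiseCombiner;

import java.util.List;

final class PerChunkWrites {

    // Hands each TDS chunk to Netty as a separate write (so it leaves the channel as its own
    // outbound buffer) while still completing the caller's single promise once all writes finish.
    static void writeChunks(ChannelHandlerContext ctx, List<ByteBuf> chunks, ChannelPromise promise) {

        // Aggregates the individual per-chunk write promises into the one promise the caller sees.
        PromiseCombiner combiner = new PromiseCombiner(ctx.executor());

        try {
            for (ByteBuf chunk : chunks) {
                // One write per chunk: the buffers are not coalesced into a single outbound buffer.
                combiner.add(ctx.write(chunk, ctx.newPromise()));
            }
            combiner.finish(promise);
        } catch (RuntimeException e) {
            promise.tryFailure(e);
            throw e;
        }
    }
}

PromiseCombiner completes the aggregate promise once all added futures have finished and fails it if any individual write fails, which preserves the previous single-promise contract of writeChunkedMessage.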
66 changes: 36 additions & 30 deletions src/main/java/io/r2dbc/mssql/client/TdsEncoder.java
@@ -22,6 +22,7 @@
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.ChannelOutboundHandlerAdapter;
 import io.netty.channel.ChannelPromise;
+import io.netty.util.concurrent.PromiseCombiner;
 import io.r2dbc.mssql.message.header.Header;
 import io.r2dbc.mssql.message.header.HeaderOptions;
 import io.r2dbc.mssql.message.header.PacketIdProvider;
@@ -256,51 +257,56 @@ private void writeSingleMessage(ChannelHandlerContext ctx, ChannelPromise promis
 
     private void writeChunkedMessage(ChannelHandlerContext ctx, ChannelPromise promise, ByteBuf body, HeaderOptions headerOptions, boolean lastLogicalPacket) {
 
-        ByteBuf chunked = body.alloc().buffer(estimateChunkedSize(getBytesToWrite(body.readableBytes())));
-
-        while (body.readableBytes() > 0) {
-
-            ByteBuf chunk;
-            if (this.lastChunkRemainder != null) {
-
-                int combinedSize = this.lastChunkRemainder.readableBytes() + body.readableBytes();
-                HeaderOptions optionsToUse = isLastTransportPacket(combinedSize, lastLogicalPacket) ? getLastHeader(headerOptions) : getChunkedHeaderOptions(headerOptions);
-                Header.encode(chunked, optionsToUse, this.packetSize, this.packetIdProvider);
-
-                int actualBodyReadableBytes = this.packetSize - Header.LENGTH - this.lastChunkRemainder.readableBytes();
-                chunked.writeBytes(this.lastChunkRemainder);
-                chunked.writeBytes(body, actualBodyReadableBytes);
-
-                this.lastChunkRemainder.release();
-                this.lastChunkRemainder = null;
-
-            } else {
-
-                if (!lastLogicalPacket && !requiresChunking(body.readableBytes())) {
-
-                    // Prevent partial packets/buffer underrun if not the last packet.
-                    this.lastChunkRemainder = body.alloc().compositeBuffer();
-                    this.lastChunkRemainder.addComponent(true, body.retain());
-                    break;
-                }
-
-                HeaderOptions optionsToUse = isLastTransportPacket(body.readableBytes(), lastLogicalPacket) ? getLastHeader(headerOptions) : getChunkedHeaderOptions(headerOptions);
-
-                chunk = body.readSlice(getEffectiveChunkSizeWithoutHeader(body.readableBytes()));
-
-                Header.encode(chunked, optionsToUse, Header.LENGTH + chunk.readableBytes(), this.packetIdProvider);
-                chunked.writeBytes(chunk);
-            }
-        }
-
-        ctx.write(chunked, promise);
-    }
-
-    int estimateChunkedSize(int readableBytes) {
-
-        int netPacketSize = this.packetSize + 1 - Header.LENGTH;
-
-        return readableBytes + (((readableBytes / netPacketSize) + 1) * Header.LENGTH);
+        PromiseCombiner combiner = new PromiseCombiner(ctx.executor());
+
+        try {
+            while (body.readableBytes() > 0) {
+
+                ByteBuf chunk = body.alloc().buffer(estimateChunkSize(getBytesToWrite(body.readableBytes())));
+
+                if (this.lastChunkRemainder != null) {
+
+                    int combinedSize = this.lastChunkRemainder.readableBytes() + body.readableBytes();
+                    HeaderOptions optionsToUse = isLastTransportPacket(combinedSize, lastLogicalPacket) ? getLastHeader(headerOptions) : getChunkedHeaderOptions(headerOptions);
+                    Header.encode(chunk, optionsToUse, this.packetSize, this.packetIdProvider);
+
+                    int actualBodyReadableBytes = this.packetSize - Header.LENGTH - this.lastChunkRemainder.readableBytes();
+                    chunk.writeBytes(this.lastChunkRemainder);
+                    chunk.writeBytes(body, actualBodyReadableBytes);
+
+                    this.lastChunkRemainder.release();
+                    this.lastChunkRemainder = null;
+
+                } else {
+
+                    if (!lastLogicalPacket && !requiresChunking(body.readableBytes())) {
+
+                        // Prevent partial packets/buffer underrun if not the last packet.
+                        this.lastChunkRemainder = body.alloc().compositeBuffer();
+                        this.lastChunkRemainder.addComponent(true, body.retain());
+                        break;
+                    }
+
+                    HeaderOptions optionsToUse = isLastTransportPacket(body.readableBytes(), lastLogicalPacket) ? getLastHeader(headerOptions) : getChunkedHeaderOptions(headerOptions);
+
+                    int byteCount = getEffectiveChunkSizeWithoutHeader(body.readableBytes());
+                    Header.encode(chunk, optionsToUse, Header.LENGTH + byteCount, this.packetIdProvider);
+
+                    chunk.writeBytes(body, byteCount);
+                }
+
+                combiner.add(ctx.write(chunk, ctx.newPromise()));
+            }
+
+            combiner.finish(promise);
+        } catch (RuntimeException e) {
+            promise.tryFailure(e);
+            throw e;
+        }
+    }
+
+    int estimateChunkSize(int readableBytes) {
+        return Math.min(readableBytes + Header.LENGTH, this.packetSize);
     }
 
     private boolean requiresChunking(int readableBytes) {
13 changes: 9 additions & 4 deletions src/test/java/io/r2dbc/mssql/RpcBlobUnitTests.java
@@ -18,8 +18,10 @@
 
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.ByteBufAllocator;
+import io.netty.buffer.Unpooled;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.ChannelPromise;
+import io.netty.util.concurrent.ImmediateEventExecutor;
 import io.r2dbc.mssql.client.TdsEncoder;
 import io.r2dbc.mssql.codec.DefaultCodecs;
 import io.r2dbc.mssql.codec.PlpEncoded;
@@ -72,18 +74,22 @@ void shouldEncodeChunkedStream() {
         TdsEncoder encoder = new TdsEncoder(PacketIdProvider.just(1), 8000);
         ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
         when(ctx.alloc()).thenReturn(ByteBufAllocator.DEFAULT);
+        when(ctx.executor()).thenReturn(ImmediateEventExecutor.INSTANCE);
         ChannelPromise promise = mock(ChannelPromise.class);
+        when(ctx.newPromise()).thenReturn(promise);
         when(ctx.write(any(), any(ChannelPromise.class))).then(invocationOnMock -> {
 
             ByteBuf buf = invocationOnMock.getArgument(0);
 
             int toRead = buf.readableBytes();
             byte[] bytes = new byte[toRead];
             buf.readBytes(bytes);
-            buf.release();
-
-            return null;
+            if (buf != Unpooled.EMPTY_BUFFER) {
+                buf.release();
+            }
+
+            return invocationOnMock.getArgument(1);
         });
 
         Flux.from(request.encode(ByteBufAllocator.DEFAULT, 8000))
@@ -92,6 +98,5 @@
             .as(StepVerifier::create)
             .expectNextCount(32)
             .verifyComplete();
-
     }
 }
18 changes: 12 additions & 6 deletions src/test/java/io/r2dbc/mssql/client/TdsEncoderUnitTests.java
@@ -128,12 +128,13 @@ void shouldEncodeAndSplitContextualTdsFragment() {
 
         channel.writeOutbound(fragment);
 
+        // Chunk 1
         assertThat(channel).outbound().hasByteBufMessage().isEncodedAs(buffer -> {
-
-            // Chunk 1
             encodeExpectation(buffer, StatusBit.NORMAL, 0x0c, "foob");
+        });
 
-            // Chunk 2
+        // Chunk 2
+        assertThat(channel).outbound().hasByteBufMessage().isEncodedAs(buffer -> {
             encodeExpectation(buffer, StatusBit.EOM, 0x0a, "ar");
         });
     }
@@ -288,6 +289,11 @@ void shouldChunkMessagesLargeLargeLarge() {
         assertThat(channel).outbound().hasByteBufMessage().isEncodedAs(buffer -> {
 
             encodeExpectation(buffer, StatusBit.NORMAL, 0x0c, "ijkl");
+        });
+
+        // Chunk 4
+        assertThat(channel).outbound().hasByteBufMessage().isEncodedAs(buffer -> {
+
             encodeExpectation(buffer, StatusBit.EOM, 0x0c, "mnop");
         });
     }
@@ -297,9 +303,9 @@ void shouldEstimateTdsPacketSize() {
 
         TdsEncoder encoder = new TdsEncoder(PacketIdProvider.just(42), 12);
 
-        Assertions.assertThat(encoder.estimateChunkedSize(1)).isEqualTo(9);
-        Assertions.assertThat(encoder.estimateChunkedSize(4)).isEqualTo(12);
-        Assertions.assertThat(encoder.estimateChunkedSize(5)).isEqualTo(21);
+        Assertions.assertThat(encoder.estimateChunkSize(1)).isEqualTo(9);
+        Assertions.assertThat(encoder.estimateChunkSize(4)).isEqualTo(12);
+        Assertions.assertThat(encoder.estimateChunkSize(5)).isEqualTo(12);
     }
 
     private static void encodeExpectation(ByteBuf buffer, StatusBit bit, int length, String content) {
@@ -75,7 +75,7 @@ private MessagesAssert(String direction, Queue<Object> actual) {
     public EncodedAssert hasByteBufMessage() {
 
         isNotNull();
-        Object poll = actual.poll();
+        Object poll = this.actual.poll();
 
         Assertions.assertThat(poll).describedAs(this.direction + " message").isNotNull().isInstanceOf(ByteBuf.class);
 
