Skip to content

Commit

Permalink
provides rollback to CompositeByteBuf usage (#750)
Browse files Browse the repository at this point in the history
* provides rollback to CompositeByteBuf usage

Signed-off-by: Oleh Dokuka <shadowgun@i.ua>

* fixes test

Signed-off-by: Oleh Dokuka <shadowgun@i.ua>
  • Loading branch information
OlegDokuka committed Mar 22, 2020
1 parent 2f71f73 commit 0cc46d0
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 13 deletions.
Expand Up @@ -3,7 +3,6 @@
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.rsocket.buffer.TupleByteBuf;

class DataAndMetadataFlyweight {
public static final int FRAME_LENGTH_MASK = 0xFFFFFF;
Expand Down Expand Up @@ -33,19 +32,19 @@ private static int decodeLength(final ByteBuf byteBuf) {

static ByteBuf encodeOnlyMetadata(
ByteBufAllocator allocator, final ByteBuf header, ByteBuf metadata) {
return TupleByteBuf.of(allocator, header, metadata);
return allocator.compositeBuffer(2).addComponents(true, header, metadata);
}

static ByteBuf encodeOnlyData(ByteBufAllocator allocator, final ByteBuf header, ByteBuf data) {
return TupleByteBuf.of(allocator, header, data);
return allocator.compositeBuffer(2).addComponents(true, header, data);
}

static ByteBuf encode(
ByteBufAllocator allocator, final ByteBuf header, ByteBuf metadata, ByteBuf data) {

int length = metadata.readableBytes();
encodeLength(header, length);
return TupleByteBuf.of(allocator, header, metadata, data);
return allocator.compositeBuffer(3).addComponents(true, header, metadata, data);
}

static ByteBuf metadataWithoutMarking(ByteBuf byteBuf, boolean hasMetadata) {
Expand Down
Expand Up @@ -2,7 +2,6 @@

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.rsocket.buffer.TupleByteBuf;

/**
* Some transports like TCP aren't framed, and require a length. This is used by DuplexConnections
Expand Down Expand Up @@ -35,7 +34,7 @@ private static int decodeLength(final ByteBuf byteBuf) {
public static ByteBuf encode(ByteBufAllocator allocator, int length, ByteBuf frame) {
ByteBuf buffer = allocator.buffer();
encodeLength(buffer, length);
return TupleByteBuf.of(allocator, buffer, frame);
return allocator.compositeBuffer(2).addComponents(true, buffer, frame);
}

public static int length(ByteBuf byteBuf) {
Expand Down
Expand Up @@ -5,7 +5,6 @@
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.util.CharsetUtil;
import io.rsocket.buffer.TupleByteBuf;
import io.rsocket.util.CharByteBufUtil;

public class AuthMetadataFlyweight {
Expand Down Expand Up @@ -49,7 +48,7 @@ public static ByteBuf encodeMetadata(

ByteBufUtil.reserveAndWriteUtf8(headerBuffer, customAuthType, actualASCIILength);

return TupleByteBuf.of(allocator, headerBuffer, metadata);
return allocator.compositeBuffer(2).addComponents(true, headerBuffer, metadata);
}

/**
Expand All @@ -76,7 +75,7 @@ public static ByteBuf encodeMetadata(
.buffer(capacity, capacity)
.writeByte(authType.getIdentifier() | STREAM_METADATA_KNOWN_MASK);

return TupleByteBuf.of(allocator, headerBuffer, metadata);
return allocator.compositeBuffer(2).addComponents(true, headerBuffer, metadata);
}

/**
Expand Down
Expand Up @@ -182,7 +182,10 @@ public static char[] readUtf8(ByteBuf byteBuf, int length) {
char[] ca = new char[en];

CharBuffer charBuffer = CharBuffer.wrap(ca);
ByteBuffer byteBuffer = byteBuf.internalNioBuffer(byteBuf.readerIndex(), length);
ByteBuffer byteBuffer =
byteBuf.nioBufferCount() == 1
? byteBuf.internalNioBuffer(byteBuf.readerIndex(), length)
: byteBuf.nioBuffer(byteBuf.readerIndex(), length);
byteBuffer.mark();
try {
CoderResult cr = charsetDecoder.decode(byteBuffer, charBuffer, true);
Expand Down
76 changes: 73 additions & 3 deletions rsocket-test/src/main/java/io/rsocket/test/TransportTest.java
Expand Up @@ -23,10 +23,14 @@
import io.rsocket.transport.ClientTransport;
import io.rsocket.transport.ServerTransport;
import io.rsocket.util.DefaultPayload;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.time.Duration;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.zip.GZIPInputStream;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
Expand All @@ -38,6 +42,25 @@

public interface TransportTest {

String MOCK_DATA = "test-data";
String MOCK_METADATA = "metadata";
String LARGE_DATA = read("words.shakespeare.txt.gz");
Payload LARGE_PAYLOAD = DefaultPayload.create(LARGE_DATA, LARGE_DATA);

static String read(String resourceName) {

try (BufferedReader br =
new BufferedReader(
new InputStreamReader(
new GZIPInputStream(
TransportTest.class.getClassLoader().getResourceAsStream(resourceName))))) {

return br.lines().map(String::toLowerCase).collect(Collectors.joining("\n\r"));
} catch (Throwable e) {
throw new RuntimeException(e);
}
}

@AfterEach
default void close() {
getTransportPair().dispose();
Expand All @@ -54,12 +77,12 @@ default Payload createTestPayload(int metadataPresent) {
metadata1 = "";
break;
default:
metadata1 = "metadata";
metadata1 = MOCK_METADATA;
break;
}
String metadata = metadata1;

return DefaultPayload.create("test-data", metadata);
return DefaultPayload.create(MOCK_DATA, metadata);
}

@DisplayName("makes 10 fireAndForget requests")
Expand All @@ -73,6 +96,17 @@ default void fireAndForget10() {
.verify(getTimeout());
}

@DisplayName("makes 10 fireAndForget with Large Payload in Requests")
@Test
default void largePayloadFireAndForget10() {
Flux.range(1, 10)
.flatMap(i -> getClient().fireAndForget(LARGE_PAYLOAD))
.as(StepVerifier::create)
.expectNextCount(0)
.expectComplete()
.verify(getTimeout());
}

default RSocket getClient() {
return getTransportPair().getClient();
}
Expand All @@ -92,6 +126,17 @@ default void metadataPush10() {
.verify(getTimeout());
}

@DisplayName("makes 10 metadataPush with Large Metadata in requests")
@Test
default void largePayloadMetadataPush10() {
Flux.range(1, 10)
.flatMap(i -> getClient().metadataPush(DefaultPayload.create("", LARGE_DATA)))
.as(StepVerifier::create)
.expectNextCount(0)
.expectComplete()
.verify(getTimeout());
}

@DisplayName("makes 1 requestChannel request with 0 payloads")
@Test
default void requestChannel0() {
Expand Down Expand Up @@ -127,6 +172,19 @@ default void requestChannel200_000() {
.verify(getTimeout());
}

@DisplayName("makes 1 requestChannel request with 2,000 large payloads")
@Test
default void largePayloadRequestChannel200() {
Flux<Payload> payloads = Flux.range(0, 200).map(__ -> LARGE_PAYLOAD);

getClient()
.requestChannel(payloads)
.as(StepVerifier::create)
.expectNextCount(200)
.expectComplete()
.verify(getTimeout());
}

@DisplayName("makes 1 requestChannel request with 20,000 payloads")
@Test
default void requestChannel20_000() {
Expand Down Expand Up @@ -223,6 +281,17 @@ default void requestResponse100() {
.verify(getTimeout());
}

@DisplayName("makes 100 requestResponse requests")
@Test
default void largePayloadRequestResponse100() {
Flux.range(1, 100)
.flatMap(i -> getClient().requestResponse(LARGE_PAYLOAD).map(Payload::getDataUtf8))
.as(StepVerifier::create)
.expectNextCount(100)
.expectComplete()
.verify(getTimeout());
}

@DisplayName("makes 10,000 requestResponse requests")
@Test
default void requestResponse10_000() {
Expand Down Expand Up @@ -283,7 +352,7 @@ default void assertPayload(Payload p) {
}

default void assertChannelPayload(Payload p) {
if (!"test-data".equals(p.getDataUtf8()) || !"metadata".equals(p.getMetadataUtf8())) {
if (!MOCK_DATA.equals(p.getDataUtf8()) || !MOCK_METADATA.equals(p.getMetadataUtf8())) {
throw new IllegalStateException("Unexpected payload");
}
}
Expand Down Expand Up @@ -312,6 +381,7 @@ public TransportPair(

client =
RSocketFactory.connect()
.keepAlive(Duration.ZERO, Duration.ZERO, 1)
.transport(clientTransportSupplier.apply(address, server))
.start()
.doOnError(Throwable::printStackTrace)
Expand Down
Binary file not shown.
@@ -0,0 +1,55 @@
package io.rsocket.transport.netty;

import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.rsocket.test.TransportTest;
import io.rsocket.transport.netty.client.TcpClientTransport;
import io.rsocket.transport.netty.server.TcpServerTransport;
import java.net.InetSocketAddress;
import java.security.cert.CertificateException;
import java.time.Duration;
import reactor.core.Exceptions;
import reactor.netty.tcp.TcpClient;
import reactor.netty.tcp.TcpServer;

public class TcpSecureTransportTest implements TransportTest {
private final TransportPair transportPair =
new TransportPair<>(
() -> new InetSocketAddress("localhost", 0),
(address, server) ->
TcpClientTransport.create(
TcpClient.create()
.addressSupplier(server::address)
.secure(
ssl ->
ssl.sslContext(
SslContextBuilder.forClient()
.trustManager(InsecureTrustManagerFactory.INSTANCE)))),
address -> {
try {
SelfSignedCertificate ssc = new SelfSignedCertificate();
TcpServer server =
TcpServer.create()
.addressSupplier(() -> address)
.secure(
ssl ->
ssl.sslContext(
SslContextBuilder.forServer(
ssc.certificate(), ssc.privateKey())));
return TcpServerTransport.create(server);
} catch (CertificateException e) {
throw Exceptions.propagate(e);
}
});

@Override
public Duration getTimeout() {
return Duration.ofMinutes(10);
}

@Override
public TransportPair getTransportPair() {
return transportPair;
}
}

0 comments on commit 0cc46d0

Please sign in to comment.