diff --git a/build.gradle b/build.gradle index 128ad7931..679b93135 100644 --- a/build.gradle +++ b/build.gradle @@ -50,6 +50,7 @@ subprojects { dependencySet(group: 'org.junit.jupiter', version: '5.1.0') { entry 'junit-jupiter-api' entry 'junit-jupiter-engine' + entry 'junit-jupiter-params' } // TODO: Remove after JUnit5 migration diff --git a/rsocket-core/build.gradle b/rsocket-core/build.gradle index d51da411f..0cdb6103c 100644 --- a/rsocket-core/build.gradle +++ b/rsocket-core/build.gradle @@ -36,6 +36,7 @@ dependencies { testImplementation 'io.projectreactor:reactor-test' testImplementation 'org.assertj:assertj-core' testImplementation 'org.junit.jupiter:junit-jupiter-api' + testImplementation 'org.junit.jupiter:junit-jupiter-params' testImplementation 'org.mockito:mockito-core' testRuntimeOnly 'ch.qos.logback:logback-classic' diff --git a/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java b/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java index 559a1ff59..d88cfe445 100644 --- a/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java @@ -34,6 +34,10 @@ public static ConnectionSetupPayload create(final Frame setupFrame) { return new DefaultConnectionSetupPayload(setupFrame); } + public abstract int keepAliveInterval(); + + public abstract int keepAliveMaxLifetime(); + public abstract String metadataMimeType(); public abstract String dataMimeType(); @@ -73,6 +77,16 @@ public DefaultConnectionSetupPayload(final Frame setupFrame) { this.setupFrame = setupFrame; } + @Override + public int keepAliveInterval() { + return SetupFrameFlyweight.keepaliveInterval(setupFrame.content()); + } + + @Override + public int keepAliveMaxLifetime() { + return SetupFrameFlyweight.maxLifetime(setupFrame.content()); + } + @Override public String metadataMimeType() { return Setup.metadataMimeType(setupFrame); diff --git a/rsocket-core/src/main/java/io/rsocket/KeepAliveHandler.java b/rsocket-core/src/main/java/io/rsocket/KeepAliveHandler.java new file mode 100644 index 000000000..eac560dfa --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/KeepAliveHandler.java @@ -0,0 +1,120 @@ +package io.rsocket; + +import io.netty.buffer.Unpooled; +import java.time.Duration; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.UnicastProcessor; + +abstract class KeepAliveHandler { + private final KeepAlive keepAlive; + private final UnicastProcessor sent = UnicastProcessor.create(); + private final MonoProcessor timeout = MonoProcessor.create(); + private final Flux interval; + private Disposable intervalDisposable; + private volatile long lastReceivedMillis; + + static KeepAliveHandler ofServer(KeepAlive keepAlive) { + return new KeepAliveHandler.Server(keepAlive); + } + + static KeepAliveHandler ofClient(KeepAlive keepAlive) { + return new KeepAliveHandler.Client(keepAlive); + } + + private KeepAliveHandler(KeepAlive keepAlive) { + this.keepAlive = keepAlive; + this.interval = Flux.interval(Duration.ofMillis(keepAlive.getTickPeriod())); + } + + public void start() { + this.lastReceivedMillis = System.currentTimeMillis(); + intervalDisposable = interval.subscribe(v -> onIntervalTick()); + } + + public void stop() { + sent.onComplete(); + timeout.onComplete(); + if (intervalDisposable != null) { + intervalDisposable.dispose(); + } + } + + public void receive(Frame keepAliveFrame) { + this.lastReceivedMillis = System.currentTimeMillis(); + if (Frame.Keepalive.hasRespondFlag(keepAliveFrame)) { + doSend(Frame.Keepalive.from(Unpooled.wrappedBuffer(keepAliveFrame.getData()), false)); + } + } + + public Flux send() { + return sent; + } + + public Mono timeout() { + return timeout; + } + + abstract void onIntervalTick(); + + void doSend(Frame frame) { + sent.onNext(frame); + } + + void doCheckTimeout() { + long now = System.currentTimeMillis(); + if (now - lastReceivedMillis >= keepAlive.getTimeoutMillis()) { + timeout.onNext(keepAlive); + } + } + + private static class Server extends KeepAliveHandler { + + Server(KeepAlive keepAlive) { + super(keepAlive); + } + + @Override + void onIntervalTick() { + doCheckTimeout(); + } + } + + private static final class Client extends KeepAliveHandler { + + Client(KeepAlive keepAlive) { + super(keepAlive); + } + + @Override + void onIntervalTick() { + doCheckTimeout(); + doSend(Frame.Keepalive.from(Unpooled.EMPTY_BUFFER, true)); + } + } + + static final class KeepAlive { + private final long tickPeriod; + private final long timeoutMillis; + + KeepAlive(Duration tickPeriod, Duration timeoutMillis, int maxTicks) { + this.tickPeriod = tickPeriod.toMillis(); + this.timeoutMillis = timeoutMillis.toMillis() + maxTicks * tickPeriod.toMillis(); + } + + KeepAlive(long tickPeriod, long timeoutMillis) { + this.tickPeriod = tickPeriod; + this.timeoutMillis = timeoutMillis; + } + + public long getTickPeriod() { + return tickPeriod; + } + + public long getTimeoutMillis() { + return timeoutMillis; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketClient.java b/rsocket-core/src/main/java/io/rsocket/RSocketClient.java index 4d779b2ad..8ce6ede18 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketClient.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketClient.java @@ -16,7 +16,6 @@ package io.rsocket; -import io.netty.buffer.Unpooled; import io.rsocket.exceptions.ConnectionErrorException; import io.rsocket.exceptions.Exceptions; import io.rsocket.framing.FrameType; @@ -24,14 +23,11 @@ import io.rsocket.internal.UnboundedProcessor; import java.time.Duration; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; import java.util.function.Function; -import javax.annotation.Nullable; import org.jctools.maps.NonBlockingHashMapLong; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; -import reactor.core.Disposable; import reactor.core.publisher.*; /** Client Side of a RSocket socket. Sends {@link Frame}s to a {@link RSocketServer} */ @@ -44,13 +40,10 @@ class RSocketClient implements RSocket { private final MonoProcessor started; private final NonBlockingHashMapLong senders; private final NonBlockingHashMapLong> receivers; - private final AtomicInteger missedAckCounter; - private final UnboundedProcessor sendProcessor; + private KeepAliveHandler keepAliveHandler; - private @Nullable Disposable keepAliveSendSub; - private volatile long timeLastTickSentMs; - + /*server requester*/ RSocketClient( DuplexConnection connection, Function frameDecoder, @@ -59,7 +52,7 @@ class RSocketClient implements RSocket { this( connection, frameDecoder, errorConsumer, streamIdSupplier, Duration.ZERO, Duration.ZERO, 0); } - + /*client requester*/ RSocketClient( DuplexConnection connection, Function frameDecoder, @@ -75,24 +68,29 @@ class RSocketClient implements RSocket { this.started = MonoProcessor.create(); this.senders = new NonBlockingHashMapLong<>(256); this.receivers = new NonBlockingHashMapLong<>(256); - this.missedAckCounter = new AtomicInteger(); // DO NOT Change the order here. The Send processor must be subscribed to before receiving this.sendProcessor = new UnboundedProcessor<>(); if (!Duration.ZERO.equals(tickPeriod)) { - long ackTimeoutMs = ackTimeout.toMillis(); - - this.keepAliveSendSub = - started - .thenMany(Flux.interval(tickPeriod)) - .doOnSubscribe(s -> timeLastTickSentMs = System.currentTimeMillis()) - .subscribe( - i -> sendKeepAlive(ackTimeoutMs, missedAcks), - t -> { - errorConsumer.accept(t); - connection.dispose(); - }); + this.keepAliveHandler = + KeepAliveHandler.ofClient( + new KeepAliveHandler.KeepAlive(tickPeriod, ackTimeout, missedAcks)); + + started.doOnTerminate(() -> keepAliveHandler.start()).subscribe(); + + keepAliveHandler + .timeout() + .subscribe( + keepAlive -> { + String message = + String.format("No keep-alive acks for %d ms", keepAlive.getTimeoutMillis()); + errorConsumer.accept(new ConnectionErrorException(message)); + connection.dispose(); + }); + keepAliveHandler.send().subscribe(sendProcessor::onNext); + } else { + keepAliveHandler = null; } connection.onClose().doFinally(signalType -> cleanup()).subscribe(null, errorConsumer); @@ -140,22 +138,6 @@ private void handleSendProcessorCancel(SignalType t) { } } - private void sendKeepAlive(long ackTimeoutMs, int missedAcks) { - long now = System.currentTimeMillis(); - if (now - timeLastTickSentMs > ackTimeoutMs) { - int count = missedAckCounter.incrementAndGet(); - if (count >= missedAcks) { - String message = - String.format( - "Missed %d keep-alive acks with a threshold of %d and a ack timeout of %d ms", - count, missedAcks, ackTimeoutMs); - throw new ConnectionErrorException(message); - } - } - - sendProcessor.onNext(Frame.Keepalive.from(Unpooled.EMPTY_BUFFER, true)); - } - @Override public Mono fireAndForget(Payload payload) { Mono defer = @@ -380,6 +362,9 @@ private boolean contains(int streamId) { } protected void cleanup() { + if (keepAliveHandler != null) { + keepAliveHandler.stop(); + } try { for (UnicastProcessor subscriber : receivers.values()) { cleanUpSubscriber(subscriber); @@ -387,10 +372,6 @@ protected void cleanup() { for (LimitableRequestPublisher p : senders.values()) { cleanUpLimitableRequestPublisher(p); } - - if (null != keepAliveSendSub) { - keepAliveSendSub.dispose(); - } } finally { senders.clear(); receivers.clear(); @@ -437,8 +418,8 @@ private void handleStreamZero(FrameType type, Frame frame) { break; } case KEEPALIVE: - if (!Frame.Keepalive.hasRespondFlag(frame)) { - timeLastTickSentMs = System.currentTimeMillis(); + if (keepAliveHandler != null) { + keepAliveHandler.receive(frame); } break; default: diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java b/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java index 545dd863f..3926fe4cd 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java @@ -87,7 +87,7 @@ public static class ClientRSocketFactory implements ClientTransportAcceptor { private Payload setupPayload = EmptyPayload.INSTANCE; private Function frameDecoder = DefaultPayload::create; - private Duration tickPeriod = Duration.ZERO; + private Duration tickPeriod = Duration.ofSeconds(20); private Duration ackTimeout = Duration.ofSeconds(30); private int missedAcks = 3; @@ -109,8 +109,13 @@ public ClientRSocketFactory addServerPlugin(RSocketInterceptor interceptor) { return this; } + /** + * Deprecated as Keep-Alive is not optional according to spec + * + * @return this ClientRSocketFactory + */ + @Deprecated public ClientRSocketFactory keepAlive() { - tickPeriod = Duration.ofSeconds(20); return this; } @@ -205,8 +210,8 @@ public Mono start() { Frame setupFrame = Frame.Setup.from( flags, - (int) ackTimeout.toMillis(), - (int) ackTimeout.toMillis() * missedAcks, + (int) tickPeriod.toMillis(), + (int) (ackTimeout.toMillis() + tickPeriod.toMillis() * missedAcks), metadataMimeType, dataMimeType, setupPayload); @@ -339,6 +344,8 @@ private Mono processSetupFrame( } ConnectionSetupPayload setupPayload = ConnectionSetupPayload.create(setupFrame); + int keepAliveInterval = setupPayload.keepAliveInterval(); + int keepAliveMaxLifetime = setupPayload.keepAliveMaxLifetime(); RSocketClient rSocketClient = new RSocketClient( @@ -361,7 +368,9 @@ private Mono processSetupFrame( multiplexer.asClientConnection(), wrappedRSocketServer, frameDecoder, - errorConsumer); + errorConsumer, + keepAliveInterval, + keepAliveMaxLifetime); }) .doFinally(signalType -> setupPayload.release()) .then(); diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/RSocketServer.java index eb00b6e3d..9ec273bf6 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketServer.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketServer.java @@ -20,9 +20,8 @@ import static io.rsocket.frame.FrameHeaderFlyweight.FLAGS_C; import static io.rsocket.frame.FrameHeaderFlyweight.FLAGS_M; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.exceptions.ConnectionErrorException; import io.rsocket.framing.FrameType; import io.rsocket.internal.LimitableRequestPublisher; import io.rsocket.internal.UnboundedProcessor; @@ -33,10 +32,7 @@ import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import reactor.core.Disposable; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.SignalType; -import reactor.core.publisher.UnicastProcessor; +import reactor.core.publisher.*; /** Server side RSocket. Receives {@link Frame}s from a {@link RSocketClient} */ class RSocketServer implements RSocket { @@ -45,35 +41,73 @@ class RSocketServer implements RSocket { private final RSocket requestHandler; private final Function frameDecoder; private final Consumer errorConsumer; + private final MonoProcessor started; private final NonBlockingHashMapLong sendingSubscriptions; private final NonBlockingHashMapLong> channelProcessors; private final UnboundedProcessor sendProcessor; private Disposable receiveDisposable; + private KeepAliveHandler keepAliveHandler; + /*client responder*/ RSocketServer( DuplexConnection connection, RSocket requestHandler, Function frameDecoder, Consumer errorConsumer) { + this(connection, requestHandler, frameDecoder, errorConsumer, 0, 0); + } + /*server responder*/ + RSocketServer( + DuplexConnection connection, + RSocket requestHandler, + Function frameDecoder, + Consumer errorConsumer, + long tickPeriod, + long ackTimeout) { this.connection = connection; this.requestHandler = requestHandler; this.frameDecoder = frameDecoder; this.errorConsumer = errorConsumer; this.sendingSubscriptions = new NonBlockingHashMapLong<>(); this.channelProcessors = new NonBlockingHashMapLong<>(); + this.started = MonoProcessor.create(); // DO NOT Change the order here. The Send processor must be subscribed to before receiving // connections this.sendProcessor = new UnboundedProcessor<>(); + if (tickPeriod != 0) { + keepAliveHandler = + KeepAliveHandler.ofServer(new KeepAliveHandler.KeepAlive(tickPeriod, ackTimeout)); + + started.doOnTerminate(() -> keepAliveHandler.start()).subscribe(); + + keepAliveHandler + .timeout() + .subscribe( + keepAlive -> { + String message = + String.format("No keep-alive acks for %d ms", keepAlive.getTimeoutMillis()); + errorConsumer.accept(new ConnectionErrorException(message)); + connection.dispose(); + }); + keepAliveHandler.send().subscribe(sendProcessor::onNext); + } else { + keepAliveHandler = null; + } + connection .send(sendProcessor) .doFinally(this::handleSendProcessorCancel) .subscribe(null, this::handleSendProcessorError); - this.receiveDisposable = connection.receive().subscribe(this::handleFrame, errorConsumer); + this.receiveDisposable = + connection + .receive() + .doOnSubscribe(subscription -> started.onComplete()) + .subscribe(this::handleFrame, errorConsumer); this.connection .onClose() @@ -186,6 +220,9 @@ public Mono onClose() { } private void cleanup() { + if (keepAliveHandler != null) { + keepAliveHandler.stop(); + } cleanUpSendingSubscriptions(); cleanUpChannelProcessors(); @@ -302,7 +339,8 @@ private void handleRequestResponse(int streamId, Mono response) { payload.release(); return frame; }) - .switchIfEmpty(Mono.fromCallable(() -> Frame.PayloadFrame.from(streamId, FrameType.COMPLETE))) + .switchIfEmpty( + Mono.fromCallable(() -> Frame.PayloadFrame.from(streamId, FrameType.COMPLETE))) .doFinally(signalType -> sendingSubscriptions.remove(streamId)) .subscribe(sendProcessor::onNext, t -> handleError(streamId, t)); } @@ -347,9 +385,8 @@ private void handleChannel(int streamId, Frame firstFrame) { } private void handleKeepAliveFrame(Frame frame) { - if (Frame.Keepalive.hasRespondFlag(frame)) { - ByteBuf data = Unpooled.wrappedBuffer(frame.getData()); - sendProcessor.onNext(Frame.Keepalive.from(data, false)); + if (keepAliveHandler != null) { + keepAliveHandler.receive(frame); } } diff --git a/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java b/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java index c1d9b0b9e..d1f34589e 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java @@ -23,11 +23,10 @@ import io.netty.util.AbstractReferenceCounted; import io.netty.util.Recycler; import io.rsocket.Payload; - -import javax.annotation.Nullable; import java.nio.ByteBuffer; import java.nio.CharBuffer; import java.nio.charset.Charset; +import javax.annotation.Nullable; public final class ByteBufPayload extends AbstractReferenceCounted implements Payload { private static final Recycler RECYCLER = diff --git a/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java b/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java index 8804e9438..54cc53cd1 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java @@ -155,7 +155,8 @@ public static Payload create(ByteBuffer data, @Nullable ByteBuffer metadata) { } public static Payload create(Payload payload) { - return create(copy(payload.sliceData()), payload.hasMetadata() ? copy(payload.sliceMetadata()) : null); + return create( + copy(payload.sliceData()), payload.hasMetadata() ? copy(payload.sliceMetadata()) : null); } private static ByteBuffer copy(ByteBuf byteBuf) { diff --git a/rsocket-core/src/test/java/io/rsocket/KeepAliveTest.java b/rsocket-core/src/test/java/io/rsocket/KeepAliveTest.java new file mode 100644 index 000000000..abcddd37d --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/KeepAliveTest.java @@ -0,0 +1,153 @@ +package io.rsocket; + +import io.netty.buffer.Unpooled; +import io.rsocket.exceptions.ConnectionErrorException; +import io.rsocket.framing.FrameType; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; +import java.util.function.Supplier; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +public class KeepAliveTest { + private static final int CLIENT_REQUESTER_TICK_PERIOD = 100; + private static final int CLIENT_REQUESTER_TIMEOUT = 700; + private static final int CLIENT_REQUESTER_MISSED_ACKS = 3; + private static final int SERVER_RESPONDER_TICK_PERIOD = 100; + private static final int SERVER_RESPONDER_TIMEOUT = 1000; + + @ParameterizedTest + @MethodSource("testData") + void keepAlives(Supplier testDataSupplier) { + TestData testData = testDataSupplier.get(); + TestDuplexConnection connection = testData.connection(); + + Flux.interval(Duration.ofMillis(100)) + .subscribe( + n -> connection.addToReceivedBuffer(Frame.Keepalive.from(Unpooled.EMPTY_BUFFER, true))); + + Mono.delay(Duration.ofMillis(1500)).block(); + + RSocket rSocket = testData.rSocket(); + List errors = testData.errors().errors(); + + Assertions.assertThat(rSocket.isDisposed()).isFalse(); + Assertions.assertThat(errors).isEmpty(); + } + + @ParameterizedTest + @MethodSource("testData") + void keepAlivesMissing(Supplier testDataSupplier) { + TestData testData = testDataSupplier.get(); + RSocket rSocket = testData.rSocket(); + + Mono.delay(Duration.ofMillis(1500)).block(); + + List errors = testData.errors().errors(); + Assertions.assertThat(rSocket.isDisposed()).isTrue(); + Assertions.assertThat(errors).hasSize(1); + Throwable throwable = errors.get(0); + Assertions.assertThat(throwable).isInstanceOf(ConnectionErrorException.class); + } + + @Test + void clientRequesterRespondsToKeepAlives() { + TestData testData = requester(100, 700, 3).get(); + TestDuplexConnection connection = testData.connection(); + + Mono.delay(Duration.ofMillis(100)) + .subscribe( + l -> connection.addToReceivedBuffer(Frame.Keepalive.from(Unpooled.EMPTY_BUFFER, true))); + + Mono keepAliveResponse = + Flux.from(connection.getSentAsPublisher()) + .filter(f -> f.getType() == FrameType.KEEPALIVE && !Frame.Keepalive.hasRespondFlag(f)) + .next() + .then(); + + StepVerifier.create(keepAliveResponse).expectComplete().verify(Duration.ofSeconds(5)); + } + + static Stream> testData() { + return Stream.of( + requester( + CLIENT_REQUESTER_TICK_PERIOD, CLIENT_REQUESTER_TIMEOUT, CLIENT_REQUESTER_MISSED_ACKS), + responder(SERVER_RESPONDER_TICK_PERIOD, SERVER_RESPONDER_TIMEOUT)); + } + + static Supplier requester(int tickPeriod, int timeout, int missedAcks) { + return () -> { + TestDuplexConnection connection = new TestDuplexConnection(); + Errors errors = new Errors(); + RSocketClient rSocket = + new RSocketClient( + connection, + DefaultPayload::create, + errors, + StreamIdSupplier.clientSupplier(), + Duration.ofMillis(tickPeriod), + Duration.ofMillis(timeout), + missedAcks); + return new TestData(rSocket, errors, connection); + }; + } + + static Supplier responder(int tickPeriod, int timeout) { + return () -> { + TestDuplexConnection connection = new TestDuplexConnection(); + AbstractRSocket handler = new AbstractRSocket() {}; + Errors errors = new Errors(); + RSocketServer rSocket = + new RSocketServer( + connection, handler, DefaultPayload::create, errors, tickPeriod, timeout); + return new TestData(rSocket, errors, connection); + }; + } + + static class TestData { + private final RSocket rSocket; + private final Errors errors; + private final TestDuplexConnection connection; + + public TestData(RSocket rSocket, Errors errors, TestDuplexConnection connection) { + this.rSocket = rSocket; + this.errors = errors; + this.connection = connection; + } + + public TestDuplexConnection connection() { + return connection; + } + + public RSocket rSocket() { + return rSocket; + } + + public Errors errors() { + return errors; + } + } + + static class Errors implements Consumer { + private final List errors = new ArrayList<>(); + + @Override + public void accept(Throwable throwable) { + errors.add(throwable); + } + + public List errors() { + return new ArrayList<>(errors); + } + } +} diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java index c43cefed5..68d4f3fd3 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java @@ -9,14 +9,13 @@ import io.rsocket.transport.netty.server.NettyContextCloseable; import io.rsocket.transport.netty.server.TcpServerTransport; import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.function.Supplier; import org.junit.jupiter.api.Test; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import java.time.Duration; -import java.util.function.Supplier; - public class InteractionsLoadTest { @Test @@ -24,33 +23,35 @@ public class InteractionsLoadTest { public void channel() { TcpServerTransport serverTransport = TcpServerTransport.create(0); - NettyContextCloseable server = RSocketFactory.receive() - .acceptor((setup, rsocket) -> Mono.just(new EchoRSocket())) - .transport(serverTransport) - .start() - .block(Duration.ofSeconds(10)); + NettyContextCloseable server = + RSocketFactory.receive() + .acceptor((setup, rsocket) -> Mono.just(new EchoRSocket())) + .transport(serverTransport) + .start() + .block(Duration.ofSeconds(10)); TcpClientTransport transport = TcpClientTransport.create(server.address()); - RSocket client = RSocketFactory - .connect() - .transport(transport).start() - .block(Duration.ofSeconds(10)); + RSocket client = + RSocketFactory.connect().transport(transport).start().block(Duration.ofSeconds(10)); int concurrency = 16; Flux.range(1, concurrency) - .flatMap(v -> - client.requestChannel( - input().onBackpressureDrop().map(iv -> - DefaultPayload.create("foo"))) - .limitRate(10000), concurrency) + .flatMap( + v -> + client + .requestChannel( + input().onBackpressureDrop().map(iv -> DefaultPayload.create("foo"))) + .limitRate(10000), + concurrency) .timeout(Duration.ofSeconds(5)) - .doOnNext(p -> { - String data = p.getDataUtf8(); - if (!data.equals("bar")) { - throw new IllegalStateException("Channel Client Bad message: " + data); - } - }) + .doOnNext( + p -> { + String data = p.getDataUtf8(); + if (!data.equals("bar")) { + throw new IllegalStateException("Channel Client Bad message: " + data); + } + }) .window(Duration.ofSeconds(1)) .flatMap(Flux::count) .doOnNext(d -> System.out.println("Got: " + d)) @@ -59,12 +60,10 @@ public void channel() { .subscribe(); server.onClose().block(); - } private static Flux input() { - Flux interval = Flux.interval(Duration.ofMillis(1)) - .onBackpressureDrop(); + Flux interval = Flux.interval(Duration.ofMillis(1)).onBackpressureDrop(); for (int i = 0; i < 10; i++) { interval = interval.mergeWith(interval); } @@ -74,31 +73,36 @@ private static Flux input() { private static class EchoRSocket extends AbstractRSocket { @Override public Flux requestChannel(Publisher payloads) { - return Flux.from(payloads).map(p -> { - - String data = p.getDataUtf8(); - if (!data.equals("foo")) { - throw new IllegalStateException("Channel Server Bad message: " + data); - } - return DefaultPayload.create(DefaultPayload.create("bar")); - }); + return Flux.from(payloads) + .map( + p -> { + String data = p.getDataUtf8(); + if (!data.equals("foo")) { + throw new IllegalStateException("Channel Server Bad message: " + data); + } + return DefaultPayload.create(DefaultPayload.create("bar")); + }); } @Override public Flux requestStream(Payload payload) { return Flux.just(payload) - .map(p -> { - String data = p.getDataUtf8(); - return data; - }) - .doOnNext((data) -> { - if (!data.equals("foo")) { - throw new IllegalStateException("Stream Server Bad message: " + data); - } - }).flatMap(data -> { - Supplier p = () -> DefaultPayload.create("bar"); - return Flux.range(1, 100).map(v -> p.get()); - }); + .map( + p -> { + String data = p.getDataUtf8(); + return data; + }) + .doOnNext( + (data) -> { + if (!data.equals("foo")) { + throw new IllegalStateException("Stream Server Bad message: " + data); + } + }) + .flatMap( + data -> { + Supplier p = () -> DefaultPayload.create("bar"); + return Flux.range(1, 100).map(v -> p.get()); + }); } } }