diff --git a/build.gradle b/build.gradle
index 128ad7931..679b93135 100644
--- a/build.gradle
+++ b/build.gradle
@@ -50,6 +50,7 @@ subprojects {
dependencySet(group: 'org.junit.jupiter', version: '5.1.0') {
entry 'junit-jupiter-api'
entry 'junit-jupiter-engine'
+ entry 'junit-jupiter-params'
}
// TODO: Remove after JUnit5 migration
diff --git a/rsocket-core/build.gradle b/rsocket-core/build.gradle
index d51da411f..0cdb6103c 100644
--- a/rsocket-core/build.gradle
+++ b/rsocket-core/build.gradle
@@ -36,6 +36,7 @@ dependencies {
testImplementation 'io.projectreactor:reactor-test'
testImplementation 'org.assertj:assertj-core'
testImplementation 'org.junit.jupiter:junit-jupiter-api'
+ testImplementation 'org.junit.jupiter:junit-jupiter-params'
testImplementation 'org.mockito:mockito-core'
testRuntimeOnly 'ch.qos.logback:logback-classic'
diff --git a/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java b/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java
index 559a1ff59..d88cfe445 100644
--- a/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java
+++ b/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java
@@ -34,6 +34,10 @@ public static ConnectionSetupPayload create(final Frame setupFrame) {
return new DefaultConnectionSetupPayload(setupFrame);
}
+ public abstract int keepAliveInterval();
+
+ public abstract int keepAliveMaxLifetime();
+
public abstract String metadataMimeType();
public abstract String dataMimeType();
@@ -73,6 +77,16 @@ public DefaultConnectionSetupPayload(final Frame setupFrame) {
this.setupFrame = setupFrame;
}
+ @Override
+ public int keepAliveInterval() {
+ return SetupFrameFlyweight.keepaliveInterval(setupFrame.content());
+ }
+
+ @Override
+ public int keepAliveMaxLifetime() {
+ return SetupFrameFlyweight.maxLifetime(setupFrame.content());
+ }
+
@Override
public String metadataMimeType() {
return Setup.metadataMimeType(setupFrame);
diff --git a/rsocket-core/src/main/java/io/rsocket/KeepAliveHandler.java b/rsocket-core/src/main/java/io/rsocket/KeepAliveHandler.java
new file mode 100644
index 000000000..eac560dfa
--- /dev/null
+++ b/rsocket-core/src/main/java/io/rsocket/KeepAliveHandler.java
@@ -0,0 +1,120 @@
+package io.rsocket;
+
+import io.netty.buffer.Unpooled;
+import java.time.Duration;
+import reactor.core.Disposable;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+import reactor.core.publisher.MonoProcessor;
+import reactor.core.publisher.UnicastProcessor;
+
+abstract class KeepAliveHandler {
+ private final KeepAlive keepAlive;
+ private final UnicastProcessor sent = UnicastProcessor.create();
+ private final MonoProcessor timeout = MonoProcessor.create();
+ private final Flux interval;
+ private Disposable intervalDisposable;
+ private volatile long lastReceivedMillis;
+
+ static KeepAliveHandler ofServer(KeepAlive keepAlive) {
+ return new KeepAliveHandler.Server(keepAlive);
+ }
+
+ static KeepAliveHandler ofClient(KeepAlive keepAlive) {
+ return new KeepAliveHandler.Client(keepAlive);
+ }
+
+ private KeepAliveHandler(KeepAlive keepAlive) {
+ this.keepAlive = keepAlive;
+ this.interval = Flux.interval(Duration.ofMillis(keepAlive.getTickPeriod()));
+ }
+
+ public void start() {
+ this.lastReceivedMillis = System.currentTimeMillis();
+ intervalDisposable = interval.subscribe(v -> onIntervalTick());
+ }
+
+ public void stop() {
+ sent.onComplete();
+ timeout.onComplete();
+ if (intervalDisposable != null) {
+ intervalDisposable.dispose();
+ }
+ }
+
+ public void receive(Frame keepAliveFrame) {
+ this.lastReceivedMillis = System.currentTimeMillis();
+ if (Frame.Keepalive.hasRespondFlag(keepAliveFrame)) {
+ doSend(Frame.Keepalive.from(Unpooled.wrappedBuffer(keepAliveFrame.getData()), false));
+ }
+ }
+
+ public Flux send() {
+ return sent;
+ }
+
+ public Mono timeout() {
+ return timeout;
+ }
+
+ abstract void onIntervalTick();
+
+ void doSend(Frame frame) {
+ sent.onNext(frame);
+ }
+
+ void doCheckTimeout() {
+ long now = System.currentTimeMillis();
+ if (now - lastReceivedMillis >= keepAlive.getTimeoutMillis()) {
+ timeout.onNext(keepAlive);
+ }
+ }
+
+ private static class Server extends KeepAliveHandler {
+
+ Server(KeepAlive keepAlive) {
+ super(keepAlive);
+ }
+
+ @Override
+ void onIntervalTick() {
+ doCheckTimeout();
+ }
+ }
+
+ private static final class Client extends KeepAliveHandler {
+
+ Client(KeepAlive keepAlive) {
+ super(keepAlive);
+ }
+
+ @Override
+ void onIntervalTick() {
+ doCheckTimeout();
+ doSend(Frame.Keepalive.from(Unpooled.EMPTY_BUFFER, true));
+ }
+ }
+
+ static final class KeepAlive {
+ private final long tickPeriod;
+ private final long timeoutMillis;
+
+ KeepAlive(Duration tickPeriod, Duration timeoutMillis, int maxTicks) {
+ this.tickPeriod = tickPeriod.toMillis();
+ this.timeoutMillis = timeoutMillis.toMillis() + maxTicks * tickPeriod.toMillis();
+ }
+
+ KeepAlive(long tickPeriod, long timeoutMillis) {
+ this.tickPeriod = tickPeriod;
+ this.timeoutMillis = timeoutMillis;
+ }
+
+ public long getTickPeriod() {
+ return tickPeriod;
+ }
+
+ public long getTimeoutMillis() {
+ return timeoutMillis;
+ }
+ }
+}
diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketClient.java b/rsocket-core/src/main/java/io/rsocket/RSocketClient.java
index 4d779b2ad..8ce6ede18 100644
--- a/rsocket-core/src/main/java/io/rsocket/RSocketClient.java
+++ b/rsocket-core/src/main/java/io/rsocket/RSocketClient.java
@@ -16,7 +16,6 @@
package io.rsocket;
-import io.netty.buffer.Unpooled;
import io.rsocket.exceptions.ConnectionErrorException;
import io.rsocket.exceptions.Exceptions;
import io.rsocket.framing.FrameType;
@@ -24,14 +23,11 @@
import io.rsocket.internal.UnboundedProcessor;
import java.time.Duration;
import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.function.Function;
-import javax.annotation.Nullable;
import org.jctools.maps.NonBlockingHashMapLong;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
-import reactor.core.Disposable;
import reactor.core.publisher.*;
/** Client Side of a RSocket socket. Sends {@link Frame}s to a {@link RSocketServer} */
@@ -44,13 +40,10 @@ class RSocketClient implements RSocket {
private final MonoProcessor started;
private final NonBlockingHashMapLong senders;
private final NonBlockingHashMapLong> receivers;
- private final AtomicInteger missedAckCounter;
-
private final UnboundedProcessor sendProcessor;
+ private KeepAliveHandler keepAliveHandler;
- private @Nullable Disposable keepAliveSendSub;
- private volatile long timeLastTickSentMs;
-
+ /*server requester*/
RSocketClient(
DuplexConnection connection,
Function frameDecoder,
@@ -59,7 +52,7 @@ class RSocketClient implements RSocket {
this(
connection, frameDecoder, errorConsumer, streamIdSupplier, Duration.ZERO, Duration.ZERO, 0);
}
-
+ /*client requester*/
RSocketClient(
DuplexConnection connection,
Function frameDecoder,
@@ -75,24 +68,29 @@ class RSocketClient implements RSocket {
this.started = MonoProcessor.create();
this.senders = new NonBlockingHashMapLong<>(256);
this.receivers = new NonBlockingHashMapLong<>(256);
- this.missedAckCounter = new AtomicInteger();
// DO NOT Change the order here. The Send processor must be subscribed to before receiving
this.sendProcessor = new UnboundedProcessor<>();
if (!Duration.ZERO.equals(tickPeriod)) {
- long ackTimeoutMs = ackTimeout.toMillis();
-
- this.keepAliveSendSub =
- started
- .thenMany(Flux.interval(tickPeriod))
- .doOnSubscribe(s -> timeLastTickSentMs = System.currentTimeMillis())
- .subscribe(
- i -> sendKeepAlive(ackTimeoutMs, missedAcks),
- t -> {
- errorConsumer.accept(t);
- connection.dispose();
- });
+ this.keepAliveHandler =
+ KeepAliveHandler.ofClient(
+ new KeepAliveHandler.KeepAlive(tickPeriod, ackTimeout, missedAcks));
+
+ started.doOnTerminate(() -> keepAliveHandler.start()).subscribe();
+
+ keepAliveHandler
+ .timeout()
+ .subscribe(
+ keepAlive -> {
+ String message =
+ String.format("No keep-alive acks for %d ms", keepAlive.getTimeoutMillis());
+ errorConsumer.accept(new ConnectionErrorException(message));
+ connection.dispose();
+ });
+ keepAliveHandler.send().subscribe(sendProcessor::onNext);
+ } else {
+ keepAliveHandler = null;
}
connection.onClose().doFinally(signalType -> cleanup()).subscribe(null, errorConsumer);
@@ -140,22 +138,6 @@ private void handleSendProcessorCancel(SignalType t) {
}
}
- private void sendKeepAlive(long ackTimeoutMs, int missedAcks) {
- long now = System.currentTimeMillis();
- if (now - timeLastTickSentMs > ackTimeoutMs) {
- int count = missedAckCounter.incrementAndGet();
- if (count >= missedAcks) {
- String message =
- String.format(
- "Missed %d keep-alive acks with a threshold of %d and a ack timeout of %d ms",
- count, missedAcks, ackTimeoutMs);
- throw new ConnectionErrorException(message);
- }
- }
-
- sendProcessor.onNext(Frame.Keepalive.from(Unpooled.EMPTY_BUFFER, true));
- }
-
@Override
public Mono fireAndForget(Payload payload) {
Mono defer =
@@ -380,6 +362,9 @@ private boolean contains(int streamId) {
}
protected void cleanup() {
+ if (keepAliveHandler != null) {
+ keepAliveHandler.stop();
+ }
try {
for (UnicastProcessor subscriber : receivers.values()) {
cleanUpSubscriber(subscriber);
@@ -387,10 +372,6 @@ protected void cleanup() {
for (LimitableRequestPublisher p : senders.values()) {
cleanUpLimitableRequestPublisher(p);
}
-
- if (null != keepAliveSendSub) {
- keepAliveSendSub.dispose();
- }
} finally {
senders.clear();
receivers.clear();
@@ -437,8 +418,8 @@ private void handleStreamZero(FrameType type, Frame frame) {
break;
}
case KEEPALIVE:
- if (!Frame.Keepalive.hasRespondFlag(frame)) {
- timeLastTickSentMs = System.currentTimeMillis();
+ if (keepAliveHandler != null) {
+ keepAliveHandler.receive(frame);
}
break;
default:
diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java b/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java
index 545dd863f..3926fe4cd 100644
--- a/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java
+++ b/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java
@@ -87,7 +87,7 @@ public static class ClientRSocketFactory implements ClientTransportAcceptor {
private Payload setupPayload = EmptyPayload.INSTANCE;
private Function frameDecoder = DefaultPayload::create;
- private Duration tickPeriod = Duration.ZERO;
+ private Duration tickPeriod = Duration.ofSeconds(20);
private Duration ackTimeout = Duration.ofSeconds(30);
private int missedAcks = 3;
@@ -109,8 +109,13 @@ public ClientRSocketFactory addServerPlugin(RSocketInterceptor interceptor) {
return this;
}
+ /**
+ * Deprecated as Keep-Alive is not optional according to spec
+ *
+ * @return this ClientRSocketFactory
+ */
+ @Deprecated
public ClientRSocketFactory keepAlive() {
- tickPeriod = Duration.ofSeconds(20);
return this;
}
@@ -205,8 +210,8 @@ public Mono start() {
Frame setupFrame =
Frame.Setup.from(
flags,
- (int) ackTimeout.toMillis(),
- (int) ackTimeout.toMillis() * missedAcks,
+ (int) tickPeriod.toMillis(),
+ (int) (ackTimeout.toMillis() + tickPeriod.toMillis() * missedAcks),
metadataMimeType,
dataMimeType,
setupPayload);
@@ -339,6 +344,8 @@ private Mono processSetupFrame(
}
ConnectionSetupPayload setupPayload = ConnectionSetupPayload.create(setupFrame);
+ int keepAliveInterval = setupPayload.keepAliveInterval();
+ int keepAliveMaxLifetime = setupPayload.keepAliveMaxLifetime();
RSocketClient rSocketClient =
new RSocketClient(
@@ -361,7 +368,9 @@ private Mono processSetupFrame(
multiplexer.asClientConnection(),
wrappedRSocketServer,
frameDecoder,
- errorConsumer);
+ errorConsumer,
+ keepAliveInterval,
+ keepAliveMaxLifetime);
})
.doFinally(signalType -> setupPayload.release())
.then();
diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/RSocketServer.java
index eb00b6e3d..9ec273bf6 100644
--- a/rsocket-core/src/main/java/io/rsocket/RSocketServer.java
+++ b/rsocket-core/src/main/java/io/rsocket/RSocketServer.java
@@ -20,9 +20,8 @@
import static io.rsocket.frame.FrameHeaderFlyweight.FLAGS_C;
import static io.rsocket.frame.FrameHeaderFlyweight.FLAGS_M;
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.Unpooled;
import io.rsocket.exceptions.ApplicationErrorException;
+import io.rsocket.exceptions.ConnectionErrorException;
import io.rsocket.framing.FrameType;
import io.rsocket.internal.LimitableRequestPublisher;
import io.rsocket.internal.UnboundedProcessor;
@@ -33,10 +32,7 @@
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import reactor.core.Disposable;
-import reactor.core.publisher.Flux;
-import reactor.core.publisher.Mono;
-import reactor.core.publisher.SignalType;
-import reactor.core.publisher.UnicastProcessor;
+import reactor.core.publisher.*;
/** Server side RSocket. Receives {@link Frame}s from a {@link RSocketClient} */
class RSocketServer implements RSocket {
@@ -45,35 +41,73 @@ class RSocketServer implements RSocket {
private final RSocket requestHandler;
private final Function frameDecoder;
private final Consumer errorConsumer;
+ private final MonoProcessor started;
private final NonBlockingHashMapLong sendingSubscriptions;
private final NonBlockingHashMapLong> channelProcessors;
private final UnboundedProcessor sendProcessor;
private Disposable receiveDisposable;
+ private KeepAliveHandler keepAliveHandler;
+ /*client responder*/
RSocketServer(
DuplexConnection connection,
RSocket requestHandler,
Function frameDecoder,
Consumer errorConsumer) {
+ this(connection, requestHandler, frameDecoder, errorConsumer, 0, 0);
+ }
+ /*server responder*/
+ RSocketServer(
+ DuplexConnection connection,
+ RSocket requestHandler,
+ Function frameDecoder,
+ Consumer errorConsumer,
+ long tickPeriod,
+ long ackTimeout) {
this.connection = connection;
this.requestHandler = requestHandler;
this.frameDecoder = frameDecoder;
this.errorConsumer = errorConsumer;
this.sendingSubscriptions = new NonBlockingHashMapLong<>();
this.channelProcessors = new NonBlockingHashMapLong<>();
+ this.started = MonoProcessor.create();
// DO NOT Change the order here. The Send processor must be subscribed to before receiving
// connections
this.sendProcessor = new UnboundedProcessor<>();
+ if (tickPeriod != 0) {
+ keepAliveHandler =
+ KeepAliveHandler.ofServer(new KeepAliveHandler.KeepAlive(tickPeriod, ackTimeout));
+
+ started.doOnTerminate(() -> keepAliveHandler.start()).subscribe();
+
+ keepAliveHandler
+ .timeout()
+ .subscribe(
+ keepAlive -> {
+ String message =
+ String.format("No keep-alive acks for %d ms", keepAlive.getTimeoutMillis());
+ errorConsumer.accept(new ConnectionErrorException(message));
+ connection.dispose();
+ });
+ keepAliveHandler.send().subscribe(sendProcessor::onNext);
+ } else {
+ keepAliveHandler = null;
+ }
+
connection
.send(sendProcessor)
.doFinally(this::handleSendProcessorCancel)
.subscribe(null, this::handleSendProcessorError);
- this.receiveDisposable = connection.receive().subscribe(this::handleFrame, errorConsumer);
+ this.receiveDisposable =
+ connection
+ .receive()
+ .doOnSubscribe(subscription -> started.onComplete())
+ .subscribe(this::handleFrame, errorConsumer);
this.connection
.onClose()
@@ -186,6 +220,9 @@ public Mono onClose() {
}
private void cleanup() {
+ if (keepAliveHandler != null) {
+ keepAliveHandler.stop();
+ }
cleanUpSendingSubscriptions();
cleanUpChannelProcessors();
@@ -302,7 +339,8 @@ private void handleRequestResponse(int streamId, Mono response) {
payload.release();
return frame;
})
- .switchIfEmpty(Mono.fromCallable(() -> Frame.PayloadFrame.from(streamId, FrameType.COMPLETE)))
+ .switchIfEmpty(
+ Mono.fromCallable(() -> Frame.PayloadFrame.from(streamId, FrameType.COMPLETE)))
.doFinally(signalType -> sendingSubscriptions.remove(streamId))
.subscribe(sendProcessor::onNext, t -> handleError(streamId, t));
}
@@ -347,9 +385,8 @@ private void handleChannel(int streamId, Frame firstFrame) {
}
private void handleKeepAliveFrame(Frame frame) {
- if (Frame.Keepalive.hasRespondFlag(frame)) {
- ByteBuf data = Unpooled.wrappedBuffer(frame.getData());
- sendProcessor.onNext(Frame.Keepalive.from(data, false));
+ if (keepAliveHandler != null) {
+ keepAliveHandler.receive(frame);
}
}
diff --git a/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java b/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java
index c1d9b0b9e..d1f34589e 100644
--- a/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java
+++ b/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java
@@ -23,11 +23,10 @@
import io.netty.util.AbstractReferenceCounted;
import io.netty.util.Recycler;
import io.rsocket.Payload;
-
-import javax.annotation.Nullable;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.Charset;
+import javax.annotation.Nullable;
public final class ByteBufPayload extends AbstractReferenceCounted implements Payload {
private static final Recycler RECYCLER =
diff --git a/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java b/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java
index 8804e9438..54cc53cd1 100644
--- a/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java
+++ b/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java
@@ -155,7 +155,8 @@ public static Payload create(ByteBuffer data, @Nullable ByteBuffer metadata) {
}
public static Payload create(Payload payload) {
- return create(copy(payload.sliceData()), payload.hasMetadata() ? copy(payload.sliceMetadata()) : null);
+ return create(
+ copy(payload.sliceData()), payload.hasMetadata() ? copy(payload.sliceMetadata()) : null);
}
private static ByteBuffer copy(ByteBuf byteBuf) {
diff --git a/rsocket-core/src/test/java/io/rsocket/KeepAliveTest.java b/rsocket-core/src/test/java/io/rsocket/KeepAliveTest.java
new file mode 100644
index 000000000..abcddd37d
--- /dev/null
+++ b/rsocket-core/src/test/java/io/rsocket/KeepAliveTest.java
@@ -0,0 +1,153 @@
+package io.rsocket;
+
+import io.netty.buffer.Unpooled;
+import io.rsocket.exceptions.ConnectionErrorException;
+import io.rsocket.framing.FrameType;
+import io.rsocket.test.util.TestDuplexConnection;
+import io.rsocket.util.DefaultPayload;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Consumer;
+import java.util.function.Supplier;
+import java.util.stream.Stream;
+import org.assertj.core.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.MethodSource;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+import reactor.test.StepVerifier;
+
+public class KeepAliveTest {
+ private static final int CLIENT_REQUESTER_TICK_PERIOD = 100;
+ private static final int CLIENT_REQUESTER_TIMEOUT = 700;
+ private static final int CLIENT_REQUESTER_MISSED_ACKS = 3;
+ private static final int SERVER_RESPONDER_TICK_PERIOD = 100;
+ private static final int SERVER_RESPONDER_TIMEOUT = 1000;
+
+ @ParameterizedTest
+ @MethodSource("testData")
+ void keepAlives(Supplier testDataSupplier) {
+ TestData testData = testDataSupplier.get();
+ TestDuplexConnection connection = testData.connection();
+
+ Flux.interval(Duration.ofMillis(100))
+ .subscribe(
+ n -> connection.addToReceivedBuffer(Frame.Keepalive.from(Unpooled.EMPTY_BUFFER, true)));
+
+ Mono.delay(Duration.ofMillis(1500)).block();
+
+ RSocket rSocket = testData.rSocket();
+ List errors = testData.errors().errors();
+
+ Assertions.assertThat(rSocket.isDisposed()).isFalse();
+ Assertions.assertThat(errors).isEmpty();
+ }
+
+ @ParameterizedTest
+ @MethodSource("testData")
+ void keepAlivesMissing(Supplier testDataSupplier) {
+ TestData testData = testDataSupplier.get();
+ RSocket rSocket = testData.rSocket();
+
+ Mono.delay(Duration.ofMillis(1500)).block();
+
+ List errors = testData.errors().errors();
+ Assertions.assertThat(rSocket.isDisposed()).isTrue();
+ Assertions.assertThat(errors).hasSize(1);
+ Throwable throwable = errors.get(0);
+ Assertions.assertThat(throwable).isInstanceOf(ConnectionErrorException.class);
+ }
+
+ @Test
+ void clientRequesterRespondsToKeepAlives() {
+ TestData testData = requester(100, 700, 3).get();
+ TestDuplexConnection connection = testData.connection();
+
+ Mono.delay(Duration.ofMillis(100))
+ .subscribe(
+ l -> connection.addToReceivedBuffer(Frame.Keepalive.from(Unpooled.EMPTY_BUFFER, true)));
+
+ Mono keepAliveResponse =
+ Flux.from(connection.getSentAsPublisher())
+ .filter(f -> f.getType() == FrameType.KEEPALIVE && !Frame.Keepalive.hasRespondFlag(f))
+ .next()
+ .then();
+
+ StepVerifier.create(keepAliveResponse).expectComplete().verify(Duration.ofSeconds(5));
+ }
+
+ static Stream> testData() {
+ return Stream.of(
+ requester(
+ CLIENT_REQUESTER_TICK_PERIOD, CLIENT_REQUESTER_TIMEOUT, CLIENT_REQUESTER_MISSED_ACKS),
+ responder(SERVER_RESPONDER_TICK_PERIOD, SERVER_RESPONDER_TIMEOUT));
+ }
+
+ static Supplier requester(int tickPeriod, int timeout, int missedAcks) {
+ return () -> {
+ TestDuplexConnection connection = new TestDuplexConnection();
+ Errors errors = new Errors();
+ RSocketClient rSocket =
+ new RSocketClient(
+ connection,
+ DefaultPayload::create,
+ errors,
+ StreamIdSupplier.clientSupplier(),
+ Duration.ofMillis(tickPeriod),
+ Duration.ofMillis(timeout),
+ missedAcks);
+ return new TestData(rSocket, errors, connection);
+ };
+ }
+
+ static Supplier responder(int tickPeriod, int timeout) {
+ return () -> {
+ TestDuplexConnection connection = new TestDuplexConnection();
+ AbstractRSocket handler = new AbstractRSocket() {};
+ Errors errors = new Errors();
+ RSocketServer rSocket =
+ new RSocketServer(
+ connection, handler, DefaultPayload::create, errors, tickPeriod, timeout);
+ return new TestData(rSocket, errors, connection);
+ };
+ }
+
+ static class TestData {
+ private final RSocket rSocket;
+ private final Errors errors;
+ private final TestDuplexConnection connection;
+
+ public TestData(RSocket rSocket, Errors errors, TestDuplexConnection connection) {
+ this.rSocket = rSocket;
+ this.errors = errors;
+ this.connection = connection;
+ }
+
+ public TestDuplexConnection connection() {
+ return connection;
+ }
+
+ public RSocket rSocket() {
+ return rSocket;
+ }
+
+ public Errors errors() {
+ return errors;
+ }
+ }
+
+ static class Errors implements Consumer {
+ private final List errors = new ArrayList<>();
+
+ @Override
+ public void accept(Throwable throwable) {
+ errors.add(throwable);
+ }
+
+ public List errors() {
+ return new ArrayList<>(errors);
+ }
+ }
+}
diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java
index c43cefed5..68d4f3fd3 100644
--- a/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java
+++ b/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java
@@ -9,14 +9,13 @@
import io.rsocket.transport.netty.server.NettyContextCloseable;
import io.rsocket.transport.netty.server.TcpServerTransport;
import io.rsocket.util.DefaultPayload;
+import java.time.Duration;
+import java.util.function.Supplier;
import org.junit.jupiter.api.Test;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
-import java.time.Duration;
-import java.util.function.Supplier;
-
public class InteractionsLoadTest {
@Test
@@ -24,33 +23,35 @@ public class InteractionsLoadTest {
public void channel() {
TcpServerTransport serverTransport = TcpServerTransport.create(0);
- NettyContextCloseable server = RSocketFactory.receive()
- .acceptor((setup, rsocket) -> Mono.just(new EchoRSocket()))
- .transport(serverTransport)
- .start()
- .block(Duration.ofSeconds(10));
+ NettyContextCloseable server =
+ RSocketFactory.receive()
+ .acceptor((setup, rsocket) -> Mono.just(new EchoRSocket()))
+ .transport(serverTransport)
+ .start()
+ .block(Duration.ofSeconds(10));
TcpClientTransport transport = TcpClientTransport.create(server.address());
- RSocket client = RSocketFactory
- .connect()
- .transport(transport).start()
- .block(Duration.ofSeconds(10));
+ RSocket client =
+ RSocketFactory.connect().transport(transport).start().block(Duration.ofSeconds(10));
int concurrency = 16;
Flux.range(1, concurrency)
- .flatMap(v ->
- client.requestChannel(
- input().onBackpressureDrop().map(iv ->
- DefaultPayload.create("foo")))
- .limitRate(10000), concurrency)
+ .flatMap(
+ v ->
+ client
+ .requestChannel(
+ input().onBackpressureDrop().map(iv -> DefaultPayload.create("foo")))
+ .limitRate(10000),
+ concurrency)
.timeout(Duration.ofSeconds(5))
- .doOnNext(p -> {
- String data = p.getDataUtf8();
- if (!data.equals("bar")) {
- throw new IllegalStateException("Channel Client Bad message: " + data);
- }
- })
+ .doOnNext(
+ p -> {
+ String data = p.getDataUtf8();
+ if (!data.equals("bar")) {
+ throw new IllegalStateException("Channel Client Bad message: " + data);
+ }
+ })
.window(Duration.ofSeconds(1))
.flatMap(Flux::count)
.doOnNext(d -> System.out.println("Got: " + d))
@@ -59,12 +60,10 @@ public void channel() {
.subscribe();
server.onClose().block();
-
}
private static Flux input() {
- Flux interval = Flux.interval(Duration.ofMillis(1))
- .onBackpressureDrop();
+ Flux interval = Flux.interval(Duration.ofMillis(1)).onBackpressureDrop();
for (int i = 0; i < 10; i++) {
interval = interval.mergeWith(interval);
}
@@ -74,31 +73,36 @@ private static Flux input() {
private static class EchoRSocket extends AbstractRSocket {
@Override
public Flux requestChannel(Publisher payloads) {
- return Flux.from(payloads).map(p -> {
-
- String data = p.getDataUtf8();
- if (!data.equals("foo")) {
- throw new IllegalStateException("Channel Server Bad message: " + data);
- }
- return DefaultPayload.create(DefaultPayload.create("bar"));
- });
+ return Flux.from(payloads)
+ .map(
+ p -> {
+ String data = p.getDataUtf8();
+ if (!data.equals("foo")) {
+ throw new IllegalStateException("Channel Server Bad message: " + data);
+ }
+ return DefaultPayload.create(DefaultPayload.create("bar"));
+ });
}
@Override
public Flux requestStream(Payload payload) {
return Flux.just(payload)
- .map(p -> {
- String data = p.getDataUtf8();
- return data;
- })
- .doOnNext((data) -> {
- if (!data.equals("foo")) {
- throw new IllegalStateException("Stream Server Bad message: " + data);
- }
- }).flatMap(data -> {
- Supplier p = () -> DefaultPayload.create("bar");
- return Flux.range(1, 100).map(v -> p.get());
- });
+ .map(
+ p -> {
+ String data = p.getDataUtf8();
+ return data;
+ })
+ .doOnNext(
+ (data) -> {
+ if (!data.equals("foo")) {
+ throw new IllegalStateException("Stream Server Bad message: " + data);
+ }
+ })
+ .flatMap(
+ data -> {
+ Supplier p = () -> DefaultPayload.create("bar");
+ return Flux.range(1, 100).map(v -> p.get());
+ });
}
}
}