Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ subprojects {
dependencySet(group: 'org.junit.jupiter', version: '5.1.0') {
entry 'junit-jupiter-api'
entry 'junit-jupiter-engine'
entry 'junit-jupiter-params'
}

// TODO: Remove after JUnit5 migration
Expand Down
1 change: 1 addition & 0 deletions rsocket-core/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dependencies {
testImplementation 'io.projectreactor:reactor-test'
testImplementation 'org.assertj:assertj-core'
testImplementation 'org.junit.jupiter:junit-jupiter-api'
testImplementation 'org.junit.jupiter:junit-jupiter-params'
testImplementation 'org.mockito:mockito-core'

testRuntimeOnly 'ch.qos.logback:logback-classic'
Expand Down
14 changes: 14 additions & 0 deletions rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ public static ConnectionSetupPayload create(final Frame setupFrame) {
return new DefaultConnectionSetupPayload(setupFrame);
}

public abstract int keepAliveInterval();

public abstract int keepAliveMaxLifetime();

public abstract String metadataMimeType();

public abstract String dataMimeType();
Expand Down Expand Up @@ -73,6 +77,16 @@ public DefaultConnectionSetupPayload(final Frame setupFrame) {
this.setupFrame = setupFrame;
}

@Override
public int keepAliveInterval() {
return SetupFrameFlyweight.keepaliveInterval(setupFrame.content());
}

@Override
public int keepAliveMaxLifetime() {
return SetupFrameFlyweight.maxLifetime(setupFrame.content());
}

@Override
public String metadataMimeType() {
return Setup.metadataMimeType(setupFrame);
Expand Down
120 changes: 120 additions & 0 deletions rsocket-core/src/main/java/io/rsocket/KeepAliveHandler.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package io.rsocket;

import io.netty.buffer.Unpooled;
import java.time.Duration;
import reactor.core.Disposable;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoProcessor;
import reactor.core.publisher.UnicastProcessor;

abstract class KeepAliveHandler {
private final KeepAlive keepAlive;
private final UnicastProcessor<Frame> sent = UnicastProcessor.create();
private final MonoProcessor<KeepAlive> timeout = MonoProcessor.create();
private final Flux<Long> interval;
private Disposable intervalDisposable;
private volatile long lastReceivedMillis;

static KeepAliveHandler ofServer(KeepAlive keepAlive) {
return new KeepAliveHandler.Server(keepAlive);
}

static KeepAliveHandler ofClient(KeepAlive keepAlive) {
return new KeepAliveHandler.Client(keepAlive);
}

private KeepAliveHandler(KeepAlive keepAlive) {
this.keepAlive = keepAlive;
this.interval = Flux.interval(Duration.ofMillis(keepAlive.getTickPeriod()));
}

public void start() {
this.lastReceivedMillis = System.currentTimeMillis();
intervalDisposable = interval.subscribe(v -> onIntervalTick());
}

public void stop() {
sent.onComplete();
timeout.onComplete();
if (intervalDisposable != null) {
intervalDisposable.dispose();
}
}

public void receive(Frame keepAliveFrame) {
this.lastReceivedMillis = System.currentTimeMillis();
if (Frame.Keepalive.hasRespondFlag(keepAliveFrame)) {
doSend(Frame.Keepalive.from(Unpooled.wrappedBuffer(keepAliveFrame.getData()), false));
}
}

public Flux<Frame> send() {
return sent;
}

public Mono<KeepAlive> timeout() {
return timeout;
}

abstract void onIntervalTick();

void doSend(Frame frame) {
sent.onNext(frame);
}

void doCheckTimeout() {
long now = System.currentTimeMillis();
if (now - lastReceivedMillis >= keepAlive.getTimeoutMillis()) {
timeout.onNext(keepAlive);
}
}

private static class Server extends KeepAliveHandler {

Server(KeepAlive keepAlive) {
super(keepAlive);
}

@Override
void onIntervalTick() {
doCheckTimeout();
}
}

private static final class Client extends KeepAliveHandler {

Client(KeepAlive keepAlive) {
super(keepAlive);
}

@Override
void onIntervalTick() {
doCheckTimeout();
doSend(Frame.Keepalive.from(Unpooled.EMPTY_BUFFER, true));
}
}

static final class KeepAlive {
private final long tickPeriod;
private final long timeoutMillis;

KeepAlive(Duration tickPeriod, Duration timeoutMillis, int maxTicks) {
this.tickPeriod = tickPeriod.toMillis();
this.timeoutMillis = timeoutMillis.toMillis() + maxTicks * tickPeriod.toMillis();
}

KeepAlive(long tickPeriod, long timeoutMillis) {
this.tickPeriod = tickPeriod;
this.timeoutMillis = timeoutMillis;
}

public long getTickPeriod() {
return tickPeriod;
}

public long getTimeoutMillis() {
return timeoutMillis;
}
}
}
71 changes: 26 additions & 45 deletions rsocket-core/src/main/java/io/rsocket/RSocketClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,18 @@

package io.rsocket;

import io.netty.buffer.Unpooled;
import io.rsocket.exceptions.ConnectionErrorException;
import io.rsocket.exceptions.Exceptions;
import io.rsocket.framing.FrameType;
import io.rsocket.internal.LimitableRequestPublisher;
import io.rsocket.internal.UnboundedProcessor;
import java.time.Duration;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.function.Function;
import javax.annotation.Nullable;
import org.jctools.maps.NonBlockingHashMapLong;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import reactor.core.Disposable;
import reactor.core.publisher.*;

/** Client Side of a RSocket socket. Sends {@link Frame}s to a {@link RSocketServer} */
Expand All @@ -44,13 +40,10 @@ class RSocketClient implements RSocket {
private final MonoProcessor<Void> started;
private final NonBlockingHashMapLong<LimitableRequestPublisher> senders;
private final NonBlockingHashMapLong<UnicastProcessor<Payload>> receivers;
private final AtomicInteger missedAckCounter;

private final UnboundedProcessor<Frame> sendProcessor;
private KeepAliveHandler keepAliveHandler;

private @Nullable Disposable keepAliveSendSub;
private volatile long timeLastTickSentMs;

/*server requester*/
RSocketClient(
DuplexConnection connection,
Function<Frame, ? extends Payload> frameDecoder,
Expand All @@ -59,7 +52,7 @@ class RSocketClient implements RSocket {
this(
connection, frameDecoder, errorConsumer, streamIdSupplier, Duration.ZERO, Duration.ZERO, 0);
}

/*client requester*/
RSocketClient(
DuplexConnection connection,
Function<Frame, ? extends Payload> frameDecoder,
Expand All @@ -75,24 +68,29 @@ class RSocketClient implements RSocket {
this.started = MonoProcessor.create();
this.senders = new NonBlockingHashMapLong<>(256);
this.receivers = new NonBlockingHashMapLong<>(256);
this.missedAckCounter = new AtomicInteger();

// DO NOT Change the order here. The Send processor must be subscribed to before receiving
this.sendProcessor = new UnboundedProcessor<>();

if (!Duration.ZERO.equals(tickPeriod)) {
long ackTimeoutMs = ackTimeout.toMillis();

this.keepAliveSendSub =
started
.thenMany(Flux.interval(tickPeriod))
.doOnSubscribe(s -> timeLastTickSentMs = System.currentTimeMillis())
.subscribe(
i -> sendKeepAlive(ackTimeoutMs, missedAcks),
t -> {
errorConsumer.accept(t);
connection.dispose();
});
this.keepAliveHandler =
KeepAliveHandler.ofClient(
new KeepAliveHandler.KeepAlive(tickPeriod, ackTimeout, missedAcks));

started.doOnTerminate(() -> keepAliveHandler.start()).subscribe();

keepAliveHandler
.timeout()
.subscribe(
keepAlive -> {
String message =
String.format("No keep-alive acks for %d ms", keepAlive.getTimeoutMillis());
errorConsumer.accept(new ConnectionErrorException(message));
connection.dispose();
});
keepAliveHandler.send().subscribe(sendProcessor::onNext);
} else {
keepAliveHandler = null;
}

connection.onClose().doFinally(signalType -> cleanup()).subscribe(null, errorConsumer);
Expand Down Expand Up @@ -140,22 +138,6 @@ private void handleSendProcessorCancel(SignalType t) {
}
}

private void sendKeepAlive(long ackTimeoutMs, int missedAcks) {
long now = System.currentTimeMillis();
if (now - timeLastTickSentMs > ackTimeoutMs) {
int count = missedAckCounter.incrementAndGet();
if (count >= missedAcks) {
String message =
String.format(
"Missed %d keep-alive acks with a threshold of %d and a ack timeout of %d ms",
count, missedAcks, ackTimeoutMs);
throw new ConnectionErrorException(message);
}
}

sendProcessor.onNext(Frame.Keepalive.from(Unpooled.EMPTY_BUFFER, true));
}

@Override
public Mono<Void> fireAndForget(Payload payload) {
Mono<Void> defer =
Expand Down Expand Up @@ -380,17 +362,16 @@ private boolean contains(int streamId) {
}

protected void cleanup() {
if (keepAliveHandler != null) {
keepAliveHandler.stop();
}
try {
for (UnicastProcessor<Payload> subscriber : receivers.values()) {
cleanUpSubscriber(subscriber);
}
for (LimitableRequestPublisher p : senders.values()) {
cleanUpLimitableRequestPublisher(p);
}

if (null != keepAliveSendSub) {
keepAliveSendSub.dispose();
}
} finally {
senders.clear();
receivers.clear();
Expand Down Expand Up @@ -437,8 +418,8 @@ private void handleStreamZero(FrameType type, Frame frame) {
break;
}
case KEEPALIVE:
if (!Frame.Keepalive.hasRespondFlag(frame)) {
timeLastTickSentMs = System.currentTimeMillis();
if (keepAliveHandler != null) {
keepAliveHandler.receive(frame);
}
break;
default:
Expand Down
19 changes: 14 additions & 5 deletions rsocket-core/src/main/java/io/rsocket/RSocketFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public static class ClientRSocketFactory implements ClientTransportAcceptor {
private Payload setupPayload = EmptyPayload.INSTANCE;
private Function<Frame, ? extends Payload> frameDecoder = DefaultPayload::create;

private Duration tickPeriod = Duration.ZERO;
private Duration tickPeriod = Duration.ofSeconds(20);
private Duration ackTimeout = Duration.ofSeconds(30);
private int missedAcks = 3;

Expand All @@ -109,8 +109,13 @@ public ClientRSocketFactory addServerPlugin(RSocketInterceptor interceptor) {
return this;
}

/**
* Deprecated as Keep-Alive is not optional according to spec
*
* @return this ClientRSocketFactory
*/
@Deprecated
public ClientRSocketFactory keepAlive() {
tickPeriod = Duration.ofSeconds(20);
return this;
}

Expand Down Expand Up @@ -205,8 +210,8 @@ public Mono<RSocket> start() {
Frame setupFrame =
Frame.Setup.from(
flags,
(int) ackTimeout.toMillis(),
(int) ackTimeout.toMillis() * missedAcks,
(int) tickPeriod.toMillis(),
(int) (ackTimeout.toMillis() + tickPeriod.toMillis() * missedAcks),
metadataMimeType,
dataMimeType,
setupPayload);
Expand Down Expand Up @@ -339,6 +344,8 @@ private Mono<Void> processSetupFrame(
}

ConnectionSetupPayload setupPayload = ConnectionSetupPayload.create(setupFrame);
int keepAliveInterval = setupPayload.keepAliveInterval();
int keepAliveMaxLifetime = setupPayload.keepAliveMaxLifetime();

RSocketClient rSocketClient =
new RSocketClient(
Expand All @@ -361,7 +368,9 @@ private Mono<Void> processSetupFrame(
multiplexer.asClientConnection(),
wrappedRSocketServer,
frameDecoder,
errorConsumer);
errorConsumer,
keepAliveInterval,
keepAliveMaxLifetime);
})
.doFinally(signalType -> setupPayload.release())
.then();
Expand Down
Loading