diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketClient.java b/rsocket-core/src/main/java/io/rsocket/RSocketClient.java index 44460a938..8f0941285 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketClient.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketClient.java @@ -21,8 +21,11 @@ import io.rsocket.framing.FrameType; import io.rsocket.internal.LimitableRequestPublisher; import io.rsocket.internal.UnboundedProcessor; + +import java.nio.channels.ClosedChannelException; import java.time.Duration; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; import org.jctools.maps.NonBlockingHashMapLong; @@ -72,7 +75,7 @@ class RSocketClient implements RSocket { // DO NOT Change the order here. The Send processor must be subscribed to before receiving this.sendProcessor = new UnboundedProcessor<>(); - connection.onClose().doFinally(signalType -> cleanup()).subscribe(null, errorConsumer); + connection.onClose().doFinally(signalType -> terminate()).subscribe(null, errorConsumer); connection .send(sendProcessor) @@ -92,7 +95,9 @@ class RSocketClient implements RSocket { keepAlive -> { String message = String.format("No keep-alive acks for %d ms", keepAlive.getTimeoutMillis()); - errorConsumer.accept(new ConnectionErrorException(message)); + ConnectionErrorException err = new ConnectionErrorException(message); + lifecycle.terminate(err); + errorConsumer.accept(err); connection.dispose(); }); keepAliveHandler.send().subscribe(sendProcessor::onNext); @@ -157,12 +162,7 @@ public Flux requestChannel(Publisher payloads) { @Override public Mono metadataPush(Payload payload) { - return Mono.fromRunnable( - () -> { - final Frame requestFrame = Frame.Request.from(0, FrameType.METADATA_PUSH, payload, 1); - payload.release(); - sendProcessor.onNext(requestFrame); - }); + return handleMetadataPush(payload); } @Override @@ -187,7 +187,7 @@ public Mono onClose() { private Mono handleFireAndForget(Payload payload) { return lifecycle - .started() + .active() .then( Mono.fromRunnable( () -> { @@ -201,7 +201,7 @@ private Mono handleFireAndForget(Payload payload) { private Flux handleRequestStream(final Payload payload) { return lifecycle - .started() + .active() .thenMany( Flux.defer( () -> { @@ -247,7 +247,7 @@ private Flux handleRequestStream(final Payload payload) { private Mono handleRequestResponse(final Payload payload) { return lifecycle - .started() + .active() .then( Mono.defer( () -> { @@ -274,7 +274,7 @@ private Mono handleRequestResponse(final Payload payload) { private Flux handleChannel(Flux request) { return lifecycle - .started() + .active() .thenMany( Flux.defer( () -> { @@ -365,11 +365,25 @@ private Flux handleChannel(Flux request) { })); } + private Mono handleMetadataPush(Payload payload) { + return lifecycle + .active() + .then(Mono.fromRunnable( + () -> { + final Frame requestFrame = Frame.Request.from(0, FrameType.METADATA_PUSH, payload, 1); + payload.release(); + sendProcessor.onNext(requestFrame); + })); + } + private boolean contains(int streamId) { return receivers.containsKey(streamId); } - protected void cleanup() { + protected void terminate() { + + lifecycle.terminate(new ClosedChannelException()); + if (keepAliveHandler != null) { keepAliveHandler.dispose(); } @@ -397,13 +411,8 @@ private synchronized void cleanUpLimitableRequestPublisher( } private synchronized void cleanUpSubscriber(UnicastProcessor subscriber) { - Throwable err = lifecycle.terminationError(); try { - if (err != null) { - subscriber.onError(err); - } else { - subscriber.cancel(); - } + subscriber.onError(lifecycle.terminationError()); } catch (Throwable t) { errorConsumer.accept(t); } @@ -519,12 +528,12 @@ private void handleMissingResponseProcessor(int streamId, FrameType type, Frame private static class Lifecycle { - private volatile Throwable terminationError; + private final AtomicReference terminationError = new AtomicReference<>(); - public Mono started() { + public Mono active() { return Mono.create( sink -> { - Throwable err = terminationError; + Throwable err = terminationError(); if (err == null) { sink.success(); } else { @@ -534,11 +543,11 @@ public Mono started() { } public void terminate(Throwable err) { - this.terminationError = err; + this.terminationError.compareAndSet(null, err); } public Throwable terminationError() { - return terminationError; + return terminationError.get(); } } } diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketClientTerminationTest.java b/rsocket-core/src/test/java/io/rsocket/RSocketClientTerminationTest.java new file mode 100644 index 000000000..a8e60d02b --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/RSocketClientTerminationTest.java @@ -0,0 +1,64 @@ +package io.rsocket; + +import io.rsocket.RSocketClientTest.ClientSocketRule; +import io.rsocket.util.EmptyPayload; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import java.nio.channels.ClosedChannelException; +import java.time.Duration; +import java.util.Arrays; +import java.util.function.Function; + +@RunWith(Parameterized.class) +public class RSocketClientTerminationTest { + + @Rule + public final ClientSocketRule rule = new ClientSocketRule(); + private Function> interaction; + + public RSocketClientTerminationTest(Function> interaction) { + this.interaction = interaction; + } + + @Test + public void testCurrentStreamIsTerminatedOnConnectionClose() { + RSocketClient rSocket = rule.socket; + + Mono.delay(Duration.ofSeconds(1)) + .doOnNext(v -> rule.connection.dispose()) + .subscribe(); + + StepVerifier.create(interaction.apply(rSocket)) + .expectError(ClosedChannelException.class) + .verify(Duration.ofSeconds(5)); + } + + @Test + public void testSubsequentStreamIsTerminatedAfterConnectionClose() { + RSocketClient rSocket = rule.socket; + + rule.connection.dispose(); + StepVerifier.create(interaction.apply(rSocket)) + .expectError(ClosedChannelException.class) + .verify(Duration.ofSeconds(5)); + } + + @Parameterized.Parameters + public static Iterable>> rsocketInteractions() { + EmptyPayload payload = EmptyPayload.INSTANCE; + Publisher payloadStream = Flux.just(payload); + + Function> resp = rSocket -> rSocket.requestResponse(payload); + Function> stream = rSocket -> rSocket.requestStream(payload); + Function> channel = rSocket -> rSocket.requestChannel(payloadStream); + + return Arrays.asList(resp, stream, channel); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketClientTest.java b/rsocket-core/src/test/java/io/rsocket/RSocketClientTest.java index 4fcb65751..a153e2f63 100644 --- a/rsocket-core/src/test/java/io/rsocket/RSocketClientTest.java +++ b/rsocket-core/src/test/java/io/rsocket/RSocketClientTest.java @@ -215,7 +215,7 @@ protected RSocketClient newRSocket() { throwable -> errors.add(throwable), StreamIdSupplier.clientSupplier(), Duration.ofMillis(100), - Duration.ofMillis(100), + Duration.ofMillis(10_000), 4); }