diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java index f7fb161fd..d9da1017b 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -189,7 +189,7 @@ public void dispose() { @Override public boolean isDisposed() { - return onClose.isDisposed(); + return terminationError != null; } @Override @@ -222,6 +222,12 @@ private Mono handleFireAndForget(Payload payload) { new IllegalStateException("FireAndForgetMono allows only a single subscriber")); } + if (isDisposed()) { + payload.release(); + final Throwable t = terminationError; + return Mono.error(t); + } + final int streamId = streamIdSupplier.nextStreamId(receivers); final ByteBuf requestFrame = RequestFireAndForgetFrameCodec.encodeReleasingPayload( @@ -270,6 +276,13 @@ private Mono handleRequestResponse(final Payload payload) { @Override void hookOnFirstRequest(long n) { + if (isDisposed()) { + payload.release(); + final Throwable t = terminationError; + receiver.onError(t); + return; + } + int streamId = streamIdSupplier.nextStreamId(receivers); this.streamId = streamId; @@ -335,6 +348,13 @@ private Flux handleRequestStream(final Payload payload) { @Override void hookOnFirstRequest(long n) { + if (isDisposed()) { + payload.release(); + final Throwable t = terminationError; + receiver.onError(t); + return; + } + int streamId = streamIdSupplier.nextStreamId(receivers); this.streamId = streamId; @@ -477,6 +497,14 @@ protected void hookFinally(SignalType type) { @Override void hookOnFirstRequest(long n) { + if (isDisposed()) { + initialPayload.release(); + final Throwable t = terminationError; + upstreamSubscriber.cancel(); + receiver.onError(t); + return; + } + final int streamId = streamIdSupplier.nextStreamId(receivers); this.streamId = streamId; @@ -712,7 +740,7 @@ private void tryTerminate(Supplier errorSupplier) { if (terminationError == null) { Throwable e = errorSupplier.get(); if (TERMINATION_ERROR.compareAndSet(this, null, e)) { - terminate(e); + serialScheduler.schedule(() -> terminate(e)); } } } @@ -720,7 +748,7 @@ private void tryTerminate(Supplier errorSupplier) { private void tryShutdown() { if (terminationError == null) { if (TERMINATION_ERROR.compareAndSet(this, null, CLOSED_CHANNEL_EXCEPTION)) { - terminate(CLOSED_CHANNEL_EXCEPTION); + serialScheduler.schedule(() -> terminate(CLOSED_CHANNEL_EXCEPTION)); } } } @@ -729,34 +757,30 @@ private void terminate(Throwable e) { connection.dispose(); leaseHandler.dispose(); - synchronized (receivers) { - receivers - .values() - .forEach( - receiver -> { - try { - receiver.onError(e); - } catch (Throwable t) { - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Dropped exception", t); - } + receivers + .values() + .forEach( + receiver -> { + try { + receiver.onError(e); + } catch (Throwable t) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Dropped exception", t); } - }); - } - synchronized (senders) { - senders - .values() - .forEach( - sender -> { - try { - sender.cancel(); - } catch (Throwable t) { - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Dropped exception", t); - } + } + }); + senders + .values() + .forEach( + sender -> { + try { + sender.cancel(); + } catch (Throwable t) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Dropped exception", t); } - }); - } + } + }); senders.clear(); receivers.clear(); sendProcessor.dispose(); diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java index 1ba75f75a..85995298a 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java @@ -62,11 +62,13 @@ import io.rsocket.util.ByteBufPayload; import io.rsocket.util.DefaultPayload; import io.rsocket.util.EmptyPayload; +import java.nio.channels.ClosedChannelException; import java.time.Duration; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Function; @@ -90,6 +92,7 @@ import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; import reactor.core.publisher.UnicastProcessor; +import reactor.core.scheduler.Schedulers; import reactor.test.StepVerifier; import reactor.test.publisher.TestPublisher; import reactor.test.util.RaceTestUtils; @@ -976,24 +979,66 @@ public static Stream streamIdRacingCases() { (BiFunction>) (r, p) -> r.socket.requestStream(p), (BiFunction>) - (r, p) -> r.socket.requestChannel(Flux.just(p))), + (r, p) -> { + AtomicBoolean subscribed = new AtomicBoolean(); + Flux just = Flux.just(p).doOnSubscribe((__) -> subscribed.set(true)); + return r.socket + .requestChannel(just) + .doFinally( + __ -> { + if (!subscribed.get()) { + p.release(); + } + }); + }), Arguments.of( (BiFunction>) - (r, p) -> r.socket.requestChannel(Flux.just(p)), + (r, p) -> { + AtomicBoolean subscribed = new AtomicBoolean(); + Flux just = Flux.just(p).doOnSubscribe((__) -> subscribed.set(true)); + return r.socket + .requestChannel(just) + .doFinally( + __ -> { + if (!subscribed.get()) { + p.release(); + } + }); + }, (BiFunction>) (r, p) -> r.socket.fireAndForget(p))); } - public int sendRequestResponse(Publisher response) { - Subscriber sub = TestSubscriber.create(); - response.subscribe(sub); - int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); - rule.connection.addToReceivedBuffer( - PayloadFrameCodec.encodeNextCompleteReleasingPayload( - rule.alloc(), streamId, EmptyPayload.INSTANCE)); - verify(sub).onNext(any(Payload.class)); - verify(sub).onComplete(); - return streamId; + @ParameterizedTest + @MethodSource("streamIdRacingCases") + @SuppressWarnings({"rawtypes", "unchecked"}) + public void shouldTerminateAllStreamsIfThereRacingBetweenDisposeAndRequests( + BiFunction> interaction1, + BiFunction> interaction2) { + for (int i = 1; i < 10000; i++) { + Payload payload1 = ByteBufPayload.create("test"); + Payload payload2 = ByteBufPayload.create("test"); + AssertSubscriber assertSubscriber1 = AssertSubscriber.create(); + AssertSubscriber assertSubscriber2 = AssertSubscriber.create(); + Publisher publisher1 = interaction1.apply(rule, payload1); + Publisher publisher2 = interaction2.apply(rule, payload2); + RaceTestUtils.race( + () -> rule.socket.dispose(), + () -> + RaceTestUtils.race( + () -> publisher1.subscribe(assertSubscriber1), + () -> publisher2.subscribe(assertSubscriber2), + Schedulers.parallel()), + Schedulers.parallel()); + + assertSubscriber1.await().assertTerminated().assertError(ClosedChannelException.class); + assertSubscriber2.await().assertTerminated().assertError(ClosedChannelException.class); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + Assertions.assertThat(payload1.refCnt()).isZero(); + Assertions.assertThat(payload2.refCnt()).isZero(); + } } public static class ClientSocketRule extends AbstractSocketRule {