Skip to content

Commit

Permalink
ensures streams are terminated if racing terminal and new stream
Browse files Browse the repository at this point in the history
Signed-off-by: Oleh Dokuka <shadowgun@i.ua>
  • Loading branch information
OlegDokuka committed May 21, 2020
1 parent 2ca6a9c commit 5431886
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 41 deletions.
82 changes: 53 additions & 29 deletions rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java
Expand Up @@ -189,7 +189,7 @@ public void dispose() {

@Override
public boolean isDisposed() {
return onClose.isDisposed();
return terminationError != null;
}

@Override
Expand Down Expand Up @@ -222,6 +222,12 @@ private Mono<Void> handleFireAndForget(Payload payload) {
new IllegalStateException("FireAndForgetMono allows only a single subscriber"));
}

if (isDisposed()) {
payload.release();
final Throwable t = terminationError;
return Mono.error(t);
}

final int streamId = streamIdSupplier.nextStreamId(receivers);
final ByteBuf requestFrame =
RequestFireAndForgetFrameCodec.encodeReleasingPayload(
Expand Down Expand Up @@ -270,6 +276,13 @@ private Mono<Payload> handleRequestResponse(final Payload payload) {

@Override
void hookOnFirstRequest(long n) {
if (isDisposed()) {
payload.release();
final Throwable t = terminationError;
receiver.onError(t);
return;
}

int streamId = streamIdSupplier.nextStreamId(receivers);
this.streamId = streamId;

Expand Down Expand Up @@ -335,6 +348,13 @@ private Flux<Payload> handleRequestStream(final Payload payload) {

@Override
void hookOnFirstRequest(long n) {
if (isDisposed()) {
payload.release();
final Throwable t = terminationError;
receiver.onError(t);
return;
}

int streamId = streamIdSupplier.nextStreamId(receivers);
this.streamId = streamId;

Expand Down Expand Up @@ -477,6 +497,14 @@ protected void hookFinally(SignalType type) {

@Override
void hookOnFirstRequest(long n) {
if (isDisposed()) {
initialPayload.release();
final Throwable t = terminationError;
upstreamSubscriber.cancel();
receiver.onError(t);
return;
}

final int streamId = streamIdSupplier.nextStreamId(receivers);
this.streamId = streamId;

Expand Down Expand Up @@ -712,15 +740,15 @@ private void tryTerminate(Supplier<Throwable> errorSupplier) {
if (terminationError == null) {
Throwable e = errorSupplier.get();
if (TERMINATION_ERROR.compareAndSet(this, null, e)) {
terminate(e);
serialScheduler.schedule(() -> terminate(e));
}
}
}

private void tryShutdown() {
if (terminationError == null) {
if (TERMINATION_ERROR.compareAndSet(this, null, CLOSED_CHANNEL_EXCEPTION)) {
terminate(CLOSED_CHANNEL_EXCEPTION);
serialScheduler.schedule(() -> terminate(CLOSED_CHANNEL_EXCEPTION));
}
}
}
Expand All @@ -729,34 +757,30 @@ private void terminate(Throwable e) {
connection.dispose();
leaseHandler.dispose();

synchronized (receivers) {
receivers
.values()
.forEach(
receiver -> {
try {
receiver.onError(e);
} catch (Throwable t) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Dropped exception", t);
}
receivers
.values()
.forEach(
receiver -> {
try {
receiver.onError(e);
} catch (Throwable t) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Dropped exception", t);
}
});
}
synchronized (senders) {
senders
.values()
.forEach(
sender -> {
try {
sender.cancel();
} catch (Throwable t) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Dropped exception", t);
}
}
});
senders
.values()
.forEach(
sender -> {
try {
sender.cancel();
} catch (Throwable t) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Dropped exception", t);
}
});
}
}
});
senders.clear();
receivers.clear();
sendProcessor.dispose();
Expand Down
Expand Up @@ -62,11 +62,13 @@
import io.rsocket.util.ByteBufPayload;
import io.rsocket.util.DefaultPayload;
import io.rsocket.util.EmptyPayload;
import java.nio.channels.ClosedChannelException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Function;
Expand All @@ -90,6 +92,7 @@
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoProcessor;
import reactor.core.publisher.UnicastProcessor;
import reactor.core.scheduler.Schedulers;
import reactor.test.StepVerifier;
import reactor.test.publisher.TestPublisher;
import reactor.test.util.RaceTestUtils;
Expand Down Expand Up @@ -976,24 +979,66 @@ public static Stream<Arguments> streamIdRacingCases() {
(BiFunction<ClientSocketRule, Payload, Publisher<?>>)
(r, p) -> r.socket.requestStream(p),
(BiFunction<ClientSocketRule, Payload, Publisher<?>>)
(r, p) -> r.socket.requestChannel(Flux.just(p))),
(r, p) -> {
AtomicBoolean subscribed = new AtomicBoolean();
Flux<Payload> just = Flux.just(p).doOnSubscribe((__) -> subscribed.set(true));
return r.socket
.requestChannel(just)
.doFinally(
__ -> {
if (!subscribed.get()) {
p.release();
}
});
}),
Arguments.of(
(BiFunction<ClientSocketRule, Payload, Publisher<?>>)
(r, p) -> r.socket.requestChannel(Flux.just(p)),
(r, p) -> {
AtomicBoolean subscribed = new AtomicBoolean();
Flux<Payload> just = Flux.just(p).doOnSubscribe((__) -> subscribed.set(true));
return r.socket
.requestChannel(just)
.doFinally(
__ -> {
if (!subscribed.get()) {
p.release();
}
});
},
(BiFunction<ClientSocketRule, Payload, Publisher<?>>)
(r, p) -> r.socket.fireAndForget(p)));
}

public int sendRequestResponse(Publisher<Payload> response) {
Subscriber<Payload> sub = TestSubscriber.create();
response.subscribe(sub);
int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE);
rule.connection.addToReceivedBuffer(
PayloadFrameCodec.encodeNextCompleteReleasingPayload(
rule.alloc(), streamId, EmptyPayload.INSTANCE));
verify(sub).onNext(any(Payload.class));
verify(sub).onComplete();
return streamId;
@ParameterizedTest
@MethodSource("streamIdRacingCases")
@SuppressWarnings({"rawtypes", "unchecked"})
public void shouldTerminateAllStreamsIfThereRacingBetweenDisposeAndRequests(
BiFunction<ClientSocketRule, Payload, Publisher<?>> interaction1,
BiFunction<ClientSocketRule, Payload, Publisher<?>> interaction2) {
for (int i = 1; i < 10000; i++) {
Payload payload1 = ByteBufPayload.create("test");
Payload payload2 = ByteBufPayload.create("test");
AssertSubscriber assertSubscriber1 = AssertSubscriber.create();
AssertSubscriber assertSubscriber2 = AssertSubscriber.create();
Publisher<?> publisher1 = interaction1.apply(rule, payload1);
Publisher<?> publisher2 = interaction2.apply(rule, payload2);
RaceTestUtils.race(
() -> rule.socket.dispose(),
() ->
RaceTestUtils.race(
() -> publisher1.subscribe(assertSubscriber1),
() -> publisher2.subscribe(assertSubscriber2),
Schedulers.parallel()),
Schedulers.parallel());

assertSubscriber1.await().assertTerminated().assertError(ClosedChannelException.class);
assertSubscriber2.await().assertTerminated().assertError(ClosedChannelException.class);

Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release);

Assertions.assertThat(payload1.refCnt()).isZero();
Assertions.assertThat(payload2.refCnt()).isZero();
}
}

public static class ClientSocketRule extends AbstractSocketRule<RSocketRequester> {
Expand Down

0 comments on commit 5431886

Please sign in to comment.