Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ensures streams are terminated if racing terminal and new stream #848

Merged
merged 1 commit into from May 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
82 changes: 53 additions & 29 deletions rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java
Expand Up @@ -189,7 +189,7 @@ public void dispose() {

@Override
public boolean isDisposed() {
return onClose.isDisposed();
return terminationError != null;
}

@Override
Expand Down Expand Up @@ -222,6 +222,12 @@ private Mono<Void> handleFireAndForget(Payload payload) {
new IllegalStateException("FireAndForgetMono allows only a single subscriber"));
}

if (isDisposed()) {
payload.release();
final Throwable t = terminationError;
return Mono.error(t);
}

final int streamId = streamIdSupplier.nextStreamId(receivers);
final ByteBuf requestFrame =
RequestFireAndForgetFrameCodec.encodeReleasingPayload(
Expand Down Expand Up @@ -270,6 +276,13 @@ private Mono<Payload> handleRequestResponse(final Payload payload) {

@Override
void hookOnFirstRequest(long n) {
if (isDisposed()) {
payload.release();
final Throwable t = terminationError;
receiver.onError(t);
return;
}

int streamId = streamIdSupplier.nextStreamId(receivers);
this.streamId = streamId;

Expand Down Expand Up @@ -335,6 +348,13 @@ private Flux<Payload> handleRequestStream(final Payload payload) {

@Override
void hookOnFirstRequest(long n) {
if (isDisposed()) {
payload.release();
final Throwable t = terminationError;
receiver.onError(t);
return;
}

int streamId = streamIdSupplier.nextStreamId(receivers);
this.streamId = streamId;

Expand Down Expand Up @@ -477,6 +497,14 @@ protected void hookFinally(SignalType type) {

@Override
void hookOnFirstRequest(long n) {
if (isDisposed()) {
initialPayload.release();
final Throwable t = terminationError;
upstreamSubscriber.cancel();
receiver.onError(t);
return;
}

final int streamId = streamIdSupplier.nextStreamId(receivers);
this.streamId = streamId;

Expand Down Expand Up @@ -712,15 +740,15 @@ private void tryTerminate(Supplier<Throwable> errorSupplier) {
if (terminationError == null) {
Throwable e = errorSupplier.get();
if (TERMINATION_ERROR.compareAndSet(this, null, e)) {
terminate(e);
serialScheduler.schedule(() -> terminate(e));
}
}
}

private void tryShutdown() {
if (terminationError == null) {
if (TERMINATION_ERROR.compareAndSet(this, null, CLOSED_CHANNEL_EXCEPTION)) {
terminate(CLOSED_CHANNEL_EXCEPTION);
serialScheduler.schedule(() -> terminate(CLOSED_CHANNEL_EXCEPTION));
}
}
}
Expand All @@ -729,34 +757,30 @@ private void terminate(Throwable e) {
connection.dispose();
leaseHandler.dispose();

synchronized (receivers) {
receivers
.values()
.forEach(
receiver -> {
try {
receiver.onError(e);
} catch (Throwable t) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Dropped exception", t);
}
receivers
.values()
.forEach(
receiver -> {
try {
receiver.onError(e);
} catch (Throwable t) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Dropped exception", t);
}
});
}
synchronized (senders) {
senders
.values()
.forEach(
sender -> {
try {
sender.cancel();
} catch (Throwable t) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Dropped exception", t);
}
}
});
senders
.values()
.forEach(
sender -> {
try {
sender.cancel();
} catch (Throwable t) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Dropped exception", t);
}
});
}
}
});
senders.clear();
receivers.clear();
sendProcessor.dispose();
Expand Down
Expand Up @@ -25,20 +25,31 @@
import java.util.Iterator;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import reactor.core.Exceptions;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import reactor.util.retry.Retry;

public class RSocketReconnectTest {

private Queue<Retry.RetrySignal> retries = new ConcurrentLinkedQueue<>();

@Test
public void shouldBeASharedReconnectableInstanceOfRSocketMono() {
public void shouldBeASharedReconnectableInstanceOfRSocketMono() throws InterruptedException {
CountDownLatch latch = new CountDownLatch(1);
Schedulers.onScheduleHook(
"test",
r ->
() -> {
r.run();
latch.countDown();
});
TestClientTransport[] testClientTransport =
new TestClientTransport[] {new TestClientTransport()};
Mono<RSocket> rSocketMono =
Expand All @@ -52,8 +63,10 @@ public void shouldBeASharedReconnectableInstanceOfRSocketMono() {
Assertions.assertThat(rSocket1).isEqualTo(rSocket2);

testClientTransport[0].testConnection().dispose();
Assertions.assertThat(latch.await(5, TimeUnit.SECONDS)).isTrue();
testClientTransport[0] = new TestClientTransport();

System.out.println("here");
RSocket rSocket3 = rSocketMono.block();
RSocket rSocket4 = rSocketMono.block();

Expand Down
111 changes: 94 additions & 17 deletions rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java
Expand Up @@ -62,11 +62,13 @@
import io.rsocket.util.ByteBufPayload;
import io.rsocket.util.DefaultPayload;
import io.rsocket.util.EmptyPayload;
import java.nio.channels.ClosedChannelException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Function;
Expand All @@ -90,6 +92,7 @@
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoProcessor;
import reactor.core.publisher.UnicastProcessor;
import reactor.core.scheduler.Schedulers;
import reactor.test.StepVerifier;
import reactor.test.publisher.TestPublisher;
import reactor.test.util.RaceTestUtils;
Expand Down Expand Up @@ -941,7 +944,7 @@ private static Stream<Arguments> requestNInteractions() {
}

@ParameterizedTest
@MethodSource("streamIdRacingCases")
@MethodSource("streamRacingCases")
public void ensuresCorrectOrderOfStreamIdIssuingInCaseOfRacing(
BiFunction<ClientSocketRule, Payload, Publisher<?>> interaction1,
BiFunction<ClientSocketRule, Payload, Publisher<?>> interaction2) {
Expand All @@ -956,44 +959,118 @@ public void ensuresCorrectOrderOfStreamIdIssuingInCaseOfRacing(
Assertions.assertThat(rule.connection.getSent())
.extracting(FrameHeaderCodec::streamId)
.containsExactly(i, i + 2);
rule.connection.getSent().forEach(bb -> bb.release());
rule.connection.getSent().clear();
}
}

public static Stream<Arguments> streamIdRacingCases() {
public static Stream<Arguments> streamRacingCases() {
return Stream.of(
Arguments.of(
(BiFunction<ClientSocketRule, Payload, Publisher<?>>)
(r, p) -> r.socket.fireAndForget(p),
(BiFunction<ClientSocketRule, Payload, Publisher<?>>)
(r, p) -> r.socket.requestResponse(p)),
(r, p) -> r.socket.requestResponse(p),
REQUEST_FNF,
REQUEST_RESPONSE),
Arguments.of(
(BiFunction<ClientSocketRule, Payload, Publisher<?>>)
(r, p) -> r.socket.requestResponse(p),
(BiFunction<ClientSocketRule, Payload, Publisher<?>>)
(r, p) -> r.socket.requestStream(p)),
(r, p) -> r.socket.requestStream(p),
REQUEST_RESPONSE,
REQUEST_STREAM),
Arguments.of(
(BiFunction<ClientSocketRule, Payload, Publisher<?>>)
(r, p) -> r.socket.requestStream(p),
(BiFunction<ClientSocketRule, Payload, Publisher<?>>)
(r, p) -> r.socket.requestChannel(Flux.just(p))),
(r, p) -> {
AtomicBoolean subscribed = new AtomicBoolean();
Flux<Payload> just = Flux.just(p).doOnSubscribe((__) -> subscribed.set(true));
return r.socket
.requestChannel(just)
.doFinally(
__ -> {
if (!subscribed.get()) {
p.release();
}
});
},
REQUEST_STREAM,
REQUEST_CHANNEL),
Arguments.of(
(BiFunction<ClientSocketRule, Payload, Publisher<?>>)
(r, p) -> r.socket.requestChannel(Flux.just(p)),
(r, p) -> {
AtomicBoolean subscribed = new AtomicBoolean();
Flux<Payload> just = Flux.just(p).doOnSubscribe((__) -> subscribed.set(true));
return r.socket
.requestChannel(just)
.doFinally(
__ -> {
if (!subscribed.get()) {
p.release();
}
});
},
(BiFunction<ClientSocketRule, Payload, Publisher<?>>)
(r, p) -> r.socket.fireAndForget(p)));
(r, p) -> r.socket.fireAndForget(p),
REQUEST_CHANNEL,
REQUEST_FNF));
}

public int sendRequestResponse(Publisher<Payload> response) {
Subscriber<Payload> sub = TestSubscriber.create();
response.subscribe(sub);
int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE);
rule.connection.addToReceivedBuffer(
PayloadFrameCodec.encodeNextCompleteReleasingPayload(
rule.alloc(), streamId, EmptyPayload.INSTANCE));
verify(sub).onNext(any(Payload.class));
verify(sub).onComplete();
return streamId;
@ParameterizedTest
@MethodSource("streamRacingCases")
@SuppressWarnings({"rawtypes", "unchecked"})
public void shouldTerminateAllStreamsIfThereRacingBetweenDisposeAndRequests(
BiFunction<ClientSocketRule, Payload, Publisher<?>> interaction1,
BiFunction<ClientSocketRule, Payload, Publisher<?>> interaction2,
FrameType interactionType1,
FrameType interactionType2) {
for (int i = 1; i < 10000; i++) {
Payload payload1 = ByteBufPayload.create("test");
Payload payload2 = ByteBufPayload.create("test");
AssertSubscriber assertSubscriber1 = AssertSubscriber.create();
AssertSubscriber assertSubscriber2 = AssertSubscriber.create();
Publisher<?> publisher1 = interaction1.apply(rule, payload1);
Publisher<?> publisher2 = interaction2.apply(rule, payload2);
RaceTestUtils.race(
() -> rule.socket.dispose(),
() ->
RaceTestUtils.race(
() -> publisher1.subscribe(assertSubscriber1),
() -> publisher2.subscribe(assertSubscriber2),
Schedulers.parallel()),
Schedulers.parallel());

assertSubscriber1.await().assertTerminated();
if (interactionType1 != REQUEST_FNF) {
assertSubscriber1.assertError(ClosedChannelException.class);
} else {
try {
assertSubscriber1.assertError(ClosedChannelException.class);
} catch (Throwable t) {
// fnf call may be completed
assertSubscriber1.assertComplete();
}
}
assertSubscriber2.await().assertTerminated();
if (interactionType2 != REQUEST_FNF) {
assertSubscriber2.assertError(ClosedChannelException.class);
} else {
try {
assertSubscriber2.assertError(ClosedChannelException.class);
} catch (Throwable t) {
// fnf call may be completed
assertSubscriber2.assertComplete();
}
}

Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release);
rule.connection.getSent().clear();

Assertions.assertThat(payload1.refCnt()).isZero();
Assertions.assertThat(payload2.refCnt()).isZero();
}
}

public static class ClientSocketRule extends AbstractSocketRule<RSocketRequester> {
Expand Down