From de75f44e6ab3d782702ff597f36b4cfa732a3b46 Mon Sep 17 00:00:00 2001 From: Oleh Dokuka Date: Sat, 2 May 2020 13:25:38 +0300 Subject: [PATCH] provides ordered stream id issuing Signed-off-by: Oleh Dokuka --- .../io/rsocket/core/RSocketConnector.java | 4 +- .../io/rsocket/core/RSocketRequester.java | 248 +++++++++--------- .../java/io/rsocket/core/RSocketServer.java | 4 +- .../test/java/io/rsocket/TestScheduler.java | 80 ++++++ .../java/io/rsocket/core/KeepAliveTest.java | 7 +- .../io/rsocket/core/RSocketLeaseTest.java | 4 +- .../core/RSocketRequesterSubscribersTest.java | 25 +- .../io/rsocket/core/RSocketRequesterTest.java | 48 +++- .../java/io/rsocket/core/RSocketTest.java | 4 +- .../io/rsocket/core/SetupRejectionTest.java | 6 +- 10 files changed, 292 insertions(+), 138 deletions(-) create mode 100644 rsocket-core/src/test/java/io/rsocket/TestScheduler.java diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java index a7eed8c76..4e47109cf 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java @@ -42,6 +42,7 @@ import java.util.function.Supplier; import reactor.core.Disposable; import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; import reactor.util.retry.Retry; public class RSocketConnector { @@ -293,7 +294,8 @@ public Mono connect(Supplier transportSupplier) { (int) keepAliveInterval.toMillis(), (int) keepAliveMaxLifeTime.toMillis(), keepAliveHandler, - requesterLeaseHandler); + requesterLeaseHandler, + Schedulers.single(Schedulers.parallel())); RSocket wrappedRSocketRequester = interceptors.initRequester(rSocketRequester); diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java index fabea217b..af3434fae 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -67,6 +67,7 @@ import reactor.core.publisher.Operators; import reactor.core.publisher.SignalType; import reactor.core.publisher.UnicastProcessor; +import reactor.core.scheduler.Scheduler; import reactor.util.concurrent.Queues; /** @@ -105,6 +106,7 @@ class RSocketRequester implements RSocket { private final KeepAliveFramesAcceptor keepAliveFramesAcceptor; private volatile Throwable terminationError; private final MonoProcessor onClose; + private final Scheduler serialScheduler; RSocketRequester( DuplexConnection connection, @@ -115,7 +117,8 @@ class RSocketRequester implements RSocket { int keepAliveTickPeriod, int keepAliveAckTimeout, @Nullable KeepAliveHandler keepAliveHandler, - RequesterLeaseHandler leaseHandler) { + RequesterLeaseHandler leaseHandler, + Scheduler serialScheduler) { this.connection = connection; this.allocator = connection.alloc(); this.payloadDecoder = payloadDecoder; @@ -126,6 +129,7 @@ class RSocketRequester implements RSocket { this.senders = new SynchronizedIntObjectHashMap<>(); this.receivers = new SynchronizedIntObjectHashMap<>(); this.onClose = MonoProcessor.create(); + this.serialScheduler = serialScheduler; // DO NOT Change the order here. The Send processor must be subscribed to before receiving this.sendProcessor = new UnboundedProcessor<>(); @@ -208,22 +212,23 @@ private Mono handleFireAndForget(Payload payload) { final AtomicBoolean once = new AtomicBoolean(); - return Mono.defer( - () -> { - if (once.getAndSet(true)) { - return Mono.error( - new IllegalStateException("FireAndForgetMono allows only a single subscriber")); - } + return Mono.defer( + () -> { + if (once.getAndSet(true)) { + return Mono.error( + new IllegalStateException("FireAndForgetMono allows only a single subscriber")); + } - final int streamId = streamIdSupplier.nextStreamId(receivers); - final ByteBuf requestFrame = - RequestFireAndForgetFrameFlyweight.encodeReleasingPayload( - allocator, streamId, payload); + final int streamId = streamIdSupplier.nextStreamId(receivers); + final ByteBuf requestFrame = + RequestFireAndForgetFrameFlyweight.encodeReleasingPayload( + allocator, streamId, payload); - sendProcessor.onNext(requestFrame); + sendProcessor.onNext(requestFrame); - return Mono.empty(); - }); + return Mono.empty(); + }) + .subscribeOn(serialScheduler); } private Mono handleRequestResponse(final Payload payload) { @@ -284,6 +289,7 @@ public void hookOnTerminal(SignalType signalType) { receivers.remove(streamId, receiver); } })) + .subscribeOn(serialScheduler) .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); }); } @@ -356,6 +362,7 @@ void hookOnTerminal(SignalType signalType) { receivers.remove(streamId); } })) + .subscribeOn(serialScheduler, false) .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); }); } @@ -392,120 +399,125 @@ private Flux handleChannel(Payload initialPayload, Flux receiver = UnicastProcessor.create(); - return receiver.transform( - Operators.lift( - (s, actual) -> - new RequestOperator(actual) { + return receiver + .transform( + Operators.lift( + (s, actual) -> + new RequestOperator(actual) { - final BaseSubscriber upstreamSubscriber = - new BaseSubscriber() { + final BaseSubscriber upstreamSubscriber = + new BaseSubscriber() { - boolean first = true; + boolean first = true; - @Override - protected void hookOnSubscribe(Subscription subscription) { - // noops - } + @Override + protected void hookOnSubscribe(Subscription subscription) { + // noops + } - @Override - protected void hookOnNext(Payload payload) { - if (first) { - // need to skip first since we have already sent it - // no need to release it since it was released earlier on the request - // establishment - // phase - first = false; - request(1); - return; - } - if (!PayloadValidationUtils.isValid(mtu, payload)) { - payload.release(); - cancel(); - final IllegalArgumentException t = - new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); - errorConsumer.accept(t); - // no need to send any errors. - sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); - receiver.onError(t); - return; - } - final ByteBuf frame = - PayloadFrameFlyweight.encodeNextReleasingPayload( - allocator, streamId, payload); - - sendProcessor.onNext(frame); - } + @Override + protected void hookOnNext(Payload payload) { + if (first) { + // need to skip first since we have already sent it + // no need to release it since it was released earlier on the + // request + // establishment + // phase + first = false; + request(1); + return; + } + if (!PayloadValidationUtils.isValid(mtu, payload)) { + payload.release(); + cancel(); + final IllegalArgumentException t = + new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); + errorConsumer.accept(t); + // no need to send any errors. + sendProcessor.onNext( + CancelFrameFlyweight.encode(allocator, streamId)); + receiver.onError(t); + return; + } + final ByteBuf frame = + PayloadFrameFlyweight.encodeNextReleasingPayload( + allocator, streamId, payload); + + sendProcessor.onNext(frame); + } + + @Override + protected void hookOnComplete() { + ByteBuf frame = + PayloadFrameFlyweight.encodeComplete(allocator, streamId); + sendProcessor.onNext(frame); + } + + @Override + protected void hookOnError(Throwable t) { + ByteBuf frame = ErrorFrameFlyweight.encode(allocator, streamId, t); + sendProcessor.onNext(frame); + receiver.onError(t); + } - @Override - protected void hookOnComplete() { - ByteBuf frame = PayloadFrameFlyweight.encodeComplete(allocator, streamId); - sendProcessor.onNext(frame); + @Override + protected void hookFinally(SignalType type) { + senders.remove(streamId, this); + } + }; + + @Override + void hookOnFirstRequest(long n) { + final int streamId = streamIdSupplier.nextStreamId(receivers); + this.streamId = streamId; + + final ByteBuf frame = + RequestChannelFrameFlyweight.encodeReleasingPayload( + allocator, streamId, false, n, initialPayload); + + senders.put(streamId, upstreamSubscriber); + receivers.put(streamId, receiver); + + inboundFlux + .limitRate(Queues.SMALL_BUFFER_SIZE) + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER) + .subscribe(upstreamSubscriber); + + sendProcessor.onNext(frame); + } + + @Override + void hookOnRemainingRequests(long n) { + if (receiver.isDisposed()) { + return; } - @Override - protected void hookOnError(Throwable t) { - ByteBuf frame = ErrorFrameFlyweight.encode(allocator, streamId, t); - sendProcessor.onNext(frame); - receiver.onError(t); + sendProcessor.onNext(RequestNFrameFlyweight.encode(allocator, streamId, n)); + } + + @Override + void hookOnCancel() { + senders.remove(streamId, upstreamSubscriber); + if (receivers.remove(streamId, receiver)) { + sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); } + } - @Override - protected void hookFinally(SignalType type) { - senders.remove(streamId, this); + @Override + void hookOnTerminal(SignalType signalType) { + if (signalType == SignalType.ON_ERROR) { + upstreamSubscriber.cancel(); } - }; - - @Override - void hookOnFirstRequest(long n) { - final int streamId = streamIdSupplier.nextStreamId(receivers); - this.streamId = streamId; - - final ByteBuf frame = - RequestChannelFrameFlyweight.encodeReleasingPayload( - allocator, streamId, false, n, initialPayload); - - senders.put(streamId, upstreamSubscriber); - receivers.put(streamId, receiver); - - inboundFlux - .limitRate(Queues.SMALL_BUFFER_SIZE) - .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER) - .subscribe(upstreamSubscriber); - - sendProcessor.onNext(frame); - } - - @Override - void hookOnRemainingRequests(long n) { - if (receiver.isDisposed()) { - return; - } - - sendProcessor.onNext(RequestNFrameFlyweight.encode(allocator, streamId, n)); - } - - @Override - void hookOnCancel() { - senders.remove(streamId, upstreamSubscriber); - if (receivers.remove(streamId, receiver)) { - sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); - } - } - - @Override - void hookOnTerminal(SignalType signalType) { - if (signalType == SignalType.ON_ERROR) { - upstreamSubscriber.cancel(); - } - receivers.remove(streamId, receiver); - } - - @Override - public void cancel() { - upstreamSubscriber.cancel(); - super.cancel(); - } - })); + receivers.remove(streamId, receiver); + } + + @Override + public void cancel() { + upstreamSubscriber.cancel(); + super.cancel(); + } + })) + .subscribeOn(serialScheduler, false); } private Mono handleMetadataPush(Payload payload) { diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java index 19f0c5008..d5d8cee0f 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java @@ -39,6 +39,7 @@ import java.util.function.Consumer; import java.util.function.Supplier; import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; public final class RSocketServer { private static final String SERVER_TAG = "server"; @@ -222,7 +223,8 @@ private Mono acceptSetup( setupPayload.keepAliveInterval(), setupPayload.keepAliveMaxLifetime(), keepAliveHandler, - requesterLeaseHandler); + requesterLeaseHandler, + Schedulers.single(Schedulers.parallel())); RSocket wrappedRSocketRequester = interceptors.initRequester(rSocketRequester); diff --git a/rsocket-core/src/test/java/io/rsocket/TestScheduler.java b/rsocket-core/src/test/java/io/rsocket/TestScheduler.java new file mode 100644 index 000000000..7bc98d45d --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/TestScheduler.java @@ -0,0 +1,80 @@ +package io.rsocket; + +import java.util.Queue; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.Exceptions; +import reactor.core.scheduler.Scheduler; +import reactor.util.concurrent.Queues; + +/** + * This is an implementation of scheduler which allows task execution on the caller thread or + * scheduling it for thread which are currently working (with "work stealing" behaviour) + */ +public final class TestScheduler implements Scheduler { + + public static final Scheduler INSTANCE = new TestScheduler(); + + volatile int wip; + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(TestScheduler.class, "wip"); + + final Worker sharedWorker = new TestWorker(this); + final Queue tasks = Queues.unboundedMultiproducer().get(); + + private TestScheduler() {} + + @Override + public Disposable schedule(Runnable task) { + tasks.offer(task); + if (WIP.getAndIncrement(this) != 0) { + return Disposables.never(); + } + + int missed = 1; + + for (; ; ) { + for (; ; ) { + Runnable runnable = tasks.poll(); + + if (runnable == null) { + break; + } + + try { + runnable.run(); + } catch (Throwable t) { + Exceptions.throwIfFatal(t); + } + } + + missed = WIP.addAndGet(this, -missed); + if (missed == 0) { + return Disposables.never(); + } + } + } + + @Override + public Worker createWorker() { + return sharedWorker; + } + + static class TestWorker implements Worker { + + final TestScheduler parent; + + TestWorker(TestScheduler parent) { + this.parent = parent; + } + + @Override + public Disposable schedule(Runnable task) { + return parent.schedule(task); + } + + @Override + public void dispose() {} + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java b/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java index e8f3f4190..7e465db08 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java @@ -23,6 +23,7 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; import io.rsocket.RSocket; +import io.rsocket.TestScheduler; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.ConnectionErrorException; import io.rsocket.frame.FrameHeaderFlyweight; @@ -67,7 +68,8 @@ static RSocketState requester(int tickPeriod, int timeout) { tickPeriod, timeout, new DefaultKeepAliveHandler(connection), - RequesterLeaseHandler.None); + RequesterLeaseHandler.None, + TestScheduler.INSTANCE); return new RSocketState(rSocket, errors, allocator, connection); } @@ -94,7 +96,8 @@ static ResumableRSocketState resumableRequester(int tickPeriod, int timeout) { tickPeriod, timeout, new ResumableKeepAliveHandler(resumableConnection), - RequesterLeaseHandler.None); + RequesterLeaseHandler.None, + TestScheduler.INSTANCE); return new ResumableRSocketState(rSocket, errors, connection, resumableConnection, allocator); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java index 51f5afc24..04d5fe174 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java @@ -28,6 +28,7 @@ import io.netty.buffer.Unpooled; import io.rsocket.Payload; import io.rsocket.RSocket; +import io.rsocket.TestScheduler; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.Exceptions; import io.rsocket.frame.FrameHeaderFlyweight; @@ -98,7 +99,8 @@ void setUp() { 0, 0, null, - requesterLeaseHandler); + requesterLeaseHandler, + TestScheduler.INSTANCE); RSocket mockRSocketHandler = mock(RSocket.class); when(mockRSocketHandler.metadataPush(any())).thenReturn(Mono.empty()); diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java index 4a9d907fa..3e7479af3 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java @@ -19,6 +19,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.rsocket.RSocket; +import io.rsocket.TestScheduler; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.frame.FrameHeaderFlyweight; import io.rsocket.frame.FrameType; @@ -28,7 +29,6 @@ import io.rsocket.lease.RequesterLeaseHandler; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.util.DefaultPayload; -import java.time.Duration; import java.util.Arrays; import java.util.Collection; import java.util.HashSet; @@ -41,7 +41,6 @@ import org.junit.jupiter.params.provider.MethodSource; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; -import reactor.test.StepVerifier; import reactor.test.util.RaceTestUtils; class RSocketRequesterSubscribersTest { @@ -73,21 +72,25 @@ void setUp() { 0, 0, null, - RequesterLeaseHandler.None); + RequesterLeaseHandler.None, + TestScheduler.INSTANCE); } @ParameterizedTest @MethodSource("allInteractions") void singleSubscriber(Function> interaction) { Flux response = Flux.from(interaction.apply(rSocketRequester)); - StepVerifier.withVirtualTime(() -> response.take(Duration.ofMillis(10))) - .thenAwait(Duration.ofMillis(10)) - .expectComplete() - .verify(Duration.ofSeconds(5)); - StepVerifier.withVirtualTime(() -> response.take(Duration.ofMillis(10))) - .thenAwait(Duration.ofMillis(10)) - .expectError(IllegalStateException.class) - .verify(Duration.ofSeconds(5)); + + AssertSubscriber assertSubscriberA = AssertSubscriber.create(); + AssertSubscriber assertSubscriberB = AssertSubscriber.create(); + + response.subscribe(assertSubscriberA); + response.subscribe(assertSubscriberB); + + connection.addToReceivedBuffer(PayloadFrameFlyweight.encodeComplete(connection.alloc(), 1)); + + assertSubscriberA.assertTerminated(); + assertSubscriberB.assertTerminated(); Assertions.assertThat(requestFramesCount(connection.getSent())).isEqualTo(1); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java index d7cd8c24b..d58d70e5d 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java @@ -41,6 +41,7 @@ import io.netty.util.ReferenceCounted; import io.rsocket.Payload; import io.rsocket.RSocket; +import io.rsocket.TestScheduler; import io.rsocket.exceptions.ApplicationErrorException; import io.rsocket.exceptions.CustomRSocketException; import io.rsocket.exceptions.RejectedSetupException; @@ -924,6 +925,50 @@ private static Stream requestNInteractions() { (rule, payload) -> rule.socket.requestChannel(Flux.just(payload)))); } + @ParameterizedTest + @MethodSource("streamIdRacingCases") + public void ensuresCorrectOrderOfStreamIdIssuingInCaseOfRacing( + BiFunction> interaction1, + BiFunction> interaction2) { + for (int i = 1; i < 10000; i += 4) { + Payload payload = DefaultPayload.create("test"); + Publisher publisher1 = interaction1.apply(rule, payload); + Publisher publisher2 = interaction2.apply(rule, payload); + RaceTestUtils.race( + () -> publisher1.subscribe(AssertSubscriber.create()), + () -> publisher2.subscribe(AssertSubscriber.create())); + + Assertions.assertThat(rule.connection.getSent()) + .extracting(FrameHeaderFlyweight::streamId) + .containsExactly(i, i + 2); + rule.connection.getSent().clear(); + } + } + + public static Stream streamIdRacingCases() { + return Stream.of( + Arguments.of( + (BiFunction>) + (r, p) -> r.socket.fireAndForget(p), + (BiFunction>) + (r, p) -> r.socket.requestResponse(p)), + Arguments.of( + (BiFunction>) + (r, p) -> r.socket.requestResponse(p), + (BiFunction>) + (r, p) -> r.socket.requestStream(p)), + Arguments.of( + (BiFunction>) + (r, p) -> r.socket.requestStream(p), + (BiFunction>) + (r, p) -> r.socket.requestChannel(Flux.just(p))), + Arguments.of( + (BiFunction>) + (r, p) -> r.socket.requestChannel(Flux.just(p)), + (BiFunction>) + (r, p) -> r.socket.fireAndForget(p))); + } + public int sendRequestResponse(Publisher response) { Subscriber sub = TestSubscriber.create(); response.subscribe(sub); @@ -948,7 +993,8 @@ protected RSocketRequester newRSocket() { 0, 0, null, - RequesterLeaseHandler.None); + RequesterLeaseHandler.None, + TestScheduler.INSTANCE); } public int getStreamIdForRequestType(FrameType expectedFrameType) { diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java index 02c3dfca8..48ce150d6 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java @@ -25,6 +25,7 @@ import io.netty.buffer.ByteBufAllocator; import io.rsocket.Payload; import io.rsocket.RSocket; +import io.rsocket.TestScheduler; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.ApplicationErrorException; import io.rsocket.exceptions.CustomRSocketException; @@ -492,7 +493,8 @@ public Flux requestChannel(Publisher payloads) { 0, 0, null, - RequesterLeaseHandler.None); + RequesterLeaseHandler.None, + TestScheduler.INSTANCE); } public void setRequestAcceptor(RSocket requestAcceptor) { diff --git a/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java index 4d5cdc0d5..388bfffeb 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java @@ -64,7 +64,8 @@ void requesterStreamsTerminatedOnZeroErrorFrame() { 0, 0, null, - RequesterLeaseHandler.None); + RequesterLeaseHandler.None, + TestScheduler.INSTANCE); String errorMsg = "error"; @@ -101,7 +102,8 @@ void requesterNewStreamsTerminatedAfterZeroErrorFrame() { 0, 0, null, - RequesterLeaseHandler.None); + RequesterLeaseHandler.None, + TestScheduler.INSTANCE); conn.addToReceivedBuffer( ErrorFrameFlyweight.encode(