diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketClient.java b/rsocket-core/src/main/java/io/rsocket/RSocketClient.java index c18020833..04be7dac2 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketClient.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketClient.java @@ -16,14 +16,19 @@ package io.rsocket; -import static io.rsocket.util.ExceptionUtil.noStacktrace; - import io.netty.buffer.Unpooled; import io.netty.util.collection.IntObjectHashMap; import io.rsocket.exceptions.ConnectionException; import io.rsocket.exceptions.Exceptions; import io.rsocket.internal.LimitableRequestPublisher; +import io.rsocket.internal.UnboundedProcessor; import io.rsocket.util.PayloadImpl; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import reactor.core.Disposable; +import reactor.core.publisher.*; + +import javax.annotation.Nullable; import java.nio.channels.ClosedChannelException; import java.time.Duration; import java.util.Collection; @@ -32,11 +37,8 @@ import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; -import javax.annotation.Nullable; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import reactor.core.Disposable; -import reactor.core.publisher.*; + +import static io.rsocket.util.ExceptionUtil.noStacktrace; /** Client Side of a RSocket socket. Sends {@link Frame}s to a {@link RSocketServer} */ class RSocketClient implements RSocket { @@ -52,7 +54,7 @@ class RSocketClient implements RSocket { private final IntObjectHashMap> receivers; private final AtomicInteger missedAckCounter; - private final FluxProcessor sendProcessor; + private final UnboundedProcessor sendProcessor; private @Nullable Disposable keepAliveSendSub; private volatile long timeLastTickSentMs; @@ -80,8 +82,7 @@ class RSocketClient implements RSocket { this.missedAckCounter = new AtomicInteger(); // DO NOT Change the order here. The Send processor must be subscribed to before receiving - // connections - this.sendProcessor = EmitterProcessor.create().serialize(); + this.sendProcessor = new UnboundedProcessor<>(); if (!Duration.ZERO.equals(tickPeriod)) { long ackTimeoutMs = ackTimeout.toMillis(); @@ -98,8 +99,15 @@ class RSocketClient implements RSocket { }) .subscribe(); } - - connection.onClose().doFinally(signalType -> cleanup()).doOnError(errorConsumer).subscribe(); + + connection + .onClose() + .doFinally( + signalType -> { + cleanup(); + }) + .doOnError(errorConsumer) + .subscribe(); connection .send(sendProcessor) @@ -205,7 +213,7 @@ public Flux requestStream(Payload payload) { @Override public Flux requestChannel(Publisher payloads) { - return handleStreamResponse(Flux.from(payloads), FrameType.REQUEST_CHANNEL); + return handleChannel(Flux.from(payloads), FrameType.REQUEST_CHANNEL); } @Override @@ -255,6 +263,7 @@ public Flux handleRequestStream(final Payload payload) { } else if (contains(streamId) && !receiver.isTerminated()) { sendProcessor.onNext(Frame.RequestN.from(streamId, l)); } + sendProcessor.drain(); }) .doOnError( t -> { @@ -268,7 +277,10 @@ public Flux handleRequestStream(final Payload payload) { sendProcessor.onNext(Frame.Cancel.from(streamId)); } }) - .doFinally(s -> removeReceiver(streamId)); + .doFinally( + s -> { + removeReceiver(streamId); + }); })); } @@ -291,11 +303,14 @@ private Mono handleRequestResponse(final Payload payload) { return receiver .doOnError(t -> sendProcessor.onNext(Frame.Error.from(streamId, t))) .doOnCancel(() -> sendProcessor.onNext(Frame.Cancel.from(streamId))) - .doFinally(s -> removeReceiver(streamId)); + .doFinally( + s -> { + removeReceiver(streamId); + }); })); } - private Flux handleStreamResponse(Flux request, FrameType requestType) { + private Flux handleChannel(Flux request, FrameType requestType) { return started.thenMany( Flux.defer( new Supplier>() { @@ -328,6 +343,7 @@ public Flux get() { } if (_firstRequest) { + AtomicBoolean firstPayload = new AtomicBoolean(true); Flux requestFrames = request .transform( @@ -345,19 +361,10 @@ public Flux get() { }) .map( new Function() { - boolean firstPayload = true; @Override public Frame apply(Payload payload) { - boolean _firstPayload = false; - synchronized (this) { - if (firstPayload) { - firstPayload = false; - _firstPayload = true; - } - } - - if (_firstPayload) { + if (firstPayload.compareAndSet(true, false)) { return Frame.Request.from( streamId, requestType, payload, l); } else { @@ -372,6 +379,9 @@ public Frame apply(Payload payload) { sendOneFrame( Frame.PayloadFrame.from( streamId, FrameType.COMPLETE)); + if (firstPayload.get()) { + receiver.onComplete(); + } } }); @@ -522,6 +532,7 @@ private void handleFrame(int streamId, FrameType type, Frame frame) { if (sender != null) { int n = Frame.RequestN.requestN(frame); sender.increaseRequestLimit(n); + sendProcessor.drain(); } break; } diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/RSocketServer.java index 4313e8455..3edec97ac 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketServer.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketServer.java @@ -16,24 +16,29 @@ package io.rsocket; -import static io.rsocket.Frame.Request.initialRequestN; -import static io.rsocket.frame.FrameHeaderFlyweight.FLAGS_C; -import static io.rsocket.frame.FrameHeaderFlyweight.FLAGS_M; - import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.util.collection.IntObjectHashMap; import io.rsocket.exceptions.ApplicationException; import io.rsocket.internal.LimitableRequestPublisher; +import io.rsocket.internal.UnboundedProcessor; import io.rsocket.util.PayloadImpl; -import java.util.Collection; -import java.util.function.Consumer; -import javax.annotation.Nullable; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import reactor.core.Disposable; -import reactor.core.publisher.*; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.SignalType; +import reactor.core.publisher.UnicastProcessor; + +import javax.annotation.Nullable; +import java.util.Collection; +import java.util.function.Consumer; + +import static io.rsocket.Frame.Request.initialRequestN; +import static io.rsocket.frame.FrameHeaderFlyweight.FLAGS_C; +import static io.rsocket.frame.FrameHeaderFlyweight.FLAGS_M; /** Server side RSocket. Receives {@link Frame}s from a {@link RSocketClient} */ class RSocketServer implements RSocket { @@ -45,7 +50,7 @@ class RSocketServer implements RSocket { private final IntObjectHashMap sendingSubscriptions; private final IntObjectHashMap> channelProcessors; - private final FluxProcessor sendProcessor; + private final UnboundedProcessor sendProcessor; private Disposable receiveDisposable; RSocketServer( @@ -58,7 +63,7 @@ class RSocketServer implements RSocket { // DO NOT Change the order here. The Send processor must be subscribed to before receiving // connections - this.sendProcessor = EmitterProcessor.create().serialize(); + this.sendProcessor = new UnboundedProcessor<>(); connection .send(sendProcessor) @@ -302,7 +307,10 @@ private Mono handleRequestResponse(int streamId, Mono response) { .doOnError(errorConsumer) .onErrorResume(t -> Mono.just(Frame.Error.from(streamId, t))) .doOnNext(sendProcessor::onNext) - .doFinally(signalType -> removeSubscription(streamId)) + .doFinally( + signalType -> { + removeSubscription(streamId); + }) .then(); } diff --git a/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java b/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java new file mode 100644 index 000000000..595e69927 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java @@ -0,0 +1,314 @@ +/* + * Copyright (c) 2011-2017 Pivotal Software Inc, All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal; + +import io.netty.util.internal.shaded.org.jctools.queues.atomic.MpscGrowableAtomicArrayQueue; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Fuseable; +import reactor.core.publisher.FluxProcessor; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.concurrent.Queues; +import reactor.util.context.Context; + +import java.util.Objects; +import java.util.Queue; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; + +/** + * A Processor implementation that takes a custom queue and allows only a single subscriber. + * + *

The implementation keeps the order of signals. + * + * @param the input and output type + */ +public final class UnboundedProcessor extends FluxProcessor + implements Fuseable.QueueSubscription, Fuseable { + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater ONCE = + AtomicIntegerFieldUpdater.newUpdater(UnboundedProcessor.class, "once"); + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(UnboundedProcessor.class, "wip"); + + @SuppressWarnings("rawtypes") + static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(UnboundedProcessor.class, "requested"); + + final Queue queue; + + volatile boolean done; + Throwable error; + volatile CoreSubscriber actual; + volatile boolean cancelled; + volatile int once; + volatile int wip; + volatile long requested; + volatile long processed; + + public UnboundedProcessor() { + this.queue = new MpscGrowableAtomicArrayQueue<>(Queues.SMALL_BUFFER_SIZE, 1 << 24); + } + + @Override + public int getBufferSize() { + return Queues.capacity(this.queue); + } + + void drainRegular(Subscriber a) { + int missed = 1; + + final Queue q = queue; + + for (; ; ) { + + long r = requested; + long e = 0L; + + while (r != e) { + boolean d = done; + + T t = q.poll(); + boolean empty = t == null; + + if (checkTerminated(d, empty, a, q)) { + return; + } + + if (empty) { + break; + } + a.onNext(t); + + e++; + } + + if (r == e) { + if (checkTerminated(done, q.isEmpty(), a, q)) { + return; + } + } + + if (e != 0 && r != Long.MAX_VALUE) { + REQUESTED.addAndGet(this, -e); + } + + missed = WIP.addAndGet(this, -missed); + if (missed == 0) { + break; + } + } + } + + public void drain() { + if (WIP.getAndIncrement(this) != 0) { + return; + } + + int missed = 1; + + for (; ; ) { + Subscriber a = actual; + if (a != null) { + + drainRegular(a); + + return; + } + + missed = WIP.addAndGet(this, -missed); + if (missed == 0) { + break; + } + } + } + + boolean checkTerminated(boolean d, boolean empty, Subscriber a, Queue q) { + if (cancelled) { + q.clear(); + actual = null; + return true; + } + if (d && empty) { + Throwable e = error; + actual = null; + if (e != null) { + a.onError(e); + } else { + a.onComplete(); + } + return true; + } + + return false; + } + + @Override + public void onSubscribe(Subscription s) { + if (done || cancelled) { + s.cancel(); + } else { + s.request(Long.MAX_VALUE); + } + } + + @Override + public int getPrefetch() { + return Integer.MAX_VALUE; + } + + @Override + public Context currentContext() { + CoreSubscriber actual = this.actual; + return actual != null ? actual.currentContext() : Context.empty(); + } + + @Override + public void onNext(T t) { + if (done || cancelled) { + Operators.onNextDropped(t, currentContext()); + return; + } + + if (!queue.offer(t)) { + Throwable ex = + Operators.onOperatorError(null, Exceptions.failWithOverflow(), t, currentContext()); + onError(Operators.onOperatorError(null, ex, t, currentContext())); + return; + } + + drain(); + } + + @Override + public void onError(Throwable t) { + if (done || cancelled) { + Operators.onErrorDropped(t, currentContext()); + return; + } + + error = t; + done = true; + + drain(); + } + + @Override + public void onComplete() { + if (done || cancelled) { + return; + } + + done = true; + + drain(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + Objects.requireNonNull(actual, "subscribe"); + if (once == 0 && ONCE.compareAndSet(this, 0, 1)) { + this.actual = actual; + actual.onSubscribe(this); + if (cancelled) { + this.actual = null; + } else { + drain(); + } + } else { + Operators.error( + actual, + new IllegalStateException("UnboundedProcessor " + "allows only a single Subscriber")); + } + } + + @Override + public void request(long n) { + if (Operators.validate(n)) { + Operators.addCap(REQUESTED, this, n); + drain(); + } + } + + @Override + public void cancel() { + if (cancelled) { + return; + } + cancelled = true; + if (WIP.getAndIncrement(this) == 0) { + queue.clear(); + } + } + + @Override + @Nullable + public T poll() { + return queue.poll(); + } + + @Override + public int size() { + return queue.size(); + } + + @Override + public boolean isEmpty() { + return queue.isEmpty(); + } + + @Override + public void clear() { + queue.clear(); + } + + @Override + public int requestFusion(int requestedMode) { + return Fuseable.NONE; + } + + @Override + public boolean isDisposed() { + return cancelled || done; + } + + @Override + public boolean isTerminated() { + return done; + } + + @Override + @Nullable + public Throwable getError() { + return error; + } + + @Override + public long downstreamCount() { + return hasDownstreams() ? 1L : 0L; + } + + @Override + public boolean hasDownstreams() { + return actual != null; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java b/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java new file mode 100644 index 000000000..88f54b933 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java @@ -0,0 +1,87 @@ +package io.rsocket.internal; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.concurrent.CountDownLatch; + +public class UnboundedProcessorTest { + @Test + public void testOnNextBeforeSubscribe_10() { + testOnNextBeforeSubscribeN(10); + } + + @Test + public void testOnNextBeforeSubscribe_100() { + testOnNextBeforeSubscribeN(100); + } + + @Test + public void testOnNextBeforeSubscribe_10_000() { + testOnNextBeforeSubscribeN(10_000); + } + + @Test + public void testOnNextBeforeSubscribe_100_000() { + testOnNextBeforeSubscribeN(100_000); + } + + @Test + public void testOnNextBeforeSubscribe_1_000_000() { + testOnNextBeforeSubscribeN(1_000_000); + } + + @Test + public void testOnNextBeforeSubscribe_10_000_000() { + testOnNextBeforeSubscribeN(10_000_000); + } + + public void testOnNextBeforeSubscribeN(int n) { + UnboundedProcessor processor = new UnboundedProcessor<>(); + + for (int i = 0; i < n; i++) { + processor.onNext(i); + } + + processor.onComplete(); + + long count = processor.count().block(); + + Assert.assertEquals(n, count); + } + + @Test + public void testOnNextAfterSubscribe_10() throws Exception { + testOnNextAfterSubscribeN(10); + } + + @Test + public void testOnNextAfterSubscribe_100() throws Exception { + testOnNextAfterSubscribeN(100); + } + + @Test + public void testOnNextAfterSubscribe_1000() throws Exception { + testOnNextAfterSubscribeN(1000); + } + + public void testOnNextAfterSubscribeN(int n) throws Exception { + CountDownLatch latch = new CountDownLatch(n); + UnboundedProcessor processor = new UnboundedProcessor<>(); + processor + .log() + .doOnNext(integer -> + latch.countDown()) + .subscribe(); + + for (int i = 0; i < n; i++) { + System.out.println("onNexting -> " + i); + processor.onNext(i); + } + + processor.drain(); + + latch.await(); + } + +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/BaseClientServerTest.java b/rsocket-test/src/main/java/io/rsocket/test/BaseClientServerTest.java index 5c8e64264..ba08661f3 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/BaseClientServerTest.java +++ b/rsocket-test/src/main/java/io/rsocket/test/BaseClientServerTest.java @@ -16,14 +16,17 @@ package io.rsocket.test; -import static org.junit.Assert.assertEquals; - import io.rsocket.Payload; import io.rsocket.util.PayloadImpl; import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import reactor.core.publisher.Flux; +import reactor.core.scheduler.Schedulers; + +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.Assert.assertEquals; public abstract class BaseClientServerTest> { @Rule public final T setup = createClientServer(); @@ -54,7 +57,7 @@ public void testPushMetadata10() { assertEquals(0, outputCount); } - @Test(timeout = 10000) + @Test//(timeout = 10000) public void testRequestResponse1() { long outputCount = Flux.range(1, 1) @@ -121,15 +124,24 @@ public void testRequestResponse10_000() { assertEquals(10_000, outputCount); } - + @Test(timeout = 10000) public void testRequestStream() { Flux publisher = setup.getRSocket().requestStream(testPayload(3)); - + long count = publisher.take(5).count().block(); - + assertEquals(5, count); } + + @Test(timeout = 10000) + public void testRequestStreamAll() { + Flux publisher = setup.getRSocket().requestStream(testPayload(3)); + + long count = publisher.count().block(); + + assertEquals(10000, count); + } @Test(timeout = 10000) public void testRequestStreamWithRequestN() { @@ -167,7 +179,6 @@ public void testRequestStreamWithDelayedRequestN() { } @Test(timeout = 10000) - @Ignore public void testChannel0() { Flux publisher = setup.getRSocket().requestChannel(Flux.empty()); @@ -196,4 +207,49 @@ public void testChannel3() { assertEquals(3, count); } + + @Test(timeout = 10000) + public void testChannel512() { + Flux payloads = Flux.range(1, 512).map(i -> new PayloadImpl("hello " + i)); + + long count = setup.getRSocket().requestChannel(payloads).count().block(); + + assertEquals(512, count); + } + + @Test(timeout = 30000) + public void testChannel20_000() { + Flux payloads = Flux.range(1, 20_000).map(i -> new PayloadImpl("hello " + i)); + + long count = setup.getRSocket().requestChannel(payloads).count().block(); + + assertEquals(20_000, count); + } + + @Test(timeout = 60_000) + public void testChannel200_000() { + Flux payloads = Flux.range(1, 200_000).map(i -> new PayloadImpl("hello " + i)); + + long count = setup.getRSocket().requestChannel(payloads).count().block(); + + assertEquals(200_000, count); + } + + @Test(timeout = 60_000) + @Ignore + public void testChannel2_000_000() { + AtomicInteger counter = new AtomicInteger(0); + + Flux payloads = + Flux.range(1, 2_000_000) + .map(i -> new PayloadImpl("hello " + i)); + long count = + setup + .getRSocket() + .requestChannel(payloads) + .count() + .block(); + + assertEquals(2_000_000, count); + } } diff --git a/rsocket-test/src/main/java/io/rsocket/test/PingClient.java b/rsocket-test/src/main/java/io/rsocket/test/PingClient.java index a2b8ec334..2d66fef6c 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/PingClient.java +++ b/rsocket-test/src/main/java/io/rsocket/test/PingClient.java @@ -65,7 +65,7 @@ public Flux startPingPong(int count, final Recorder histogram) { histogram.recordValue(diff); }); }, - 16)) + 64)) .doOnError(Throwable::printStackTrace); } }