From 03e0e4b0ce1c7dcd3158f40da2c5d72c8048ced7 Mon Sep 17 00:00:00 2001 From: Ryland Degnan Date: Mon, 17 Sep 2018 14:31:19 -0700 Subject: [PATCH 1/2] Allow Frame to be recycled when content is retained --- .../src/main/java/io/rsocket/Frame.java | 70 ++++++------------- .../java/io/rsocket/util/ByteBufPayload.java | 11 +-- .../java/io/rsocket/util/DefaultPayload.java | 15 ++-- .../FragmentationDuplexConnectionTest.java | 20 +++--- .../integration/InteractionsLoadTest.java | 4 +- 5 files changed, 48 insertions(+), 72 deletions(-) diff --git a/rsocket-core/src/main/java/io/rsocket/Frame.java b/rsocket-core/src/main/java/io/rsocket/Frame.java index 997255719..6f400b78f 100644 --- a/rsocket-core/src/main/java/io/rsocket/Frame.java +++ b/rsocket-core/src/main/java/io/rsocket/Frame.java @@ -19,6 +19,7 @@ import static io.rsocket.frame.FrameHeaderFlyweight.FLAGS_M; import io.netty.buffer.*; +import io.netty.util.AbstractReferenceCounted; import io.netty.util.IllegalReferenceCountException; import io.netty.util.Recycler; import io.netty.util.Recycler.Handle; @@ -33,7 +34,6 @@ import io.rsocket.frame.VersionFlyweight; import io.rsocket.framing.FrameType; import java.nio.charset.StandardCharsets; -import java.util.Objects; import javax.annotation.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -43,7 +43,7 @@ * *

This provides encoding, decoding and field accessors. */ -public class Frame implements Payload, ByteBufHolder { +public class Frame extends AbstractReferenceCounted implements Payload, ByteBufHolder { private static final Recycler RECYCLER = new Recycler() { protected Frame newObject(Handle handle) { @@ -58,12 +58,6 @@ private Frame(final Handle handle) { this.handle = handle; } - /** Clear and recycle this instance. */ - private void recycle() { - content = null; - handle.recycle(this); - } - /** Return the content which is held by this {@link Frame}. */ @Override public ByteBuf content() { @@ -105,26 +99,17 @@ public Frame replace(ByteBuf content) { return from(content); } - /** - * Returns the reference count of this object. If {@code 0}, it means this object has been - * deallocated. - */ - @Override - public int refCnt() { - return content.refCnt(); - } - /** Increases the reference count by {@code 1}. */ @Override public Frame retain() { - content.retain(); + super.retain(); return this; } /** Increases the reference count by the specified {@code increment}. */ @Override public Frame retain(int increment) { - content.retain(increment); + super.retain(increment); return this; } @@ -151,35 +136,13 @@ public Frame touch(@Nullable Object hint) { } /** - * Decreases the reference count by {@code 1} and deallocates this object if the reference count - * reaches at {@code 0}. - * - * @return {@code true} if and only if the reference count became {@code 0} and this object has - * been deallocated - */ - @Override - public boolean release() { - if (content != null && content.release()) { - recycle(); - return true; - } - return false; - } - - /** - * Decreases the reference count by the specified {@code decrement} and deallocates this object if - * the reference count reaches at {@code 0}. - * - * @return {@code true} if and only if the reference count became {@code 0} and this object has - * been deallocated + * Called once {@link #refCnt()} is equals 0. */ @Override - public boolean release(int decrement) { - if (content != null && content.release(decrement)) { - recycle(); - return true; - } - return false; + protected void deallocate() { + content.release(); + content = null; + handle.recycle(this); } /** @@ -239,6 +202,7 @@ public int flags() { */ public static Frame from(final ByteBuf content) { final Frame frame = RECYCLER.get(); + frame.setRefCnt(1); frame.content = content; return frame; @@ -281,6 +245,7 @@ public static Frame from( final ByteBuf data = payload.sliceData(); final Frame frame = RECYCLER.get(); + frame.setRefCnt(1); frame.content = ByteBufAllocator.DEFAULT.buffer( SetupFrameFlyweight.computeFrameLength( @@ -347,6 +312,7 @@ public static Frame from(int streamId, final Throwable throwable, ByteBuf dataBu final int code = ErrorFrameFlyweight.errorCodeFromException(throwable); final Frame frame = RECYCLER.get(); + frame.setRefCnt(1); frame.content = ByteBufAllocator.DEFAULT.buffer( ErrorFrameFlyweight.computeFrameLength(dataBuffer.readableBytes())); @@ -378,6 +344,7 @@ private Lease() {} public static Frame from(int ttl, int numberOfRequests, ByteBuf metadata) { final Frame frame = RECYCLER.get(); + frame.setRefCnt(1); frame.content = ByteBufAllocator.DEFAULT.buffer( LeaseFrameFlyweight.computeFrameLength(metadata.readableBytes())); @@ -411,6 +378,7 @@ public static Frame from(int streamId, int requestN) { } final Frame frame = RECYCLER.get(); + frame.setRefCnt(1); frame.content = ByteBufAllocator.DEFAULT.buffer(RequestNFrameFlyweight.computeFrameLength()); frame.content.writerIndex(RequestNFrameFlyweight.encode(frame.content, streamId, requestN)); return frame; @@ -438,6 +406,7 @@ public static Frame from(int streamId, FrameType type, Payload payload, int init final ByteBuf data = payload.sliceData(); final Frame frame = RECYCLER.get(); + frame.setRefCnt(1); frame.content = ByteBufAllocator.DEFAULT.buffer( RequestFrameFlyweight.computeFrameLength( @@ -464,6 +433,7 @@ public static Frame from(int streamId, FrameType type, Payload payload, int init public static Frame from(int streamId, FrameType type, int flags) { final Frame frame = RECYCLER.get(); + frame.setRefCnt(1); frame.content = ByteBufAllocator.DEFAULT.buffer(RequestFrameFlyweight.computeFrameLength(type, null, 0)); frame.content.writerIndex( @@ -480,6 +450,7 @@ public static Frame from( int initialRequestN, int flags) { final Frame frame = RECYCLER.get(); + frame.setRefCnt(1); frame.content = ByteBufAllocator.DEFAULT.buffer( RequestFrameFlyweight.computeFrameLength( @@ -543,6 +514,7 @@ public static Frame from(int streamId, FrameType type, Payload payload, int flag public static Frame from( int streamId, FrameType type, @Nullable ByteBuf metadata, ByteBuf data, int flags) { final Frame frame = RECYCLER.get(); + frame.setRefCnt(1); frame.content = ByteBufAllocator.DEFAULT.buffer( FrameHeaderFlyweight.computeFrameHeaderLength( @@ -559,6 +531,7 @@ private Cancel() {} public static Frame from(int streamId) { final Frame frame = RECYCLER.get(); + frame.setRefCnt(1); frame.content = ByteBufAllocator.DEFAULT.buffer( FrameHeaderFlyweight.computeFrameHeaderLength(FrameType.CANCEL, null, 0)); @@ -575,6 +548,7 @@ private Keepalive() {} public static Frame from(ByteBuf data, boolean respond) { final Frame frame = RECYCLER.get(); + frame.setRefCnt(1); frame.content = ByteBufAllocator.DEFAULT.buffer( KeepaliveFrameFlyweight.computeFrameLength(data.readableBytes())); @@ -611,12 +585,12 @@ public boolean equals(Object o) { return false; } final Frame frame = (Frame) o; - return Objects.equals(content, frame.content); + return content.equals(frame.content()); } @Override public int hashCode() { - return Objects.hash(content); + return content.hashCode(); } @Override diff --git a/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java b/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java index d1f34589e..e1c134000 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java @@ -22,6 +22,7 @@ import io.netty.buffer.Unpooled; import io.netty.util.AbstractReferenceCounted; import io.netty.util.Recycler; +import io.netty.util.Recycler.Handle; import io.rsocket.Payload; import java.nio.ByteBuffer; import java.nio.CharBuffer; @@ -36,11 +37,11 @@ protected ByteBufPayload newObject(Handle handle) { } }; - private final Recycler.Handle handle; + private final Handle handle; private ByteBuf data; private ByteBuf metadata; - private ByteBufPayload(final Recycler.Handle handle) { + private ByteBufPayload(final Handle handle) { this.handle = handle; } @@ -168,12 +169,12 @@ public static Payload create(ByteBuf data) { public static Payload create(ByteBuf data, @Nullable ByteBuf metadata) { ByteBufPayload payload = RECYCLER.get(); payload.setRefCnt(1); - payload.data = data.retain(); - payload.metadata = metadata == null ? Unpooled.EMPTY_BUFFER : metadata.retain(); + payload.data = data; + payload.metadata = metadata; return payload; } public static Payload create(Payload payload) { - return create(payload.sliceData(), payload.hasMetadata() ? payload.sliceMetadata() : null); + return create(payload.sliceData().retain(), payload.hasMetadata() ? payload.sliceMetadata().retain() : null); } } diff --git a/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java b/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java index 54cc53cd1..a21bf8da1 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java @@ -154,14 +154,15 @@ public static Payload create(ByteBuffer data, @Nullable ByteBuffer metadata) { return new DefaultPayload(data, metadata); } - public static Payload create(Payload payload) { - return create( - copy(payload.sliceData()), payload.hasMetadata() ? copy(payload.sliceMetadata()) : null); + public static Payload create(ByteBuf data) { + return create(data, null); } - private static ByteBuffer copy(ByteBuf byteBuf) { - byte[] contents = new byte[byteBuf.readableBytes()]; - byteBuf.readBytes(contents); - return ByteBuffer.wrap(contents); + public static Payload create(ByteBuf data, @Nullable ByteBuf metadata) { + return create(data.nioBuffer(), metadata == null ? null : metadata.nioBuffer()); + } + + public static Payload create(Payload payload) { + return create(Unpooled.copiedBuffer(payload.sliceData()), payload.hasMetadata() ? Unpooled.copiedBuffer(payload.sliceMetadata()) : null); } } diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java index a729f7856..3e6d2c2a7 100644 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java @@ -184,7 +184,7 @@ void reassembleNonFragment() { toAbstractionLeakingFrame( DEFAULT, 1, createPayloadFrame(DEFAULT, false, true, (ByteBuf) null, null)); - when(delegate.receive()).thenReturn(Flux.just(frame)); + when(delegate.receive()).thenReturn(Flux.just(frame.retain())); when(delegate.onClose()).thenReturn(Mono.never()); new FragmentationDuplexConnection(DEFAULT, delegate, 2) @@ -199,7 +199,7 @@ void reassembleNonFragment() { void reassembleNonFragmentableFrame() { Frame frame = toAbstractionLeakingFrame(DEFAULT, 1, createTestCancelFrame()); - when(delegate.receive()).thenReturn(Flux.just(frame)); + when(delegate.receive()).thenReturn(Flux.just(frame.retain())); when(delegate.onClose()).thenReturn(Mono.never()); new FragmentationDuplexConnection(DEFAULT, delegate, 2) @@ -232,7 +232,7 @@ void sendData() { when(delegate.onClose()).thenReturn(Mono.never()); - new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame); + new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame.retain()); verify(delegate).send(publishers.capture()); StepVerifier.create(Flux.from(publishers.getValue())) @@ -251,7 +251,7 @@ void sendEqualToMaxFragmentLength() { when(delegate.onClose()).thenReturn(Mono.never()); - new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame); + new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame.retain()); verify(delegate).send(publishers.capture()); StepVerifier.create(Flux.from(publishers.getValue())).expectNext(frame).verifyComplete(); @@ -266,7 +266,7 @@ void sendFragment() { when(delegate.onClose()).thenReturn(Mono.never()); - new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame); + new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame.retain()); verify(delegate).send(publishers.capture()); StepVerifier.create(Flux.from(publishers.getValue())).expectNext(frame).verifyComplete(); @@ -281,7 +281,7 @@ void sendLessThanMaxFragmentLength() { when(delegate.onClose()).thenReturn(Mono.never()); - new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame); + new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame.retain()); verify(delegate).send(publishers.capture()); StepVerifier.create(Flux.from(publishers.getValue())).expectNext(frame).verifyComplete(); @@ -310,7 +310,7 @@ void sendMetadata() { when(delegate.onClose()).thenReturn(Mono.never()); - new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame); + new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame.retain()); verify(delegate).send(publishers.capture()); StepVerifier.create(Flux.from(publishers.getValue())) @@ -354,7 +354,7 @@ void sendMetadataAndData() { when(delegate.onClose()).thenReturn(Mono.never()); - new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame); + new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame.retain()); verify(delegate).send(publishers.capture()); StepVerifier.create(Flux.from(publishers.getValue())) @@ -373,7 +373,7 @@ void sendNonFragmentable() { when(delegate.onClose()).thenReturn(Mono.never()); - new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame); + new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame.retain()); verify(delegate).send(publishers.capture()); StepVerifier.create(Flux.from(publishers.getValue())).expectNext(frame).verifyComplete(); @@ -398,7 +398,7 @@ void sendZeroMaxFragmentLength() { when(delegate.onClose()).thenReturn(Mono.never()); - new FragmentationDuplexConnection(DEFAULT, delegate, 0).sendOne(frame); + new FragmentationDuplexConnection(DEFAULT, delegate, 0).sendOne(frame.retain()); verify(delegate).send(publishers.capture()); StepVerifier.create(Flux.from(publishers.getValue())).expectNext(frame).verifyComplete(); diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java index 68d4f3fd3..41bc197e0 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java @@ -19,7 +19,7 @@ public class InteractionsLoadTest { @Test - @SlowTest + //@SlowTest public void channel() { TcpServerTransport serverTransport = TcpServerTransport.create(0); @@ -80,7 +80,7 @@ public Flux requestChannel(Publisher payloads) { if (!data.equals("foo")) { throw new IllegalStateException("Channel Server Bad message: " + data); } - return DefaultPayload.create(DefaultPayload.create("bar")); + return DefaultPayload.create("bar"); }); } From 73b9fd94400781e59bac7fe96680f464a7cfabf0 Mon Sep 17 00:00:00 2001 From: Ryland Degnan Date: Mon, 21 May 2018 08:52:29 -0700 Subject: [PATCH 2/2] Ensure sendProcessor is disposed --- .../main/java/io/rsocket/AbstractRSocket.java | 4 + .../src/main/java/io/rsocket/Frame.java | 4 +- .../java/io/rsocket/KeepAliveHandler.java | 18 +- .../main/java/io/rsocket/RSocketClient.java | 353 +++++++++--------- .../main/java/io/rsocket/RSocketFactory.java | 8 +- .../main/java/io/rsocket/RSocketServer.java | 94 +++-- .../rsocket/internal/UnboundedProcessor.java | 5 + .../java/io/rsocket/util/ByteBufPayload.java | 4 +- .../java/io/rsocket/util/DefaultPayload.java | 4 +- .../integration/InteractionsLoadTest.java | 2 +- 10 files changed, 238 insertions(+), 258 deletions(-) diff --git a/rsocket-core/src/main/java/io/rsocket/AbstractRSocket.java b/rsocket-core/src/main/java/io/rsocket/AbstractRSocket.java index d55885a83..c099a3120 100644 --- a/rsocket-core/src/main/java/io/rsocket/AbstractRSocket.java +++ b/rsocket-core/src/main/java/io/rsocket/AbstractRSocket.java @@ -31,16 +31,19 @@ public abstract class AbstractRSocket implements RSocket { @Override public Mono fireAndForget(Payload payload) { + payload.release(); return Mono.error(new UnsupportedOperationException("Fire and forget not implemented.")); } @Override public Mono requestResponse(Payload payload) { + payload.release(); return Mono.error(new UnsupportedOperationException("Request-Response not implemented.")); } @Override public Flux requestStream(Payload payload) { + payload.release(); return Flux.error(new UnsupportedOperationException("Request-Stream not implemented.")); } @@ -51,6 +54,7 @@ public Flux requestChannel(Publisher payloads) { @Override public Mono metadataPush(Payload payload) { + payload.release(); return Mono.error(new UnsupportedOperationException("Metadata-Push not implemented.")); } diff --git a/rsocket-core/src/main/java/io/rsocket/Frame.java b/rsocket-core/src/main/java/io/rsocket/Frame.java index 6f400b78f..a8e5cf980 100644 --- a/rsocket-core/src/main/java/io/rsocket/Frame.java +++ b/rsocket-core/src/main/java/io/rsocket/Frame.java @@ -135,9 +135,7 @@ public Frame touch(@Nullable Object hint) { return this; } - /** - * Called once {@link #refCnt()} is equals 0. - */ + /** Called once {@link #refCnt()} is equals 0. */ @Override protected void deallocate() { content.release(); diff --git a/rsocket-core/src/main/java/io/rsocket/KeepAliveHandler.java b/rsocket-core/src/main/java/io/rsocket/KeepAliveHandler.java index eac560dfa..5ca1d4c7e 100644 --- a/rsocket-core/src/main/java/io/rsocket/KeepAliveHandler.java +++ b/rsocket-core/src/main/java/io/rsocket/KeepAliveHandler.java @@ -8,11 +8,10 @@ import reactor.core.publisher.MonoProcessor; import reactor.core.publisher.UnicastProcessor; -abstract class KeepAliveHandler { +abstract class KeepAliveHandler implements Disposable { private final KeepAlive keepAlive; private final UnicastProcessor sent = UnicastProcessor.create(); private final MonoProcessor timeout = MonoProcessor.create(); - private final Flux interval; private Disposable intervalDisposable; private volatile long lastReceivedMillis; @@ -26,20 +25,17 @@ static KeepAliveHandler ofClient(KeepAlive keepAlive) { private KeepAliveHandler(KeepAlive keepAlive) { this.keepAlive = keepAlive; - this.interval = Flux.interval(Duration.ofMillis(keepAlive.getTickPeriod())); - } - - public void start() { this.lastReceivedMillis = System.currentTimeMillis(); - intervalDisposable = interval.subscribe(v -> onIntervalTick()); + this.intervalDisposable = + Flux.interval(Duration.ofMillis(keepAlive.getTickPeriod())) + .subscribe(v -> onIntervalTick()); } - public void stop() { + @Override + public void dispose() { sent.onComplete(); timeout.onComplete(); - if (intervalDisposable != null) { - intervalDisposable.dispose(); - } + intervalDisposable.dispose(); } public void receive(Frame keepAliveFrame) { diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketClient.java b/rsocket-core/src/main/java/io/rsocket/RSocketClient.java index 8ce6ede18..44aedbf50 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketClient.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketClient.java @@ -37,7 +37,6 @@ class RSocketClient implements RSocket { private final Function frameDecoder; private final Consumer errorConsumer; private final StreamIdSupplier streamIdSupplier; - private final MonoProcessor started; private final NonBlockingHashMapLong senders; private final NonBlockingHashMapLong> receivers; private final UnboundedProcessor sendProcessor; @@ -52,6 +51,7 @@ class RSocketClient implements RSocket { this( connection, frameDecoder, errorConsumer, streamIdSupplier, Duration.ZERO, Duration.ZERO, 0); } + /*client requester*/ RSocketClient( DuplexConnection connection, @@ -65,20 +65,26 @@ class RSocketClient implements RSocket { this.frameDecoder = frameDecoder; this.errorConsumer = errorConsumer; this.streamIdSupplier = streamIdSupplier; - this.started = MonoProcessor.create(); this.senders = new NonBlockingHashMapLong<>(256); this.receivers = new NonBlockingHashMapLong<>(256); // DO NOT Change the order here. The Send processor must be subscribed to before receiving this.sendProcessor = new UnboundedProcessor<>(); + connection.onClose().doFinally(signalType -> cleanup()).subscribe(null, errorConsumer); + + connection + .send(sendProcessor) + .doFinally(this::handleSendProcessorCancel) + .subscribe(null, this::handleSendProcessorError); + + connection.receive().subscribe(this::handleIncomingFrames, errorConsumer); + if (!Duration.ZERO.equals(tickPeriod)) { this.keepAliveHandler = KeepAliveHandler.ofClient( new KeepAliveHandler.KeepAlive(tickPeriod, ackTimeout, missedAcks)); - started.doOnTerminate(() -> keepAliveHandler.start()).subscribe(); - keepAliveHandler .timeout() .subscribe( @@ -92,18 +98,6 @@ class RSocketClient implements RSocket { } else { keepAliveHandler = null; } - - connection.onClose().doFinally(signalType -> cleanup()).subscribe(null, errorConsumer); - - connection - .send(sendProcessor) - .doFinally(this::handleSendProcessorCancel) - .subscribe(null, this::handleSendProcessorError); - - connection - .receive() - .doOnSubscribe(subscription -> started.onComplete()) - .subscribe(this::handleIncomingFrames, errorConsumer); } private void handleSendProcessorError(Throwable t) { @@ -140,17 +134,14 @@ private void handleSendProcessorCancel(SignalType t) { @Override public Mono fireAndForget(Payload payload) { - Mono defer = - Mono.fromRunnable( - () -> { - final int streamId = streamIdSupplier.nextStreamId(); - final Frame requestFrame = - Frame.Request.from(streamId, FrameType.REQUEST_FNF, payload, 1); - payload.release(); - sendProcessor.onNext(requestFrame); - }); - - return started.then(defer); + return Mono.fromRunnable( + () -> { + final int streamId = streamIdSupplier.nextStreamId(); + final Frame requestFrame = + Frame.Request.from(streamId, FrameType.REQUEST_FNF, payload, 1); + payload.release(); + sendProcessor.onNext(requestFrame); + }); } @Override @@ -170,15 +161,12 @@ public Flux requestChannel(Publisher payloads) { @Override public Mono metadataPush(Payload payload) { - Mono defer = - Mono.fromRunnable( - () -> { - final Frame requestFrame = Frame.Request.from(0, FrameType.METADATA_PUSH, payload, 1); - payload.release(); - sendProcessor.onNext(requestFrame); - }); - - return started.then(defer); + return Mono.fromRunnable( + () -> { + final Frame requestFrame = Frame.Request.from(0, FrameType.METADATA_PUSH, payload, 1); + payload.release(); + sendProcessor.onNext(requestFrame); + }); } @Override @@ -202,159 +190,155 @@ public Mono onClose() { } public Flux handleRequestStream(final Payload payload) { - return started.thenMany( - Flux.defer( - () -> { - int streamId = streamIdSupplier.nextStreamId(); - - UnicastProcessor receiver = UnicastProcessor.create(); - receivers.put(streamId, receiver); - - AtomicBoolean first = new AtomicBoolean(false); - - return receiver - .doOnRequest( - l -> { - if (first.compareAndSet(false, true) && !receiver.isDisposed()) { - final Frame requestFrame = - Frame.Request.from(streamId, FrameType.REQUEST_STREAM, payload, l); - payload.release(); - sendProcessor.onNext(requestFrame); - } else if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext(Frame.RequestN.from(streamId, l)); - } - sendProcessor.drain(); - }) - .doOnError( - t -> { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext(Frame.Error.from(streamId, t)); - } - }) - .doOnCancel( - () -> { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext(Frame.Cancel.from(streamId)); - } - }) - .doFinally( - s -> { - receivers.remove(streamId); - }); - })); + return Flux.defer( + () -> { + int streamId = streamIdSupplier.nextStreamId(); + + UnicastProcessor receiver = UnicastProcessor.create(); + receivers.put(streamId, receiver); + + AtomicBoolean first = new AtomicBoolean(false); + + return receiver + .doOnRequest( + n -> { + if (first.compareAndSet(false, true) && !receiver.isDisposed()) { + final Frame requestFrame = + Frame.Request.from(streamId, FrameType.REQUEST_STREAM, payload, n); + payload.release(); + sendProcessor.onNext(requestFrame); + } else if (contains(streamId) && !receiver.isDisposed()) { + sendProcessor.onNext(Frame.RequestN.from(streamId, n)); + } + sendProcessor.drain(); + }) + .doOnError( + t -> { + if (contains(streamId) && !receiver.isDisposed()) { + sendProcessor.onNext(Frame.Error.from(streamId, t)); + } + }) + .doOnCancel( + () -> { + if (contains(streamId) && !receiver.isDisposed()) { + sendProcessor.onNext(Frame.Cancel.from(streamId)); + } + }) + .doFinally( + s -> { + receivers.remove(streamId); + }); + }); } private Mono handleRequestResponse(final Payload payload) { - return started.then( - Mono.defer( - () -> { - int streamId = streamIdSupplier.nextStreamId(); - final Frame requestFrame = - Frame.Request.from(streamId, FrameType.REQUEST_RESPONSE, payload, 1); - payload.release(); - - UnicastProcessor receiver = UnicastProcessor.create(); - receivers.put(streamId, receiver); - - sendProcessor.onNext(requestFrame); - - return receiver - .singleOrEmpty() - .doOnError(t -> sendProcessor.onNext(Frame.Error.from(streamId, t))) - .doOnCancel(() -> sendProcessor.onNext(Frame.Cancel.from(streamId))) - .doFinally( - s -> { - receivers.remove(streamId); - }); - })); + return Mono.defer( + () -> { + int streamId = streamIdSupplier.nextStreamId(); + final Frame requestFrame = + Frame.Request.from(streamId, FrameType.REQUEST_RESPONSE, payload, 1); + payload.release(); + + UnicastProcessor receiver = UnicastProcessor.create(); + receivers.put(streamId, receiver); + + sendProcessor.onNext(requestFrame); + + return receiver + .singleOrEmpty() + .doOnError(t -> sendProcessor.onNext(Frame.Error.from(streamId, t))) + .doOnCancel(() -> sendProcessor.onNext(Frame.Cancel.from(streamId))) + .doFinally( + s -> { + receivers.remove(streamId); + }); + }); } private Flux handleChannel(Flux request) { - return started.thenMany( - Flux.defer( - () -> { - final UnicastProcessor receiver = UnicastProcessor.create(); - final int streamId = streamIdSupplier.nextStreamId(); - final AtomicBoolean firstRequest = new AtomicBoolean(true); - - return receiver - .doOnRequest( - n -> { - if (firstRequest.compareAndSet(true, false)) { - final AtomicBoolean firstPayload = new AtomicBoolean(true); - final Flux requestFrames = - request - .transform( - f -> { - LimitableRequestPublisher wrapped = - LimitableRequestPublisher.wrap(f); - // Need to set this to one for first the frame - wrapped.increaseRequestLimit(1); - senders.put(streamId, wrapped); - receivers.put(streamId, receiver); - - return wrapped; - }) - .map( - payload -> { - final Frame requestFrame; - if (firstPayload.compareAndSet(true, false)) { - requestFrame = - Frame.Request.from( - streamId, FrameType.REQUEST_CHANNEL, payload, n); - } else { - requestFrame = - Frame.PayloadFrame.from( - streamId, FrameType.NEXT, payload); - } - payload.release(); - return requestFrame; - }) - .doOnComplete( - () -> { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext( - Frame.PayloadFrame.from( - streamId, FrameType.COMPLETE)); - } - if (firstPayload.get()) { - receiver.onComplete(); - } - }); - - requestFrames.subscribe( - sendProcessor::onNext, - t -> { - errorConsumer.accept(t); - receiver.dispose(); - }); - } else { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext(Frame.RequestN.from(streamId, n)); - } - } - }) - .doOnError( - t -> { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext(Frame.Error.from(streamId, t)); - } - }) - .doOnCancel( - () -> { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext(Frame.Cancel.from(streamId)); - } - }) - .doFinally( - s -> { - receivers.remove(streamId); - LimitableRequestPublisher sender = senders.remove(streamId); - if (sender != null) { - sender.cancel(); - } - }); - })); + return Flux.defer( + () -> { + final UnicastProcessor receiver = UnicastProcessor.create(); + final int streamId = streamIdSupplier.nextStreamId(); + final AtomicBoolean firstRequest = new AtomicBoolean(true); + + return receiver + .doOnRequest( + n -> { + if (firstRequest.compareAndSet(true, false)) { + final AtomicBoolean firstPayload = new AtomicBoolean(true); + final Flux requestFrames = + request + .transform( + f -> { + LimitableRequestPublisher wrapped = + LimitableRequestPublisher.wrap(f); + // Need to set this to one for first the frame + wrapped.increaseRequestLimit(1); + senders.put(streamId, wrapped); + receivers.put(streamId, receiver); + + return wrapped; + }) + .map( + payload -> { + final Frame requestFrame; + if (firstPayload.compareAndSet(true, false)) { + requestFrame = + Frame.Request.from( + streamId, FrameType.REQUEST_CHANNEL, payload, n); + } else { + requestFrame = + Frame.PayloadFrame.from( + streamId, FrameType.NEXT, payload); + } + payload.release(); + return requestFrame; + }) + .doOnComplete( + () -> { + if (contains(streamId) && !receiver.isDisposed()) { + sendProcessor.onNext( + Frame.PayloadFrame.from(streamId, FrameType.COMPLETE)); + } + if (firstPayload.get()) { + receiver.onComplete(); + } + }); + + requestFrames.subscribe( + sendProcessor::onNext, + t -> { + errorConsumer.accept(t); + receiver.dispose(); + }); + } else { + if (contains(streamId) && !receiver.isDisposed()) { + sendProcessor.onNext(Frame.RequestN.from(streamId, n)); + } + } + }) + .doOnError( + t -> { + if (contains(streamId) && !receiver.isDisposed()) { + sendProcessor.onNext(Frame.Error.from(streamId, t)); + } + }) + .doOnCancel( + () -> { + if (contains(streamId) && !receiver.isDisposed()) { + sendProcessor.onNext(Frame.Cancel.from(streamId)); + } + }) + .doFinally( + s -> { + receivers.remove(streamId); + LimitableRequestPublisher sender = senders.remove(streamId); + if (sender != null) { + sender.cancel(); + } + }); + }); } private boolean contains(int streamId) { @@ -363,7 +347,7 @@ private boolean contains(int streamId) { protected void cleanup() { if (keepAliveHandler != null) { - keepAliveHandler.stop(); + keepAliveHandler.dispose(); } try { for (UnicastProcessor subscriber : receivers.values()) { @@ -375,6 +359,7 @@ protected void cleanup() { } finally { senders.clear(); receivers.clear(); + sendProcessor.dispose(); } } diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java b/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java index 3926fe4cd..06017ca1d 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java @@ -253,7 +253,7 @@ public Mono start() { } public static class ServerRSocketFactory { - private Supplier acceptor; + private SocketAcceptor acceptor; private Function frameDecoder = DefaultPayload::create; private Consumer errorConsumer = Throwable::printStackTrace; private int mtu = 0; @@ -277,11 +277,6 @@ public ServerRSocketFactory addServerPlugin(RSocketInterceptor interceptor) { } public ServerTransportAcceptor acceptor(SocketAcceptor acceptor) { - this.acceptor = () -> acceptor; - return ServerStart::new; - } - - public ServerTransportAcceptor acceptor(Supplier acceptor) { this.acceptor = acceptor; return ServerStart::new; } @@ -357,7 +352,6 @@ private Mono processSetupFrame( RSocket wrappedRSocketClient = plugins.applyClient(rSocketClient); return acceptor - .get() .accept(setupPayload, wrappedRSocketClient) .doOnNext( unwrappedServerSocket -> { diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/RSocketServer.java index 9ec273bf6..95933c6e2 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketServer.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketServer.java @@ -41,13 +41,11 @@ class RSocketServer implements RSocket { private final RSocket requestHandler; private final Function frameDecoder; private final Consumer errorConsumer; - private final MonoProcessor started; private final NonBlockingHashMapLong sendingSubscriptions; private final NonBlockingHashMapLong> channelProcessors; private final UnboundedProcessor sendProcessor; - private Disposable receiveDisposable; private KeepAliveHandler keepAliveHandler; /*client responder*/ @@ -58,6 +56,7 @@ class RSocketServer implements RSocket { Consumer errorConsumer) { this(connection, requestHandler, frameDecoder, errorConsumer, 0, 0); } + /*server responder*/ RSocketServer( DuplexConnection connection, @@ -72,18 +71,31 @@ class RSocketServer implements RSocket { this.errorConsumer = errorConsumer; this.sendingSubscriptions = new NonBlockingHashMapLong<>(); this.channelProcessors = new NonBlockingHashMapLong<>(); - this.started = MonoProcessor.create(); // DO NOT Change the order here. The Send processor must be subscribed to before receiving // connections this.sendProcessor = new UnboundedProcessor<>(); + connection + .send(sendProcessor) + .doFinally(this::handleSendProcessorCancel) + .subscribe(null, this::handleSendProcessorError); + + Disposable receiveDisposable = connection.receive().subscribe(this::handleFrame, errorConsumer); + + this.connection + .onClose() + .doFinally( + s -> { + cleanup(); + receiveDisposable.dispose(); + }) + .subscribe(null, errorConsumer); + if (tickPeriod != 0) { keepAliveHandler = KeepAliveHandler.ofServer(new KeepAliveHandler.KeepAlive(tickPeriod, ackTimeout)); - started.doOnTerminate(() -> keepAliveHandler.start()).subscribe(); - keepAliveHandler .timeout() .subscribe( @@ -97,26 +109,6 @@ class RSocketServer implements RSocket { } else { keepAliveHandler = null; } - - connection - .send(sendProcessor) - .doFinally(this::handleSendProcessorCancel) - .subscribe(null, this::handleSendProcessorError); - - this.receiveDisposable = - connection - .receive() - .doOnSubscribe(subscription -> started.onComplete()) - .subscribe(this::handleFrame, errorConsumer); - - this.connection - .onClose() - .doFinally( - s -> { - cleanup(); - receiveDisposable.dispose(); - }) - .subscribe(null, errorConsumer); } private void handleSendProcessorError(Throwable t) { @@ -221,12 +213,13 @@ public Mono onClose() { private void cleanup() { if (keepAliveHandler != null) { - keepAliveHandler.stop(); + keepAliveHandler.dispose(); } cleanUpSendingSubscriptions(); cleanUpChannelProcessors(); requestHandler.dispose(); + sendProcessor.dispose(); } private synchronized void cleanUpSendingSubscriptions() { @@ -263,14 +256,14 @@ private void handleFrame(Frame frame) { handleStream(streamId, requestStream(frameDecoder.apply(frame)), initialRequestN(frame)); break; case REQUEST_CHANNEL: - handleChannel(streamId, frame); - break; - case PAYLOAD: - // TODO: Hook in receiving socket. + handleChannel(streamId, frameDecoder.apply(frame), initialRequestN(frame)); break; case METADATA_PUSH: metadataPush(frameDecoder.apply(frame)); break; + case PAYLOAD: + // TODO: Hook in receiving socket. + break; case LEASE: // Lease must not be received here as this is the server end of the socket which sends // leases. @@ -317,12 +310,9 @@ private void handleFrame(Frame frame) { private void handleFireAndForget(int streamId, Mono result) { result + .doOnSubscribe(subscription -> sendingSubscriptions.put(streamId, subscription)) .doFinally(signalType -> sendingSubscriptions.remove(streamId)) - .subscribe( - null, - errorConsumer, - null, - subscription -> sendingSubscriptions.put(streamId, subscription)); + .subscribe(null, errorConsumer); } private void handleRequestResponse(int streamId, Mono response) { @@ -347,25 +337,29 @@ private void handleRequestResponse(int streamId, Mono response) { private void handleStream(int streamId, Flux response, int initialRequestN) { response - .map( - payload -> { - final Frame frame = Frame.PayloadFrame.from(streamId, FrameType.NEXT, payload); - payload.release(); - return frame; - }) - .concatWith(Mono.fromCallable(() -> Frame.PayloadFrame.from(streamId, FrameType.COMPLETE))) .transform( frameFlux -> { - LimitableRequestPublisher frames = LimitableRequestPublisher.wrap(frameFlux); - sendingSubscriptions.put(streamId, frames); - frames.increaseRequestLimit(initialRequestN); - return frames; + LimitableRequestPublisher payloads = + LimitableRequestPublisher.wrap(frameFlux); + sendingSubscriptions.put(streamId, payloads); + payloads.increaseRequestLimit(initialRequestN); + return payloads; }) .doFinally(signalType -> sendingSubscriptions.remove(streamId)) - .subscribe(sendProcessor::onNext, t -> handleError(streamId, t)); + .subscribe( + payload -> { + final Frame frame = Frame.PayloadFrame.from(streamId, FrameType.NEXT, payload); + payload.release(); + sendProcessor.onNext(frame); + }, + t -> handleError(streamId, t), + () -> { + final Frame frame = Frame.PayloadFrame.from(streamId, FrameType.COMPLETE); + sendProcessor.onNext(frame); + }); } - private void handleChannel(int streamId, Frame firstFrame) { + private void handleChannel(int streamId, Payload payload, int initialRequestN) { UnicastProcessor frames = UnicastProcessor.create(); channelProcessors.put(streamId, frames); @@ -379,9 +373,9 @@ private void handleChannel(int streamId, Frame firstFrame) { // not chained, as the payload should be enqueued in the Unicast processor before this method // returns // and any later payload can be processed - frames.onNext(frameDecoder.apply(firstFrame)); + frames.onNext(payload); - handleStream(streamId, requestChannel(payloads), initialRequestN(firstFrame)); + handleStream(streamId, requestChannel(payloads), initialRequestN); } private void handleKeepAliveFrame(Frame frame) { diff --git a/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java b/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java index 99b86abc0..110ddd90d 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java @@ -311,6 +311,11 @@ public int requestFusion(int requestedMode) { return Fuseable.NONE; } + @Override + public void dispose() { + cancel(); + } + @Override public boolean isDisposed() { return cancelled || done; diff --git a/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java b/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java index e1c134000..0b33e5e22 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java @@ -175,6 +175,8 @@ public static Payload create(ByteBuf data, @Nullable ByteBuf metadata) { } public static Payload create(Payload payload) { - return create(payload.sliceData().retain(), payload.hasMetadata() ? payload.sliceMetadata().retain() : null); + return create( + payload.sliceData().retain(), + payload.hasMetadata() ? payload.sliceMetadata().retain() : null); } } diff --git a/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java b/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java index a21bf8da1..71bbf3874 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java @@ -163,6 +163,8 @@ public static Payload create(ByteBuf data, @Nullable ByteBuf metadata) { } public static Payload create(Payload payload) { - return create(Unpooled.copiedBuffer(payload.sliceData()), payload.hasMetadata() ? Unpooled.copiedBuffer(payload.sliceMetadata()) : null); + return create( + Unpooled.copiedBuffer(payload.sliceData()), + payload.hasMetadata() ? Unpooled.copiedBuffer(payload.sliceMetadata()) : null); } } diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java index 41bc197e0..1c488db4f 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java @@ -19,7 +19,7 @@ public class InteractionsLoadTest { @Test - //@SlowTest + @SlowTest public void channel() { TcpServerTransport serverTransport = TcpServerTransport.create(0);