diff --git a/gradle.properties b/gradle.properties index 27abb8cac..9cd597b08 100644 --- a/gradle.properties +++ b/gradle.properties @@ -12,4 +12,4 @@ # limitations under the License. # -version=0.11.16-SNAPSHOT +version=0.12.1-SNAPSHOT diff --git a/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java b/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java index 9f1d7ea6b..dcae3b74e 100644 --- a/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java @@ -127,5 +127,15 @@ public ByteBuf sliceMetadata() { public ByteBuf sliceData() { return SetupFrameFlyweight.data(setupFrame); } + + @Override + public ByteBuf data() { + return sliceData(); + } + + @Override + public ByteBuf metadata() { + return sliceMetadata(); + } } } diff --git a/rsocket-core/src/main/java/io/rsocket/Payload.java b/rsocket-core/src/main/java/io/rsocket/Payload.java index 58fab3382..fc130528e 100644 --- a/rsocket-core/src/main/java/io/rsocket/Payload.java +++ b/rsocket-core/src/main/java/io/rsocket/Payload.java @@ -32,8 +32,8 @@ public interface Payload extends ReferenceCounted { boolean hasMetadata(); /** - * Returns the Payload metadata. Always non-null, check {@link #hasMetadata()} to differentiate - * null from "". + * Returns a slice Payload metadata. Always non-null, check {@link #hasMetadata()} to + * differentiate null from "". * * @return payload metadata. */ @@ -46,6 +46,22 @@ public interface Payload extends ReferenceCounted { */ ByteBuf sliceData(); + /** + * Returns the Payloads' data without slicing if possible. This is not safe and editing this could + * effect the payload. It is recommended to call sliceData(). + * + * @return data as a bytebuf or slice of the data + */ + ByteBuf data(); + + /** + * Returns the Payloads' metadata without slicing if possible. This is not safe and editing this + * could effect the payload. It is recommended to call sliceMetadata(). + * + * @return metadata as a bytebuf or slice of the metadata + */ + ByteBuf metadata(); + /** Increases the reference count by {@code 1}. */ @Override Payload retain(); diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketClient.java b/rsocket-core/src/main/java/io/rsocket/RSocketClient.java index 27a882d01..902d8487e 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketClient.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketClient.java @@ -18,6 +18,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.util.ReferenceCountUtil; import io.netty.util.collection.IntObjectHashMap; import io.rsocket.exceptions.ConnectionErrorException; import io.rsocket.exceptions.Exceptions; @@ -220,7 +221,6 @@ private Mono handleFireAndForget(Payload payload) { false, payload.hasMetadata() ? payload.sliceMetadata().retain() : null, payload.sliceData().retain()); - payload.release(); sendProcessor.onNext(requestFrame); })); @@ -292,12 +292,12 @@ private Mono handleRequestResponse(final Payload payload) { false, payload.sliceMetadata().retain(), payload.sliceData().retain()); + payload.release(); UnicastMonoProcessor receiver = UnicastMonoProcessor.create(); receivers.put(streamId, receiver); sendProcessor.onNext(requestFrame); - return receiver .doOnError( t -> @@ -472,8 +472,10 @@ private void handleIncomingFrames(ByteBuf frame) { } else { handleFrame(streamId, type, frame); } - } finally { frame.release(); + } catch (Throwable t) { + ReferenceCountUtil.safeRelease(frame); + throw reactor.core.Exceptions.propagate(t); } } diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java b/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java index 8e8afda0a..1e7b056ca 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java @@ -20,7 +20,6 @@ import io.netty.buffer.ByteBufAllocator; import io.rsocket.exceptions.InvalidSetupException; import io.rsocket.exceptions.RejectedSetupException; -import io.rsocket.fragmentation.FragmentationDuplexConnection; import io.rsocket.frame.ErrorFrameFlyweight; import io.rsocket.frame.SetupFrameFlyweight; import io.rsocket.frame.VersionFlyweight; @@ -216,7 +215,7 @@ private class StartClient implements Start { public Mono start() { return transportClient .get() - .connect() + .connect(mtu) .flatMap( connection -> { ByteBuf setupFrame = @@ -231,10 +230,6 @@ public Mono start() { setupPayload.sliceMetadata(), setupPayload.sliceData()); - if (mtu > 0) { - connection = new FragmentationDuplexConnection(connection, mtu); - } - ClientServerInputMultiplexer multiplexer = new ClientServerInputMultiplexer(connection, plugins); @@ -333,10 +328,6 @@ public Mono start() { .get() .start( connection -> { - if (mtu > 0) { - connection = new FragmentationDuplexConnection(connection, mtu); - } - ClientServerInputMultiplexer multiplexer = new ClientServerInputMultiplexer(connection, plugins); @@ -345,7 +336,8 @@ public Mono start() { .receive() .next() .flatMap(setupFrame -> processSetupFrame(multiplexer, setupFrame)); - }); + }, + mtu); } private Mono processSetupFrame( diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/RSocketServer.java index 2b0eadaf2..a184e11a1 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketServer.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketServer.java @@ -34,6 +34,7 @@ import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import reactor.core.Disposable; +import reactor.core.Exceptions; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.SignalType; @@ -330,8 +331,10 @@ private void handleFrame(ByteBuf frame) { new IllegalStateException("ServerRSocket: Unexpected frame type: " + frameType)); break; } - } finally { ReferenceCountUtil.safeRelease(frame); + } catch (Throwable t) { + ReferenceCountUtil.safeRelease(frame); + throw Exceptions.propagate(t); } } @@ -345,11 +348,28 @@ private void handleFireAndForget(int streamId, Mono result) { private void handleRequestResponse(int streamId, Mono response) { response .doOnSubscribe(subscription -> sendingSubscriptions.put(streamId, subscription)) - .map(payload -> PayloadFrameFlyweight.encodeNextComplete(allocator, streamId, payload)) + .map( + payload -> { + ByteBuf byteBuf = null; + try { + byteBuf = PayloadFrameFlyweight.encodeNextComplete(allocator, streamId, payload); + } catch (Throwable t) { + if (byteBuf != null) { + ReferenceCountUtil.safeRelease(byteBuf); + ReferenceCountUtil.safeRelease(payload); + } + } + payload.release(); + return byteBuf; + }) .switchIfEmpty( Mono.fromCallable(() -> PayloadFrameFlyweight.encodeComplete(allocator, streamId))) .doFinally(signalType -> sendingSubscriptions.remove(streamId)) - .subscribe(t1 -> sendProcessor.onNext(t1), t -> handleError(streamId, t)); + .subscribe( + t1 -> { + sendProcessor.onNext(t1); + }, + t -> handleError(streamId, t)); } private void handleStream(int streamId, Flux response, int initialRequestN) { @@ -364,9 +384,20 @@ private void handleStream(int streamId, Flux response, int initialReque }) .doFinally(signalType -> sendingSubscriptions.remove(streamId)) .subscribe( - payload -> - sendProcessor.onNext( - PayloadFrameFlyweight.encodeNext(allocator, streamId, payload)), + payload -> { + ByteBuf byteBuf = null; + try { + byteBuf = PayloadFrameFlyweight.encodeNext(allocator, streamId, payload); + } catch (Throwable t) { + if (byteBuf != null) { + ReferenceCountUtil.safeRelease(byteBuf); + ReferenceCountUtil.safeRelease(payload); + } + throw Exceptions.propagate(t); + } + payload.release(); + sendProcessor.onNext(byteBuf); + }, t -> handleError(streamId, t), () -> sendProcessor.onNext(PayloadFrameFlyweight.encodeComplete(allocator, streamId))); } diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java index af492508d..023c4e689 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java @@ -16,9 +16,18 @@ package io.rsocket.fragmentation; +import static io.rsocket.fragmentation.FrameFragmenter.fragmentFrame; + import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; +import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameLengthFlyweight; +import io.rsocket.frame.FrameType; +import java.util.Objects; import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -30,139 +39,84 @@ * and Reassembly */ public final class FragmentationDuplexConnection implements DuplexConnection { - public FragmentationDuplexConnection(DuplexConnection connection, int mtu) {} - - @Override - public Mono send(Publisher frames) { - return null; - } - - @Override - public Flux receive() { - return null; - } - - @Override - public Mono onClose() { - return null; - } - - @Override - public void dispose() {} - - /* - private final ByteBufAllocator byteBufAllocator; - + private static final int MIN_MTU_SIZE = 64; + private static final Logger logger = LoggerFactory.getLogger(FragmentationDuplexConnection.class); private final DuplexConnection delegate; + private final int mtu; + private final ByteBufAllocator allocator; + private final FrameReassembler frameReassembler; + private final boolean encodeLength; - private final FrameFragmenter frameFragmenter; - - private final IntObjectHashMap frameReassemblers = new IntObjectHashMap<>(); - - */ - /** - * Creates a new instance. - * - * @param delegate the {@link DuplexConnection} to decorate - * @param maxFragmentSize the maximum fragment size - * @throws NullPointerException if {@code delegate} is {@code null} - * @throws IllegalArgumentException if {@code maxFragmentSize} is not {@code positive} - */ - /* - // TODO: Remove once ByteBufAllocators are shared - public FragmentationDuplexConnection(DuplexConnection delegate, int maxFragmentSize) { - this(PooledByteBufAllocator.DEFAULT, delegate, maxFragmentSize); - } - - */ - /** - * Creates a new instance. - * - * @param byteBufAllocator the {@link ByteBufAllocator} to use - * @param delegate the {@link DuplexConnection} to decorate - * @param maxFragmentSize the maximum fragment size. A value of 0 indicates that frames should not - * be fragmented. - * @throws NullPointerException if {@code byteBufAllocator} or {@code delegate} are {@code null} - * @throws IllegalArgumentException if {@code maxFragmentSize} is not {@code positive} - */ - /* public FragmentationDuplexConnection( - ByteBufAllocator byteBufAllocator, DuplexConnection delegate, int maxFragmentSize) { - - this.byteBufAllocator = - Objects.requireNonNull(byteBufAllocator, "byteBufAllocator must not be null"); - this.delegate = Objects.requireNonNull(delegate, "delegate must not be null"); - - NumberUtils.requireNonNegative(maxFragmentSize, "maxFragmentSize must be positive"); + DuplexConnection delegate, ByteBufAllocator allocator, int mtu, boolean encodeLength) { + Objects.requireNonNull(delegate, "delegate must not be null"); + Objects.requireNonNull(allocator, "byteBufAllocator must not be null"); + if (mtu < MIN_MTU_SIZE) { + throw new IllegalArgumentException("smallest allowed mtu size is " + MIN_MTU_SIZE + " bytes"); + } + this.encodeLength = encodeLength; + this.allocator = allocator; + this.delegate = delegate; + this.mtu = mtu; + this.frameReassembler = new FrameReassembler(allocator); - this.frameFragmenter = new FrameFragmenter(byteBufAllocator, maxFragmentSize); + delegate.onClose().doFinally(s -> frameReassembler.dispose()).subscribe(); + } - delegate - .onClose() - .doFinally( - signalType -> { - Collection values; - synchronized (FragmentationDuplexConnection.this) { - values = frameReassemblers.values(); - } - values.forEach(FrameReassembler::dispose); - }) - .subscribe(); + private boolean shouldFragment(FrameType frameType, int readableBytes) { + return frameType.isFragmentable() && readableBytes > mtu; } @Override - public double availability() { - return delegate.availability(); + public Mono send(Publisher frames) { + return Flux.from(frames).concatMap(this::sendOne).then(); } @Override - public void dispose() { - delegate.dispose(); + public Mono sendOne(ByteBuf frame) { + FrameType frameType = FrameHeaderFlyweight.frameType(frame); + int readableBytes = frame.readableBytes(); + if (shouldFragment(frameType, readableBytes)) { + return delegate.send(fragmentFrame(allocator, mtu, frame, frameType, encodeLength)); + } else { + return delegate.sendOne(encode(frame)); + } } - @Override - public boolean isDisposed() { - return delegate.isDisposed(); + private ByteBuf encode(ByteBuf frame) { + if (encodeLength) { + return FrameLengthFlyweight.encode(allocator, frame.readableBytes(), frame).retain(); + } else { + return frame; + } } - @Override - public Mono onClose() { - return delegate.onClose(); + private ByteBuf decode(ByteBuf frame) { + if (encodeLength) { + return FrameLengthFlyweight.frame(frame).retain(); + } else { + return frame; + } } @Override public Flux receive() { return delegate .receive() - .map(AbstractionLeakingFrameUtils::fromAbstractionLeakingFrame) - .concatMap(t2 -> toReassembledFrames(t2.getT1(), t2.getT2())); + .handle( + (byteBuf, sink) -> { + ByteBuf decode = decode(byteBuf); + frameReassembler.reassembleFrame(decode, sink); + }); } @Override - public Mono send(Publisher frames) { - Objects.requireNonNull(frames, "frames must not be null"); - - return delegate.send( - Flux.from(frames) - .map(AbstractionLeakingFrameUtils::fromAbstractionLeakingFrame) - .concatMap(t2 -> toFragmentedFrames(t2.getT1(), t2.getT2()))); + public Mono onClose() { + return delegate.onClose(); } - private Flux toFragmentedFrames(int streamId, io.rsocket.framing.Frame frame) { - return this.frameFragmenter - .fragment(frame) - .map(fragment -> toAbstractionLeakingFrame(byteBufAllocator, streamId, fragment)); + @Override + public void dispose() { + delegate.dispose(); } - - private Mono toReassembledFrames(int streamId, io.rsocket.framing.Frame fragment) { - FrameReassembler frameReassembler; - synchronized (this) { - frameReassembler = - frameReassemblers.computeIfAbsent( - streamId, i -> createFrameReassembler(byteBufAllocator)); - } - - return Mono.justOrEmpty(frameReassembler.reassemble(fragment)) - .map(frame -> toAbstractionLeakingFrame(byteBufAllocator, streamId, frame)); - }*/ } diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java index e9b4de243..d634f7374 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java @@ -16,6 +16,16 @@ package io.rsocket.fragmentation; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.frame.*; +import java.util.function.Consumer; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.SynchronousSink; + /** * The implementation of the RSocket fragmentation behavior. * @@ -24,168 +34,208 @@ * and Reassembly */ final class FrameFragmenter { - /* - private final ByteBufAllocator byteBufAllocator; - - private final Logger logger = LoggerFactory.getLogger(this.getClass()); - - private final int maxFragmentSize; - - */ - /** - * Creates a new instance - * - * @param byteBufAllocator the {@link ByteBufAllocator} to use - * @param maxFragmentSize the maximum size of each fragment - */ - /* - FrameFragmenter(ByteBufAllocator byteBufAllocator, int maxFragmentSize) { - this.byteBufAllocator = - Objects.requireNonNull(byteBufAllocator, "byteBufAllocator must not be null"); - this.maxFragmentSize = maxFragmentSize; - } - - */ - /** - * Returns a {@link Flux} of fragments frames - * - * @param frame the {@link ByteBuf} to fragment - * @return a {@link Flux} of fragment frames - * @throws NullPointerException if {@code frame} is {@code null} - */ - /* - public Flux fragment(ByteBuf frame) { - Objects.requireNonNull(frame, "frame must not be null"); - - if (!shouldFragment(frame)) { - logger.debug("Not fragmenting {}", frame); - return Flux.just(frame); - } - - logger.debug("Fragmenting {}", frame); + static Publisher fragmentFrame( + ByteBufAllocator allocator, + int mtu, + final ByteBuf frame, + FrameType frameType, + boolean encodeLength) { + ByteBuf metadata = getMetadata(frame, frameType); + ByteBuf data = getData(frame, frameType); + int streamId = FrameHeaderFlyweight.streamId(frame); return Flux.generate( - () -> new FragmentationState((FragmentableFrame) frame), - this::generate, - FragmentationState::dispose); + new Consumer>() { + boolean first = true; + + @Override + public void accept(SynchronousSink sink) { + ByteBuf byteBuf; + if (first) { + first = false; + byteBuf = + encodeFirstFragment( + allocator, mtu, frame, frameType, streamId, metadata, data); + } else { + byteBuf = encodeFollowsFragment(allocator, mtu, streamId, metadata, data); + } + + sink.next(encode(allocator, byteBuf, encodeLength)); + if (!metadata.isReadable() && !data.isReadable()) { + sink.complete(); + } + } + }) + .doFinally(signalType -> ReferenceCountUtil.safeRelease(frame)); } - private FragmentationState generate(FragmentationState state, SynchronousSink sink) { - int fragmentLength = maxFragmentSize; - - ByteBuf metadata; - if (state.hasReadableMetadata()) { - metadata = state.readMetadataFragment(fragmentLength); - fragmentLength -= metadata.readableBytes(); - } else { - metadata = null; + static ByteBuf encodeFirstFragment( + ByteBufAllocator allocator, + int mtu, + ByteBuf frame, + FrameType frameType, + int streamId, + ByteBuf metadata, + ByteBuf data) { + // subtract the header bytes + int remaining = mtu - FrameHeaderFlyweight.size(); + + // substract the initial request n + switch (frameType) { + case REQUEST_STREAM: + case REQUEST_CHANNEL: + remaining -= Integer.BYTES; + break; + default: } - if (state.hasReadableMetadata()) { - ByteBuf fragment = state.createFrame(byteBufAllocator, false, metadata, null); - logger.debug("Fragment {}", fragment); - - sink.next(fragment); - return state; + ByteBuf metadataFragment = null; + if (metadata.isReadable()) { + // subtract the metadata frame length + remaining -= 3; + int r = Math.min(remaining, metadata.readableBytes()); + remaining -= r; + metadataFragment = metadata.readRetainedSlice(r); } - ByteBuf data; - data = state.hasReadableData() ? state.readDataFragment(fragmentLength) : null; - - if (state.hasReadableData()) { - ByteBuf fragment = state.createFrame(byteBufAllocator, false, metadata, data); - logger.debug("Fragment {}", fragment); - - sink.next(fragment); - return state; + ByteBuf dataFragment = Unpooled.EMPTY_BUFFER; + if (remaining > 0 && data.isReadable()) { + int r = Math.min(remaining, data.readableBytes()); + dataFragment = data.readRetainedSlice(r); } - ByteBuf fragment = state.createFrame(byteBufAllocator, true, metadata, data); - logger.debug("Final Fragment {}", fragment); - - sink.next(fragment); - sink.complete(); - return state; - } - - private int getFragmentableLength(FragmentableFrame fragmentableFrame) { - return fragmentableFrame.getMetadataLength().orElse(0) + fragmentableFrame.getDataLength(); - } - - private boolean shouldFragment(ByteBuf frame) { - if (maxFragmentSize == 0 || !(frame instanceof FragmentableFrame)) { - return false; + switch (frameType) { + case REQUEST_FNF: + return RequestFireAndForgetFrameFlyweight.encode( + allocator, streamId, true, metadataFragment, dataFragment); + case REQUEST_STREAM: + return RequestStreamFrameFlyweight.encode( + allocator, + streamId, + true, + RequestStreamFrameFlyweight.initialRequestN(frame), + metadataFragment, + dataFragment); + case REQUEST_RESPONSE: + return RequestResponseFrameFlyweight.encode( + allocator, streamId, true, metadataFragment, dataFragment); + case REQUEST_CHANNEL: + return RequestChannelFrameFlyweight.encode( + allocator, + streamId, + true, + false, + RequestChannelFrameFlyweight.initialRequestN(frame), + metadataFragment, + dataFragment); + // Payload and synthetic types + case PAYLOAD: + return PayloadFrameFlyweight.encode( + allocator, streamId, true, false, false, metadataFragment, dataFragment); + case NEXT: + return PayloadFrameFlyweight.encode( + allocator, streamId, true, false, true, metadataFragment, dataFragment); + case NEXT_COMPLETE: + return PayloadFrameFlyweight.encode( + allocator, streamId, true, true, true, metadataFragment, dataFragment); + case COMPLETE: + return PayloadFrameFlyweight.encode( + allocator, streamId, true, true, false, metadataFragment, dataFragment); + default: + throw new IllegalStateException("unsupported fragment type: " + frameType); } - - FragmentableFrame fragmentableFrame = (FragmentableFrame) frame; - return !fragmentableFrame.isFollowsFlagSet() - && getFragmentableLength(fragmentableFrame) > maxFragmentSize; } - static final class FragmentationState implements Disposable { - - private final FragmentableFrame frame; - - private int dataIndex = 0; - - private boolean initialFragmentCreated = false; - - private int metadataIndex = 0; - - FragmentationState(FragmentableFrame frame) { - this.frame = frame; + static ByteBuf encodeFollowsFragment( + ByteBufAllocator allocator, int mtu, int streamId, ByteBuf metadata, ByteBuf data) { + // subtract the header bytes + int remaining = mtu - FrameHeaderFlyweight.size(); + + ByteBuf metadataFragment = null; + if (metadata.isReadable()) { + // subtract the metadata frame length + remaining -= 3; + int r = Math.min(remaining, metadata.readableBytes()); + remaining -= r; + metadataFragment = metadata.readRetainedSlice(r); } - @Override - public void dispose() { - disposeQuietly(frame); + ByteBuf dataFragment = Unpooled.EMPTY_BUFFER; + if (remaining > 0 && data.isReadable()) { + int r = Math.min(remaining, data.readableBytes()); + dataFragment = data.readRetainedSlice(r); } - ByteBuf createFrame( - ByteBufAllocator byteBufAllocator, - boolean complete, - @Nullable ByteBuf metadata, - @Nullable ByteBuf data) { - - if (initialFragmentCreated) { - return createPayloadFrame(byteBufAllocator, !complete, data == null, metadata, data); - } else { - initialFragmentCreated = true; - return frame.createFragment(byteBufAllocator, metadata, data); - } - } - - boolean hasReadableData() { - return frame.getDataLength() - dataIndex > 0; - } + boolean follows = data.isReadable() || metadata.isReadable(); + return PayloadFrameFlyweight.encode( + allocator, streamId, follows, false, true, metadataFragment, dataFragment); + } - boolean hasReadableMetadata() { - Integer metadataLength = frame.getUnsafeMetadataLength(); - return metadataLength != null && metadataLength - metadataIndex > 0; + static ByteBuf getMetadata(ByteBuf frame, FrameType frameType) { + boolean hasMetadata = FrameHeaderFlyweight.hasMetadata(frame); + if (hasMetadata) { + ByteBuf metadata; + switch (frameType) { + case REQUEST_FNF: + metadata = RequestFireAndForgetFrameFlyweight.metadata(frame); + break; + case REQUEST_STREAM: + metadata = RequestStreamFrameFlyweight.metadata(frame); + break; + case REQUEST_RESPONSE: + metadata = RequestResponseFrameFlyweight.metadata(frame); + break; + case REQUEST_CHANNEL: + metadata = RequestChannelFrameFlyweight.metadata(frame); + break; + // Payload and synthetic types + case PAYLOAD: + case NEXT: + case NEXT_COMPLETE: + case COMPLETE: + metadata = PayloadFrameFlyweight.metadata(frame); + break; + default: + throw new IllegalStateException("unsupported fragment type"); + } + return metadata; + } else { + return Unpooled.EMPTY_BUFFER; } + } - ByteBuf readDataFragment(int length) { - int safeLength = min(length, frame.getDataLength() - dataIndex); - - ByteBuf fragment = frame.getUnsafeData().slice(dataIndex, safeLength); - - dataIndex += fragment.readableBytes(); - return fragment; + static ByteBuf getData(ByteBuf frame, FrameType frameType) { + ByteBuf data; + switch (frameType) { + case REQUEST_FNF: + data = RequestFireAndForgetFrameFlyweight.data(frame); + break; + case REQUEST_STREAM: + data = RequestStreamFrameFlyweight.data(frame); + break; + case REQUEST_RESPONSE: + data = RequestResponseFrameFlyweight.data(frame); + break; + case REQUEST_CHANNEL: + data = RequestChannelFrameFlyweight.data(frame); + break; + // Payload and synthetic types + case PAYLOAD: + case NEXT: + case NEXT_COMPLETE: + case COMPLETE: + data = PayloadFrameFlyweight.data(frame); + break; + default: + throw new IllegalStateException("unsupported fragment type"); } + return data; + } - ByteBuf readMetadataFragment(int length) { - Integer metadataLength = frame.getUnsafeMetadataLength(); - ByteBuf metadata = frame.getUnsafeMetadata(); - - if (metadataLength == null || metadata == null) { - throw new IllegalStateException("Cannot read metadata fragment with no metadata"); - } - - int safeLength = min(length, metadataLength - metadataIndex); - - ByteBuf fragment = metadata.slice(metadataIndex, safeLength); - - metadataIndex += fragment.readableBytes(); - return fragment; + static ByteBuf encode(ByteBufAllocator allocator, ByteBuf frame, boolean encodeLength) { + if (encodeLength) { + return FrameLengthFlyweight.encode(allocator, frame.readableBytes(), frame); + } else { + return frame; } - }*/ + } } diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java index a44883915..1d0ae6792 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java @@ -16,7 +16,19 @@ package io.rsocket.fragmentation; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.collection.IntObjectHashMap; +import io.netty.util.collection.IntObjectMap; +import io.rsocket.frame.*; +import java.util.concurrent.atomic.AtomicBoolean; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.Disposable; +import reactor.core.publisher.SynchronousSink; /** * The implementation of the RSocket reassembly behavior. @@ -25,143 +37,252 @@ * href="https://github.com/rsocket/rsocket/blob/master/Protocol.md#fragmentation-and-reassembly">Fragmentation * and Reassembly */ -final class FrameReassembler implements Disposable { +final class FrameReassembler extends AtomicBoolean implements Disposable { + private static final Logger logger = LoggerFactory.getLogger(FrameReassembler.class); + + final IntObjectMap headers; + final IntObjectMap metadata; + final IntObjectMap data; + + private final ByteBufAllocator allocator; + + public FrameReassembler(ByteBufAllocator allocator) { + this.allocator = allocator; + this.headers = new IntObjectHashMap<>(); + this.metadata = new IntObjectHashMap<>(); + this.data = new IntObjectHashMap<>(); + } + @Override - public void dispose() {} + public void dispose() { + if (compareAndSet(false, true)) { + synchronized (FrameReassembler.this) { + for (ByteBuf byteBuf : headers.values()) { + ReferenceCountUtil.safeRelease(byteBuf); + } + headers.clear(); + + for (ByteBuf byteBuf : metadata.values()) { + ReferenceCountUtil.safeRelease(byteBuf); + } + metadata.clear(); + + for (ByteBuf byteBuf : data.values()) { + ReferenceCountUtil.safeRelease(byteBuf); + } + data.clear(); + } + } + } @Override public boolean isDisposed() { - return false; + return get(); } - /* - private static final Recycler RECYCLER = createRecycler(FrameReassembler::new); - private final Handle handle; + synchronized ByteBuf getHeader(int streamId) { + return headers.get(streamId); + } - private ByteBufAllocator byteBufAllocator; + synchronized CompositeByteBuf getMetadata(int streamId) { + CompositeByteBuf byteBuf = metadata.get(streamId); - private ReassemblyState state; + if (byteBuf == null) { + byteBuf = allocator.compositeBuffer(); + metadata.put(streamId, byteBuf); + } - private FrameReassembler(Handle handle) { - this.handle = handle; + return byteBuf; } - @Override - public void dispose() { - if (state != null) { - disposeQuietly(state); - } + synchronized CompositeByteBuf getData(int streamId) { + CompositeByteBuf byteBuf = data.get(streamId); - byteBufAllocator = null; - state = null; - - handle.recycle(this); - } - - */ - /** - * Creates a new instance - * - * @param byteBufAllocator the {@link ByteBufAllocator} to use - * @return the {@code FrameReassembler} - * @throws NullPointerException if {@code byteBufAllocator} is {@code null} - */ - /* - static FrameReassembler createFrameReassembler(ByteBufAllocator byteBufAllocator) { - return RECYCLER.get().setByteBufAllocator(byteBufAllocator); - } - - */ - /** - * Reassembles a frame. If the frame is not a candidate for fragmentation, emits the frame. If - * frame is a candidate for fragmentation, accumulates the content until the final fragment. - * - * @param frame the frame to inspect for reassembly - * @return the reassembled frame if complete, otherwise {@code null} - * @throws NullPointerException if {@code frame} is {@code null} - */ - /* - @Nullable - Frame reassemble(Frame frame) { - Objects.requireNonNull(frame, "frame must not be null"); - - if (!(frame instanceof FragmentableFrame)) { - return frame; + if (byteBuf == null) { + byteBuf = allocator.compositeBuffer(); + data.put(streamId, byteBuf); } - FragmentableFrame fragmentableFrame = (FragmentableFrame) frame; - - if (fragmentableFrame.isFollowsFlagSet()) { - if (state == null) { - state = new ReassemblyState(fragmentableFrame); - } else { - state.accumulate(fragmentableFrame); - } - } else if (state != null) { - state.accumulate(fragmentableFrame); - - Frame reassembledFrame = state.createFrame(byteBufAllocator); - state.dispose(); - state = null; + return byteBuf; + } - return reassembledFrame; - } else { - return fragmentableFrame; - } + synchronized ByteBuf removeHeader(int streamId) { + return headers.remove(streamId); + } - return null; + synchronized CompositeByteBuf removeMetadata(int streamId) { + return metadata.remove(streamId); } - FrameReassembler setByteBufAllocator(ByteBufAllocator byteBufAllocator) { - this.byteBufAllocator = - Objects.requireNonNull(byteBufAllocator, "byteBufAllocator must not be null"); + synchronized CompositeByteBuf removeData(int streamId) { + return data.remove(streamId); + } - return this; + synchronized void putHeader(int streamId, ByteBuf header) { + headers.put(streamId, header); } - static final class ReassemblyState implements Disposable { + void cancelAssemble(int streamId) { + ByteBuf header = removeHeader(streamId); + CompositeByteBuf metadata = removeMetadata(streamId); + CompositeByteBuf data = removeData(streamId); - private ByteBuf data; + if (header != null) { + ReferenceCountUtil.safeRelease(header); + } - private List fragments = new ArrayList<>(); + if (metadata != null) { + ReferenceCountUtil.safeRelease(metadata); + } - private ByteBuf metadata; + if (data != null) { + ReferenceCountUtil.safeRelease(data); + } + } - ReassemblyState(FragmentableFrame fragment) { - accumulate(fragment); + void handleNoFollowsFlag(ByteBuf frame, SynchronousSink sink, int streamId) { + ByteBuf header = removeHeader(streamId); + if (header != null) { + if (FrameHeaderFlyweight.hasMetadata(header)) { + ByteBuf assembledFrame = assembleFrameWithMetadata(frame, streamId, header); + sink.next(assembledFrame); + } else { + ByteBuf data = assembleData(frame, streamId); + ByteBuf assembledFrame = FragmentationFlyweight.encode(allocator, header, data); + sink.next(assembledFrame); + } + } else { + sink.next(frame); } + } - @Override - public void dispose() { - fragments.forEach(Disposable::dispose); + void handleFollowsFlag(ByteBuf frame, int streamId, FrameType frameType) { + ByteBuf header = getHeader(streamId); + if (header == null) { + header = frame.copy(frame.readerIndex(), FrameHeaderFlyweight.size()); + + if (frameType == FrameType.REQUEST_CHANNEL || frameType == FrameType.REQUEST_STREAM) { + int i = RequestChannelFrameFlyweight.initialRequestN(frame); + header.writeInt(i); + } + putHeader(streamId, header); } - void accumulate(FragmentableFrame fragment) { - fragments.add(fragment); - metadata = accumulateMetadata(fragment); - data = accumulateData(fragment); + if (FrameHeaderFlyweight.hasMetadata(frame)) { + CompositeByteBuf metadata = getMetadata(streamId); + switch (frameType) { + case REQUEST_FNF: + metadata.addComponents(true, RequestFireAndForgetFrameFlyweight.metadata(frame)); + break; + case REQUEST_STREAM: + metadata.addComponents(true, RequestStreamFrameFlyweight.metadata(frame)); + break; + case REQUEST_RESPONSE: + metadata.addComponents(true, RequestResponseFrameFlyweight.metadata(frame)); + break; + case REQUEST_CHANNEL: + metadata.addComponents(true, RequestChannelFrameFlyweight.metadata(frame)); + break; + // Payload and synthetic types + case PAYLOAD: + case NEXT: + case NEXT_COMPLETE: + case COMPLETE: + metadata.addComponents(true, PayloadFrameFlyweight.metadata(frame)); + break; + default: + throw new IllegalStateException("unsupported fragment type"); + } } - Frame createFrame(ByteBufAllocator byteBufAllocator) { - FragmentableFrame root = fragments.get(0); - return root.createNonFragment(byteBufAllocator, metadata, data); + ByteBuf data; + switch (frameType) { + case REQUEST_FNF: + data = RequestFireAndForgetFrameFlyweight.data(frame); + break; + case REQUEST_STREAM: + data = RequestStreamFrameFlyweight.data(frame); + break; + case REQUEST_RESPONSE: + data = RequestResponseFrameFlyweight.data(frame); + break; + case REQUEST_CHANNEL: + data = RequestChannelFrameFlyweight.data(frame); + break; + // Payload and synthetic types + case PAYLOAD: + case NEXT: + case NEXT_COMPLETE: + case COMPLETE: + data = PayloadFrameFlyweight.data(frame); + break; + default: + throw new IllegalStateException("unsupported fragment type"); } - private ByteBuf accumulateData(FragmentableFrame fragment) { - ByteBuf data = fragment.getUnsafeData(); - return this.data == null ? data.retain() : Unpooled.wrappedBuffer(this.data, data.retain()); + if (data != Unpooled.EMPTY_BUFFER) { + getData(streamId).addComponents(true, data); } + } + + void reassembleFrame(ByteBuf frame, SynchronousSink sink) { + try { + FrameType frameType = FrameHeaderFlyweight.frameType(frame); + int streamId = FrameHeaderFlyweight.streamId(frame); + switch (frameType) { + case CANCEL: + case ERROR: + cancelAssemble(streamId); + default: + } - private @Nullable ByteBuf accumulateMetadata(FragmentableFrame fragment) { - ByteBuf metadata = fragment.getUnsafeMetadata(); + if (!frameType.isFragmentable()) { + sink.next(frame); + return; + } + + boolean hasFollows = FrameHeaderFlyweight.hasFollows(frame); - if (metadata == null) { - return this.metadata; + if (!hasFollows) { + handleNoFollowsFlag(frame, sink, streamId); + } else { + handleFollowsFlag(frame, streamId, frameType); } - return this.metadata == null - ? metadata.retain() - : Unpooled.wrappedBuffer(this.metadata, metadata.retain()); + } catch (Throwable t) { + logger.error("error reassemble frame", t); + sink.error(t); } - }*/ + } + + private ByteBuf assembleFrameWithMetadata(ByteBuf frame, int streamId, ByteBuf header) { + ByteBuf metadata; + CompositeByteBuf cm = removeMetadata(streamId); + if (cm != null) { + ByteBuf m = PayloadFrameFlyweight.metadata(frame); + metadata = cm.addComponents(true, m); + } else { + metadata = PayloadFrameFlyweight.metadata(frame); + } + + ByteBuf data = assembleData(frame, streamId); + + return FragmentationFlyweight.encode(allocator, header, metadata, data); + } + + private ByteBuf assembleData(ByteBuf frame, int streamId) { + ByteBuf data; + CompositeByteBuf cd = removeData(streamId); + if (cd != null) { + ByteBuf d = PayloadFrameFlyweight.data(frame); + if (d != null) { + cd.addComponents(true, d); + } + data = cd; + } else { + data = Unpooled.EMPTY_BUFFER; + } + + return data; + } } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/DataAndMetadataFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/DataAndMetadataFlyweight.java index f07f5f004..6a493fffe 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/DataAndMetadataFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/DataAndMetadataFlyweight.java @@ -21,9 +21,12 @@ private static void encodeLength(final ByteBuf byteBuf, final int length) { } private static int decodeLength(final ByteBuf byteBuf) { - int length = (byteBuf.readByte() & 0xFF) << 16; - length |= (byteBuf.readByte() & 0xFF) << 8; - length |= byteBuf.readByte() & 0xFF; + byte b = byteBuf.readByte(); + int length = (b & 0xFF) << 16; + byte b1 = byteBuf.readByte(); + length |= (b1 & 0xFF) << 8; + byte b2 = byteBuf.readByte(); + length |= b2 & 0xFF; return length; } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FragmentationFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/FragmentationFlyweight.java new file mode 100644 index 000000000..d5b3742b5 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/FragmentationFlyweight.java @@ -0,0 +1,21 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import reactor.util.annotation.Nullable; + +/** FragmentationFlyweight is used to re-assemble frames */ +public class FragmentationFlyweight { + public static ByteBuf encode(final ByteBufAllocator allocator, ByteBuf header, ByteBuf data) { + return encode(allocator, header, null, data); + } + + public static ByteBuf encode( + final ByteBufAllocator allocator, ByteBuf header, @Nullable ByteBuf metadata, ByteBuf data) { + if (metadata == null) { + return DataAndMetadataFlyweight.encodeOnlyData(allocator, header, data); + } else { + return DataAndMetadataFlyweight.encode(allocator, header, metadata, data); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderFlyweight.java index 7dbe8053a..7f03984d8 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderFlyweight.java @@ -45,7 +45,7 @@ static ByteBuf encodeStreamZero( return encode(allocator, 0, frameType, flags); } - static ByteBuf encode( + public static ByteBuf encode( final ByteBufAllocator allocator, final int streamId, final FrameType frameType, int flags) { if (!frameType.canHaveMetadata() && ((flags & FLAGS_M) == FLAGS_M)) { throw new IllegalStateException("bad value for metadata flag"); @@ -56,6 +56,10 @@ static ByteBuf encode( return allocator.buffer().writeInt(streamId).writeShort(typeAndFlags); } + public static boolean hasFollows(ByteBuf byteBuf) { + return (flags(byteBuf) & FLAGS_F) == FLAGS_F; + } + public static int streamId(ByteBuf byteBuf) { byteBuf.markReaderIndex(); int streamId = byteBuf.readInt(); @@ -113,7 +117,7 @@ public static void ensureFrameType(final FrameType frameType, ByteBuf byteBuf) { } } - static int size() { + public static int size() { return HEADER_SIZE; } } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java new file mode 100644 index 000000000..f9ae72ed2 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java @@ -0,0 +1,108 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; + +public class FrameUtil { + private FrameUtil() {} + + public static String toString(ByteBuf frame) { + FrameType frameType = FrameHeaderFlyweight.frameType(frame); + int streamId = FrameHeaderFlyweight.streamId(frame); + StringBuilder payload = new StringBuilder(); + + payload + .append("\nFrame => Stream ID: ") + .append(streamId) + .append(" Type: ") + .append(frameType) + .append(" Flags: 0b") + .append(Integer.toBinaryString(FrameHeaderFlyweight.flags(frame))) + .append(" Length: " + frame.readableBytes()); + + if (FrameHeaderFlyweight.hasMetadata(frame)) { + payload.append("\nMetadata:\n"); + + ByteBufUtil.appendPrettyHexDump(payload, getMetadata(frame, frameType)); + } + + payload.append("\nData:\n"); + ByteBufUtil.appendPrettyHexDump(payload, getData(frame, frameType)); + + return payload.toString(); + } + + private static ByteBuf getMetadata(ByteBuf frame, FrameType frameType) { + boolean hasMetadata = FrameHeaderFlyweight.hasMetadata(frame); + if (hasMetadata) { + ByteBuf metadata; + switch (frameType) { + case REQUEST_FNF: + metadata = RequestFireAndForgetFrameFlyweight.metadata(frame); + break; + case REQUEST_STREAM: + metadata = RequestStreamFrameFlyweight.metadata(frame); + break; + case REQUEST_RESPONSE: + metadata = RequestResponseFrameFlyweight.metadata(frame); + break; + case REQUEST_CHANNEL: + metadata = RequestChannelFrameFlyweight.metadata(frame); + break; + // Payload and synthetic types + case PAYLOAD: + case NEXT: + case NEXT_COMPLETE: + case COMPLETE: + metadata = PayloadFrameFlyweight.metadata(frame); + break; + case METADATA_PUSH: + metadata = MetadataPushFrameFlyweight.metadata(frame); + break; + case SETUP: + metadata = SetupFrameFlyweight.metadata(frame); + break; + case LEASE: + metadata = LeaseFlyweight.metadata(frame); + break; + default: + return Unpooled.EMPTY_BUFFER; + } + return metadata.retain(); + } else { + return Unpooled.EMPTY_BUFFER; + } + } + + private static ByteBuf getData(ByteBuf frame, FrameType frameType) { + ByteBuf data; + switch (frameType) { + case REQUEST_FNF: + data = RequestFireAndForgetFrameFlyweight.data(frame); + break; + case REQUEST_STREAM: + data = RequestStreamFrameFlyweight.data(frame); + break; + case REQUEST_RESPONSE: + data = RequestResponseFrameFlyweight.data(frame); + break; + case REQUEST_CHANNEL: + data = RequestChannelFrameFlyweight.data(frame); + break; + // Payload and synthetic types + case PAYLOAD: + case NEXT: + case NEXT_COMPLETE: + case COMPLETE: + data = PayloadFrameFlyweight.data(frame); + break; + case SETUP: + data = SetupFrameFlyweight.data(frame); + break; + default: + return Unpooled.EMPTY_BUFFER; + } + return data.retain(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameFlyweight.java index 83f2406dd..4f67d9c72 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameFlyweight.java @@ -35,8 +35,8 @@ public static ByteBuf encode( complete, next, 0, - payload.hasMetadata() ? payload.sliceMetadata().retain() : null, - payload.sliceData().retain()); + payload.hasMetadata() ? payload.metadata().retain() : null, + payload.data().retain()); } public static ByteBuf encodeNextComplete( @@ -48,8 +48,8 @@ public static ByteBuf encodeNextComplete( true, true, 0, - payload.hasMetadata() ? payload.sliceMetadata().retain() : null, - payload.sliceData().retain()); + payload.hasMetadata() ? payload.metadata().retain() : null, + payload.data().retain()); } public static ByteBuf encodeNext(ByteBufAllocator allocator, int streamId, Payload payload) { @@ -60,8 +60,8 @@ public static ByteBuf encodeNext(ByteBufAllocator allocator, int streamId, Paylo false, true, 0, - payload.hasMetadata() ? payload.sliceMetadata().retain() : null, - payload.sliceData().retain()); + payload.hasMetadata() ? payload.metadata().retain() : null, + payload.data().retain()); } public static ByteBuf encodeComplete(ByteBufAllocator allocator, int streamId) { diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameFlyweight.java index fb6ecebb0..06ddcda03 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameFlyweight.java @@ -2,6 +2,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; public class RequestChannelFrameFlyweight { @@ -9,6 +10,27 @@ public class RequestChannelFrameFlyweight { private RequestChannelFrameFlyweight() {} + public static ByteBuf encode( + ByteBufAllocator allocator, + int streamId, + boolean fragmentFollows, + boolean complete, + long requestN, + Payload payload) { + + int reqN = requestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) requestN; + + return FLYWEIGHT.encode( + allocator, + streamId, + fragmentFollows, + complete, + false, + reqN, + payload.metadata(), + payload.data()); + } + public static ByteBuf encode( ByteBufAllocator allocator, int streamId, diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameFlyweight.java index 680374f71..5f2d606e4 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameFlyweight.java @@ -2,6 +2,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; public class RequestFireAndForgetFrameFlyweight { @@ -9,6 +10,13 @@ public class RequestFireAndForgetFrameFlyweight { private RequestFireAndForgetFrameFlyweight() {} + public static ByteBuf encode( + ByteBufAllocator allocator, int streamId, boolean fragmentFollows, Payload payload) { + + return FLYWEIGHT.encode( + allocator, streamId, fragmentFollows, payload.metadata(), payload.data()); + } + public static ByteBuf encode( ByteBufAllocator allocator, int streamId, diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameFlyweight.java index efbffbd40..2e06c9b82 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameFlyweight.java @@ -12,8 +12,7 @@ private RequestResponseFrameFlyweight() {} public static ByteBuf encode( ByteBufAllocator allocator, int streamId, boolean fragmentFollows, Payload payload) { - return encode( - allocator, streamId, fragmentFollows, payload.sliceMetadata(), payload.sliceData()); + return encode(allocator, streamId, fragmentFollows, payload.metadata(), payload.data()); } public static ByteBuf encode( diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameFlyweight.java index 3e858f5d4..171c41990 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameFlyweight.java @@ -17,12 +17,7 @@ public static ByteBuf encode( long requestN, Payload payload) { return encode( - allocator, - streamId, - fragmentFollows, - requestN, - payload.sliceMetadata(), - payload.sliceData()); + allocator, streamId, fragmentFollows, requestN, payload.metadata(), payload.data()); } public static ByteBuf encode( @@ -32,12 +27,7 @@ public static ByteBuf encode( int requestN, Payload payload) { return encode( - allocator, - streamId, - fragmentFollows, - requestN, - payload.sliceMetadata(), - payload.sliceData()); + allocator, streamId, fragmentFollows, requestN, payload.metadata(), payload.data()); } public static ByteBuf encode( diff --git a/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java b/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java index e6178bd5b..5c0c1d74f 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java @@ -21,6 +21,7 @@ import io.rsocket.DuplexConnection; import io.rsocket.frame.FrameHeaderFlyweight; import io.rsocket.frame.FrameType; +import io.rsocket.frame.FrameUtil; import io.rsocket.plugins.DuplexConnectionInterceptor.Type; import io.rsocket.plugins.PluginRegistry; import org.reactivestreams.Publisher; @@ -148,7 +149,7 @@ public InternalDuplexConnection( @Override public Mono send(Publisher frame) { if (debugEnabled) { - frame = Flux.from(frame).doOnNext(f -> LOGGER.debug("sending -> " + f.toString())); + frame = Flux.from(frame).doOnNext(f -> LOGGER.debug("sending -> " + FrameUtil.toString(f))); } return source.send(frame); @@ -157,7 +158,7 @@ public Mono send(Publisher frame) { @Override public Mono sendOne(ByteBuf frame) { if (debugEnabled) { - LOGGER.debug("sending -> " + frame.toString()); + LOGGER.debug("sending -> " + FrameUtil.toString(frame)); } return source.sendOne(frame); @@ -168,7 +169,7 @@ public Flux receive() { return processor.flatMapMany( f -> { if (debugEnabled) { - return f.doOnNext(frame -> LOGGER.debug("receiving -> " + frame.toString())); + return f.doOnNext(frame -> LOGGER.debug("receiving -> " + FrameUtil.toString(frame))); } else { return f; } diff --git a/rsocket-core/src/main/java/io/rsocket/transport/ClientTransport.java b/rsocket-core/src/main/java/io/rsocket/transport/ClientTransport.java index d5a8fe775..25fd67097 100644 --- a/rsocket-core/src/main/java/io/rsocket/transport/ClientTransport.java +++ b/rsocket-core/src/main/java/io/rsocket/transport/ClientTransport.java @@ -26,7 +26,8 @@ public interface ClientTransport extends Transport { * Returns a {@code Publisher}, every subscription to which returns a single {@code * DuplexConnection}. * + * @param mtu The mtu used for fragmentation - if set to zero fragmentation will be disabled * @return {@code Publisher}, every subscription returns a single {@code DuplexConnection}. */ - Mono connect(); + Mono connect(int mtu); } diff --git a/rsocket-core/src/main/java/io/rsocket/transport/ServerTransport.java b/rsocket-core/src/main/java/io/rsocket/transport/ServerTransport.java index 28af3fd4c..3adc90cc8 100644 --- a/rsocket-core/src/main/java/io/rsocket/transport/ServerTransport.java +++ b/rsocket-core/src/main/java/io/rsocket/transport/ServerTransport.java @@ -29,10 +29,11 @@ public interface ServerTransport extends Transport { * Starts this server. * * @param acceptor An acceptor to process a newly accepted {@code DuplexConnection} + * @param mtu The mtu used for fragmentation - if set to zero fragmentation will be disabled * @return A handle to retrieve information about a started server. * @throws NullPointerException if {@code acceptor} is {@code null} */ - Mono start(ConnectionAcceptor acceptor); + Mono start(ConnectionAcceptor acceptor, int mtu); /** A contract to accept a new {@code DuplexConnection}. */ interface ConnectionAcceptor extends Function> { diff --git a/rsocket-core/src/main/java/io/rsocket/uri/UriTransportRegistry.java b/rsocket-core/src/main/java/io/rsocket/uri/UriTransportRegistry.java index 5275d2304..204c5d1ea 100644 --- a/rsocket-core/src/main/java/io/rsocket/uri/UriTransportRegistry.java +++ b/rsocket-core/src/main/java/io/rsocket/uri/UriTransportRegistry.java @@ -34,9 +34,9 @@ */ public class UriTransportRegistry { private static final ClientTransport FAILED_CLIENT_LOOKUP = - () -> Mono.error(new UnsupportedOperationException()); + (mtu) -> Mono.error(new UnsupportedOperationException()); private static final ServerTransport FAILED_SERVER_LOOKUP = - acceptor -> Mono.error(new UnsupportedOperationException()); + (acceptor, mtu) -> Mono.error(new UnsupportedOperationException()); private List handlers; @@ -55,6 +55,10 @@ public static ClientTransport clientForUri(String uri) { return UriTransportRegistry.fromServices().findClient(uri); } + public static ServerTransport serverForUri(String uri) { + return UriTransportRegistry.fromServices().findServer(uri); + } + private ClientTransport findClient(String uriString) { URI uri = URI.create(uriString); @@ -68,10 +72,6 @@ private ClientTransport findClient(String uriString) { return FAILED_CLIENT_LOOKUP; } - public static ServerTransport serverForUri(String uri) { - return UriTransportRegistry.fromServices().findServer(uri); - } - private ServerTransport findServer(String uriString) { URI uri = URI.create(uriString); diff --git a/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java b/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java index 0b33e5e22..b91cf8ac6 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java @@ -45,62 +45,6 @@ private ByteBufPayload(final Handle handle) { this.handle = handle; } - @Override - public boolean hasMetadata() { - return metadata != null; - } - - @Override - public ByteBuf sliceMetadata() { - return metadata == null ? Unpooled.EMPTY_BUFFER : metadata.slice(); - } - - @Override - public ByteBuf sliceData() { - return data.slice(); - } - - @Override - public ByteBufPayload retain() { - super.retain(); - return this; - } - - @Override - public ByteBufPayload retain(int increment) { - super.retain(increment); - return this; - } - - @Override - public ByteBufPayload touch() { - data.touch(); - if (metadata != null) { - metadata.touch(); - } - return this; - } - - @Override - public ByteBufPayload touch(Object hint) { - data.touch(hint); - if (metadata != null) { - metadata.touch(hint); - } - return this; - } - - @Override - protected void deallocate() { - data.release(); - data = null; - if (metadata != null) { - metadata.release(); - metadata = null; - } - handle.recycle(this); - } - /** * Static factory method for a text payload. Mainly looks better than "new ByteBufPayload(data)" * @@ -179,4 +123,70 @@ public static Payload create(Payload payload) { payload.sliceData().retain(), payload.hasMetadata() ? payload.sliceMetadata().retain() : null); } + + @Override + public boolean hasMetadata() { + return metadata != null; + } + + @Override + public ByteBuf sliceMetadata() { + return metadata == null ? Unpooled.EMPTY_BUFFER : metadata.slice(); + } + + @Override + public ByteBuf data() { + return data; + } + + @Override + public ByteBuf metadata() { + return metadata == null ? Unpooled.EMPTY_BUFFER : metadata; + } + + @Override + public ByteBuf sliceData() { + return data.slice(); + } + + @Override + public ByteBufPayload retain() { + super.retain(); + return this; + } + + @Override + public ByteBufPayload retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public ByteBufPayload touch() { + data.touch(); + if (metadata != null) { + metadata.touch(); + } + return this; + } + + @Override + public ByteBufPayload touch(Object hint) { + data.touch(hint); + if (metadata != null) { + metadata.touch(hint); + } + return this; + } + + @Override + protected void deallocate() { + data.release(); + data = null; + if (metadata != null) { + metadata.release(); + metadata = null; + } + handle.recycle(this); + } } diff --git a/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java b/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java index 71bbf3874..ec73399f1 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java @@ -40,66 +40,6 @@ private DefaultPayload(ByteBuffer data, @Nullable ByteBuffer metadata) { this.metadata = metadata; } - @Override - public boolean hasMetadata() { - return metadata != null; - } - - @Override - public ByteBuf sliceMetadata() { - return metadata == null ? Unpooled.EMPTY_BUFFER : Unpooled.wrappedBuffer(metadata); - } - - @Override - public ByteBuf sliceData() { - return Unpooled.wrappedBuffer(data); - } - - @Override - public ByteBuffer getMetadata() { - return metadata == null ? DefaultPayload.EMPTY_BUFFER : metadata.duplicate(); - } - - @Override - public ByteBuffer getData() { - return data.duplicate(); - } - - @Override - public int refCnt() { - return 1; - } - - @Override - public DefaultPayload retain() { - return this; - } - - @Override - public DefaultPayload retain(int increment) { - return this; - } - - @Override - public DefaultPayload touch() { - return this; - } - - @Override - public DefaultPayload touch(Object hint) { - return this; - } - - @Override - public boolean release() { - return false; - } - - @Override - public boolean release(int decrement) { - return false; - } - /** * Static factory method for a text payload. Mainly looks better than "new DefaultPayload(data)" * @@ -167,4 +107,74 @@ public static Payload create(Payload payload) { Unpooled.copiedBuffer(payload.sliceData()), payload.hasMetadata() ? Unpooled.copiedBuffer(payload.sliceMetadata()) : null); } + + @Override + public boolean hasMetadata() { + return metadata != null; + } + + @Override + public ByteBuf sliceMetadata() { + return metadata == null ? Unpooled.EMPTY_BUFFER : Unpooled.wrappedBuffer(metadata); + } + + @Override + public ByteBuf sliceData() { + return Unpooled.wrappedBuffer(data); + } + + @Override + public ByteBuffer getMetadata() { + return metadata == null ? DefaultPayload.EMPTY_BUFFER : metadata.duplicate(); + } + + @Override + public ByteBuffer getData() { + return data.duplicate(); + } + + @Override + public ByteBuf data() { + return sliceData(); + } + + @Override + public ByteBuf metadata() { + return sliceMetadata(); + } + + @Override + public int refCnt() { + return 1; + } + + @Override + public DefaultPayload retain() { + return this; + } + + @Override + public DefaultPayload retain(int increment) { + return this; + } + + @Override + public DefaultPayload touch() { + return this; + } + + @Override + public DefaultPayload touch(Object hint) { + return this; + } + + @Override + public boolean release() { + return false; + } + + @Override + public boolean release(int decrement) { + return false; + } } diff --git a/rsocket-core/src/main/java/io/rsocket/util/EmptyPayload.java b/rsocket-core/src/main/java/io/rsocket/util/EmptyPayload.java index d5eda1d6b..99df97d70 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/EmptyPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/util/EmptyPayload.java @@ -40,6 +40,16 @@ public ByteBuf sliceData() { return Unpooled.EMPTY_BUFFER; } + @Override + public ByteBuf data() { + return sliceData(); + } + + @Override + public ByteBuf metadata() { + return sliceMetadata(); + } + @Override public int refCnt() { return 1; diff --git a/rsocket-core/src/test/java/io/rsocket/SetupRejectionTest.java b/rsocket-core/src/test/java/io/rsocket/SetupRejectionTest.java index 2326f338d..57070ad69 100644 --- a/rsocket-core/src/test/java/io/rsocket/SetupRejectionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/SetupRejectionTest.java @@ -124,7 +124,7 @@ private static class SingleConnectionTransport implements ServerTransport start(ConnectionAcceptor acceptor) { + public Mono start(ConnectionAcceptor acceptor, int mtu) { return Mono.just(new TestCloseable(acceptor, conn)); } diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java index 6b25ac902..ac4413caa 100644 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java @@ -16,27 +16,69 @@ package io.rsocket.fragmentation; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; +import static org.mockito.Mockito.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.DuplexConnection; +import io.rsocket.frame.*; +import io.rsocket.util.DefaultPayload; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.Assert; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + final class FragmentationDuplexConnectionTest { - /* + private static byte[] data = new byte[1024]; + private static byte[] metadata = new byte[1024]; + + static { + ThreadLocalRandom.current().nextBytes(data); + ThreadLocalRandom.current().nextBytes(metadata); + } + private final DuplexConnection delegate = mock(DuplexConnection.class, RETURNS_SMART_NULLS); @SuppressWarnings("unchecked") - private final ArgumentCaptor> publishers = + private final ArgumentCaptor> publishers = ArgumentCaptor.forClass(Publisher.class); + private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; + @DisplayName("constructor throws IllegalArgumentException with negative maxFragmentLength") @Test void constructorInvalidMaxFragmentSize() { assertThatIllegalArgumentException() - .isThrownBy(() -> new FragmentationDuplexConnection(DEFAULT, delegate, Integer.MIN_VALUE)) - .withMessage("maxFragmentSize must be positive"); + .isThrownBy( + () -> new FragmentationDuplexConnection(delegate, allocator, Integer.MIN_VALUE, false)) + .withMessage("smallest allowed mtu size is 64 bytes"); + } + + @DisplayName("constructor throws IllegalArgumentException with negative maxFragmentLength") + @Test + void constructorMtuLessThanMin() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new FragmentationDuplexConnection(delegate, allocator, 2, false)) + .withMessage("smallest allowed mtu size is 64 bytes"); } @DisplayName("constructor throws NullPointerException with null byteBufAllocator") @Test void constructorNullByteBufAllocator() { assertThatNullPointerException() - .isThrownBy(() -> new FragmentationDuplexConnection(null, delegate, 2)) + .isThrownBy(() -> new FragmentationDuplexConnection(delegate, null, 64, false)) .withMessage("byteBufAllocator must not be null"); } @@ -44,339 +86,241 @@ void constructorNullByteBufAllocator() { @Test void constructorNullDelegate() { assertThatNullPointerException() - .isThrownBy(() -> new FragmentationDuplexConnection(DEFAULT, null, 2)) + .isThrownBy(() -> new FragmentationDuplexConnection(null, allocator, 64, false)) .withMessage("delegate must not be null"); } @DisplayName("reassembles data") @Test void reassembleData() { - ByteBuf data = getRandomByteBuf(6); - - Frame frame = - toAbstractionLeakingFrame( - DEFAULT, 1, createRequestStreamFrame(DEFAULT, false, 1, null, data)); - - Frame fragment1 = - toAbstractionLeakingFrame( - DEFAULT, 1, createRequestStreamFrame(DEFAULT, true, 1, null, data.slice(0, 2))); - - Frame fragment2 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, true, false, null, data.slice(2, 2))); - - Frame fragment3 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, false, false, null, data.slice(4, 2))); - - when(delegate.receive()).thenReturn(Flux.just(fragment1, fragment2, fragment3)); + List byteBufs = + Arrays.asList( + RequestResponseFrameFlyweight.encode(allocator, 1, true, DefaultPayload.create(data)), + PayloadFrameFlyweight.encode( + allocator, 1, true, false, true, DefaultPayload.create(data)), + PayloadFrameFlyweight.encode( + allocator, 1, true, false, true, DefaultPayload.create(data)), + PayloadFrameFlyweight.encode( + allocator, 1, true, false, true, DefaultPayload.create(data)), + PayloadFrameFlyweight.encode( + allocator, 1, false, false, true, DefaultPayload.create(data))); + + CompositeByteBuf data = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.data), + Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.data), + Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.data), + Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.data), + Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.data)); + + when(delegate.receive()).thenReturn(Flux.fromIterable(byteBufs)); when(delegate.onClose()).thenReturn(Mono.never()); - new FragmentationDuplexConnection(DEFAULT, delegate, 2) + new FragmentationDuplexConnection(delegate, allocator, 1030, false) .receive() .as(StepVerifier::create) - .expectNext(frame) + .assertNext( + byteBuf -> { + Assert.assertEquals(data, RequestResponseFrameFlyweight.data(byteBuf)); + ReferenceCountUtil.safeRelease(byteBuf); + }) .verifyComplete(); } @DisplayName("reassembles metadata") @Test void reassembleMetadata() { - ByteBuf metadata = getRandomByteBuf(6); - - Frame frame = - toAbstractionLeakingFrame( - DEFAULT, 1, createRequestStreamFrame(DEFAULT, false, 1, metadata, null)); - - Frame fragment1 = - toAbstractionLeakingFrame( - DEFAULT, 1, createRequestStreamFrame(DEFAULT, true, 1, metadata.slice(0, 2), null)); - - Frame fragment2 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, true, true, metadata.slice(2, 2), null)); - - Frame fragment3 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, false, true, metadata.slice(4, 2), null)); - - when(delegate.receive()).thenReturn(Flux.just(fragment1, fragment2, fragment3)); + List byteBufs = + Arrays.asList( + RequestResponseFrameFlyweight.encode( + allocator, + 1, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + false, + false, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata)))); + + CompositeByteBuf metadata = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata)); + + when(delegate.receive()).thenReturn(Flux.fromIterable(byteBufs)); when(delegate.onClose()).thenReturn(Mono.never()); - new FragmentationDuplexConnection(DEFAULT, delegate, 2) + new FragmentationDuplexConnection(delegate, allocator, 1030, false) .receive() .as(StepVerifier::create) - .expectNext(frame) + .assertNext( + byteBuf -> { + System.out.println(byteBuf.readableBytes()); + ByteBuf m = RequestResponseFrameFlyweight.metadata(byteBuf); + Assert.assertEquals(metadata, m); + }) .verifyComplete(); } @DisplayName("reassembles metadata and data") @Test void reassembleMetadataAndData() { - ByteBuf metadata = getRandomByteBuf(5); - ByteBuf data = getRandomByteBuf(5); - - Frame frame = - toAbstractionLeakingFrame( - DEFAULT, 1, createRequestStreamFrame(DEFAULT, false, 1, metadata, data)); - - Frame fragment1 = - toAbstractionLeakingFrame( - DEFAULT, 1, createRequestStreamFrame(DEFAULT, true, 1, metadata.slice(0, 2), null)); - - Frame fragment2 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, true, true, metadata.slice(2, 2), null)); - - Frame fragment3 = - toAbstractionLeakingFrame( - DEFAULT, - 1, - createPayloadFrame(DEFAULT, true, false, metadata.slice(4, 1), data.slice(0, 1))); - - Frame fragment4 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, true, false, null, data.slice(1, 2))); - - Frame fragment5 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, false, false, null, data.slice(3, 2))); - - when(delegate.receive()) - .thenReturn(Flux.just(fragment1, fragment2, fragment3, fragment4, fragment5)); + List byteBufs = + Arrays.asList( + RequestResponseFrameFlyweight.encode( + allocator, + 1, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + DefaultPayload.create( + Unpooled.wrappedBuffer(data), Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, 1, false, false, true, DefaultPayload.create(data))); + + CompositeByteBuf data = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.data), + Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.data)); + + CompositeByteBuf metadata = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata)); + + when(delegate.receive()).thenReturn(Flux.fromIterable(byteBufs)); when(delegate.onClose()).thenReturn(Mono.never()); - new FragmentationDuplexConnection(DEFAULT, delegate, 2) + new FragmentationDuplexConnection(delegate, allocator, 1030, false) .receive() .as(StepVerifier::create) - .expectNext(frame) + .assertNext( + byteBuf -> { + Assert.assertEquals(data, RequestResponseFrameFlyweight.data(byteBuf)); + Assert.assertEquals(metadata, RequestResponseFrameFlyweight.metadata(byteBuf)); + }) .verifyComplete(); } @DisplayName("does not reassemble a non-fragment frame") @Test void reassembleNonFragment() { - Frame frame = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, false, true, (ByteBuf) null, null)); + ByteBuf encode = + RequestResponseFrameFlyweight.encode( + allocator, 1, false, DefaultPayload.create(Unpooled.wrappedBuffer(data))); - when(delegate.receive()).thenReturn(Flux.just(frame.retain())); + when(delegate.receive()).thenReturn(Flux.just(encode)); when(delegate.onClose()).thenReturn(Mono.never()); - new FragmentationDuplexConnection(DEFAULT, delegate, 2) + new FragmentationDuplexConnection(delegate, allocator, 1030, false) .receive() .as(StepVerifier::create) - .expectNext(frame) + .assertNext( + byteBuf -> { + Assert.assertEquals( + Unpooled.wrappedBuffer(data), RequestResponseFrameFlyweight.data(byteBuf)); + }) .verifyComplete(); } @DisplayName("does not reassemble non fragmentable frame") @Test void reassembleNonFragmentableFrame() { - Frame frame = toAbstractionLeakingFrame(DEFAULT, 1, createTestCancelFrame()); + ByteBuf encode = CancelFrameFlyweight.encode(allocator, 2); - when(delegate.receive()).thenReturn(Flux.just(frame.retain())); + when(delegate.receive()).thenReturn(Flux.just(encode)); when(delegate.onClose()).thenReturn(Mono.never()); - new FragmentationDuplexConnection(DEFAULT, delegate, 2) + new FragmentationDuplexConnection(delegate, allocator, 1030, false) .receive() .as(StepVerifier::create) - .expectNext(frame) + .assertNext( + byteBuf -> { + Assert.assertEquals(FrameType.CANCEL, FrameHeaderFlyweight.frameType(byteBuf)); + }) .verifyComplete(); } @DisplayName("fragments data") @Test void sendData() { - ByteBuf data = getRandomByteBuf(6); - - Frame frame = - toAbstractionLeakingFrame( - DEFAULT, 1, createRequestStreamFrame(DEFAULT, false, 1, null, data)); - - Frame fragment1 = - toAbstractionLeakingFrame( - DEFAULT, 1, createRequestStreamFrame(DEFAULT, true, 1, null, data.slice(0, 2))); - - Frame fragment2 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, true, false, null, data.slice(2, 2))); - - Frame fragment3 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, false, false, null, data.slice(4, 2))); - - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame.retain()); - verify(delegate).send(publishers.capture()); - - StepVerifier.create(Flux.from(publishers.getValue())) - .expectNext(fragment1) - .expectNext(fragment2) - .expectNext(fragment3) - .verifyComplete(); - } - - @DisplayName("does not fragment with size equal to maxFragmentLength") - @Test - void sendEqualToMaxFragmentLength() { - Frame frame = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, false, false, null, getRandomByteBuf(2))); + ByteBuf encode = + RequestResponseFrameFlyweight.encode( + allocator, 1, false, Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(data)); when(delegate.onClose()).thenReturn(Mono.never()); - new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame.retain()); - verify(delegate).send(publishers.capture()); - - StepVerifier.create(Flux.from(publishers.getValue())).expectNext(frame).verifyComplete(); - } + new FragmentationDuplexConnection(delegate, allocator, 64, false).sendOne(encode.retain()); - @DisplayName("does not fragment an already-fragmented frame") - @Test - void sendFragment() { - Frame frame = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, true, true, (ByteBuf) null, null)); - - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame.retain()); - verify(delegate).send(publishers.capture()); - - StepVerifier.create(Flux.from(publishers.getValue())).expectNext(frame).verifyComplete(); - } - - @DisplayName("does not fragment with size smaller than maxFragmentLength") - @Test - void sendLessThanMaxFragmentLength() { - Frame frame = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, false, false, null, getRandomByteBuf(1))); - - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame.retain()); - verify(delegate).send(publishers.capture()); - - StepVerifier.create(Flux.from(publishers.getValue())).expectNext(frame).verifyComplete(); - } - - @DisplayName("fragments metadata") - @Test - void sendMetadata() { - ByteBuf metadata = getRandomByteBuf(6); - - Frame frame = - toAbstractionLeakingFrame( - DEFAULT, 1, createRequestStreamFrame(DEFAULT, false, 1, metadata, null)); - - Frame fragment1 = - toAbstractionLeakingFrame( - DEFAULT, 1, createRequestStreamFrame(DEFAULT, true, 1, metadata.slice(0, 2), null)); - - Frame fragment2 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, true, true, metadata.slice(2, 2), null)); - - Frame fragment3 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, false, true, metadata.slice(4, 2), null)); - - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame.retain()); - verify(delegate).send(publishers.capture()); - - StepVerifier.create(Flux.from(publishers.getValue())) - .expectNext(fragment1) - .expectNext(fragment2) - .expectNext(fragment3) - .verifyComplete(); - } - - @DisplayName("fragments metadata and data") - @Test - void sendMetadataAndData() { - ByteBuf metadata = getRandomByteBuf(5); - ByteBuf data = getRandomByteBuf(5); - - Frame frame = - toAbstractionLeakingFrame( - DEFAULT, 1, createRequestStreamFrame(DEFAULT, false, 1, metadata, data)); - - Frame fragment1 = - toAbstractionLeakingFrame( - DEFAULT, 1, createRequestStreamFrame(DEFAULT, true, 1, metadata.slice(0, 2), null)); - - Frame fragment2 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, true, true, metadata.slice(2, 2), null)); - - Frame fragment3 = - toAbstractionLeakingFrame( - DEFAULT, - 1, - createPayloadFrame(DEFAULT, true, false, metadata.slice(4, 1), data.slice(0, 1))); - - Frame fragment4 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, true, false, null, data.slice(1, 2))); - - Frame fragment5 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, false, false, null, data.slice(3, 2))); - - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame.retain()); verify(delegate).send(publishers.capture()); StepVerifier.create(Flux.from(publishers.getValue())) - .expectNext(fragment1) - .expectNext(fragment2) - .expectNext(fragment3) - .expectNext(fragment4) - .expectNext(fragment5) + .expectNextCount(17) + .assertNext( + byteBuf -> { + Assert.assertEquals(FrameType.NEXT, FrameHeaderFlyweight.frameType(byteBuf)); + Assert.assertFalse(FrameHeaderFlyweight.hasFollows(byteBuf)); + }) .verifyComplete(); } - - @DisplayName("does not fragment non-fragmentable frame") - @Test - void sendNonFragmentable() { - Frame frame = toAbstractionLeakingFrame(DEFAULT, 1, createTestCancelFrame()); - - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame.retain()); - verify(delegate).send(publishers.capture()); - - StepVerifier.create(Flux.from(publishers.getValue())).expectNext(frame).verifyComplete(); - } - - @DisplayName("send throws NullPointerException with null frames") - @Test - void sendNullFrames() { - when(delegate.onClose()).thenReturn(Mono.never()); - - assertThatNullPointerException() - .isThrownBy(() -> new FragmentationDuplexConnection(DEFAULT, delegate, 2).send(null)) - .withMessage("frames must not be null"); - } - - @DisplayName("does not fragment with zero maxFragmentLength") - @Test - void sendZeroMaxFragmentLength() { - Frame frame = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, false, false, null, getRandomByteBuf(2))); - - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(DEFAULT, delegate, 0).sendOne(frame.retain()); - verify(delegate).send(publishers.capture()); - - StepVerifier.create(Flux.from(publishers.getValue())).expectNext(frame).verifyComplete(); - }*/ } diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationIntegrationTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationIntegrationTest.java new file mode 100644 index 000000000..df68da9a5 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationIntegrationTest.java @@ -0,0 +1,55 @@ +package io.rsocket.fragmentation; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameUtil; +import io.rsocket.frame.PayloadFrameFlyweight; +import io.rsocket.util.DefaultPayload; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.Assert; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; + +public class FragmentationIntegrationTest { + private static byte[] data = new byte[128]; + private static byte[] metadata = new byte[128]; + + static { + ThreadLocalRandom.current().nextBytes(data); + ThreadLocalRandom.current().nextBytes(metadata); + } + + private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; + + @DisplayName("fragments and reassembles data") + @Test + void fragmentAndReassembleData() { + ByteBuf frame = + PayloadFrameFlyweight.encodeNextComplete(allocator, 2, DefaultPayload.create(data)); + System.out.println(FrameUtil.toString(frame)); + + Publisher fragments = + FrameFragmenter.fragmentFrame( + allocator, 64, frame, FrameHeaderFlyweight.frameType(frame), false); + + FrameReassembler reassembler = new FrameReassembler(allocator); + + ByteBuf assembled = + Flux.from(fragments) + .doOnNext(byteBuf -> System.out.println(FrameUtil.toString(byteBuf))) + .handle(reassembler::reassembleFrame) + .blockLast(); + + System.out.println("assembled"); + String s = FrameUtil.toString(assembled); + System.out.println(s); + + Assert.assertEquals( + FrameHeaderFlyweight.frameType(frame), FrameHeaderFlyweight.frameType(assembled)); + Assert.assertEquals(frame.readableBytes(), assembled.readableBytes()); + Assert.assertEquals(PayloadFrameFlyweight.data(frame), PayloadFrameFlyweight.data(assembled)); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java index 8cf3edb96..f5a013357 100644 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java +++ b/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java @@ -16,158 +16,337 @@ package io.rsocket.fragmentation; -final class FrameFragmenterTest { - /* - @DisplayName("constructor throws NullPointerException with null ByteBufAllocator") - @Test - void constructorNullByteBufAllocator() { - assertThatNullPointerException() - .isThrownBy(() -> new FrameFragmenter(null, 2)) - .withMessage("byteBufAllocator must not be null"); - } - - @DisplayName("fragments data") - @Test - void fragmentData() { - ByteBuf data = getRandomByteBuf(6); - - RequestStreamFrame frame = createRequestStreamFrame(DEFAULT, false, 1, null, data); +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.frame.*; +import io.rsocket.util.DefaultPayload; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.Assert; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; - RequestStreamFrame fragment1 = - createRequestStreamFrame(DEFAULT, true, 1, null, data.slice(0, 2)); +final class FrameFragmenterTest { + private static byte[] data = new byte[4096]; + private static byte[] metadata = new byte[4096]; - PayloadFrame fragment2 = createPayloadFrame(DEFAULT, true, false, null, data.slice(2, 2)); + static { + ThreadLocalRandom.current().nextBytes(data); + ThreadLocalRandom.current().nextBytes(metadata); + } - PayloadFrame fragment3 = createPayloadFrame(DEFAULT, false, false, null, data.slice(4, 2)); + private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; - new FrameFragmenter(DEFAULT, 2) - .fragment(frame) - .as(StepVerifier::create) - .expectNext(fragment1) - .expectNext(fragment2) - .expectNext(fragment3) - .verifyComplete(); + @Test + void testGettingData() { + ByteBuf rr = + RequestResponseFrameFlyweight.encode(allocator, 1, true, DefaultPayload.create(data)); + ByteBuf fnf = + RequestFireAndForgetFrameFlyweight.encode(allocator, 1, true, DefaultPayload.create(data)); + ByteBuf rs = + RequestStreamFrameFlyweight.encode(allocator, 1, true, 1, DefaultPayload.create(data)); + ByteBuf rc = + RequestChannelFrameFlyweight.encode( + allocator, 1, true, false, 1, DefaultPayload.create(data)); + + ByteBuf data = FrameFragmenter.getData(rr, FrameType.REQUEST_RESPONSE); + Assert.assertEquals(data, Unpooled.wrappedBuffer(data)); + data.release(); + + data = FrameFragmenter.getData(fnf, FrameType.REQUEST_FNF); + Assert.assertEquals(data, Unpooled.wrappedBuffer(data)); + data.release(); + + data = FrameFragmenter.getData(rs, FrameType.REQUEST_STREAM); + Assert.assertEquals(data, Unpooled.wrappedBuffer(data)); + data.release(); + + data = FrameFragmenter.getData(rc, FrameType.REQUEST_CHANNEL); + Assert.assertEquals(data, Unpooled.wrappedBuffer(data)); + data.release(); } - @DisplayName("does not fragment with size equal to maxFragmentLength") @Test - void fragmentEqualToMaxFragmentLength() { - PayloadFrame frame = createPayloadFrame(DEFAULT, false, false, null, getRandomByteBuf(2)); - - new FrameFragmenter(DEFAULT, 2) - .fragment(frame) - .as(StepVerifier::create) - .expectNext(frame) - .verifyComplete(); + void testGettingMetadata() { + ByteBuf rr = + RequestResponseFrameFlyweight.encode( + allocator, 1, true, DefaultPayload.create(data, metadata)); + ByteBuf fnf = + RequestFireAndForgetFrameFlyweight.encode( + allocator, 1, true, DefaultPayload.create(data, metadata)); + ByteBuf rs = + RequestStreamFrameFlyweight.encode( + allocator, 1, true, 1, DefaultPayload.create(data, metadata)); + ByteBuf rc = + RequestChannelFrameFlyweight.encode( + allocator, 1, true, false, 1, DefaultPayload.create(data, metadata)); + + ByteBuf data = FrameFragmenter.getMetadata(rr, FrameType.REQUEST_RESPONSE); + Assert.assertEquals(data, Unpooled.wrappedBuffer(metadata)); + data.release(); + + data = FrameFragmenter.getMetadata(fnf, FrameType.REQUEST_FNF); + Assert.assertEquals(data, Unpooled.wrappedBuffer(metadata)); + data.release(); + + data = FrameFragmenter.getMetadata(rs, FrameType.REQUEST_STREAM); + Assert.assertEquals(data, Unpooled.wrappedBuffer(metadata)); + data.release(); + + data = FrameFragmenter.getMetadata(rc, FrameType.REQUEST_CHANNEL); + Assert.assertEquals(data, Unpooled.wrappedBuffer(metadata)); + data.release(); } - @DisplayName("does not fragment an already-fragmented frame") @Test - void fragmentFragment() { - PayloadFrame frame = createPayloadFrame(DEFAULT, true, true, (ByteBuf) null, null); + void returnEmptBufferWhenNoMetadataPresent() { + ByteBuf rr = + RequestResponseFrameFlyweight.encode(allocator, 1, true, DefaultPayload.create(data)); - new FrameFragmenter(DEFAULT, 2) - .fragment(frame) - .as(StepVerifier::create) - .expectNext(frame) - .verifyComplete(); + ByteBuf data = FrameFragmenter.getMetadata(rr, FrameType.REQUEST_RESPONSE); + Assert.assertEquals(data, Unpooled.EMPTY_BUFFER); + data.release(); } - @DisplayName("does not fragment with size smaller than maxFragmentLength") + @DisplayName("encode first frame") @Test - void fragmentLessThanMaxFragmentLength() { - PayloadFrame frame = createPayloadFrame(DEFAULT, false, false, null, getRandomByteBuf(1)); - - new FrameFragmenter(DEFAULT, 2) - .fragment(frame) - .as(StepVerifier::create) - .expectNext(frame) - .verifyComplete(); + void encodeFirstFrameWithData() { + ByteBuf rr = + RequestResponseFrameFlyweight.encode(allocator, 1, true, DefaultPayload.create(data)); + + ByteBuf fragment = + FrameFragmenter.encodeFirstFragment( + allocator, + 256, + rr, + FrameType.REQUEST_RESPONSE, + 1, + Unpooled.EMPTY_BUFFER, + Unpooled.wrappedBuffer(data)); + + Assert.assertEquals(256, fragment.readableBytes()); + Assert.assertEquals(FrameType.REQUEST_RESPONSE, FrameHeaderFlyweight.frameType(fragment)); + Assert.assertEquals(1, FrameHeaderFlyweight.streamId(fragment)); + Assert.assertTrue(FrameHeaderFlyweight.hasFollows(fragment)); + + ByteBuf data = RequestResponseFrameFlyweight.data(fragment); + ByteBuf byteBuf = Unpooled.wrappedBuffer(this.data).readSlice(data.readableBytes()); + Assert.assertEquals(byteBuf, data); + + Assert.assertFalse(FrameHeaderFlyweight.hasMetadata(fragment)); } - @DisplayName("fragments metadata") + @DisplayName("encode first channel frame") @Test - void fragmentMetadata() { - ByteBuf metadata = getRandomByteBuf(6); - - RequestStreamFrame frame = createRequestStreamFrame(DEFAULT, false, 1, metadata, null); - - RequestStreamFrame fragment1 = - createRequestStreamFrame(DEFAULT, true, 1, metadata.slice(0, 2), null); - - PayloadFrame fragment2 = createPayloadFrame(DEFAULT, true, true, metadata.slice(2, 2), null); - - PayloadFrame fragment3 = createPayloadFrame(DEFAULT, false, true, metadata.slice(4, 2), null); - - new FrameFragmenter(DEFAULT, 2) - .fragment(frame) - .as(StepVerifier::create) - .expectNext(fragment1) - .expectNext(fragment2) - .expectNext(fragment3) - .verifyComplete(); + void encodeFirstWithDataChannel() { + ByteBuf rc = + RequestChannelFrameFlyweight.encode( + allocator, 1, true, false, 10, DefaultPayload.create(data)); + + ByteBuf fragment = + FrameFragmenter.encodeFirstFragment( + allocator, + 256, + rc, + FrameType.REQUEST_CHANNEL, + 1, + Unpooled.EMPTY_BUFFER, + Unpooled.wrappedBuffer(data)); + + Assert.assertEquals(256, fragment.readableBytes()); + Assert.assertEquals(FrameType.REQUEST_CHANNEL, FrameHeaderFlyweight.frameType(fragment)); + Assert.assertEquals(1, FrameHeaderFlyweight.streamId(fragment)); + Assert.assertEquals(10, RequestChannelFrameFlyweight.initialRequestN(fragment)); + Assert.assertTrue(FrameHeaderFlyweight.hasFollows(fragment)); + + ByteBuf data = RequestChannelFrameFlyweight.data(fragment); + ByteBuf byteBuf = Unpooled.wrappedBuffer(this.data).readSlice(data.readableBytes()); + Assert.assertEquals(byteBuf, data); + + Assert.assertFalse(FrameHeaderFlyweight.hasMetadata(fragment)); } - @DisplayName("fragments metadata and data") + @DisplayName("encode first stream frame") @Test - void fragmentMetadataAndData() { - ByteBuf metadata = getRandomByteBuf(5); - ByteBuf data = getRandomByteBuf(5); - - RequestStreamFrame frame = createRequestStreamFrame(DEFAULT, false, 1, metadata, data); - - RequestStreamFrame fragment1 = - createRequestStreamFrame(DEFAULT, true, 1, metadata.slice(0, 2), null); - - PayloadFrame fragment2 = createPayloadFrame(DEFAULT, true, true, metadata.slice(2, 2), null); - - PayloadFrame fragment3 = - createPayloadFrame(DEFAULT, true, false, metadata.slice(4, 1), data.slice(0, 1)); - - PayloadFrame fragment4 = createPayloadFrame(DEFAULT, true, false, null, data.slice(1, 2)); - - PayloadFrame fragment5 = createPayloadFrame(DEFAULT, false, false, null, data.slice(3, 2)); + void encodeFirstWithDataStream() { + ByteBuf rc = + RequestStreamFrameFlyweight.encode(allocator, 1, true, 50, DefaultPayload.create(data)); + + ByteBuf fragment = + FrameFragmenter.encodeFirstFragment( + allocator, + 256, + rc, + FrameType.REQUEST_STREAM, + 1, + Unpooled.EMPTY_BUFFER, + Unpooled.wrappedBuffer(data)); + + Assert.assertEquals(256, fragment.readableBytes()); + Assert.assertEquals(FrameType.REQUEST_STREAM, FrameHeaderFlyweight.frameType(fragment)); + Assert.assertEquals(1, FrameHeaderFlyweight.streamId(fragment)); + Assert.assertEquals(50, RequestStreamFrameFlyweight.initialRequestN(fragment)); + Assert.assertTrue(FrameHeaderFlyweight.hasFollows(fragment)); + + ByteBuf data = RequestStreamFrameFlyweight.data(fragment); + ByteBuf byteBuf = Unpooled.wrappedBuffer(this.data).readSlice(data.readableBytes()); + Assert.assertEquals(byteBuf, data); + + Assert.assertFalse(FrameHeaderFlyweight.hasMetadata(fragment)); + } - new FrameFragmenter(DEFAULT, 2) - .fragment(frame) - .as(StepVerifier::create) - .expectNext(fragment1) - .expectNext(fragment2) - .expectNext(fragment3) - .expectNext(fragment4) - .expectNext(fragment5) - .verifyComplete(); + @DisplayName("encode first frame with only metadata") + @Test + void encodeFirstFrameWithMetadata() { + ByteBuf rr = + RequestResponseFrameFlyweight.encode( + allocator, + 1, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))); + + ByteBuf fragment = + FrameFragmenter.encodeFirstFragment( + allocator, + 256, + rr, + FrameType.REQUEST_RESPONSE, + 1, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER); + + Assert.assertEquals(256, fragment.readableBytes()); + Assert.assertEquals(FrameType.REQUEST_RESPONSE, FrameHeaderFlyweight.frameType(fragment)); + Assert.assertEquals(1, FrameHeaderFlyweight.streamId(fragment)); + Assert.assertTrue(FrameHeaderFlyweight.hasFollows(fragment)); + + ByteBuf data = RequestResponseFrameFlyweight.data(fragment); + Assert.assertEquals(data, Unpooled.EMPTY_BUFFER); + + Assert.assertTrue(FrameHeaderFlyweight.hasMetadata(fragment)); } - @DisplayName("does not fragment non-fragmentable frame") + @DisplayName("encode first stream frame with data and metadata") @Test - void fragmentNonFragmentable() { - CancelFrame frame = createTestCancelFrame(); + void encodeFirstWithDataAndMetadataStream() { + ByteBuf rc = + RequestStreamFrameFlyweight.encode( + allocator, 1, true, 50, DefaultPayload.create(data, metadata)); + + ByteBuf fragment = + FrameFragmenter.encodeFirstFragment( + allocator, + 256, + rc, + FrameType.REQUEST_STREAM, + 1, + Unpooled.wrappedBuffer(metadata), + Unpooled.wrappedBuffer(data)); + + Assert.assertEquals(256, fragment.readableBytes()); + Assert.assertEquals(FrameType.REQUEST_STREAM, FrameHeaderFlyweight.frameType(fragment)); + Assert.assertEquals(1, FrameHeaderFlyweight.streamId(fragment)); + Assert.assertEquals(50, RequestStreamFrameFlyweight.initialRequestN(fragment)); + Assert.assertTrue(FrameHeaderFlyweight.hasFollows(fragment)); + + ByteBuf data = RequestStreamFrameFlyweight.data(fragment); + Assert.assertEquals(0, data.readableBytes()); + + ByteBuf metadata = RequestStreamFrameFlyweight.metadata(fragment); + ByteBuf byteBuf = Unpooled.wrappedBuffer(this.metadata).readSlice(metadata.readableBytes()); + Assert.assertEquals(byteBuf, metadata); + + Assert.assertTrue(FrameHeaderFlyweight.hasMetadata(fragment)); + } - new FrameFragmenter(DEFAULT, 2) - .fragment(frame) - .as(StepVerifier::create) - .expectNext(frame) + @DisplayName("fragments frame with only data") + @Test + void fragmentData() { + ByteBuf rr = + RequestResponseFrameFlyweight.encode(allocator, 1, true, DefaultPayload.create(data)); + + Publisher fragments = + FrameFragmenter.fragmentFrame(allocator, 1024, rr, FrameType.REQUEST_RESPONSE, false); + + StepVerifier.create(Flux.from(fragments).doOnError(Throwable::printStackTrace)) + .expectNextCount(1) + .assertNext( + byteBuf -> { + Assert.assertEquals(FrameType.NEXT, FrameHeaderFlyweight.frameType(byteBuf)); + Assert.assertEquals(1, FrameHeaderFlyweight.streamId(byteBuf)); + Assert.assertTrue(FrameHeaderFlyweight.hasFollows(byteBuf)); + }) + .expectNextCount(2) + .assertNext( + byteBuf -> { + Assert.assertEquals(FrameType.NEXT, FrameHeaderFlyweight.frameType(byteBuf)); + Assert.assertFalse(FrameHeaderFlyweight.hasFollows(byteBuf)); + }) .verifyComplete(); } - @DisplayName("fragment throws NullPointerException with null frame") + @DisplayName("fragments frame with only metadata") @Test - void fragmentWithNullFrame() { - assertThatNullPointerException() - .isThrownBy(() -> new FrameFragmenter(DEFAULT, 2).fragment(null)) - .withMessage("frame must not be null"); + void fragmentMetadata() { + ByteBuf rr = + RequestStreamFrameFlyweight.encode( + allocator, + 1, + true, + 10, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))); + + Publisher fragments = + FrameFragmenter.fragmentFrame(allocator, 1024, rr, FrameType.REQUEST_STREAM, false); + + StepVerifier.create(Flux.from(fragments).doOnError(Throwable::printStackTrace)) + .expectNextCount(1) + .assertNext( + byteBuf -> { + Assert.assertEquals(FrameType.NEXT, FrameHeaderFlyweight.frameType(byteBuf)); + Assert.assertEquals(1, FrameHeaderFlyweight.streamId(byteBuf)); + Assert.assertTrue(FrameHeaderFlyweight.hasFollows(byteBuf)); + }) + .expectNextCount(2) + .assertNext( + byteBuf -> { + Assert.assertEquals(FrameType.NEXT, FrameHeaderFlyweight.frameType(byteBuf)); + Assert.assertFalse(FrameHeaderFlyweight.hasFollows(byteBuf)); + }) + .verifyComplete(); } - @DisplayName("does not fragment with zero maxFragmentLength") + @DisplayName("fragments frame with data and metadata") @Test - void fragmentZeroMaxFragmentLength() { - PayloadFrame frame = createPayloadFrame(DEFAULT, false, false, null, getRandomByteBuf(2)); - - new FrameFragmenter(DEFAULT, 0) - .fragment(frame) - .as(StepVerifier::create) - .expectNext(frame) + void fragmentDataAndMetadata() { + ByteBuf rr = + RequestResponseFrameFlyweight.encode( + allocator, 1, true, DefaultPayload.create(data, metadata)); + + Publisher fragments = + FrameFragmenter.fragmentFrame(allocator, 1024, rr, FrameType.REQUEST_RESPONSE, false); + + StepVerifier.create(Flux.from(fragments).doOnError(Throwable::printStackTrace)) + .assertNext( + byteBuf -> { + Assert.assertEquals( + FrameType.REQUEST_RESPONSE, FrameHeaderFlyweight.frameType(byteBuf)); + Assert.assertTrue(FrameHeaderFlyweight.hasFollows(byteBuf)); + }) + .expectNextCount(6) + .assertNext( + byteBuf -> { + Assert.assertEquals(FrameType.NEXT, FrameHeaderFlyweight.frameType(byteBuf)); + Assert.assertTrue(FrameHeaderFlyweight.hasFollows(byteBuf)); + }) + .assertNext( + byteBuf -> { + Assert.assertEquals(FrameType.NEXT, FrameHeaderFlyweight.frameType(byteBuf)); + Assert.assertFalse(FrameHeaderFlyweight.hasFollows(byteBuf)); + }) .verifyComplete(); - }*/ + } } diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java index 467f6c2e7..6e0d0dc1b 100644 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java +++ b/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java @@ -16,108 +16,464 @@ package io.rsocket.fragmentation; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.frame.*; +import io.rsocket.util.DefaultPayload; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.Assert; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; + final class FrameReassemblerTest { - /* - @DisplayName("createFrameReassembler throws NullPointerException") - @Test - void createFrameReassemblerNullByteBufAllocator() { - assertThatNullPointerException() - .isThrownBy(() -> createFrameReassembler(null)) - .withMessage("byteBufAllocator must not be null"); + private static byte[] data = new byte[1024]; + private static byte[] metadata = new byte[1024]; + + static { + ThreadLocalRandom.current().nextBytes(data); + ThreadLocalRandom.current().nextBytes(metadata); } + private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; + @DisplayName("reassembles data") @Test void reassembleData() { - ByteBuf data = getRandomByteBuf(6); + List byteBufs = + Arrays.asList( + RequestResponseFrameFlyweight.encode(allocator, 1, true, DefaultPayload.create(data)), + PayloadFrameFlyweight.encode( + allocator, 1, true, false, true, DefaultPayload.create(data)), + PayloadFrameFlyweight.encode( + allocator, 1, true, false, true, DefaultPayload.create(data)), + PayloadFrameFlyweight.encode( + allocator, 1, true, false, true, DefaultPayload.create(data)), + PayloadFrameFlyweight.encode( + allocator, 1, false, false, true, DefaultPayload.create(data))); + + FrameReassembler reassembler = new FrameReassembler(allocator); - RequestStreamFrame frame = createRequestStreamFrame(DEFAULT, false, 1, null, data); + Flux assembled = Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame); - RequestStreamFrame fragment1 = - createRequestStreamFrame(DEFAULT, true, 1, null, data.slice(0, 2)); + CompositeByteBuf data = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(FrameReassemblerTest.data), + Unpooled.wrappedBuffer(FrameReassemblerTest.data), + Unpooled.wrappedBuffer(FrameReassemblerTest.data), + Unpooled.wrappedBuffer(FrameReassemblerTest.data), + Unpooled.wrappedBuffer(FrameReassemblerTest.data)); + + StepVerifier.create(assembled) + .assertNext( + byteBuf -> { + Assert.assertEquals(data, RequestResponseFrameFlyweight.data(byteBuf)); + ReferenceCountUtil.safeRelease(byteBuf); + }) + .verifyComplete(); + ReferenceCountUtil.safeRelease(data); + } + + @DisplayName("pass through frames without follows") + @Test + void passthrough() { + List byteBufs = + Arrays.asList( + RequestResponseFrameFlyweight.encode(allocator, 1, false, DefaultPayload.create(data))); - PayloadFrame fragment2 = createPayloadFrame(DEFAULT, true, false, null, data.slice(2, 2)); + FrameReassembler reassembler = new FrameReassembler(allocator); - PayloadFrame fragment3 = createPayloadFrame(DEFAULT, false, false, null, data.slice(4, 2)); + Flux assembled = Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame); - FrameReassembler frameReassembler = createFrameReassembler(DEFAULT); + CompositeByteBuf data = + allocator + .compositeDirectBuffer() + .addComponents(true, Unpooled.wrappedBuffer(FrameReassemblerTest.data)); - assertThat(frameReassembler.reassemble(fragment1)).isNull(); - assertThat(frameReassembler.reassemble(fragment2)).isNull(); - assertThat(frameReassembler.reassemble(fragment3)).isEqualTo(frame); + StepVerifier.create(assembled) + .assertNext( + byteBuf -> { + Assert.assertEquals(data, RequestResponseFrameFlyweight.data(byteBuf)); + ReferenceCountUtil.safeRelease(byteBuf); + }) + .verifyComplete(); + ReferenceCountUtil.safeRelease(data); } @DisplayName("reassembles metadata") @Test void reassembleMetadata() { - ByteBuf metadata = getRandomByteBuf(6); + List byteBufs = + Arrays.asList( + RequestResponseFrameFlyweight.encode( + allocator, + 1, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + false, + false, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata)))); - RequestStreamFrame frame = createRequestStreamFrame(DEFAULT, false, 1, metadata, null); + FrameReassembler reassembler = new FrameReassembler(allocator); - RequestStreamFrame fragment1 = - createRequestStreamFrame(DEFAULT, true, 1, metadata.slice(0, 2), null); + Flux assembled = Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame); - PayloadFrame fragment2 = createPayloadFrame(DEFAULT, true, true, metadata.slice(2, 2), null); + CompositeByteBuf metadata = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata)); - PayloadFrame fragment3 = createPayloadFrame(DEFAULT, false, true, metadata.slice(4, 2), null); - - FrameReassembler frameReassembler = createFrameReassembler(DEFAULT); - - assertThat(frameReassembler.reassemble(fragment1)).isNull(); - assertThat(frameReassembler.reassemble(fragment2)).isNull(); - assertThat(frameReassembler.reassemble(fragment3)).isEqualTo(frame); + StepVerifier.create(assembled) + .assertNext( + byteBuf -> { + System.out.println(byteBuf.readableBytes()); + ByteBuf m = RequestResponseFrameFlyweight.metadata(byteBuf); + Assert.assertEquals(metadata, m); + }) + .verifyComplete(); } - @DisplayName("reassembles metadata and data") + @DisplayName("reassembles metadata request channel") @Test - void reassembleMetadataAndData() { - ByteBuf metadata = getRandomByteBuf(5); - ByteBuf data = getRandomByteBuf(5); + void reassembleMetadataChannel() { + List byteBufs = + Arrays.asList( + RequestChannelFrameFlyweight.encode( + allocator, + 1, + true, + false, + 100, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + false, + false, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata)))); - RequestStreamFrame frame = createRequestStreamFrame(DEFAULT, false, 1, metadata, data); + FrameReassembler reassembler = new FrameReassembler(allocator); - RequestStreamFrame fragment1 = - createRequestStreamFrame(DEFAULT, true, 1, metadata.slice(0, 2), null); + Flux assembled = Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame); - PayloadFrame fragment2 = createPayloadFrame(DEFAULT, true, true, metadata.slice(2, 2), null); + CompositeByteBuf metadata = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata)); - PayloadFrame fragment3 = - createPayloadFrame(DEFAULT, true, false, metadata.slice(4, 1), data.slice(0, 1)); + StepVerifier.create(assembled) + .assertNext( + byteBuf -> { + System.out.println(byteBuf.readableBytes()); + ByteBuf m = RequestChannelFrameFlyweight.metadata(byteBuf); + Assert.assertEquals(metadata, m); + Assert.assertEquals(100, RequestChannelFrameFlyweight.initialRequestN(byteBuf)); + ReferenceCountUtil.safeRelease(byteBuf); + }) + .verifyComplete(); + + ReferenceCountUtil.safeRelease(metadata); + } + + @DisplayName("reassembles metadata request stream") + @Test + void reassembleMetadataStream() { + List byteBufs = + Arrays.asList( + RequestStreamFrameFlyweight.encode( + allocator, + 1, + true, + 250, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + false, + false, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata)))); - PayloadFrame fragment4 = createPayloadFrame(DEFAULT, true, false, null, data.slice(1, 2)); + FrameReassembler reassembler = new FrameReassembler(allocator); - PayloadFrame fragment5 = createPayloadFrame(DEFAULT, false, false, null, data.slice(3, 2)); + Flux assembled = Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame); - FrameReassembler frameReassembler = createFrameReassembler(DEFAULT); + CompositeByteBuf metadata = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata)); - assertThat(frameReassembler.reassemble(fragment1)).isNull(); - assertThat(frameReassembler.reassemble(fragment2)).isNull(); - assertThat(frameReassembler.reassemble(fragment3)).isNull(); - assertThat(frameReassembler.reassemble(fragment4)).isNull(); - assertThat(frameReassembler.reassemble(fragment5)).isEqualTo(frame); + StepVerifier.create(assembled) + .assertNext( + byteBuf -> { + System.out.println(byteBuf.readableBytes()); + ByteBuf m = RequestStreamFrameFlyweight.metadata(byteBuf); + Assert.assertEquals(metadata, m); + Assert.assertEquals(250, RequestChannelFrameFlyweight.initialRequestN(byteBuf)); + ReferenceCountUtil.safeRelease(byteBuf); + }) + .verifyComplete(); + + ReferenceCountUtil.safeRelease(metadata); } - @DisplayName("does not reassemble a non-fragment frame") + @DisplayName("reassembles metadata and data") @Test - void reassembleNonFragment() { - PayloadFrame frame = createPayloadFrame(DEFAULT, false, true, (ByteBuf) null, null); + void reassembleMetadataAndData() { + + List byteBufs = + Arrays.asList( + RequestResponseFrameFlyweight.encode( + allocator, + 1, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + DefaultPayload.create( + Unpooled.wrappedBuffer(data), Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, 1, false, false, true, DefaultPayload.create(data))); + + FrameReassembler reassembler = new FrameReassembler(allocator); + + Flux assembled = Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame); + + CompositeByteBuf data = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(FrameReassemblerTest.data), + Unpooled.wrappedBuffer(FrameReassemblerTest.data)); + + CompositeByteBuf metadata = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata)); - assertThat(createFrameReassembler(DEFAULT).reassemble(frame)).isEqualTo(frame); + StepVerifier.create(assembled) + .assertNext( + byteBuf -> { + Assert.assertEquals(data, RequestResponseFrameFlyweight.data(byteBuf)); + Assert.assertEquals(metadata, RequestResponseFrameFlyweight.metadata(byteBuf)); + }) + .verifyComplete(); + ReferenceCountUtil.safeRelease(data); + ReferenceCountUtil.safeRelease(metadata); } - @DisplayName("does not reassemble non fragmentable frame") + @DisplayName("cancel removes inflight frames") @Test - void reassembleNonFragmentableFrame() { - CancelFrame frame = createTestCancelFrame(); + public void cancelBeforeAssembling() { + List byteBufs = + Arrays.asList( + RequestResponseFrameFlyweight.encode( + allocator, + 1, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + DefaultPayload.create( + Unpooled.wrappedBuffer(data), Unpooled.wrappedBuffer(metadata)))); - assertThat(createFrameReassembler(DEFAULT).reassemble(frame)).isEqualTo(frame); + FrameReassembler reassembler = new FrameReassembler(allocator); + Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame).blockLast(); + + Assert.assertTrue(reassembler.headers.containsKey(1)); + Assert.assertTrue(reassembler.metadata.containsKey(1)); + Assert.assertTrue(reassembler.data.containsKey(1)); + + Flux.just(CancelFrameFlyweight.encode(allocator, 1)) + .handle(reassembler::reassembleFrame) + .blockLast(); + + Assert.assertFalse(reassembler.headers.containsKey(1)); + Assert.assertFalse(reassembler.metadata.containsKey(1)); + Assert.assertFalse(reassembler.data.containsKey(1)); } - @DisplayName("reassemble throws NullPointerException with null frame") + @DisplayName("dispose should clean up maps") @Test - void reassembleNullFrame() { - assertThatNullPointerException() - .isThrownBy(() -> createFrameReassembler(DEFAULT).reassemble(null)) - .withMessage("frame must not be null"); - }*/ + public void dispose() { + List byteBufs = + Arrays.asList( + RequestResponseFrameFlyweight.encode( + allocator, + 1, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), + PayloadFrameFlyweight.encode( + allocator, + 1, + true, + false, + true, + DefaultPayload.create( + Unpooled.wrappedBuffer(data), Unpooled.wrappedBuffer(metadata)))); + + FrameReassembler reassembler = new FrameReassembler(allocator); + Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame).blockLast(); + + Assert.assertTrue(reassembler.headers.containsKey(1)); + Assert.assertTrue(reassembler.metadata.containsKey(1)); + Assert.assertTrue(reassembler.data.containsKey(1)); + + reassembler.dispose(); + + Assert.assertFalse(reassembler.headers.containsKey(1)); + Assert.assertFalse(reassembler.metadata.containsKey(1)); + Assert.assertFalse(reassembler.data.containsKey(1)); + } } diff --git a/rsocket-core/src/test/java/io/rsocket/uri/TestUriHandler.java b/rsocket-core/src/test/java/io/rsocket/uri/TestUriHandler.java index 46634e94b..526757fbe 100644 --- a/rsocket-core/src/test/java/io/rsocket/uri/TestUriHandler.java +++ b/rsocket-core/src/test/java/io/rsocket/uri/TestUriHandler.java @@ -36,7 +36,7 @@ public Optional buildClient(URI uri) { return Optional.empty(); } - return Optional.of(() -> Mono.just(new TestDuplexConnection())); + return Optional.of((mtu) -> Mono.just(new TestDuplexConnection())); } @Override diff --git a/rsocket-core/src/test/java/io/rsocket/uri/UriTransportRegistryTest.java b/rsocket-core/src/test/java/io/rsocket/uri/UriTransportRegistryTest.java index 9e7b92f65..7aeef708f 100644 --- a/rsocket-core/src/test/java/io/rsocket/uri/UriTransportRegistryTest.java +++ b/rsocket-core/src/test/java/io/rsocket/uri/UriTransportRegistryTest.java @@ -28,7 +28,7 @@ public class UriTransportRegistryTest { public void testTestRegistered() { ClientTransport test = UriTransportRegistry.clientForUri("test://test"); - DuplexConnection duplexConnection = test.connect().block(); + DuplexConnection duplexConnection = test.connect(0).block(); assertTrue(duplexConnection instanceof TestDuplexConnection); } @@ -37,6 +37,6 @@ public void testTestRegistered() { public void testTestUnregistered() { ClientTransport test = UriTransportRegistry.clientForUri("mailto://bonson@baulsupp.net"); - test.connect().block(); + test.connect(0).block(); } } diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java index 91e8c3e57..55d6aaf93 100644 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java @@ -17,7 +17,9 @@ package io.rsocket.transport.local; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; +import io.rsocket.fragmentation.FragmentationDuplexConnection; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.local.LocalServerTransport.ServerDuplexConnectionAcceptor; @@ -51,8 +53,7 @@ public static LocalClientTransport create(String name) { return new LocalClientTransport(name); } - @Override - public Mono connect() { + private Mono connect() { return Mono.defer( () -> { ServerDuplexConnectionAcceptor server = LocalServerTransport.findServer(name); @@ -69,4 +70,17 @@ public Mono connect() { return Mono.just((DuplexConnection) new LocalDuplexConnection(in, out, closeNotifier)); }); } + + @Override + public Mono connect(int mtu) { + Mono connect = connect(); + if (mtu > 0) { + return connect.map( + duplexConnection -> + new FragmentationDuplexConnection( + duplexConnection, ByteBufAllocator.DEFAULT, mtu, false)); + } else { + return connect; + } + } } diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java index 68e7d462f..c1850b81c 100644 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java @@ -16,8 +16,10 @@ package io.rsocket.transport.local; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.Closeable; import io.rsocket.DuplexConnection; +import io.rsocket.fragmentation.FragmentationDuplexConnection; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import java.util.Objects; @@ -76,6 +78,20 @@ public static void dispose(String name) { registry.remove(name); } + /** + * Retrieves an instance of {@link ServerDuplexConnectionAcceptor} based on the name of its {@code + * LocalServerTransport}. Returns {@code null} if that server is not registered. + * + * @param name the name of the server to retrieve + * @return the server if it has been registered, {@code null} otherwise + * @throws NullPointerException if {@code name} is {@code null} + */ + static @Nullable ServerDuplexConnectionAcceptor findServer(String name) { + Objects.requireNonNull(name, "name must not be null"); + + return registry.get(name); + } + /** * Returns a new {@link LocalClientTransport} that is connected to this {@code * LocalServerTransport}. @@ -88,13 +104,13 @@ public LocalClientTransport clientTransport() { } @Override - public Mono start(ConnectionAcceptor acceptor) { + public Mono start(ConnectionAcceptor acceptor, int mtu) { Objects.requireNonNull(acceptor, "acceptor must not be null"); return Mono.create( sink -> { ServerDuplexConnectionAcceptor serverDuplexConnectionAcceptor = - new ServerDuplexConnectionAcceptor(name, acceptor); + new ServerDuplexConnectionAcceptor(name, acceptor, mtu); if (registry.putIfAbsent(name, serverDuplexConnectionAcceptor) != null) { throw new IllegalStateException("name already registered: " + name); @@ -104,20 +120,6 @@ public Mono start(ConnectionAcceptor acceptor) { }); } - /** - * Retrieves an instance of {@link ServerDuplexConnectionAcceptor} based on the name of its {@code - * LocalServerTransport}. Returns {@code null} if that server is not registered. - * - * @param name the name of the server to retrieve - * @return the server if it has been registered, {@code null} otherwise - * @throws NullPointerException if {@code name} is {@code null} - */ - static @Nullable ServerDuplexConnectionAcceptor findServer(String name) { - Objects.requireNonNull(name, "name must not be null"); - - return registry.get(name); - } - /** * Returns the name of this instance. * @@ -138,6 +140,8 @@ static class ServerDuplexConnectionAcceptor implements Consumer onClose = MonoProcessor.create(); + private final int mtu; + /** * Creates a new instance * @@ -145,17 +149,24 @@ static class ServerDuplexConnectionAcceptor implements Consumer 0) { + duplexConnection = + new FragmentationDuplexConnection( + duplexConnection, ByteBufAllocator.DEFAULT, mtu, false); + } + acceptor.apply(duplexConnection).subscribe(); } diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientTransportTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientTransportTest.java index 92478b0bd..4cfee9a01 100644 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientTransportTest.java +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientTransportTest.java @@ -32,8 +32,8 @@ void connect() { LocalServerTransport serverTransport = LocalServerTransport.createEphemeral(); serverTransport - .start(duplexConnection -> Mono.empty()) - .flatMap(closeable -> LocalClientTransport.create(serverTransport.getName()).connect()) + .start(duplexConnection -> Mono.empty(), 0) + .flatMap(closeable -> LocalClientTransport.create(serverTransport.getName()).connect(0)) .as(StepVerifier::create) .expectNextCount(1) .verifyComplete(); @@ -43,7 +43,7 @@ void connect() { @Test void connectNoServer() { LocalClientTransport.create("test-name") - .connect() + .connect(0) .as(StepVerifier::create) .verifyErrorMessage("Could not find server: test-name"); } diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalServerTransportTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalServerTransportTest.java index 7fb350432..1656ed08d 100644 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalServerTransportTest.java +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalServerTransportTest.java @@ -63,7 +63,7 @@ void findServer() { LocalServerTransport serverTransport = LocalServerTransport.createEphemeral(); serverTransport - .start(duplexConnection -> Mono.empty()) + .start(duplexConnection -> Mono.empty(), 0) .as(StepVerifier::create) .expectNextCount(1) .verifyComplete(); @@ -97,7 +97,7 @@ void named() { @Test void start() { LocalServerTransport.createEphemeral() - .start(duplexConnection -> Mono.empty()) + .start(duplexConnection -> Mono.empty(), 0) .as(StepVerifier::create) .expectNextCount(1) .verifyComplete(); @@ -107,7 +107,7 @@ void start() { @Test void startNullAcceptor() { assertThatNullPointerException() - .isThrownBy(() -> LocalServerTransport.createEphemeral().start(null)) + .isThrownBy(() -> LocalServerTransport.createEphemeral().start(null, 0)) .withMessage("acceptor must not be null"); } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/SendPublisher.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/SendPublisher.java index b84201ac9..7b33bcf94 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/SendPublisher.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/SendPublisher.java @@ -88,22 +88,19 @@ private ChannelPromise writeCleanupPromise(V poll) { .newPromise() .addListener( future -> { - try { - if (requested != Long.MAX_VALUE) { - requested--; - } - requestedUpstream--; - pending--; - - InnerSubscriber is = (InnerSubscriber) INNER_SUBSCRIBER.get(SendPublisher.this); - if (is != null) { - is.tryRequestMoreUpstream(); - tryComplete(is); - } - } finally { - if (poll.refCnt() > 0) { - ReferenceCountUtil.safeRelease(poll); - } + if (requested != Long.MAX_VALUE) { + requested--; + } + requestedUpstream--; + pending--; + + InnerSubscriber is = (InnerSubscriber) INNER_SUBSCRIBER.get(SendPublisher.this); + if (is != null) { + is.tryRequestMoreUpstream(); + tryComplete(is); + } + if (poll.refCnt() > 0) { + ReferenceCountUtil.safeRelease(poll); } }); } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java index 57e3ff0a9..9b2e60d5c 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java @@ -35,12 +35,25 @@ public final class TcpDuplexConnection implements DuplexConnection { private final Connection connection; private final Disposable channelClosed; private final ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; + private final boolean encodeLength; + /** * Creates a new instance * - * @param connection the {@link Connection} to for managing the server + * @param connection the {@link Connection} for managing the server */ public TcpDuplexConnection(Connection connection) { + this(connection, true); + } + + /** + * Creates a new instance + * + * @param encodeLength indicates if this connection should encode the length or not. + * @param connection the {@link Connection} to for managing the server + */ + public TcpDuplexConnection(Connection connection, boolean encodeLength) { + this.encodeLength = encodeLength; this.connection = Objects.requireNonNull(connection, "connection must not be null"); this.channelClosed = FutureMono.from(connection.channel().closeFuture()) @@ -77,15 +90,7 @@ public Mono onClose() { @Override public Flux receive() { - return connection - .inbound() - .receive() - .map( - byteBuf -> { - ByteBuf frame = FrameLengthFlyweight.frame(byteBuf); - frame.retain(); - return frame; - }); + return connection.inbound().receive().map(this::decode); } @Override @@ -101,20 +106,29 @@ public Mono send(Publisher frames) { queueSubscription, frameFlux, connection.channel(), - frame -> - FrameLengthFlyweight.encode(allocator, frame.readableBytes(), frame) - .retain(), + this::encode, ByteBuf::readableBytes); } else { return new SendPublisher<>( - frameFlux, - connection.channel(), - frame -> - FrameLengthFlyweight.encode(allocator, frame.readableBytes(), frame) - .retain(), - ByteBuf::readableBytes); + frameFlux, connection.channel(), this::encode, ByteBuf::readableBytes); } }) .then(); } + + private ByteBuf encode(ByteBuf frame) { + if (encodeLength) { + return FrameLengthFlyweight.encode(allocator, frame.readableBytes(), frame).retain(); + } else { + return frame; + } + } + + private ByteBuf decode(ByteBuf frame) { + if (encodeLength) { + return FrameLengthFlyweight.frame(frame).retain(); + } else { + return frame; + } + } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java index 291494f3b..7c1070317 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java @@ -16,7 +16,9 @@ package io.rsocket.transport.netty.client; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; +import io.rsocket.fragmentation.FragmentationDuplexConnection; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.netty.RSocketLengthCodec; @@ -91,10 +93,18 @@ public static TcpClientTransport create(TcpClient client) { } @Override - public Mono connect() { + public Mono connect(int mtu) { return client .doOnConnected(c -> c.addHandlerLast(new RSocketLengthCodec())) .connect() - .map(TcpDuplexConnection::new); + .map( + c -> { + if (mtu > 0) { + return new FragmentationDuplexConnection( + new TcpDuplexConnection(c, false), ByteBufAllocator.DEFAULT, mtu, true); + } else { + return new TcpDuplexConnection(c); + } + }); } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java index 111a37e98..99de91d41 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java @@ -19,7 +19,9 @@ import static io.rsocket.transport.netty.UriUtils.getPort; import static io.rsocket.transport.netty.UriUtils.isSecure; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; +import io.rsocket.fragmentation.FragmentationDuplexConnection; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.TransportHeaderAware; @@ -133,14 +135,31 @@ public static WebsocketClientTransport create(HttpClient client, String path) { return new WebsocketClientTransport(client, path); } + private static TcpClient createClient(URI uri) { + if (isSecure(uri)) { + return TcpClient.create().secure().host(uri.getHost()).port(getPort(uri, 443)); + } else { + return TcpClient.create().host(uri.getHost()).port(getPort(uri, 80)); + } + } + @Override - public Mono connect() { + public Mono connect(int mtu) { return client .headers(headers -> transportHeaders.get().forEach(headers::set)) .websocket() .uri(path) .connect() - .map(WebsocketDuplexConnection::new); + .map( + c -> { + DuplexConnection connection = new WebsocketDuplexConnection(c); + if (mtu > 0) { + connection = + new FragmentationDuplexConnection( + connection, ByteBufAllocator.DEFAULT, mtu, false); + } + return connection; + }); } @Override @@ -148,12 +167,4 @@ public void setTransportHeaders(Supplier> transportHeaders) this.transportHeaders = Objects.requireNonNull(transportHeaders, "transportHeaders must not be null"); } - - private static TcpClient createClient(URI uri) { - if (isSecure(uri)) { - return TcpClient.create().secure().host(uri.getHost()).port(getPort(uri, 443)); - } else { - return TcpClient.create().host(uri.getHost()).port(getPort(uri, 80)); - } - } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java index 6965499a8..11adae8a6 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java @@ -16,6 +16,9 @@ package io.rsocket.transport.netty.server; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.fragmentation.FragmentationDuplexConnection; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.netty.RSocketLengthCodec; @@ -89,14 +92,21 @@ public static TcpServerTransport create(TcpServer server) { } @Override - public Mono start(ConnectionAcceptor acceptor) { + public Mono start(ConnectionAcceptor acceptor, int mtu) { Objects.requireNonNull(acceptor, "acceptor must not be null"); return server .doOnConnection( c -> { c.addHandlerLast(new RSocketLengthCodec()); - TcpDuplexConnection connection = new TcpDuplexConnection(c); + DuplexConnection connection; + if (mtu > 0) { + connection = + new FragmentationDuplexConnection( + new TcpDuplexConnection(c, false), ByteBufAllocator.DEFAULT, mtu, true); + } else { + connection = new TcpDuplexConnection(c); + } acceptor.apply(connection).then(Mono.never()).subscribe(c.disposeSubscriber()); }) .bind() diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java index b9bb43e6e..fc2ab28bb 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java @@ -16,19 +16,18 @@ package io.rsocket.transport.netty.server; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.Closeable; +import io.rsocket.DuplexConnection; +import io.rsocket.fragmentation.FragmentationDuplexConnection; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.netty.WebsocketDuplexConnection; import java.util.Objects; -import java.util.function.BiFunction; import java.util.function.Consumer; -import org.reactivestreams.Publisher; import reactor.core.publisher.Mono; import reactor.netty.Connection; import reactor.netty.http.server.HttpServer; import reactor.netty.http.server.HttpServerRoutes; -import reactor.netty.http.websocket.WebsocketInbound; -import reactor.netty.http.websocket.WebsocketOutbound; /** * An implementation of {@link ServerTransport} that connects via Websocket and listens on specified @@ -58,34 +57,26 @@ public WebsocketRouteTransport( } @Override - public Mono start(ConnectionAcceptor acceptor) { + public Mono start(ConnectionAcceptor acceptor, int mtu) { Objects.requireNonNull(acceptor, "acceptor must not be null"); return server .route( routes -> { routesBuilder.accept(routes); - routes.ws(path, newHandler(acceptor)); + routes.ws( + path, + (in, out) -> { + DuplexConnection connection = new WebsocketDuplexConnection((Connection) in); + if (mtu > 0) { + connection = + new FragmentationDuplexConnection( + connection, ByteBufAllocator.DEFAULT, mtu, false); + } + return acceptor.apply(connection).then(out.neverComplete()); + }); }) .bind() .map(CloseableChannel::new); } - - /** - * Creates a new Websocket handler - * - * @param acceptor the {@link ConnectionAcceptor} to use with the handler - * @return a new Websocket handler - * @throws NullPointerException if {@code acceptor} is {@code null} - */ - static BiFunction> newHandler( - ConnectionAcceptor acceptor) { - - Objects.requireNonNull(acceptor, "acceptor must not be null"); - - return (in, out) -> { - WebsocketDuplexConnection connection = new WebsocketDuplexConnection((Connection) in); - return acceptor.apply(connection).then(out.neverComplete()); - }; - } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java index b6ef5eaea..9e1af2395 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java @@ -16,15 +16,20 @@ package io.rsocket.transport.netty.server; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.fragmentation.FragmentationDuplexConnection; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.TransportHeaderAware; +import io.rsocket.transport.netty.WebsocketDuplexConnection; import java.net.InetSocketAddress; import java.util.Collections; import java.util.Map; import java.util.Objects; import java.util.function.Supplier; import reactor.core.publisher.Mono; +import reactor.netty.Connection; import reactor.netty.http.server.HttpServer; /** @@ -101,14 +106,23 @@ public void setTransportHeaders(Supplier> transportHeaders) } @Override - public Mono start(ConnectionAcceptor acceptor) { + public Mono start(ConnectionAcceptor acceptor, int mtu) { Objects.requireNonNull(acceptor, "acceptor must not be null"); return server .handle( (request, response) -> { transportHeaders.get().forEach(response::addHeader); - return response.sendWebsocket(WebsocketRouteTransport.newHandler(acceptor)); + return response.sendWebsocket( + (in, out) -> { + DuplexConnection connection = new WebsocketDuplexConnection((Connection) in); + if (mtu > 0) { + connection = + new FragmentationDuplexConnection( + connection, ByteBufAllocator.DEFAULT, mtu, false); + } + return acceptor.apply(connection).then(out.neverComplete()); + }); }) .bind() .map(CloseableChannel::new); diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/TcpClientTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/TcpClientTransportTest.java index 388001fb6..e0bdb9cd7 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/TcpClientTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/TcpClientTransportTest.java @@ -37,8 +37,8 @@ void connect() { TcpServerTransport serverTransport = TcpServerTransport.create(address); serverTransport - .start(duplexConnection -> Mono.empty()) - .flatMap(context -> TcpClientTransport.create(context.address()).connect()) + .start(duplexConnection -> Mono.empty(), 0) + .flatMap(context -> TcpClientTransport.create(context.address()).connect(0)) .as(StepVerifier::create) .expectNextCount(1) .verifyComplete(); @@ -47,7 +47,7 @@ void connect() { @DisplayName("create generates error if server not started") @Test void connectNoServer() { - TcpClientTransport.create(8000).connect().as(StepVerifier::create).verifyError(); + TcpClientTransport.create(8000).connect(0).as(StepVerifier::create).verifyError(); } @DisplayName("creates client with BindAddress") diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java index 202c5b3f3..58a8776a4 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java @@ -39,8 +39,8 @@ void connect() { WebsocketServerTransport serverTransport = WebsocketServerTransport.create(address); serverTransport - .start(duplexConnection -> Mono.empty()) - .flatMap(context -> WebsocketClientTransport.create(context.address()).connect()) + .start(duplexConnection -> Mono.empty(), 0) + .flatMap(context -> WebsocketClientTransport.create(context.address()).connect(0)) .as(StepVerifier::create) .expectNextCount(1) .verifyComplete(); @@ -49,7 +49,7 @@ void connect() { @DisplayName("create generates error if server not started") @Test void connectNoServer() { - WebsocketClientTransport.create(8000).connect().as(StepVerifier::create).verifyError(); + WebsocketClientTransport.create(8000).connect(0).as(StepVerifier::create).verifyError(); } @DisplayName("creates client with BindAddress") diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/TcpServerTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/TcpServerTransportTest.java index 15a216b96..84c185e26 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/TcpServerTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/TcpServerTransportTest.java @@ -87,7 +87,7 @@ void start() { TcpServerTransport serverTransport = TcpServerTransport.create(address); serverTransport - .start(duplexConnection -> Mono.empty()) + .start(duplexConnection -> Mono.empty(), 0) .as(StepVerifier::create) .expectNextCount(1) .verifyComplete(); @@ -97,7 +97,7 @@ void start() { @Test void startNullAcceptor() { assertThatNullPointerException() - .isThrownBy(() -> TcpServerTransport.create(8000).start(null)) + .isThrownBy(() -> TcpServerTransport.create(8000).start(null, 0)) .withMessage("acceptor must not be null"); } } diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketRouteTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketRouteTransportTest.java index 66822890a..e94bef13c 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketRouteTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketRouteTransportTest.java @@ -16,7 +16,6 @@ package io.rsocket.transport.netty.server; -import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatNullPointerException; import org.junit.jupiter.api.DisplayName; @@ -57,20 +56,6 @@ void constructorNullServer() { .withMessage("server must not be null"); } - @DisplayName("creates a new handler") - @Test - void newHandler() { - assertThat(WebsocketRouteTransport.newHandler(duplexConnection -> null)).isNotNull(); - } - - @DisplayName("newHandler throws NullPointerException with null acceptor") - @Test - void newHandlerNullAcceptor() { - assertThatNullPointerException() - .isThrownBy(() -> WebsocketRouteTransport.newHandler(null)) - .withMessage("acceptor must not be null"); - } - @DisplayName("starts server") @Test void start() { @@ -78,7 +63,7 @@ void start() { new WebsocketRouteTransport(HttpServer.create(), routes -> {}, "/test-path"); serverTransport - .start(duplexConnection -> Mono.empty()) + .start(duplexConnection -> Mono.empty(), 0) .as(StepVerifier::create) .expectNextCount(1) .verifyComplete(); @@ -91,7 +76,7 @@ void startNullAcceptor() { .isThrownBy( () -> new WebsocketRouteTransport(HttpServer.create(), routes -> {}, "/test-path") - .start(null)) + .start(null, 0)) .withMessage("acceptor must not be null"); } } diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java index d1a6b374e..7a5e360d2 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java @@ -102,7 +102,7 @@ void start() { WebsocketServerTransport serverTransport = WebsocketServerTransport.create(address); serverTransport - .start(duplexConnection -> Mono.empty()) + .start(duplexConnection -> Mono.empty(), 0) .as(StepVerifier::create) .expectNextCount(1) .verifyComplete(); @@ -112,7 +112,7 @@ void start() { @Test void startNullAcceptor() { assertThatNullPointerException() - .isThrownBy(() -> WebsocketServerTransport.create(8000).start(null)) + .isThrownBy(() -> WebsocketServerTransport.create(8000).start(null, 0)) .withMessage("acceptor must not be null"); } } diff --git a/rsocket-transport-netty/src/test/resources/logback-test.xml b/rsocket-transport-netty/src/test/resources/logback-test.xml index 49b11d6fb..7150e3f0f 100644 --- a/rsocket-transport-netty/src/test/resources/logback-test.xml +++ b/rsocket-transport-netty/src/test/resources/logback-test.xml @@ -24,6 +24,7 @@ +