diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java b/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java index 545dd863f..6d0acd25f 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java @@ -211,9 +211,7 @@ public Mono start() { dataMimeType, setupPayload); - if (mtu > 0) { - connection = new FragmentationDuplexConnection(connection, mtu); - } + connection = new FragmentationDuplexConnection(connection, mtu); ClientServerInputMultiplexer multiplexer = new ClientServerInputMultiplexer(connection, plugins); diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java index fe36ed074..f9f25674d 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java @@ -68,7 +68,7 @@ public FragmentationDuplexConnection(DuplexConnection delegate, int maxFragmentS * * @param byteBufAllocator the {@link ByteBufAllocator} to use * @param delegate the {@link DuplexConnection} to decorate - * @param maxFragmentSize the maximum fragment size + * @param maxFragmentSize the maximum fragment size. A value of 0 indicates that frames should not be fragmented. * @throws NullPointerException if {@code byteBufAllocator} or {@code delegate} are {@code null} * @throws IllegalArgumentException if {@code maxFragmentSize} is not {@code positive} */ @@ -79,7 +79,7 @@ public FragmentationDuplexConnection( Objects.requireNonNull(byteBufAllocator, "byteBufAllocator must not be null"); this.delegate = Objects.requireNonNull(delegate, "delegate must not be null"); - NumberUtils.requirePositive(maxFragmentSize, "maxFragmentSize must be positive"); + NumberUtils.requireNonNegative(maxFragmentSize, "maxFragmentSize must be positive"); this.frameFragmenter = new FrameFragmenter(byteBufAllocator, maxFragmentSize); } diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java index e178bc971..a0c7911f9 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java @@ -124,7 +124,7 @@ private int getFragmentableLength(FragmentableFrame fragmentableFrame) { } private boolean shouldFragment(Frame frame) { - if (!(frame instanceof FragmentableFrame)) { + if (maxFragmentSize == 0 || !(frame instanceof FragmentableFrame)) { return false; } diff --git a/rsocket-core/src/main/java/io/rsocket/util/NumberUtils.java b/rsocket-core/src/main/java/io/rsocket/util/NumberUtils.java index 672c852d0..12e3cee45 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/NumberUtils.java +++ b/rsocket-core/src/main/java/io/rsocket/util/NumberUtils.java @@ -37,6 +37,25 @@ public final class NumberUtils { private NumberUtils() {} + /** + * Requires that an {@code int} is greater than or equal to zero. + * + * @param i the {@code int} to test + * @param message detail message to be used in the event that a {@link IllegalArgumentException} + * is thrown + * @return the {@code int} if greater than or equal to zero + * @throws IllegalArgumentException if {@code i} is less than zero + */ + public static int requireNonNegative(int i, String message) { + Objects.requireNonNull(message, "message must not be null"); + + if (i < 0) { + throw new IllegalArgumentException(message); + } + + return i; + } + /** * Requires that a {@code long} is greater than zero. * diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java index 7ceb3b20e..555530641 100644 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java @@ -47,11 +47,11 @@ final class FragmentationDuplexConnectionTest { private final ArgumentCaptor> publishers = ArgumentCaptor.forClass(Publisher.class); - @DisplayName("constructor throws NullPointerException with invalid maxFragmentLength") + @DisplayName("constructor throws IllegalArgumentException with negative maxFragmentLength") @Test void constructorInvalidMaxFragmentSize() { assertThatIllegalArgumentException() - .isThrownBy(() -> new FragmentationDuplexConnection(DEFAULT, delegate, 0)) + .isThrownBy(() -> new FragmentationDuplexConnection(DEFAULT, delegate, Integer.MIN_VALUE)) .withMessage("maxFragmentSize must be positive"); } @@ -366,4 +366,17 @@ void sendNullFrames() { .isThrownBy(() -> new FragmentationDuplexConnection(DEFAULT, delegate, 2).send(null)) .withMessage("frames must not be null"); } + + @DisplayName("does not fragment with zero maxFragmentLength") + @Test + void sendZeroMaxFragmentLength() { + Frame frame = + toAbstractionLeakingFrame( + DEFAULT, 1, createPayloadFrame(DEFAULT, false, false, null, getRandomByteBuf(2))); + + new FragmentationDuplexConnection(DEFAULT, delegate, 0).sendOne(frame); + verify(delegate).send(publishers.capture()); + + StepVerifier.create(Flux.from(publishers.getValue())).expectNext(frame).verifyComplete(); + } } diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java index d69492f3e..efa1b5357 100644 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java +++ b/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java @@ -173,4 +173,16 @@ void fragmentWithNullFrame() { .isThrownBy(() -> new FrameFragmenter(DEFAULT, 2).fragment(null)) .withMessage("frame must not be null"); } + + @DisplayName("does not fragment with zero maxFragmentLength") + @Test + void fragmentZeroMaxFragmentLength() { + PayloadFrame frame = createPayloadFrame(DEFAULT, false, false, null, getRandomByteBuf(2)); + + new FrameFragmenter(DEFAULT, 0) + .fragment(frame) + .as(StepVerifier::create) + .expectNext(frame) + .verifyComplete(); + } } diff --git a/rsocket-core/src/test/java/io/rsocket/util/NumberUtilsTest.java b/rsocket-core/src/test/java/io/rsocket/util/NumberUtilsTest.java index f885a221d..6ce023783 100644 --- a/rsocket-core/src/test/java/io/rsocket/util/NumberUtilsTest.java +++ b/rsocket-core/src/test/java/io/rsocket/util/NumberUtilsTest.java @@ -25,7 +25,37 @@ final class NumberUtilsTest { - @DisplayName("returns long value with positive int") + @DisplayName("returns int value with postitive int") + @Test + void requireNonNegativeInt() { + assertThat(NumberUtils.requireNonNegative(Integer.MAX_VALUE, "test-message")) + .isEqualTo(Integer.MAX_VALUE); + } + + @DisplayName( + "requireNonNegative with int argument throws IllegalArgumentException with negative value") + @Test + void requireNonNegativeIntNegative() { + assertThatIllegalArgumentException() + .isThrownBy(() -> NumberUtils.requireNonNegative(Integer.MIN_VALUE, "test-message")) + .withMessage("test-message"); + } + + @DisplayName("requireNonNegative with int argument throws NullPointerException with null message") + @Test + void requireNonNegativeIntNullMessage() { + assertThatNullPointerException() + .isThrownBy(() -> NumberUtils.requireNonNegative(Integer.MIN_VALUE, null)) + .withMessage("message must not be null"); + } + + @DisplayName("requireNonNegative returns int value with zero") + @Test + void requireNonNegativeIntZero() { + assertThat(NumberUtils.requireNonNegative(0, "test-message")).isEqualTo(0); + } + + @DisplayName("requirePositive returns int value with positive int") @Test void requirePositiveInt() { assertThat(NumberUtils.requirePositive(Integer.MAX_VALUE, "test-message")) @@ -52,13 +82,12 @@ void requirePositiveIntNullMessage() { @DisplayName("requirePositive with int argument throws IllegalArgumentException with zero value") @Test void requirePositiveIntZero() { - assertThatIllegalArgumentException() .isThrownBy(() -> NumberUtils.requirePositive(0, "test-message")) .withMessage("test-message"); } - @DisplayName("returns long value with positive long") + @DisplayName("requirePositive returns long value with positive long") @Test void requirePositiveLong() { assertThat(NumberUtils.requirePositive(Long.MAX_VALUE, "test-message")) @@ -85,7 +114,6 @@ void requirePositiveLongNullMessage() { @DisplayName("requirePositive with long argument throws IllegalArgumentException with zero value") @Test void requirePositiveLongZero() { - assertThatIllegalArgumentException() .isThrownBy(() -> NumberUtils.requirePositive(0L, "test-message")) .withMessage("test-message");