From 87cd805e147ffd4a6683c76063897895afdba3ca Mon Sep 17 00:00:00 2001 From: Oleh Dokuka Date: Tue, 12 Mar 2019 21:59:44 +0200 Subject: [PATCH] fixes uncontrolled data sending in case of direct propagation of request from requester (#595) * fixes uncontrolled data sending in case of direct propagation of request from requester * fixes timeout typo * replaces forEach with explicit loop * optimize access to limitableRequestPublisher Signed-off-by: Oleh Dokuka --- .../main/java/io/rsocket/RSocketClient.java | 11 ++- .../main/java/io/rsocket/RSocketServer.java | 52 +++++++++- .../internal/LimitableRequestPublisher.java | 40 +++++--- .../rsocket/internal/UnboundedProcessor.java | 5 +- .../java/io/rsocket/RSocketClientTest.java | 25 +++++ .../java/io/rsocket/RSocketServerTest.java | 95 +++++++++++++++++++ .../integration/TcpIntegrationTest.java | 10 +- 7 files changed, 212 insertions(+), 26 deletions(-) diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketClient.java b/rsocket-core/src/main/java/io/rsocket/RSocketClient.java index c92c604c8..9635ce25b 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketClient.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketClient.java @@ -82,7 +82,13 @@ class RSocketClient implements RSocket { connection.onClose().doFinally(signalType -> terminate()).subscribe(null, errorConsumer); connection - .send(sendProcessor) + .send( + sendProcessor.doOnRequest( + r -> { + for (LimitableRequestPublisher lrp : senders.values()) { + lrp.increaseInternalLimit(r); + } + })) .doFinally(this::handleSendProcessorCancel) .subscribe(null, this::handleSendProcessorError); @@ -294,7 +300,8 @@ private Flux handleChannel(Flux request) { .transform( f -> { LimitableRequestPublisher wrapped = - LimitableRequestPublisher.wrap(f); + LimitableRequestPublisher.wrap( + f, sendProcessor.available()); // Need to set this to one for first the frame wrapped.increaseRequestLimit(1); senders.put(streamId, wrapped); diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/RSocketServer.java index a226b5c06..35807963c 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketServer.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketServer.java @@ -50,6 +50,7 @@ class RSocketServer implements ResponderRSocket { private final Function frameDecoder; private final Consumer errorConsumer; + private final Map sendingLimitableSubscriptions; private final Map sendingSubscriptions; private final Map> channelProcessors; @@ -81,6 +82,7 @@ class RSocketServer implements ResponderRSocket { this.connection = connection; this.frameDecoder = frameDecoder; this.errorConsumer = errorConsumer; + this.sendingLimitableSubscriptions = Collections.synchronizedMap(new IntObjectHashMap<>()); this.sendingSubscriptions = Collections.synchronizedMap(new IntObjectHashMap<>()); this.channelProcessors = Collections.synchronizedMap(new IntObjectHashMap<>()); @@ -89,7 +91,13 @@ class RSocketServer implements ResponderRSocket { this.sendProcessor = new UnboundedProcessor<>(); connection - .send(sendProcessor) + .send( + sendProcessor.doOnRequest( + r -> { + for (LimitableRequestPublisher lrp : sendingLimitableSubscriptions.values()) { + lrp.increaseInternalLimit(r); + } + })) .doFinally(this::handleSendProcessorCancel) .subscribe(null, this::handleSendProcessorError); @@ -135,6 +143,17 @@ private void handleSendProcessorError(Throwable t) { } }); + sendingLimitableSubscriptions + .values() + .forEach( + subscription -> { + try { + subscription.cancel(); + } catch (Throwable e) { + errorConsumer.accept(e); + } + }); + channelProcessors .values() .forEach( @@ -163,6 +182,17 @@ private void handleSendProcessorCancel(SignalType t) { } }); + sendingLimitableSubscriptions + .values() + .forEach( + subscription -> { + try { + subscription.cancel(); + } catch (Throwable e) { + errorConsumer.accept(e); + } + }); + channelProcessors .values() .forEach( @@ -258,6 +288,9 @@ private void cleanup() { private synchronized void cleanUpSendingSubscriptions() { sendingSubscriptions.values().forEach(Subscription::cancel); sendingSubscriptions.clear(); + + sendingLimitableSubscriptions.values().forEach(Subscription::cancel); + sendingLimitableSubscriptions.clear(); } private synchronized void cleanUpChannelProcessors() { @@ -373,12 +406,12 @@ private void handleStream(int streamId, Flux response, int initialReque .transform( frameFlux -> { LimitableRequestPublisher payloads = - LimitableRequestPublisher.wrap(frameFlux); - sendingSubscriptions.put(streamId, payloads); + LimitableRequestPublisher.wrap(frameFlux, sendProcessor.available()); + sendingLimitableSubscriptions.put(streamId, payloads); payloads.increaseRequestLimit(initialRequestN); return payloads; }) - .doFinally(signalType -> sendingSubscriptions.remove(streamId)) + .doFinally(signalType -> sendingLimitableSubscriptions.remove(streamId)) .subscribe( payload -> { final Frame frame = Frame.PayloadFrame.from(streamId, FrameType.NEXT, payload); @@ -423,6 +456,11 @@ private void handleKeepAliveFrame(Frame frame) { private void handleCancelFrame(int streamId) { Subscription subscription = sendingSubscriptions.remove(streamId); + + if (subscription == null) { + subscription = sendingLimitableSubscriptions.get(streamId); + } + if (subscription != null) { subscription.cancel(); } @@ -434,7 +472,11 @@ private void handleError(int streamId, Throwable t) { } private void handleRequestN(int streamId, Frame frame) { - final Subscription subscription = sendingSubscriptions.get(streamId); + Subscription subscription = sendingSubscriptions.get(streamId); + + if (subscription == null) { + subscription = sendingLimitableSubscriptions.get(streamId); + } if (subscription != null) { int n = Frame.RequestN.requestN(frame); subscription.request(n >= Integer.MAX_VALUE ? Long.MAX_VALUE : n); diff --git a/rsocket-core/src/main/java/io/rsocket/internal/LimitableRequestPublisher.java b/rsocket-core/src/main/java/io/rsocket/internal/LimitableRequestPublisher.java index 17372ea01..d5a05375d 100755 --- a/rsocket-core/src/main/java/io/rsocket/internal/LimitableRequestPublisher.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/LimitableRequestPublisher.java @@ -31,6 +31,8 @@ public class LimitableRequestPublisher extends Flux implements Subscriptio private final AtomicBoolean canceled; + private final long prefetch; + private long internalRequested; private long externalRequested; @@ -39,13 +41,14 @@ public class LimitableRequestPublisher extends Flux implements Subscriptio private volatile @Nullable Subscription internalSubscription; - private LimitableRequestPublisher(Publisher source) { + private LimitableRequestPublisher(Publisher source, long prefetch) { this.source = source; + this.prefetch = prefetch; this.canceled = new AtomicBoolean(); } - public static LimitableRequestPublisher wrap(Publisher source) { - return new LimitableRequestPublisher<>(source); + public static LimitableRequestPublisher wrap(Publisher source, long prefetch) { + return new LimitableRequestPublisher<>(source, prefetch); } @Override @@ -60,6 +63,7 @@ public void subscribe(CoreSubscriber destination) { destination.onSubscribe(new InnerSubscription()); source.subscribe(new InnerSubscriber(destination)); + increaseInternalLimit(prefetch); } public void increaseRequestLimit(long n) { @@ -70,6 +74,14 @@ public void increaseRequestLimit(long n) { requestN(); } + public void increaseInternalLimit(long n) { + synchronized (this) { + internalRequested = Operators.addCap(n, internalRequested); + } + + requestN(); + } + @Override public void request(long n) { increaseRequestLimit(n); @@ -82,9 +94,17 @@ private void requestN() { return; } - r = Math.min(internalRequested, externalRequested); - externalRequested -= r; - internalRequested -= r; + if (externalRequested != Long.MAX_VALUE || internalRequested != Long.MAX_VALUE) { + r = Math.min(internalRequested, externalRequested); + if (externalRequested != Long.MAX_VALUE) { + externalRequested -= r; + } + if (internalRequested != Long.MAX_VALUE) { + internalRequested -= r; + } + } else { + r = Long.MAX_VALUE; + } } if (r > 0) { @@ -144,13 +164,7 @@ public void onComplete() { private class InnerSubscription implements Subscription { @Override - public void request(long n) { - synchronized (LimitableRequestPublisher.this) { - internalRequested = Operators.addCap(n, internalRequested); - } - - requestN(); - } + public void request(long n) {} @Override public void cancel() { diff --git a/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java b/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java index bcfa77287..b1e3e5e55 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java @@ -17,7 +17,6 @@ package io.rsocket.internal; import io.netty.util.ReferenceCountUtil; -import io.netty.util.internal.shaded.org.jctools.queues.MpscUnboundedArrayQueue; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; @@ -221,6 +220,10 @@ public void onSubscribe(Subscription s) { } } + public long available() { + return requested; + } + @Override public int getPrefetch() { return Integer.MAX_VALUE; diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketClientTest.java b/rsocket-core/src/test/java/io/rsocket/RSocketClientTest.java index c64a124f9..49d11e8c4 100644 --- a/rsocket-core/src/test/java/io/rsocket/RSocketClientTest.java +++ b/rsocket-core/src/test/java/io/rsocket/RSocketClientTest.java @@ -33,12 +33,15 @@ import io.rsocket.exceptions.RejectedSetupException; import io.rsocket.frame.RequestFrameFlyweight; import io.rsocket.framing.FrameType; +import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.test.util.TestSubscriber; import io.rsocket.util.DefaultPayload; import io.rsocket.util.EmptyPayload; import java.time.Duration; import java.util.ArrayList; import java.util.List; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.stream.Collectors; import org.assertj.core.api.Assertions; import org.junit.Rule; @@ -215,6 +218,28 @@ public void testChannelRequestServerSideCancellation() { Assertions.assertThat(request.isDisposed()).isTrue(); } + @Test(timeout = 2_000) + @SuppressWarnings("unchecked") + public void + testClientSideRequestChannelShouldNotHangInfinitelySendingElementsAndShouldProduceDataValuingConnectionBackpressure() { + final Queue requests = new ConcurrentLinkedQueue<>(); + rule.connection.dispose(); + rule.connection = new TestDuplexConnection(); + rule.connection.setInitialSendRequestN(256); + rule.init(); + + rule.socket + .requestChannel( + Flux.generate(s -> s.next(EmptyPayload.INSTANCE)).doOnRequest(requests::add)) + .subscribe(); + + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + + rule.connection.addToReceivedBuffer(Frame.RequestN.from(streamId, 2)); + rule.connection.addToReceivedBuffer(Frame.RequestN.from(streamId, Integer.MAX_VALUE)); + Assertions.assertThat(requests).containsOnly(1L, 2L, 253L); + } + public int sendRequestResponse(Publisher response) { Subscriber sub = TestSubscriber.create(); response.subscribe(sub); diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketServerTest.java b/rsocket-core/src/test/java/io/rsocket/RSocketServerTest.java index db1ca2d65..425c2dac9 100644 --- a/rsocket-core/src/test/java/io/rsocket/RSocketServerTest.java +++ b/rsocket-core/src/test/java/io/rsocket/RSocketServerTest.java @@ -29,12 +29,17 @@ import io.rsocket.util.DefaultPayload; import io.rsocket.util.EmptyPayload; import java.util.Collection; +import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicBoolean; +import org.assertj.core.api.Assertions; import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; +import org.mockito.Mockito; +import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; public class RSocketServerTest { @@ -103,6 +108,87 @@ public Mono requestResponse(Payload payload) { assertThat("Subscription not cancelled.", cancelled.get(), is(true)); } + @Test(timeout = 2_000) + @SuppressWarnings("unchecked") + public void + testServerSideRequestStreamShouldNotHangInfinitelySendingElementsAndShouldProduceDataValuingConnectionBackpressure() { + final int streamId = 5; + final Queue received = new ConcurrentLinkedQueue<>(); + final Queue requests = new ConcurrentLinkedQueue<>(); + + rule.setAcceptingSocket( + new AbstractRSocket() { + @Override + public Flux requestStream(Payload payload) { + return Flux.generate(s -> s.next(payload.retain())).doOnRequest(requests::add); + } + }, + 256); + + rule.sendRequest(streamId, FrameType.REQUEST_STREAM); + + assertThat("Unexpected error.", rule.errors, is(empty())); + + Subscriber next = rule.connection.getSendSubscribers().iterator().next(); + + Mockito.doAnswer( + invocation -> { + received.add(invocation.getArgument(0)); + + if (received.size() == 256) { + throw new RuntimeException(); + } + + return null; + }) + .when(next) + .onNext(Mockito.any()); + + rule.connection.addToReceivedBuffer(Frame.RequestN.from(streamId, Integer.MAX_VALUE)); + Assertions.assertThat(requests).containsOnly(1L, 2L, 253L); + } + + @Test(timeout = 2_000) + @SuppressWarnings("unchecked") + public void + testServerSideRequestChannelShouldNotHangInfinitelySendingElementsAndShouldProduceDataValuingConnectionBackpressure() { + final int streamId = 5; + final Queue received = new ConcurrentLinkedQueue<>(); + final Queue requests = new ConcurrentLinkedQueue<>(); + + rule.setAcceptingSocket( + new AbstractRSocket() { + @Override + public Flux requestChannel(Publisher payload) { + return Flux.generate(s -> s.next(EmptyPayload.INSTANCE)) + .doOnRequest(requests::add); + } + }, + 256); + + rule.sendRequest(streamId, FrameType.REQUEST_CHANNEL); + + assertThat("Unexpected error.", rule.errors, is(empty())); + + Subscriber next = rule.connection.getSendSubscribers().iterator().next(); + + Mockito.doAnswer( + invocation -> { + received.add(invocation.getArgument(0)); + + if (received.size() == 256) { + throw new RuntimeException(); + } + + return null; + }) + .when(next) + .onNext(Mockito.any()); + + rule.connection.addToReceivedBuffer(Frame.RequestN.from(streamId, Integer.MAX_VALUE)); + Assertions.assertThat(requests).containsOnly(1L, 2L, 253L); + } + public static class ServerSocketRule extends AbstractSocketRule { private RSocket acceptingSocket; @@ -127,6 +213,15 @@ public void setAcceptingSocket(RSocket acceptingSocket) { super.init(); } + public void setAcceptingSocket(RSocket acceptingSocket, int prefetch) { + this.acceptingSocket = acceptingSocket; + connection = new TestDuplexConnection(); + connection.setInitialSendRequestN(prefetch); + connectSub = TestSubscriber.create(); + errors = new ConcurrentLinkedQueue<>(); + super.init(); + } + @Override protected RSocketServer newRSocket() { return new RSocketServer( diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java index f5d048508..41e437fee 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java @@ -67,7 +67,7 @@ public void cleanup() { server.dispose(); } - @Test(timeout = 5_000L) + @Test(timeout = 15_000L) public void testCompleteWithoutNext() { handler = new AbstractRSocket() { @@ -83,7 +83,7 @@ public Flux requestStream(Payload payload) { assertFalse(hasElements); } - @Test(timeout = 5_000L) + @Test(timeout = 15_000L) public void testSingleStream() { handler = new AbstractRSocket() { @@ -100,7 +100,7 @@ public Flux requestStream(Payload payload) { assertEquals("RESPONSE", result.getDataUtf8()); } - @Test(timeout = 5_000L) + @Test(timeout = 15_000L) public void testZeroPayload() { handler = new AbstractRSocket() { @@ -117,7 +117,7 @@ public Flux requestStream(Payload payload) { assertEquals("", result.getDataUtf8()); } - @Test(timeout = 5_000L) + @Test(timeout = 15_000L) public void testRequestResponseErrors() { handler = new AbstractRSocket() { @@ -151,7 +151,7 @@ public Mono requestResponse(Payload payload) { assertEquals("SUCCESS", response2.getDataUtf8()); } - @Test(timeout = 5_000L) + @Test(timeout = 15_000L) public void testTwoConcurrentStreams() throws InterruptedException { ConcurrentHashMap> map = new ConcurrentHashMap<>(); UnicastProcessor processor1 = UnicastProcessor.create();