Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions rsocket-core/src/main/java/io/rsocket/RSocketClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,13 @@ class RSocketClient implements RSocket {
connection.onClose().doFinally(signalType -> terminate()).subscribe(null, errorConsumer);

connection
.send(sendProcessor)
.send(
sendProcessor.doOnRequest(
r -> {
for (LimitableRequestPublisher lrp : senders.values()) {
lrp.increaseInternalLimit(r);
}
}))
.doFinally(this::handleSendProcessorCancel)
.subscribe(null, this::handleSendProcessorError);

Expand Down Expand Up @@ -294,7 +300,8 @@ private Flux<Payload> handleChannel(Flux<Payload> request) {
.transform(
f -> {
LimitableRequestPublisher<Payload> wrapped =
LimitableRequestPublisher.wrap(f);
LimitableRequestPublisher.wrap(
f, sendProcessor.available());
// Need to set this to one for first the frame
wrapped.increaseRequestLimit(1);
senders.put(streamId, wrapped);
Expand Down
52 changes: 47 additions & 5 deletions rsocket-core/src/main/java/io/rsocket/RSocketServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class RSocketServer implements ResponderRSocket {
private final Function<Frame, ? extends Payload> frameDecoder;
private final Consumer<Throwable> errorConsumer;

private final Map<Integer, LimitableRequestPublisher> sendingLimitableSubscriptions;
private final Map<Integer, Subscription> sendingSubscriptions;
private final Map<Integer, Processor<Payload, Payload>> channelProcessors;

Expand Down Expand Up @@ -81,6 +82,7 @@ class RSocketServer implements ResponderRSocket {
this.connection = connection;
this.frameDecoder = frameDecoder;
this.errorConsumer = errorConsumer;
this.sendingLimitableSubscriptions = Collections.synchronizedMap(new IntObjectHashMap<>());
this.sendingSubscriptions = Collections.synchronizedMap(new IntObjectHashMap<>());
this.channelProcessors = Collections.synchronizedMap(new IntObjectHashMap<>());

Expand All @@ -89,7 +91,13 @@ class RSocketServer implements ResponderRSocket {
this.sendProcessor = new UnboundedProcessor<>();

connection
.send(sendProcessor)
.send(
sendProcessor.doOnRequest(
r -> {
for (LimitableRequestPublisher lrp : sendingLimitableSubscriptions.values()) {
lrp.increaseInternalLimit(r);
}
}))
.doFinally(this::handleSendProcessorCancel)
.subscribe(null, this::handleSendProcessorError);

Expand Down Expand Up @@ -135,6 +143,17 @@ private void handleSendProcessorError(Throwable t) {
}
});

sendingLimitableSubscriptions
.values()
.forEach(
subscription -> {
try {
subscription.cancel();
} catch (Throwable e) {
errorConsumer.accept(e);
}
});

channelProcessors
.values()
.forEach(
Expand Down Expand Up @@ -163,6 +182,17 @@ private void handleSendProcessorCancel(SignalType t) {
}
});

sendingLimitableSubscriptions
.values()
.forEach(
subscription -> {
try {
subscription.cancel();
} catch (Throwable e) {
errorConsumer.accept(e);
}
});

channelProcessors
.values()
.forEach(
Expand Down Expand Up @@ -258,6 +288,9 @@ private void cleanup() {
private synchronized void cleanUpSendingSubscriptions() {
sendingSubscriptions.values().forEach(Subscription::cancel);
sendingSubscriptions.clear();

sendingLimitableSubscriptions.values().forEach(Subscription::cancel);
sendingLimitableSubscriptions.clear();
}

private synchronized void cleanUpChannelProcessors() {
Expand Down Expand Up @@ -373,12 +406,12 @@ private void handleStream(int streamId, Flux<Payload> response, int initialReque
.transform(
frameFlux -> {
LimitableRequestPublisher<Payload> payloads =
LimitableRequestPublisher.wrap(frameFlux);
sendingSubscriptions.put(streamId, payloads);
LimitableRequestPublisher.wrap(frameFlux, sendProcessor.available());
sendingLimitableSubscriptions.put(streamId, payloads);
payloads.increaseRequestLimit(initialRequestN);
return payloads;
})
.doFinally(signalType -> sendingSubscriptions.remove(streamId))
.doFinally(signalType -> sendingLimitableSubscriptions.remove(streamId))
.subscribe(
payload -> {
final Frame frame = Frame.PayloadFrame.from(streamId, FrameType.NEXT, payload);
Expand Down Expand Up @@ -423,6 +456,11 @@ private void handleKeepAliveFrame(Frame frame) {

private void handleCancelFrame(int streamId) {
Subscription subscription = sendingSubscriptions.remove(streamId);

if (subscription == null) {
subscription = sendingLimitableSubscriptions.get(streamId);
}

if (subscription != null) {
subscription.cancel();
}
Expand All @@ -434,7 +472,11 @@ private void handleError(int streamId, Throwable t) {
}

private void handleRequestN(int streamId, Frame frame) {
final Subscription subscription = sendingSubscriptions.get(streamId);
Subscription subscription = sendingSubscriptions.get(streamId);

if (subscription == null) {
subscription = sendingLimitableSubscriptions.get(streamId);
}
if (subscription != null) {
int n = Frame.RequestN.requestN(frame);
subscription.request(n >= Integer.MAX_VALUE ? Long.MAX_VALUE : n);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ public class LimitableRequestPublisher<T> extends Flux<T> implements Subscriptio

private final AtomicBoolean canceled;

private final long prefetch;

private long internalRequested;

private long externalRequested;
Expand All @@ -39,13 +41,14 @@ public class LimitableRequestPublisher<T> extends Flux<T> implements Subscriptio

private volatile @Nullable Subscription internalSubscription;

private LimitableRequestPublisher(Publisher<T> source) {
private LimitableRequestPublisher(Publisher<T> source, long prefetch) {
this.source = source;
this.prefetch = prefetch;
this.canceled = new AtomicBoolean();
}

public static <T> LimitableRequestPublisher<T> wrap(Publisher<T> source) {
return new LimitableRequestPublisher<>(source);
public static <T> LimitableRequestPublisher<T> wrap(Publisher<T> source, long prefetch) {
return new LimitableRequestPublisher<>(source, prefetch);
}

@Override
Expand All @@ -60,6 +63,7 @@ public void subscribe(CoreSubscriber<? super T> destination) {

destination.onSubscribe(new InnerSubscription());
source.subscribe(new InnerSubscriber(destination));
increaseInternalLimit(prefetch);
}

public void increaseRequestLimit(long n) {
Expand All @@ -70,6 +74,14 @@ public void increaseRequestLimit(long n) {
requestN();
}

public void increaseInternalLimit(long n) {
synchronized (this) {
internalRequested = Operators.addCap(n, internalRequested);
}

requestN();
}

@Override
public void request(long n) {
increaseRequestLimit(n);
Expand All @@ -82,9 +94,17 @@ private void requestN() {
return;
}

r = Math.min(internalRequested, externalRequested);
externalRequested -= r;
internalRequested -= r;
if (externalRequested != Long.MAX_VALUE || internalRequested != Long.MAX_VALUE) {
r = Math.min(internalRequested, externalRequested);
if (externalRequested != Long.MAX_VALUE) {
externalRequested -= r;
}
if (internalRequested != Long.MAX_VALUE) {
internalRequested -= r;
}
} else {
r = Long.MAX_VALUE;
}
}

if (r > 0) {
Expand Down Expand Up @@ -144,13 +164,7 @@ public void onComplete() {

private class InnerSubscription implements Subscription {
@Override
public void request(long n) {
synchronized (LimitableRequestPublisher.this) {
internalRequested = Operators.addCap(n, internalRequested);
}

requestN();
}
public void request(long n) {}

@Override
public void cancel() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package io.rsocket.internal;

import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.shaded.org.jctools.queues.MpscUnboundedArrayQueue;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import reactor.core.CoreSubscriber;
Expand Down Expand Up @@ -221,6 +220,10 @@ public void onSubscribe(Subscription s) {
}
}

public long available() {
return requested;
}

@Override
public int getPrefetch() {
return Integer.MAX_VALUE;
Expand Down
25 changes: 25 additions & 0 deletions rsocket-core/src/test/java/io/rsocket/RSocketClientTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,15 @@
import io.rsocket.exceptions.RejectedSetupException;
import io.rsocket.frame.RequestFrameFlyweight;
import io.rsocket.framing.FrameType;
import io.rsocket.test.util.TestDuplexConnection;
import io.rsocket.test.util.TestSubscriber;
import io.rsocket.util.DefaultPayload;
import io.rsocket.util.EmptyPayload;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.stream.Collectors;
import org.assertj.core.api.Assertions;
import org.junit.Rule;
Expand Down Expand Up @@ -215,6 +218,28 @@ public void testChannelRequestServerSideCancellation() {
Assertions.assertThat(request.isDisposed()).isTrue();
}

@Test(timeout = 2_000)
@SuppressWarnings("unchecked")
public void
testClientSideRequestChannelShouldNotHangInfinitelySendingElementsAndShouldProduceDataValuingConnectionBackpressure() {
final Queue<Long> requests = new ConcurrentLinkedQueue<>();
rule.connection.dispose();
rule.connection = new TestDuplexConnection();
rule.connection.setInitialSendRequestN(256);
rule.init();

rule.socket
.requestChannel(
Flux.<Payload>generate(s -> s.next(EmptyPayload.INSTANCE)).doOnRequest(requests::add))
.subscribe();

int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL);

rule.connection.addToReceivedBuffer(Frame.RequestN.from(streamId, 2));
rule.connection.addToReceivedBuffer(Frame.RequestN.from(streamId, Integer.MAX_VALUE));
Assertions.assertThat(requests).containsOnly(1L, 2L, 253L);
}

public int sendRequestResponse(Publisher<Payload> response) {
Subscriber<Payload> sub = TestSubscriber.create();
response.subscribe(sub);
Expand Down
Loading