diff --git a/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java b/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java index d88cfe445..38706fba2 100644 --- a/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java @@ -25,7 +25,7 @@ import io.rsocket.framing.FrameType; /** - * Exposed to server for determination of RequestHandler based on mime types and SETUP metadata/data + * Exposed to server for determination of ResponderRSocket based on mime types and SETUP metadata/data */ public abstract class ConnectionSetupPayload extends AbstractReferenceCounted implements Payload { diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/RSocketServer.java index e29d4d1f1..a226b5c06 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketServer.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketServer.java @@ -27,7 +27,10 @@ import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import reactor.core.Disposable; -import reactor.core.publisher.*; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.SignalType; +import reactor.core.publisher.UnicastProcessor; import java.util.Collections; import java.util.Map; @@ -39,15 +42,16 @@ import static io.rsocket.frame.FrameHeaderFlyweight.FLAGS_M; /** Server side RSocket. Receives {@link Frame}s from a {@link RSocketClient} */ -class RSocketServer implements RSocket { +class RSocketServer implements ResponderRSocket { private final DuplexConnection connection; private final RSocket requestHandler; + private final ResponderRSocket responderRSocket; private final Function frameDecoder; private final Consumer errorConsumer; private final Map sendingSubscriptions; - private final Map> channelProcessors; + private final Map> channelProcessors; private final UnboundedProcessor sendProcessor; private KeepAliveHandler keepAliveHandler; @@ -69,12 +73,16 @@ class RSocketServer implements RSocket { Consumer errorConsumer, long tickPeriod, long ackTimeout) { - this.connection = connection; + this.requestHandler = requestHandler; + this.responderRSocket = + (requestHandler instanceof ResponderRSocket) ? (ResponderRSocket) requestHandler : null; + + this.connection = connection; this.frameDecoder = frameDecoder; this.errorConsumer = errorConsumer; this.sendingSubscriptions = Collections.synchronizedMap(new IntObjectHashMap<>()); - this.channelProcessors = Collections.synchronizedMap(new IntObjectHashMap<>()); + this.channelProcessors = Collections.synchronizedMap(new IntObjectHashMap<>()); // DO NOT Change the order here. The Send processor must be subscribed to before receiving // connections @@ -116,21 +124,27 @@ class RSocketServer implements RSocket { } private void handleSendProcessorError(Throwable t) { - sendingSubscriptions.values().forEach(subscription -> { - try { - subscription.cancel(); - } catch (Throwable e) { - errorConsumer.accept(e); - } - }); + sendingSubscriptions + .values() + .forEach( + subscription -> { + try { + subscription.cancel(); + } catch (Throwable e) { + errorConsumer.accept(e); + } + }); - channelProcessors.values().forEach(subscription -> { - try { - subscription.onError(t); - } catch (Throwable e) { - errorConsumer.accept(e); - } - }); + channelProcessors + .values() + .forEach( + subscription -> { + try { + subscription.onError(t); + } catch (Throwable e) { + errorConsumer.accept(e); + } + }); } private void handleSendProcessorCancel(SignalType t) { @@ -138,21 +152,27 @@ private void handleSendProcessorCancel(SignalType t) { return; } - sendingSubscriptions.values().forEach(subscription -> { - try { - subscription.cancel(); - } catch (Throwable e) { - errorConsumer.accept(e); - } - }); + sendingSubscriptions + .values() + .forEach( + subscription -> { + try { + subscription.cancel(); + } catch (Throwable e) { + errorConsumer.accept(e); + } + }); - channelProcessors.values().forEach(subscription -> { - try { - subscription.onComplete(); - } catch (Throwable e) { - errorConsumer.accept(e); - } - }); + channelProcessors + .values() + .forEach( + subscription -> { + try { + subscription.onComplete(); + } catch (Throwable e) { + errorConsumer.accept(e); + } + }); } @Override @@ -191,6 +211,15 @@ public Flux requestChannel(Publisher payloads) { } } + @Override + public Flux requestChannel(Payload payload, Publisher payloads) { + try { + return responderRSocket.requestChannel(payload, payloads); + } catch (Throwable t) { + return Flux.error(t); + } + } + @Override public Mono metadataPush(Payload payload) { try { @@ -232,9 +261,7 @@ private synchronized void cleanUpSendingSubscriptions() { } private synchronized void cleanUpChannelProcessors() { - channelProcessors - .values() - .forEach(Processor::onComplete); + channelProcessors.values().forEach(Processor::onComplete); channelProcessors.clear(); } @@ -381,7 +408,11 @@ private void handleChannel(int streamId, Payload payload, int initialRequestN) { // and any later payload can be processed frames.onNext(payload); - handleStream(streamId, requestChannel(payloads), initialRequestN); + if (responderRSocket != null) { + handleStream(streamId, requestChannel(payload, payloads), initialRequestN); + } else { + handleStream(streamId, requestChannel(payloads), initialRequestN); + } } private void handleKeepAliveFrame(Frame frame) { diff --git a/rsocket-core/src/main/java/io/rsocket/ResponderRSocket.java b/rsocket-core/src/main/java/io/rsocket/ResponderRSocket.java new file mode 100644 index 000000000..f98901472 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/ResponderRSocket.java @@ -0,0 +1,23 @@ +package io.rsocket; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; + +/** + * Extends the {@link RSocket} that allows an implementer to peek at the first request payload of a + * channel. + */ +public interface ResponderRSocket extends RSocket { + /** + * Implement this method to peak at the first payload of the incoming request stream without + * having to subscribe to Publish<Payload> payloads + * + * @param payload First payload in the stream - this is the same payload as the first payload in + * Publisher<Payload> payloads + * @param payloads Stream of request payloads. + * @return Stream of response payloads. + */ + default Flux requestChannel(Payload payload, Publisher payloads) { + return requestChannel(payloads); + } +}