diff --git a/rsocket-core/src/main/java/io/rsocket/internal/SwitchTransform.java b/rsocket-core/src/main/java/io/rsocket/internal/SwitchTransform.java new file mode 100644 index 000000000..ebb496307 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/SwitchTransform.java @@ -0,0 +1,81 @@ +package io.rsocket.internal; + +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.DirectProcessor; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Operators; + +import java.util.Objects; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.function.BiFunction; + +public final class SwitchTransform extends Flux { + + final Publisher source; + final BiFunction, Publisher> transformer; + + public SwitchTransform(Publisher source, BiFunction, Publisher> transformer) { + this.source = Objects.requireNonNull(source, "source"); + this.transformer = Objects.requireNonNull(transformer, "transformer"); + } + + @Override + public void subscribe(CoreSubscriber actual) { + source.subscribe(new SwitchTransformSubscriber<>(actual, transformer)); + } + + static final class SwitchTransformSubscriber implements CoreSubscriber { + final CoreSubscriber actual; + final BiFunction, Publisher> transformer; + final DirectProcessor processor = DirectProcessor.create(); + + Subscription s; + + volatile int once; + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater ONCE = + AtomicIntegerFieldUpdater.newUpdater(SwitchTransformSubscriber.class, "once"); + + SwitchTransformSubscriber(CoreSubscriber actual, BiFunction, Publisher> transformer) { + this.actual = actual; + this.transformer = transformer; + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + + processor.onSubscribe(s); + } + } + + @Override + public void onNext(T t) { + if (once == 0 && ONCE.compareAndSet(this, 0, 1)) { + try { + Publisher result = Objects.requireNonNull(transformer.apply(t, processor), + "The transformer returned a null value"); + result.subscribe(actual); + } + catch (Throwable e) { + onError(Operators.onOperatorError(s, e, t, actual.currentContext())); + return; + } + } + processor.onNext(t); + } + + @Override + public void onError(Throwable t) { + processor.onError(t); + } + + @Override + public void onComplete() { + processor.onComplete(); + } + } +}