diff --git a/rsocket-core/src/main/java/io/rsocket/internal/SwitchTransformFlux.java b/rsocket-core/src/main/java/io/rsocket/internal/SwitchTransformFlux.java index 55866ed92..b8ec5b863 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/SwitchTransformFlux.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/SwitchTransformFlux.java @@ -16,7 +16,6 @@ package io.rsocket.internal; -import io.netty.util.ReferenceCountUtil; import java.util.Objects; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; @@ -122,11 +121,15 @@ public void cancel() { if (s != Operators.cancelledSubscription()) { Subscription s = this.s; this.s = Operators.cancelledSubscription(); - ReferenceCountUtil.safeRelease(first); if (WIP.getAndIncrement(this) == 0) { INNER.lazySet(this, null); - first = null; + + T f = first; + if (f != null) { + first = null; + Operators.onDiscard(f, currentContext()); + } } s.cancel(); @@ -171,7 +174,6 @@ public void onNext(T t) { return; } catch (Throwable e) { onError(Operators.onOperatorError(s, e, t, currentContext())); - ReferenceCountUtil.safeRelease(t); return; } } @@ -219,13 +221,14 @@ public void onComplete() { @Override public void request(long n) { - if (first != null && drainRegular() && n != Long.MAX_VALUE) { - n = Operators.addCap(n, -1); - if (n > 0) { + if (Operators.validate(n)) { + if (first != null && drainRegular() && n != Long.MAX_VALUE) { + if (--n > 0) { + s.request(n); + } + } else { s.request(n); } - } else { - s.request(n); } } @@ -245,12 +248,11 @@ boolean drainRegular() { first = null; if (s == Operators.cancelledSubscription()) { - Operators.onNextDropped(f, a.currentContext()); + Operators.onDiscard(f, a.currentContext()); return true; } a.onNext(f); - ReferenceCountUtil.safeRelease(f); f = null; sent = true; } @@ -345,11 +347,15 @@ public void cancel() { if (s != Operators.cancelledSubscription()) { Subscription s = this.s; this.s = Operators.cancelledSubscription(); - ReferenceCountUtil.safeRelease(first); if (WIP.getAndIncrement(this) == 0) { INNER.lazySet(this, null); - first = null; + + T f = first; + if (f != null) { + first = null; + Operators.onDiscard(f, currentContext()); + } } s.cancel(); @@ -399,7 +405,6 @@ public void onNext(T t) { return; } catch (Throwable e) { onError(Operators.onOperatorError(s, e, t, currentContext())); - ReferenceCountUtil.safeRelease(t); return; } } @@ -426,7 +431,6 @@ public boolean tryOnNext(T t) { return true; } catch (Throwable e) { onError(Operators.onOperatorError(s, e, t, currentContext())); - ReferenceCountUtil.safeRelease(t); return false; } } @@ -474,12 +478,14 @@ public void onComplete() { @Override public void request(long n) { - if (first != null && drainRegular() && n != Long.MAX_VALUE) { - if (--n > 0) { + if (Operators.validate(n)) { + if (first != null && drainRegular() && n != Long.MAX_VALUE) { + if (--n > 0) { + s.request(n); + } + } else { s.request(n); } - } else { - s.request(n); } } @@ -499,12 +505,11 @@ boolean drainRegular() { first = null; if (s == Operators.cancelledSubscription()) { - Operators.onNextDropped(f, a.currentContext()); + Operators.onDiscard(f, a.currentContext()); return true; } a.onNext(f); - ReferenceCountUtil.safeRelease(f); f = null; sent = true; } diff --git a/rsocket-core/src/test/java/io/rsocket/internal/SwitchTransformFluxTest.java b/rsocket-core/src/test/java/io/rsocket/internal/SwitchTransformFluxTest.java index e4b897409..2297d6bfa 100644 --- a/rsocket-core/src/test/java/io/rsocket/internal/SwitchTransformFluxTest.java +++ b/rsocket-core/src/test/java/io/rsocket/internal/SwitchTransformFluxTest.java @@ -394,4 +394,51 @@ public void shouldBeAbleToBeCancelledProperly() { publisher.assertCancelled(); publisher.assertWasRequested(); } + + @Test + public void shouldBeAbleToCatchDiscardedElement() { + TestPublisher publisher = TestPublisher.createCold(); + Integer[] discarded = new Integer[1]; + Flux switchTransformed = + publisher + .flux() + .transform( + flux -> + new SwitchTransformFlux<>( + flux, (first, innerFlux) -> innerFlux.map(String::valueOf))) + .doOnDiscard(Integer.class, e -> discarded[0] = e); + + publisher.next(1); + + StepVerifier.create(switchTransformed, 0).thenCancel().verify(Duration.ofSeconds(10)); + + publisher.assertCancelled(); + publisher.assertWasRequested(); + + Assert.assertArrayEquals(new Integer[] {1}, discarded); + } + + @Test + public void shouldBeAbleToCatchDiscardedElementInCaseOfConditional() { + TestPublisher publisher = TestPublisher.createCold(); + Integer[] discarded = new Integer[1]; + Flux switchTransformed = + publisher + .flux() + .transform( + flux -> + new SwitchTransformFlux<>( + flux, (first, innerFlux) -> innerFlux.map(String::valueOf))) + .filter(t -> true) + .doOnDiscard(Integer.class, e -> discarded[0] = e); + + publisher.next(1); + + StepVerifier.create(switchTransformed, 0).thenCancel().verify(Duration.ofSeconds(10)); + + publisher.assertCancelled(); + publisher.assertWasRequested(); + + Assert.assertArrayEquals(new Integer[] {1}, discarded); + } }