Skip to content

Commit

Permalink
Propagate ThreadLocals for non-Reactor upstream sources (#3418)
Browse files Browse the repository at this point in the history
Factory methods for creating `Flux` and `Mono` from non-Reactor sources
now restore `ThreadLocal` values when
`Hooks.enableAutomaticContextPropagation()`was called in the following cases:

* `Flux.from(Publisher)`
* `Mono.from(Publisher)`
* `Mono.fromDirect(Publisher)`
* `Mono.fromFuture(CompletableFuture)`
* `Mono.fromCompletionStage(CompletionStage)`

and relevant overrides.

Fixes #3366.
  • Loading branch information
chemicL committed Apr 3, 2023
1 parent 74c954e commit 28ae2a0
Show file tree
Hide file tree
Showing 5 changed files with 504 additions and 25 deletions.
102 changes: 100 additions & 2 deletions reactor-core/src/main/java/reactor/core/publisher/FluxSource.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2016-2021 VMware Inc. or its affiliates, All Rights Reserved.
* Copyright (c) 2016-2023 VMware Inc. or its affiliates, All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,12 +18,16 @@

import java.util.Objects;

import io.micrometer.context.ContextSnapshot;
import org.reactivestreams.Publisher;

import org.reactivestreams.Subscription;
import reactor.core.CorePublisher;
import reactor.core.CoreSubscriber;
import reactor.core.Fuseable;
import reactor.core.Scannable;
import reactor.util.annotation.Nullable;
import reactor.util.context.Context;

/**
* A connecting {@link Flux} Publisher (right-to-left from a composition chain perspective)
Expand Down Expand Up @@ -64,7 +68,11 @@ final class FluxSource<I> extends Flux<I> implements SourceProducer<I>,
@Override
@SuppressWarnings("unchecked")
public void subscribe(CoreSubscriber<? super I> actual) {
source.subscribe(actual);
if (ContextPropagation.shouldPropagateContextToThreadLocals()) {
source.subscribe(new FluxSourceRestoringThreadLocalsSubscriber<>(actual));
} else {
source.subscribe(actual);
}
}

@Override
Expand All @@ -91,4 +99,94 @@ public Object scanUnsafe(Attr key) {
return null;
}

static final class FluxSourceRestoringThreadLocalsSubscriber<T>
implements Fuseable.ConditionalSubscriber<T>, InnerConsumer<T> {

final CoreSubscriber<? super T> actual;
final Fuseable.ConditionalSubscriber<? super T> actualConditional;

Subscription s;

@SuppressWarnings("unchecked")
FluxSourceRestoringThreadLocalsSubscriber(CoreSubscriber<? super T> actual) {
this.actual = actual;
if (actual instanceof Fuseable.ConditionalSubscriber) {
this.actualConditional = (Fuseable.ConditionalSubscriber<? super T>) actual;
}
else {
this.actualConditional = null;
}
}

@Override
@Nullable
public Object scanUnsafe(Attr key) {
if (key == Attr.PARENT) {
return s;
}
if (key == Attr.RUN_STYLE) {
return Attr.RunStyle.SYNC;
}
if (key == Attr.ACTUAL) {
return actual;
}
return null;
}

@Override
public Context currentContext() {
return actual.currentContext();
}

@SuppressWarnings("try")
@Override
public void onSubscribe(Subscription s) {
// This is needed, as the downstream can then switch threads,
// continue the subscription using different primitives and omit this operator
try (ContextSnapshot.Scope ignored =
ContextPropagation.setThreadLocals(actual.currentContext())) {
actual.onSubscribe(s);
}
}

@SuppressWarnings("try")
@Override
public void onNext(T t) {
try (ContextSnapshot.Scope ignored =
ContextPropagation.setThreadLocals(actual.currentContext())) {
actual.onNext(t);
}
}

@SuppressWarnings("try")
@Override
public boolean tryOnNext(T t) {
try (ContextSnapshot.Scope ignored =
ContextPropagation.setThreadLocals(actual.currentContext())) {
if (actualConditional != null) {
return actualConditional.tryOnNext(t);
}
actual.onNext(t);
return true;
}
}

@SuppressWarnings("try")
@Override
public void onError(Throwable t) {
try (ContextSnapshot.Scope ignored =
ContextPropagation.setThreadLocals(actual.currentContext())) {
actual.onError(t);
}
}

@SuppressWarnings("try")
@Override
public void onComplete() {
try (ContextSnapshot.Scope ignored =
ContextPropagation.setThreadLocals(actual.currentContext())) {
actual.onComplete();
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2016-2022 VMware Inc. or its affiliates, All Rights Reserved.
* Copyright (c) 2016-2023 VMware Inc. or its affiliates, All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -24,9 +24,9 @@
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import java.util.function.BiFunction;

import io.micrometer.context.ContextSnapshot;
import reactor.core.CoreSubscriber;
import reactor.core.Exceptions;
import reactor.core.Fuseable;
import reactor.core.Scannable;
import reactor.util.annotation.Nullable;
import reactor.util.context.Context;
Expand All @@ -40,7 +40,7 @@
* @param <T> the value type
*/
final class MonoCompletionStage<T> extends Mono<T>
implements Fuseable, Scannable {
implements Scannable {

final CompletionStage<? extends T> future;
final boolean suppressCancellation;
Expand All @@ -52,7 +52,14 @@ final class MonoCompletionStage<T> extends Mono<T>

@Override
public void subscribe(CoreSubscriber<? super T> actual) {
actual.onSubscribe(new MonoCompletionStageSubscription<>(actual, future, suppressCancellation));
if (ContextPropagation.shouldPropagateContextToThreadLocals()) {
actual.onSubscribe(
new MonoCompletionStageRestoringThreadLocalsSubscription<>(
actual, future, suppressCancellation));
} else {
actual.onSubscribe(new MonoCompletionStageSubscription<>(
actual, future, suppressCancellation));
}
}

@Override
Expand All @@ -62,8 +69,6 @@ public Object scanUnsafe(Attr key) {
}

static class MonoCompletionStageSubscription<T> implements InnerProducer<T>,
Fuseable,
QueueSubscription<T>,
BiFunction<T, Throwable, Void> {

final CoreSubscriber<? super T> actual;
Expand Down Expand Up @@ -154,29 +159,104 @@ public void cancel() {
((Future<? extends T>) future).cancel(true);
}
}
}

@Override
public int requestFusion(int requestedMode) {
return NONE;
static class MonoCompletionStageRestoringThreadLocalsSubscription<T>
implements InnerProducer<T>, BiFunction<T, Throwable, Void> {

final CoreSubscriber<? super T> actual;
final CompletionStage<? extends T> future;
final boolean suppressCancellation;

volatile int requestedOnce;
@SuppressWarnings("rawtypes")
static final AtomicIntegerFieldUpdater<MonoCompletionStageRestoringThreadLocalsSubscription> REQUESTED_ONCE =
AtomicIntegerFieldUpdater.newUpdater(MonoCompletionStageRestoringThreadLocalsSubscription.class, "requestedOnce");

volatile boolean cancelled;

MonoCompletionStageRestoringThreadLocalsSubscription(
CoreSubscriber<? super T> actual,
CompletionStage<? extends T> future,
boolean suppressCancellation) {
this.actual = actual;
this.future = future;
this.suppressCancellation = suppressCancellation;
}

@Override
public T poll() {
return null;
public CoreSubscriber<? super T> actual() {
return this.actual;
}

@Override
public int size() {
return 0;
public Void apply(@Nullable T value, @Nullable Throwable e) {
final CoreSubscriber<? super T> actual = this.actual;

try (ContextSnapshot.Scope ignored =
ContextPropagation.setThreadLocals(actual.currentContext())) {
if (this.cancelled) {
//nobody is interested in the Mono anymore, don't risk dropping errors
final Context ctx = actual.currentContext();
if (e == null || e instanceof CancellationException) {
//we discard any potential value and ignore Future cancellations
Operators.onDiscard(value, ctx);
}
else {
//we make sure we keep _some_ track of a Future failure AFTER the Mono cancellation
Operators.onErrorDropped(e, ctx);
//and we discard any potential value just in case both e and v are not null
Operators.onDiscard(value, ctx);
}

return null;
}

try {
if (e instanceof CompletionException) {
actual.onError(e.getCause());
}
else if (e != null) {
actual.onError(e);
}
else if (value != null) {
actual.onNext(value);
actual.onComplete();
}
else {
actual.onComplete();
}
}
catch (Throwable e1) {
Operators.onErrorDropped(e1, actual.currentContext());
throw Exceptions.bubble(e1);
}
return null;
}
}

@Override
public boolean isEmpty() {
return true;
public void request(long n) {
if (this.cancelled) {
return;
}

if (this.requestedOnce == 1 || !REQUESTED_ONCE.compareAndSet(this, 0 , 1)) {
return;
}

future.handle(this);
}

@Override
public void clear() {
public void cancel() {
this.cancelled = true;

final CompletionStage<? extends T> future = this.future;
if (!suppressCancellation && future instanceof Future) {
//noinspection unchecked
((Future<? extends T>) future).cancel(true);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2016-2021 VMware Inc. or its affiliates, All Rights Reserved.
* Copyright (c) 2016-2023 VMware Inc. or its affiliates, All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -55,6 +55,10 @@ final class MonoFromPublisher<T> extends Mono<T> implements Scannable,
@Override
@SuppressWarnings("unchecked")
public void subscribe(CoreSubscriber<? super T> actual) {
if (ContextPropagation.shouldPropagateContextToThreadLocals()) {
actual = new MonoSource.MonoSourceRestoringThreadLocalsSubscriber<>(actual);
}

try {
CoreSubscriber<? super T> subscriber = subscribeOrReturn(actual);
if (subscriber == null) {
Expand Down

0 comments on commit 28ae2a0

Please sign in to comment.