Skip to content

Commit

Permalink
Enable use of single, specific ContextSnapshotFactory instance
Browse files Browse the repository at this point in the history
Closes gh-919
  • Loading branch information
rstoyanchev committed Apr 12, 2024
1 parent d9a583a commit 2652a57
Show file tree
Hide file tree
Showing 12 changed files with 246 additions and 39 deletions.
Expand Up @@ -26,12 +26,12 @@
import java.util.concurrent.Executor;

import graphql.GraphQLContext;
import io.micrometer.context.ContextSnapshotFactory;
import reactor.core.publisher.Mono;

import org.springframework.core.CoroutinesUtils;
import org.springframework.core.KotlinDetector;
import org.springframework.data.util.KotlinReflectionUtils;
import org.springframework.graphql.execution.ContextSnapshotFactoryHelper;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

Expand All @@ -46,8 +46,6 @@ public abstract class InvocableHandlerMethodSupport extends HandlerMethod {

private static final Object NO_VALUE = new Object();

private static final ContextSnapshotFactory SNAPSHOT_FACTORY = ContextSnapshotFactory.builder().build();


private final boolean hasCallableReturnValue;

Expand Down Expand Up @@ -131,7 +129,7 @@ private Object handleReturnValue(GraphQLContext graphQLContext, @Nullable Object
return CompletableFuture.supplyAsync(
() -> {
try {
return SNAPSHOT_FACTORY.captureFrom(graphQLContext).wrap((Callable<?>) result).call();
return ContextSnapshotFactoryHelper.captureFrom(graphQLContext).wrap((Callable<?>) result).call();
}
catch (Exception ex) {
throw new IllegalStateException(
Expand Down
Expand Up @@ -59,7 +59,6 @@ final class ContextDataFetcherDecorator implements DataFetcher<Object> {

private final SubscriptionExceptionResolver subscriptionExceptionResolver;

private final ContextSnapshotFactory snapshotFactory = ContextSnapshotFactory.builder().build();

private ContextDataFetcherDecorator(
DataFetcher<?> delegate, boolean subscription,
Expand All @@ -72,17 +71,18 @@ private ContextDataFetcherDecorator(
this.subscriptionExceptionResolver = subscriptionExceptionResolver;
}


@Override
public Object get(DataFetchingEnvironment environment) throws Exception {
public Object get(DataFetchingEnvironment env) throws Exception {

ContextSnapshot snapshot;
if (environment.getLocalContext() instanceof GraphQLContext localContext) {
snapshot = this.snapshotFactory.captureFrom(environment.getGraphQlContext(), localContext);
}
else {
snapshot = this.snapshotFactory.captureFrom(environment.getGraphQlContext());
}
Object value = snapshot.wrap(() -> this.delegate.get(environment)).call();
GraphQLContext graphQlContext = env.getGraphQlContext();
ContextSnapshotFactory snapshotFactory = ContextSnapshotFactoryHelper.getInstance(graphQlContext);

ContextSnapshot snapshot = (env.getLocalContext() instanceof GraphQLContext localContext) ?
snapshotFactory.captureFrom(graphQlContext, localContext) :
snapshotFactory.captureFrom(graphQlContext);

Object value = snapshot.wrap(() -> this.delegate.get(env)).call();

if (this.subscription) {
Assert.state(value instanceof Publisher, "Expected Publisher for a subscription");
Expand Down
@@ -0,0 +1,117 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.graphql.execution;

import graphql.GraphQLContext;
import io.micrometer.context.ContextSnapshot;
import io.micrometer.context.ContextSnapshotFactory;
import reactor.util.context.Context;
import reactor.util.context.ContextView;

import org.springframework.lang.Nullable;

/**
* Helper to use a single {@link ContextSnapshotFactory} instance by saving and
* obtaining it to and from Reactor and GraphQL contexts.
*
* @author Rossen Stoyanchev
* @since 1.3
*/
public abstract class ContextSnapshotFactoryHelper {

private static final ContextSnapshotFactory sharedInstance = ContextSnapshotFactory.builder().build();

private static final String CONTEXT_SNAPSHOT_FACTORY_KEY = ContextSnapshotFactoryHelper.class.getName() + ".KEY";


/**
* Select a {@code ContextSnapshotFactory} instance to use, either the one
* passed in if it is not {@code null}, or a shared, static instance.
* @param factory the candidate factory instance to use if not {@code null}
* @return the instance to use
*/
public static ContextSnapshotFactory selectInstance(@Nullable ContextSnapshotFactory factory) {
if (factory != null) {
return factory;
}
return sharedInstance;
}

/**
* Save the {@code ContextSnapshotFactory} in the given {@link Context}.
* @param factory the instance to save
* @param context the context to save the instance to
* @return a new context with the saved instance
*/
public static Context saveInstance(ContextSnapshotFactory factory, Context context) {
return context.put(CONTEXT_SNAPSHOT_FACTORY_KEY, factory);
}

/**
* Save the {@code ContextSnapshotFactory} in the given {@link Context}.
* @param factory the instance to save
* @param context the context to save the instance to
*/
public static void saveInstance(ContextSnapshotFactory factory, GraphQLContext context) {
context.put(CONTEXT_SNAPSHOT_FACTORY_KEY, factory);
}

/**
* Access the {@code ContextSnapshotFactory} from the given {@link ContextView}
* or return a shared, static instance.
* @param contextView the context where the instance is saved
* @return the instance to use
*/
public static ContextSnapshotFactory getInstance(ContextView contextView) {
ContextSnapshotFactory factory = contextView.getOrDefault(CONTEXT_SNAPSHOT_FACTORY_KEY, null);
return selectInstance(factory);
}

/**
* Access the {@code ContextSnapshotFactory} from the given {@link GraphQLContext}
* or return a shared, static instance.
* @param context the context where the instance is saved
* @return the instance to use
*/
public static ContextSnapshotFactory getInstance(GraphQLContext context) {
ContextSnapshotFactory factory = context.get(CONTEXT_SNAPSHOT_FACTORY_KEY);
return selectInstance(factory);
}

/**
* Shortcut to obtain the {@code ContextSnapshotFactory} instance, and to
* capture from the given {@link ContextView}.
* @param contextView the context to capture from
* @return a snapshot from the capture
*/
public static ContextSnapshot captureFrom(ContextView contextView) {
ContextSnapshotFactory factory = getInstance(contextView);
return selectInstance(factory).captureFrom(contextView);
}

/**
* Shortcut to obtain the {@code ContextSnapshotFactory} instance, and to
* capture from the given {@link GraphQLContext}.
* @param context the context to capture from
* @return a snapshot from the capture
*/
public static ContextSnapshot captureFrom(GraphQLContext context) {
ContextSnapshotFactory factory = getInstance(context);
return selectInstance(factory).captureFrom(context);
}

}
Expand Up @@ -22,7 +22,6 @@

import graphql.GraphQLError;
import graphql.schema.DataFetchingEnvironment;
import io.micrometer.context.ContextSnapshotFactory;
import io.micrometer.context.ThreadLocalAccessor;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand Down Expand Up @@ -53,8 +52,6 @@ public abstract class DataFetcherExceptionResolverAdapter implements DataFetcher

protected final Log logger = LogFactory.getLog(getClass());

protected final ContextSnapshotFactory snapshotFactory = ContextSnapshotFactory.builder().build();

private boolean threadLocalContextAware;


Expand Down Expand Up @@ -101,7 +98,7 @@ private List<GraphQLError> resolveInternal(Throwable exception, DataFetchingEnvi
return resolveToMultipleErrors(exception, env);
}
try {
return this.snapshotFactory.captureFrom(env.getGraphQlContext())
return ContextSnapshotFactoryHelper.captureFrom(env.getGraphQlContext())
.wrap(() -> resolveToMultipleErrors(exception, env))
.call();
}
Expand Down
Expand Up @@ -28,7 +28,6 @@

import graphql.GraphQLContext;
import io.micrometer.context.ContextSnapshot;
import io.micrometer.context.ContextSnapshotFactory;
import org.dataloader.BatchLoaderContextProvider;
import org.dataloader.BatchLoaderEnvironment;
import org.dataloader.BatchLoaderWithContext;
Expand All @@ -54,8 +53,6 @@
*/
public class DefaultBatchLoaderRegistry implements BatchLoaderRegistry {

private static final ContextSnapshotFactory SNAPSHOT_FACTORY = ContextSnapshotFactory.builder().build();

private final List<ReactorBatchLoader<?, ?>> loaders = new ArrayList<>();

private final List<ReactorMappedBatchLoader<?, ?>> mappedLoaders = new ArrayList<>();
Expand Down Expand Up @@ -231,7 +228,7 @@ DataLoaderOptions getOptions() {
@Override
public CompletionStage<List<V>> load(List<K> keys, BatchLoaderEnvironment environment) {
GraphQLContext graphQLContext = environment.getContext();
ContextSnapshot snapshot = SNAPSHOT_FACTORY.captureFrom(graphQLContext);
ContextSnapshot snapshot = ContextSnapshotFactoryHelper.captureFrom(graphQLContext);
try {
return snapshot.wrap(() ->
this.loader.apply(keys, environment)
Expand Down Expand Up @@ -279,7 +276,7 @@ DataLoaderOptions getOptions() {
@Override
public CompletionStage<Map<K, V>> load(Set<K> keys, BatchLoaderEnvironment environment) {
GraphQLContext graphQLContext = environment.getContext();
ContextSnapshot snapshot = SNAPSHOT_FACTORY.captureFrom(graphQLContext);
ContextSnapshot snapshot = ContextSnapshotFactoryHelper.captureFrom(graphQLContext);
try {
return snapshot.wrap(() ->
this.loader.apply(keys, environment)
Expand Down
Expand Up @@ -46,7 +46,6 @@ public class DefaultExecutionGraphQlService implements ExecutionGraphQlService {
private static final BiFunction<ExecutionInput, ExecutionInput.Builder, ExecutionInput> RESET_EXECUTION_ID_CONFIGURER =
(executionInput, builder) -> builder.executionId(null).build();

private final ContextSnapshotFactory snapshotFactory = ContextSnapshotFactory.builder().build();

private final GraphQlSource graphQlSource;

Expand Down Expand Up @@ -90,7 +89,10 @@ public final Mono<ExecutionGraphQlResponse> execute(ExecutionGraphQlRequest requ

ExecutionInput executionInput = request.toExecutionInput();

this.snapshotFactory.captureFrom(contextView).updateContext(executionInput.getGraphQLContext());
ContextSnapshotFactory factory = ContextSnapshotFactoryHelper.getInstance(contextView);
GraphQLContext graphQLContext = executionInput.getGraphQLContext();
ContextSnapshotFactoryHelper.saveInstance(factory, graphQLContext);
factory.captureFrom(contextView).updateContext(graphQLContext);

ExecutionInput updatedExecutionInput =
(this.hasDataLoaderRegistrations ? registerDataLoaders(executionInput) : executionInput);
Expand Down
Expand Up @@ -29,7 +29,6 @@
import graphql.execution.ExecutionId;
import graphql.schema.DataFetchingEnvironment;
import io.micrometer.context.ContextSnapshot;
import io.micrometer.context.ContextSnapshotFactory;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import reactor.core.publisher.Flux;
Expand All @@ -47,8 +46,6 @@ class ExceptionResolversExceptionHandler implements DataFetcherExceptionHandler

private static final Log logger = LogFactory.getLog(ExceptionResolversExceptionHandler.class);

private final ContextSnapshotFactory snapshotFactory = ContextSnapshotFactory.builder().build();

private final List<DataFetcherExceptionResolver> resolvers;

/**
Expand All @@ -65,7 +62,7 @@ class ExceptionResolversExceptionHandler implements DataFetcherExceptionHandler
public CompletableFuture<DataFetcherExceptionHandlerResult> handleException(DataFetcherExceptionHandlerParameters params) {
Throwable exception = unwrapException(params);
DataFetchingEnvironment env = params.getDataFetchingEnvironment();
ContextSnapshot snapshot = this.snapshotFactory.captureFrom(env.getGraphQlContext());
ContextSnapshot snapshot = ContextSnapshotFactoryHelper.captureFrom(env.getGraphQlContext());
try {
return Flux.fromIterable(this.resolvers)
.flatMap((resolver) -> resolver.resolveException(exception, env))
Expand Down
Expand Up @@ -22,7 +22,6 @@

import graphql.GraphQLError;
import io.micrometer.context.ContextSnapshot;
import io.micrometer.context.ContextSnapshotFactory;
import io.micrometer.context.ThreadLocalAccessor;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand Down Expand Up @@ -51,8 +50,6 @@ public abstract class SubscriptionExceptionResolverAdapter implements Subscripti

protected final Log logger = LogFactory.getLog(getClass());

protected final ContextSnapshotFactory snapshotFactory = ContextSnapshotFactory.builder().build();

private boolean threadLocalContextAware;


Expand Down Expand Up @@ -86,7 +83,7 @@ public boolean isThreadLocalContextAware() {
public final Mono<List<GraphQLError>> resolveException(Throwable exception) {
if (this.threadLocalContextAware) {
return Mono.deferContextual((contextView) -> {
ContextSnapshot snapshot = this.snapshotFactory.captureFrom(contextView);
ContextSnapshot snapshot = ContextSnapshotFactoryHelper.captureFrom(contextView);
try {
List<GraphQLError> errors = snapshot.wrap(() -> resolveToMultipleErrors(exception)).call();
return Mono.justOrEmpty(errors);
Expand Down
Expand Up @@ -25,6 +25,7 @@
import reactor.core.publisher.Mono;

import org.springframework.graphql.ExecutionGraphQlService;
import org.springframework.graphql.execution.ContextSnapshotFactoryHelper;
import org.springframework.graphql.server.WebGraphQlInterceptor.Chain;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
Expand All @@ -41,6 +42,9 @@ class DefaultWebGraphQlHandlerBuilder implements WebGraphQlHandler.Builder {

private final List<WebGraphQlInterceptor> interceptors = new ArrayList<>();

@Nullable
private ContextSnapshotFactory snapshotFactory;

@Nullable
private WebSocketGraphQlInterceptor webSocketInterceptor;

Expand Down Expand Up @@ -68,10 +72,16 @@ public WebGraphQlHandler.Builder interceptors(List<WebGraphQlInterceptor> interc
return this;
}

@Override
public WebGraphQlHandler.Builder contextSnapshotFactory(ContextSnapshotFactory snapshotFactory) {
this.snapshotFactory = snapshotFactory;
return this;
}

@Override
public WebGraphQlHandler build() {

ContextSnapshotFactory snapshotFactory = ContextSnapshotFactory.builder().build();
ContextSnapshotFactory snapshotFactory = ContextSnapshotFactoryHelper.selectInstance(this.snapshotFactory);

Chain endOfChain = (request) -> this.service.execute(request).map(WebGraphQlResponse::new);

Expand All @@ -88,10 +98,18 @@ public WebSocketGraphQlInterceptor getWebSocketInterceptor() {
DefaultWebGraphQlHandlerBuilder.this.webSocketInterceptor : new WebSocketGraphQlInterceptor() { };
}

@Override
public ContextSnapshotFactory contextSnapshotFactory() {
return snapshotFactory;
}

@Override
public Mono<WebGraphQlResponse> handleRequest(WebGraphQlRequest request) {
ContextSnapshot snapshot = snapshotFactory.captureAll();
return executionChain.next(request).contextWrite(snapshot::updateContext);
return executionChain.next(request).contextWrite((context) -> {
context = ContextSnapshotFactoryHelper.saveInstance(snapshotFactory, context);
return snapshot.updateContext(context);
});
}
};
}
Expand Down

0 comments on commit 2652a57

Please sign in to comment.