diff --git a/pom.xml b/pom.xml index a0beb607ea..2e4aab02dc 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ org.springframework.data spring-data-mongodb-parent - 5.0.0-SNAPSHOT + 5.0.0-GH-5064-SNAPSHOT pom Spring Data MongoDB diff --git a/spring-data-mongodb-distribution/pom.xml b/spring-data-mongodb-distribution/pom.xml index fc88571622..8d73f0d660 100644 --- a/spring-data-mongodb-distribution/pom.xml +++ b/spring-data-mongodb-distribution/pom.xml @@ -15,7 +15,7 @@ org.springframework.data spring-data-mongodb-parent - 5.0.0-SNAPSHOT + 5.0.0-GH-5064-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb/pom.xml b/spring-data-mongodb/pom.xml index 595e5a4250..36ac14f77e 100644 --- a/spring-data-mongodb/pom.xml +++ b/spring-data-mongodb/pom.xml @@ -13,7 +13,7 @@ org.springframework.data spring-data-mongodb-parent - 5.0.0-SNAPSHOT + 5.0.0-GH-5064-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/observability/MapRequestContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/observability/MapRequestContext.java index 6185c95db5..26f5408ebd 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/observability/MapRequestContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/observability/MapRequestContext.java @@ -21,7 +21,6 @@ import java.util.stream.Stream; import com.mongodb.RequestContext; -import org.jspecify.annotations.Nullable; /** * A {@link Map}-based {@link RequestContext}. @@ -30,24 +29,17 @@ * @author Greg Turnquist * @since 4.0.0 */ -class MapRequestContext implements RequestContext { - - private final Map map; +record MapRequestContext(Map map) implements RequestContext { public MapRequestContext() { this(new HashMap<>()); } - public MapRequestContext(Map context) { - this.map = context; - } - @Override public T get(Object key) { - T value = (T) map.get(key); - if(value != null) { + if (value != null) { return value; } throw new NoSuchElementException("%s is missing".formatted(key)); @@ -55,7 +47,7 @@ public T get(Object key) { @Override public boolean hasKey(Object key) { - return map.containsKey(key); + return map.get(key) != null; } @Override diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/observability/MongoObservationCommandListener.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/observability/MongoObservationCommandListener.java index 914396ab96..fcd4778042 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/observability/MongoObservationCommandListener.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/observability/MongoObservationCommandListener.java @@ -113,10 +113,6 @@ public void commandStarted(CommandStartedEvent event) { Observation parent = observationFromContext(requestContext); - if (log.isDebugEnabled()) { - log.debug("Found the following observation passed from the mongo context [" + parent + "]"); - } - MongoHandlerContext observationContext = new MongoHandlerContext(connectionString, event, requestContext); observationContext.setRemoteServiceName("mongo"); @@ -141,22 +137,20 @@ public void commandStarted(CommandStartedEvent event) { @Override public void commandSucceeded(CommandSucceededEvent event) { - doInObservation(event.getRequestContext(), (observation, context) -> { + stopObservation(event.getRequestContext(), (observation, context) -> { context.setCommandSucceededEvent(event); if (log.isDebugEnabled()) { log.debug("Command succeeded - will stop observation [" + observation + "]"); } - - observation.stop(); }); } @Override public void commandFailed(CommandFailedEvent event) { - doInObservation(event.getRequestContext(), (observation, context) -> { + stopObservation(event.getRequestContext(), (observation, context) -> { context.setCommandFailedEvent(event); @@ -165,18 +159,17 @@ public void commandFailed(CommandFailedEvent event) { } observation.error(event.getThrowable()); - observation.stop(); }); } /** - * Performs the given action for the {@link Observation} and {@link MongoHandlerContext} if there is an ongoing Mongo - * Observation. Exceptions thrown by the action are relayed to the caller. + * Stops the {@link Observation} after applying {@code action} given {@link MongoHandlerContext} if there is an + * ongoing Mongo Observation. Exceptions thrown by the action are relayed to the caller. * * @param requestContext the context to extract the Observation from. * @param action the action to invoke. */ - private void doInObservation(@Nullable RequestContext requestContext, + private void stopObservation(@Nullable RequestContext requestContext, BiConsumer action) { if (requestContext == null) { @@ -188,7 +181,18 @@ private void doInObservation(@Nullable RequestContext requestContext, return; } - action.accept(observation, context); + try { + action.accept(observation, context); + } finally { + + observation.stop(); + + if (log.isDebugEnabled()) { + log.debug( + "Restoring parent observation [" + observation + "] for Mongo instrumentation and put it in Mongo context"); + } + requestContext.put(ObservationThreadLocalAccessor.KEY, observation.getContext().getParentObservation()); + } } /** @@ -210,7 +214,7 @@ private void doInObservation(@Nullable RequestContext requestContext, } if (log.isDebugEnabled()) { - log.debug("No observation was found - will not create any child observations"); + log.debug("No observation was found: Creating a new root observation"); } return null; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/observability/MongoObservationCommandListenerTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/observability/MongoObservationCommandListenerTests.java index fe74a03bd6..dadb98ce2b 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/observability/MongoObservationCommandListenerTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/observability/MongoObservationCommandListenerTests.java @@ -27,6 +27,7 @@ import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import org.assertj.core.api.Assertions; import org.bson.BsonDocument; import org.bson.BsonString; import org.junit.jupiter.api.BeforeEach; @@ -251,6 +252,46 @@ public String getName() { assertThat(meterRegistry).hasMeterWithName("custom.name.active"); } + @Test // GH-5064 + void completionRestoresParentObservation() { + + // given + Observation parent = Observation.start("name", observationRegistry); + observationRegistry.setCurrentObservationScope(parent.openScope()); + RequestContext traceRequestContext = getContext(); + + // when + listener.commandStarted(new CommandStartedEvent(traceRequestContext, 0, 0, null, "database", "insert", + new BsonDocument("collection", new BsonString("user")))); + + Assertions.assertThat((Observation) traceRequestContext.get(ObservationThreadLocalAccessor.KEY)).isNotNull() + .isNotEqualTo(parent); + + listener.commandSucceeded(new CommandSucceededEvent(traceRequestContext, 0, 0, null, "insert", null, null, 0)); + + Assertions.assertThat((Observation) traceRequestContext.get(ObservationThreadLocalAccessor.KEY)).isEqualTo(parent); + } + + @Test // GH-5064 + void failureRestoresParentObservation() { + + // given + Observation parent = Observation.start("name", observationRegistry); + observationRegistry.setCurrentObservationScope(parent.openScope()); + RequestContext traceRequestContext = getContext(); + + // when + listener.commandStarted(new CommandStartedEvent(traceRequestContext, 0, 0, null, "database", "insert", + new BsonDocument("collection", new BsonString("user")))); + + Assertions.assertThat((Observation) traceRequestContext.get(ObservationThreadLocalAccessor.KEY)).isNotNull() + .isNotEqualTo(parent); + + listener.commandFailed(new CommandFailedEvent(traceRequestContext, 0, 0, null, "insert", null, 0, null)); + + Assertions.assertThat((Observation) traceRequestContext.get(ObservationThreadLocalAccessor.KEY)).isEqualTo(parent); + } + private RequestContext getContext() { return ((SynchronousContextProvider) ContextProviderFactory.create(observationRegistry)).getContext(); }