Skip to content

Commit

Permalink
Wrapping an existing JmsEndpointRegistry; fixes gh-1200 (#1211)
Browse files Browse the repository at this point in the history
Wrapping an existing JmsEndpointRegistry

fixes #1200
  • Loading branch information
marcingrzejszczak committed Feb 12, 2019
1 parent 1e2d2f1 commit 8648a34
Show file tree
Hide file tree
Showing 4 changed files with 360 additions and 191 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import brave.Tracing;
import brave.jms.JmsTracing;
import brave.kafka.clients.KafkaTracing;
import brave.propagation.CurrentTraceContext;
import brave.spring.rabbit.SpringRabbitTracing;
import org.aopalliance.intercept.MethodInterceptor;
import org.aopalliance.intercept.MethodInvocation;
Expand Down Expand Up @@ -57,6 +56,8 @@
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Role;
import org.springframework.jms.annotation.JmsListenerConfigurer;
import org.springframework.jms.config.JmsListenerEndpointRegistry;
import org.springframework.jms.config.TracingJmsListenerEndpointRegistry;
import org.springframework.kafka.core.ProducerFactory;
import org.springframework.kafka.listener.AbstractMessageListenerContainer;
import org.springframework.kafka.listener.MessageListener;
Expand Down Expand Up @@ -129,6 +130,7 @@ SleuthKafkaAspect sleuthKafkaAspect(KafkaTracing kafkaTracing, Tracer tracer) {
@Configuration
@ConditionalOnProperty(value = "spring.sleuth.messaging.jms.enabled", matchIfMissing = true)
@ConditionalOnClass(JmsListenerConfigurer.class)
@ConditionalOnBean(JmsListenerEndpointRegistry.class)
@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
protected static class SleuthJmsConfiguration {

Expand All @@ -149,12 +151,22 @@ TracingConnectionFactoryBeanPostProcessor tracingConnectionFactoryBeanPostProces
return new TracingConnectionFactoryBeanPostProcessor(beanFactory);
}

@Bean
JmsListenerConfigurer configureTracing(BeanFactory beanFactory,
JmsListenerEndpointRegistry defaultRegistry) {
return registrar -> {
TracingJmsBeanPostProcessor processor = tracingJmsBeanPostProcessor(
beanFactory);
JmsListenerEndpointRegistry registry = registrar.getEndpointRegistry();
registrar.setEndpointRegistry((JmsListenerEndpointRegistry) processor
.wrap(registry == null ? defaultRegistry : registry));
};
}

// Setup the tracing endpoint registry.
@Bean
JmsListenerConfigurer configureTracing(JmsTracing jmsTracing,
CurrentTraceContext current) {
return registrar -> registrar.setEndpointRegistry(
new TracingJmsListenerEndpointRegistry(jmsTracing, current));
TracingJmsBeanPostProcessor tracingJmsBeanPostProcessor(BeanFactory beanFactory) {
return new TracingJmsBeanPostProcessor(beanFactory);
}

}
Expand Down Expand Up @@ -325,3 +337,32 @@ public Object invoke(MethodInvocation invocation) throws Throwable {
}

}

class TracingJmsBeanPostProcessor implements BeanPostProcessor {

private final BeanFactory beanFactory;

TracingJmsBeanPostProcessor(BeanFactory beanFactory) {
this.beanFactory = beanFactory;
}

@Override
public Object postProcessAfterInitialization(Object bean, String beanName)
throws BeansException {
return wrap(bean);
}

Object wrap(Object bean) {
if (typeMatches(bean)) {
return new TracingJmsListenerEndpointRegistry(
(JmsListenerEndpointRegistry) bean, this.beanFactory);
}
return bean;
}

private boolean typeMatches(Object bean) {
return bean instanceof JmsListenerEndpointRegistry
&& !(bean instanceof TracingJmsListenerEndpointRegistry);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,23 @@

package org.springframework.cloud.sleuth.instrument.messaging;

import java.lang.reflect.Field;

import javax.jms.Connection;
import javax.jms.ConnectionFactory;
import javax.jms.JMSContext;
import javax.jms.JMSException;
import javax.jms.Message;
import javax.jms.MessageConsumer;
import javax.jms.MessageListener;
import javax.jms.Session;
import javax.jms.XAConnection;
import javax.jms.XAConnectionFactory;
import javax.jms.XAJMSContext;

import brave.Span;
import brave.jms.JmsTracing;
import brave.propagation.CurrentTraceContext;

import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.jms.config.JmsListenerContainerFactory;
import org.springframework.jms.config.JmsListenerEndpoint;
import org.springframework.jms.config.JmsListenerEndpointRegistry;
import org.springframework.jms.config.MethodJmsListenerEndpoint;
import org.springframework.jms.config.SimpleJmsListenerEndpoint;
import org.springframework.jms.connection.CachingConnectionFactory;
import org.springframework.jms.listener.adapter.MessagingMessageListenerAdapter;
import org.springframework.jms.listener.endpoint.JmsMessageEndpointManager;
import org.springframework.lang.Nullable;

/**
* {@link BeanPostProcessor} wrapping around JMS {@link ConnectionFactory}.
Expand Down Expand Up @@ -244,160 +231,3 @@ private MessageListener wrappedDelegate() {
}

}

/**
* This ensures listeners end up continuing the trace from
* {@link MessageConsumer#receive()}.
*/
class TracingJmsListenerEndpointRegistry extends JmsListenerEndpointRegistry {

final JmsTracing jmsTracing;

final CurrentTraceContext current;

// Not all state can be copied without using reflection
final Field messageHandlerMethodFactoryField;

final Field embeddedValueResolverField;

TracingJmsListenerEndpointRegistry(JmsTracing jmsTracing,
CurrentTraceContext current) {
this.jmsTracing = jmsTracing;
this.current = current;
this.messageHandlerMethodFactoryField = tryField("messageHandlerMethodFactory");
this.embeddedValueResolverField = tryField("embeddedValueResolver");
}

@Nullable
static Field tryField(String name) {
try {
Field field = MethodJmsListenerEndpoint.class.getDeclaredField(name);
field.setAccessible(true);
return field;
}
catch (NoSuchFieldException e) {
return null;
}
}

@Nullable
static <T> T get(Object object, Field field) throws IllegalAccessException {
return (T) field.get(object);
}

@Override
public void registerListenerContainer(JmsListenerEndpoint endpoint,
JmsListenerContainerFactory<?> factory, boolean startImmediately) {
if (endpoint instanceof MethodJmsListenerEndpoint) {
endpoint = trace((MethodJmsListenerEndpoint) endpoint);
}
else if (endpoint instanceof SimpleJmsListenerEndpoint) {
endpoint = trace((SimpleJmsListenerEndpoint) endpoint);
}
super.registerListenerContainer(endpoint, factory, startImmediately);
}

/**
* This wraps the {@link SimpleJmsListenerEndpoint#getMessageListener()} delegate in a
* new span.
* @param source jms endpoint
* @return wrapped endpoint
*/
SimpleJmsListenerEndpoint trace(SimpleJmsListenerEndpoint source) {
MessageListener delegate = source.getMessageListener();
if (delegate == null) {
return source;
}
source.setMessageListener(this.jmsTracing.messageListener(delegate, false));
return source;
}

/**
* It would be better to trace by wrapping, but
* {@link MethodJmsListenerEndpoint#createMessageListenerInstance()}, is protected so
* we can't call it from outside code. In other words, a forwarding pattern can't be
* used. Instead, we copy state from the input.
* <p>
* NOTE: As {@linkplain MethodJmsListenerEndpoint} is neither final, nor effectively
* final. For this reason we can't ensure copying will get all state. For example, a
* subtype could hold state we aren't aware of, or change behavior. We can consider
* checking that input is not a subtype, and most conservatively leaving unknown
* subtypes untraced.
* @param source jms endpoint
* @return wrapped endpoint
*/
MethodJmsListenerEndpoint trace(MethodJmsListenerEndpoint source) {
// Skip out rather than incompletely copying the source
if (this.messageHandlerMethodFactoryField == null
|| this.embeddedValueResolverField == null) {
return source;
}

// We want the stock implementation, except we want to wrap the message listener
// in a new span
MethodJmsListenerEndpoint dest = new MethodJmsListenerEndpoint() {
@Override
protected MessagingMessageListenerAdapter createMessageListenerInstance() {
return new TracingMessagingMessageListenerAdapter(
TracingJmsListenerEndpointRegistry.this.jmsTracing,
TracingJmsListenerEndpointRegistry.this.current);
}
};

// set state from AbstractJmsListenerEndpoint
dest.setId(source.getId());
dest.setDestination(source.getDestination());
dest.setSubscription(source.getSubscription());
dest.setSelector(source.getSelector());
dest.setConcurrency(source.getConcurrency());

// set state from MethodJmsListenerEndpoint
dest.setBean(source.getBean());
dest.setMethod(source.getMethod());
dest.setMostSpecificMethod(source.getMostSpecificMethod());

try {
dest.setMessageHandlerMethodFactory(
get(source, this.messageHandlerMethodFactoryField));
dest.setEmbeddedValueResolver(get(source, this.embeddedValueResolverField));
}
catch (IllegalAccessException e) {
return source; // skip out rather than incompletely copying the source
}
return dest;
}

}

/**
* This wraps the message listener in a child span.
*/
final class TracingMessagingMessageListenerAdapter
extends MessagingMessageListenerAdapter {

final JmsTracing jmsTracing;

final CurrentTraceContext current;

TracingMessagingMessageListenerAdapter(JmsTracing jmsTracing,
CurrentTraceContext current) {
this.jmsTracing = jmsTracing;
this.current = current;
}

@Override
public void onMessage(Message message, Session session) throws JMSException {
Span span = this.jmsTracing.nextSpan(message).name("on-message").start();
try (CurrentTraceContext.Scope ws = this.current.newScope(span.context())) {
super.onMessage(message, session);
}
catch (JMSException | RuntimeException | Error e) {
span.error(e);
throw e;
}
finally {
span.finish();
}
}

}
Loading

0 comments on commit 8648a34

Please sign in to comment.