Skip to content

Commit

Permalink
Use merged bean definitions for EntityCallback type lookup.
Browse files Browse the repository at this point in the history
We now use the merged bean definition to resolve the defined EntityCallback type.

Previously, we used just the bean definition that might have contained no type hints because of ASM-parsed configuration classes.

Closes #2853
  • Loading branch information
mp911de committed Jun 14, 2023
1 parent 2cbf0fb commit f1b7952
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.core.ResolvableType;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
Expand Down Expand Up @@ -57,7 +58,8 @@ class DefaultEntityCallbacks implements EntityCallbacks {
* @param beanFactory must not be {@literal null}.
*/
DefaultEntityCallbacks(BeanFactory beanFactory) {
this.callbackDiscoverer = new EntityCallbackDiscoverer(beanFactory);
this.callbackDiscoverer = new EntityCallbackDiscoverer(
beanFactory instanceof GenericApplicationContext ac ? ac.getBeanFactory() : beanFactory);
}

@Override
Expand Down Expand Up @@ -93,8 +95,7 @@ public void addEntityCallback(EntityCallback<?> callback) {
this.callbackDiscoverer.addEntityCallback(callback);
}

static class SimpleEntityCallbackInvoker
implements org.springframework.data.mapping.callback.EntityCallbackInvoker {
static class SimpleEntityCallbackInvoker implements org.springframework.data.mapping.callback.EntityCallbackInvoker {

@Override
public <T> T invokeCallback(EntityCallback<T> callback, T entity,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
Expand All @@ -28,10 +29,9 @@

import org.springframework.aop.framework.AopProxyUtils;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.ListableBeanFactory;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.core.ResolvableType;
import org.springframework.core.annotation.AnnotationAwareOrderComparator;
import org.springframework.lang.Nullable;
Expand All @@ -56,7 +56,6 @@ class EntityCallbackDiscoverer {
private final Map<Class<?>, ResolvableType> entityTypeCache = new ConcurrentReferenceHashMap<>(64);

@Nullable private ClassLoader beanClassLoader;
@Nullable private BeanFactory beanFactory;

private Object retrievalMutex = this.defaultRetriever;

Expand Down Expand Up @@ -104,12 +103,13 @@ void removeEntityCallback(EntityCallback<?> callback) {
* Return a {@link Collection} of all {@link EntityCallback}s matching the given entity type. Non-matching callbacks
* get excluded early.
*
* @param entity the entity to be called back for. Allows for excluding non-matching callbacks early, based on
* cached matching information.
* @param entity the entity to be called back for. Allows for excluding non-matching callbacks early, based on cached
* matching information.
* @param callbackType the source callback type.
* @return a {@link Collection} of {@link EntityCallback}s.
* @see EntityCallback
*/
@SuppressWarnings({ "unchecked", "rawtypes" })
<T extends S, S> Collection<EntityCallback<S>> getEntityCallbacks(Class<T> entity, ResolvableType callbackType) {

Class<?> sourceType = entity;
Expand All @@ -121,7 +121,7 @@ <T extends S, S> Collection<EntityCallback<S>> getEntityCallbacks(Class<T> entit
return (Collection) retriever.getEntityCallbacks();
}

if (this.beanClassLoader == null || ClassUtils.isCacheSafe(entity.getClass(), this.beanClassLoader)
if (this.beanClassLoader == null || ClassUtils.isCacheSafe(entity, this.beanClassLoader)
&& (sourceType == null || ClassUtils.isCacheSafe(sourceType, this.beanClassLoader))) {

// Fully synchronized building and caching of a CallbackRetriever
Expand Down Expand Up @@ -163,8 +163,8 @@ ResolvableType resolveDeclaredEntityType(Class<?> callbackType) {
* @param retriever the {@link CallbackRetriever}, if supposed to populate one (for caching purposes)
* @return the pre-filtered list of entity callbacks for the given entity and callback type.
*/
private Collection<EntityCallback<?>> retrieveEntityCallbacks(ResolvableType entityType,
ResolvableType callbackType, @Nullable CallbackRetriever retriever) {
private Collection<EntityCallback<?>> retrieveEntityCallbacks(ResolvableType entityType, ResolvableType callbackType,
@Nullable CallbackRetriever retriever) {

List<EntityCallback<?>> allCallbacks = new ArrayList<>();
Set<EntityCallback<?>> callbacks;
Expand Down Expand Up @@ -198,16 +198,14 @@ private Collection<EntityCallback<?>> retrieveEntityCallbacks(ResolvableType ent
}

/**
* Set the {@link BeanFactory} and optionally {@link #setBeanClassLoader(ClassLoader) class loader} if not set.
* Pre-loads {@link EntityCallback} beans by scanning the {@link BeanFactory}.
* Set the {@link BeanFactory} and optionally class loader if not set. Pre-loads {@link EntityCallback} beans by
* scanning the {@link BeanFactory}.
*
* @param beanFactory must not be {@literal null}.
* @see org.springframework.beans.factory.BeanFactoryAware#setBeanFactory(BeanFactory)
*/
public void setBeanFactory(BeanFactory beanFactory) {

this.beanFactory = beanFactory;

if (beanFactory instanceof ConfigurableBeanFactory cbf) {

if (this.beanClassLoader == null) {
Expand All @@ -228,10 +226,8 @@ static Method lookupCallbackMethod(Class<?> callbackType, Class<?> entityType, O

ReflectionUtils.doWithMethods(callbackType, methods::add, method -> {

if (!Modifier.isPublic(method.getModifiers())
|| method.getParameterCount() != args.length + 1
|| method.isBridge()
|| ReflectionUtils.isObjectMethod(method)) {
if (!Modifier.isPublic(method.getModifiers()) || method.getParameterCount() != args.length + 1
|| method.isBridge() || ReflectionUtils.isObjectMethod(method)) {
return false;
}

Expand All @@ -242,9 +238,8 @@ static Method lookupCallbackMethod(Class<?> callbackType, Class<?> entityType, O
return methods.iterator().next();
}

throw new IllegalStateException(
"%s does not define a callback method accepting %s and %s additional arguments".formatted(
ClassUtils.getShortName(callbackType), ClassUtils.getShortName(entityType), args.length));
throw new IllegalStateException("%s does not define a callback method accepting %s and %s additional arguments"
.formatted(ClassUtils.getShortName(callbackType), ClassUtils.getShortName(entityType), args.length));
}

static <T> BiFunction<EntityCallback<T>, T, Object> computeCallbackInvokerFunction(EntityCallback<T> callback,
Expand All @@ -267,10 +262,10 @@ static <T> BiFunction<EntityCallback<T>, T, Object> computeCallbackInvokerFuncti
* Filter a callback early through checking its generically declared entity type before trying to instantiate it.
* <p>
* If this method returns {@literal true} for a given callback as a first pass, the callback instance will get
* retrieved and fully evaluated through a {@link #supportsEvent(EntityCallback, ResolvableType, ResolvableType)}
* call afterwards.
* retrieved and fully evaluated through a {@link #supportsEvent(EntityCallback, ResolvableType, ResolvableType)} call
* afterwards.
*
* @param callback the callback's type as determined by the BeanFactory.
* @param callbackType the callback's type as determined by the BeanFactory.
* @param entityType the entity type to check.
* @return whether the given callback should be included in the candidates for the given callback type.
*/
Expand All @@ -286,11 +281,9 @@ static boolean supportsEvent(ResolvableType callbackType, ResolvableType entityT
* @param callbackType the source type to check against.
* @return whether the given callback should be included in the candidates for the given callback type.
*/
static boolean supportsEvent(EntityCallback<?> callback, ResolvableType entityType,
ResolvableType callbackType) {
static boolean supportsEvent(EntityCallback<?> callback, ResolvableType entityType, ResolvableType callbackType) {

return callback instanceof EntityCallbackAdapter<?> provider
? provider.supports(callbackType, entityType)
return callback instanceof EntityCallbackAdapter<?> provider ? provider.supports(callbackType, entityType)
: callbackType.isInstance(callback) && supportsEvent(ResolvableType.forInstance(callback), entityType);
}

Expand All @@ -310,13 +303,11 @@ void discoverEntityCallbacks(BeanFactory beanFactory) {

// We need both a ListableBeanFactory and BeanDefinitionRegistry here for advanced inspection.
// If we don't get that, use simple inspection.
if (!(beanFactory instanceof ListableBeanFactory && beanFactory instanceof BeanDefinitionRegistry)) {
if (!(beanFactory instanceof ConfigurableListableBeanFactory bf)) {
beanFactory.getBeanProvider(EntityCallback.class).stream().forEach(entityCallbacks::add);
return;
}

var bf = (ListableBeanFactory & BeanDefinitionRegistry) beanFactory;

for (var beanName : bf.getBeanNamesForType(EntityCallback.class)) {

EntityCallback<?> bean = EntityCallback.class.cast(bf.getBean(beanName));
Expand All @@ -328,7 +319,7 @@ void discoverEntityCallbacks(BeanFactory beanFactory) {
entityCallbacks.add(bean);
} else {

BeanDefinition definition = bf.getBeanDefinition(beanName);
BeanDefinition definition = bf.getMergedBeanDefinition(beanName);
entityCallbacks.add(new EntityCallbackAdapter<>(bean, definition.getResolvableType()));
}
}
Expand All @@ -340,8 +331,8 @@ void discoverEntityCallbacks(BeanFactory beanFactory) {
*
* @author Oliver Drotbohm
*/
private static record EntityCallbackAdapter<T>(EntityCallback<T> delegate, ResolvableType type)
implements EntityCallback<T> {
private record EntityCallbackAdapter<T> (EntityCallback<T> delegate,
ResolvableType type) implements EntityCallback<T> {

boolean supports(ResolvableType callbackType, ResolvableType entityType) {
return callbackType.isInstance(delegate) && supportsEvent(type, entityType);
Expand All @@ -351,15 +342,16 @@ boolean supports(ResolvableType callbackType, ResolvableType entityType) {
/**
* Cache key for {@link EntityCallback}, based on event type and source type.
*/
private static record CallbackCacheKey(ResolvableType callbackType, @Nullable Class<?> entityType)
implements Comparable<CallbackCacheKey> {
private record CallbackCacheKey(ResolvableType callbackType,
@Nullable Class<?> entityType) implements Comparable<CallbackCacheKey> {

private static final Comparator<CallbackCacheKey> COMPARATOR = Comparators.<CallbackCacheKey> nullsHigh() //
.thenComparing(it -> it.callbackType.toString()) //
.thenComparing(it -> it.entityType.getName());

@Override
public int compareTo(CallbackCacheKey other) {

return Comparators.<CallbackCacheKey> nullsHigh()
.thenComparing(it -> callbackType.toString())
.thenComparing(it -> entityType.getName()).compare(this, other);
return COMPARATOR.compare(this, other);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,6 @@ void skipsInvocationUsingJava18ReflectiveTypeRejection() {
void detectsMultipleCallbacksWithinOneClass() {

var ctx = new AnnotationConfigApplicationContext(MultipleCallbacksInOneClassConfig.class);

var callbacks = new DefaultEntityCallbacks(ctx);

var personDocument = new PersonDocument(null, "Walter", null);
Expand Down

0 comments on commit f1b7952

Please sign in to comment.