diff --git a/spring-boot-test/src/main/java/org/springframework/boot/test/mock/mockito/MockitoPostProcessor.java b/spring-boot-test/src/main/java/org/springframework/boot/test/mock/mockito/MockitoPostProcessor.java index 2109be3791ec..be9ddfff99b8 100644 --- a/spring-boot-test/src/main/java/org/springframework/boot/test/mock/mockito/MockitoPostProcessor.java +++ b/spring-boot-test/src/main/java/org/springframework/boot/test/mock/mockito/MockitoPostProcessor.java @@ -18,13 +18,11 @@ import java.beans.PropertyDescriptor; import java.lang.reflect.Field; -import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.LinkedHashSet; -import java.util.List; import java.util.Map; import java.util.Set; import java.util.TreeSet; @@ -37,6 +35,8 @@ import org.springframework.beans.factory.BeanCreationException; import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.BeanFactoryAware; +import org.springframework.beans.factory.BeanFactoryUtils; +import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanFactoryPostProcessor; import org.springframework.beans.factory.config.BeanPostProcessor; @@ -70,12 +70,15 @@ * {@link MockBean @MockBean}. * * @author Phillip Webb + * @author Andy Wilkinson * @since 1.4.0 */ public class MockitoPostProcessor extends InstantiationAwareBeanPostProcessorAdapter implements BeanClassLoaderAware, BeanFactoryAware, BeanFactoryPostProcessor, Ordered { + private static final String FACTORY_BEAN_OBJECT_TYPE = "factoryBeanObjectType"; + private static final String BEAN_NAME = MockitoPostProcessor.class.getName(); private static final String CONFIGURATION_CLASS_ATTRIBUTE = Conventions @@ -240,8 +243,16 @@ private void registerSpy(ConfigurableListableBeanFactory beanFactory, private String[] getExistingBeans(ConfigurableListableBeanFactory beanFactory, Class type) { - List beans = new ArrayList( + Set beans = new LinkedHashSet( Arrays.asList(beanFactory.getBeanNamesForType(type))); + for (String beanName : beanFactory.getBeanNamesForType(FactoryBean.class)) { + beanName = BeanFactoryUtils.transformedBeanName(beanName); + BeanDefinition beanDefinition = beanFactory.getBeanDefinition(beanName); + if (type.getName() + .equals(beanDefinition.getAttribute(FACTORY_BEAN_OBJECT_TYPE))) { + beans.add(beanName); + } + } for (Iterator iterator = beans.iterator(); iterator.hasNext();) { if (isScopedTarget(iterator.next())) { iterator.remove(); diff --git a/spring-boot-test/src/test/java/org/springframework/boot/test/mock/mockito/MockitoPostProcessorTests.java b/spring-boot-test/src/test/java/org/springframework/boot/test/mock/mockito/MockitoPostProcessorTests.java index 7ad46cf81fc1..a88c6b98241d 100644 --- a/spring-boot-test/src/test/java/org/springframework/boot/test/mock/mockito/MockitoPostProcessorTests.java +++ b/spring-boot-test/src/test/java/org/springframework/boot/test/mock/mockito/MockitoPostProcessorTests.java @@ -19,17 +19,23 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.mockito.internal.util.MockUtil; +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.boot.test.mock.mockito.example.ExampleService; import org.springframework.boot.test.mock.mockito.example.FailingExampleService; import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import static org.assertj.core.api.Assertions.assertThat; + /** * Test for {@link MockitoPostProcessor}. See also the integration tests. * * @author Phillip Webb + * @author Andy Wilkinson */ public class MockitoPostProcessorTests { @@ -49,6 +55,31 @@ public void cannotMockMultipleBeans() { context.refresh(); } + @Test + public void canMockBeanProducedByFactoryBeanWithObjectTypeAttribute() { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + MockitoPostProcessor.register(context); + RootBeanDefinition factoryBeanDefinition = new RootBeanDefinition( + TestFactoryBean.class); + factoryBeanDefinition.setAttribute("factoryBeanObjectType", + SomeInterface.class.getName()); + context.registerBeanDefinition("beanToBeMocked", factoryBeanDefinition); + context.register(MockedFactoryBean.class); + context.refresh(); + assertThat(new MockUtil().isMock(context.getBean("beanToBeMocked"))).isTrue(); + } + + @Configuration + @MockBean(SomeInterface.class) + static class MockedFactoryBean { + + @Bean + public TestFactoryBean testFactoryBean() { + return new TestFactoryBean(); + } + + } + @Configuration @MockBean(ExampleService.class) static class MultipleBeans { @@ -65,4 +96,31 @@ public ExampleService example2() { } + static class TestFactoryBean implements FactoryBean { + + @Override + public Object getObject() throws Exception { + return new TestBean(); + } + + @Override + public Class getObjectType() { + return null; + } + + @Override + public boolean isSingleton() { + return true; + } + + } + + interface SomeInterface { + + } + + static class TestBean implements SomeInterface { + + } + }