Skip to content

Consider @Primary annotation when using @MockBean, make @MockBean/@SpyBean behave consistently #11077

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
e6b3f9b
Consider @Primary annotation when using @MockBean
5ff8d21a-49f3-49a6-aff6-f2d456578715 Nov 17, 2017
ea1bb78
Revert findCandidateBeans changes
neiser Nov 18, 2017
48d50bc
Fix code formatting
neiser Nov 18, 2017
3d32be8
Make MockitoPostProcessor.getBeanName aware of @Primary
neiser Nov 18, 2017
9a1064e
Preserve @Primary flag for mocked bean
neiser Nov 18, 2017
b323d28
Improve error message if mocked bean not found
neiser Nov 18, 2017
d459dc6
Add @SpyBean test
5ff8d21a-49f3-49a6-aff6-f2d456578715 Nov 18, 2017
6ca9349
Fix SpyPrimaryBean test
neiser Nov 18, 2017
6eb00a9
Reorder SpyBean configuration tests, add another SpyBean test (fails)
neiser Nov 18, 2017
0d7b040
Adding TODOs
neiser Nov 18, 2017
4576e86
Add integration test for @SpyBean with qualifier (copied from @MockBe…
neiser Nov 18, 2017
a1b6f1e
Refactoring: Change String[] to Set<String> in registerSpy code path
neiser Nov 19, 2017
f554c3f
Refactoring: Move typeToMockOrSpy into base class Definition
neiser Nov 19, 2017
9982988
Refactoring: Rename variables in determinePrimaryCandidate
neiser Nov 19, 2017
f4b5fca
Make findCandidateBeans work with both SpyDefinition and MockDefinition
neiser Nov 19, 2017
acf6688
Split getBeanName into findBeanName and generate bean name
neiser Nov 19, 2017
6a26b26
Refactoring: Make parameter names of findBeanName consistent
neiser Nov 19, 2017
13ba76e
Remove superfluous try/catch
neiser Nov 19, 2017
68d7328
Remove confusing registerSpies method
neiser Nov 19, 2017
65fc16c
Replace determineBeanName by identical findBeanName
neiser Nov 19, 2017
18e8b1b
Move check for unique bean name into findBeanName
neiser Nov 19, 2017
dfa3ffa
Refactoring: Make method names for getBeanName consistent
neiser Nov 19, 2017
4be1f83
Refactoring: Code cleanups in MockitoPostProcessor
neiser Nov 19, 2017
330baea
Use findCandidateBeans when registering spies, fixes tests with quali…
neiser Nov 19, 2017
abb40a5
Refactoring: Reorder registerSpy
neiser Nov 19, 2017
23c36b9
Simplify registerSpy, make it look similar to registerMock
neiser Nov 19, 2017
d0d94de
Move responsibility of finding bean names to separate class
neiser Nov 19, 2017
0b50005
Reformat MockitoPostProcessor
neiser Nov 19, 2017
69970b8
Reformat MockitoBeanNameFinder
neiser Nov 19, 2017
778ef7c
Make check style happy
neiser Nov 19, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.springframework.boot.test.mock.mockito;

import org.springframework.core.ResolvableType;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;

/**
Expand All @@ -36,12 +38,16 @@ abstract class Definition {

private final QualifierDefinition qualifier;

private final ResolvableType type;

Definition(String name, MockReset reset, boolean proxyTargetAware,
QualifierDefinition qualifier) {
QualifierDefinition qualifier, ResolvableType type) {
Assert.notNull(type, "type must not be null");
this.name = name;
this.reset = (reset != null ? reset : MockReset.AFTER);
this.proxyTargetAware = proxyTargetAware;
this.qualifier = qualifier;
this.type = type;
}

/**
Expand Down Expand Up @@ -76,6 +82,14 @@ public QualifierDefinition getQualifier() {
return this.qualifier;
}

/**
* Get type to mock or spy.
* @return the type; never {@code null}
*/
public ResolvableType getType() {
return this.type;
}

@Override
public int hashCode() {
int result = 1;
Expand All @@ -84,6 +98,7 @@ public int hashCode() {
result = MULTIPLIER * result
+ ObjectUtils.nullSafeHashCode(this.proxyTargetAware);
result = MULTIPLIER * result + ObjectUtils.nullSafeHashCode(this.qualifier);
result = MULTIPLIER * result + ObjectUtils.nullSafeHashCode(this.type);
return result;
}

Expand All @@ -102,6 +117,7 @@ public boolean equals(Object obj) {
result = result && ObjectUtils.nullSafeEquals(this.proxyTargetAware,
other.proxyTargetAware);
result = result && ObjectUtils.nullSafeEquals(this.qualifier, other.qualifier);
result = result && ObjectUtils.nullSafeEquals(this.type, other.type);
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

import org.springframework.core.ResolvableType;
import org.springframework.core.style.ToStringCreator;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;

Expand All @@ -40,8 +39,6 @@ class MockDefinition extends Definition {

private static final int MULTIPLIER = 31;

private final ResolvableType typeToMock;

private final Set<Class<?>> extraInterfaces;

private final Answers answer;
Expand All @@ -51,9 +48,7 @@ class MockDefinition extends Definition {
MockDefinition(String name, ResolvableType typeToMock, Class<?>[] extraInterfaces,
Answers answer, boolean serializable, MockReset reset,
QualifierDefinition qualifier) {
super(name, reset, false, qualifier);
Assert.notNull(typeToMock, "TypeToMock must not be null");
this.typeToMock = typeToMock;
super(name, reset, false, qualifier, typeToMock);
this.extraInterfaces = asClassSet(extraInterfaces);
this.answer = (answer != null ? answer : Answers.RETURNS_DEFAULTS);
this.serializable = serializable;
Expand All @@ -72,7 +67,7 @@ private Set<Class<?>> asClassSet(Class<?>[] classes) {
* @return the type to mock; never {@code null}
*/
public ResolvableType getTypeToMock() {
return this.typeToMock;
return super.getType();
}

/**
Expand Down Expand Up @@ -102,7 +97,6 @@ public boolean isSerializable() {
@Override
public int hashCode() {
int result = super.hashCode();
result = MULTIPLIER * result + ObjectUtils.nullSafeHashCode(this.typeToMock);
result = MULTIPLIER * result + ObjectUtils.nullSafeHashCode(this.extraInterfaces);
result = MULTIPLIER * result + ObjectUtils.nullSafeHashCode(this.answer);
result = MULTIPLIER * result + Boolean.hashCode(this.serializable);
Expand All @@ -119,7 +113,6 @@ public boolean equals(Object obj) {
}
MockDefinition other = (MockDefinition) obj;
boolean result = super.equals(obj);
result = result && ObjectUtils.nullSafeEquals(this.typeToMock, other.typeToMock);
result = result && ObjectUtils.nullSafeEquals(this.extraInterfaces,
other.extraInterfaces);
result = result && ObjectUtils.nullSafeEquals(this.answer, other.answer);
Expand All @@ -130,7 +123,7 @@ public boolean equals(Object obj) {
@Override
public String toString() {
return new ToStringCreator(this).append("name", getName())
.append("typeToMock", this.typeToMock)
.append("typeToMock", super.getType())
.append("extraInterfaces", this.extraInterfaces)
.append("answer", this.answer).append("serializable", this.serializable)
.append("reset", getReset()).toString();
Expand All @@ -153,7 +146,7 @@ public <T> T createMock(String name) {
if (this.serializable) {
settings.serializable();
}
return (T) Mockito.mock(this.typeToMock.resolve(), settings);
return (T) Mockito.mock(this.getTypeToMock().resolve(), settings);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
* Copyright 2012-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.boot.test.mock.mockito;

import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.TreeSet;

import javax.annotation.Nullable;

import org.springframework.aop.scope.ScopedProxyUtils;
import org.springframework.beans.factory.BeanFactoryUtils;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanNameGenerator;
import org.springframework.beans.factory.support.DefaultBeanNameGenerator;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.core.ResolvableType;
import org.springframework.util.StringUtils;

/**
* Encapsulates the task to find the bean name which can be mocked or spied. Offers
* {@link MockitoBeanNameFinder#getOrGenerateBeanName} for that purpose. Used by
* {@link MockitoPostProcessor}.
*
* @author Andreas Neiser
*/
final class MockitoBeanNameFinder {

private MockitoBeanNameFinder() {
// only static method calls
}

private static final String FACTORY_BEAN_OBJECT_TYPE = "factoryBeanObjectType";

private static final BeanNameGenerator BEAN_NAME_GENERATOR = new DefaultBeanNameGenerator();

static String getOrGenerateBeanName(ConfigurableListableBeanFactory beanFactory,
BeanDefinitionRegistry registry, Definition definition,
RootBeanDefinition beanDefinition) {
Set<String> existingBeans = findCandidateBeans(beanFactory, definition);
if (existingBeans.isEmpty()) {
return BEAN_NAME_GENERATOR.generateBeanName(beanDefinition, registry);
}
return getBeanName(registry, existingBeans, definition);
}

private static String getBeanName(BeanDefinitionRegistry registry,
Set<String> existingBeanNames, Definition definition) {
if (StringUtils.hasText(definition.getName())) {
return definition.getName();
}
if (existingBeanNames.size() == 1) {
return existingBeanNames.iterator().next();
}
String beanName = findPrimaryBeanName(registry, existingBeanNames,
definition.getType());
if (beanName == null) {
throw new IllegalStateException("Unable to register bean "
+ definition.getType()
+ " expected a single matching/primary bean to replace but found "
+ existingBeanNames);
}
return beanName;
}

@Nullable
private static String findPrimaryBeanName(BeanDefinitionRegistry registry,
Set<String> existingBeanNames, ResolvableType type) {
String primaryBeanName = null;
for (String existingBeanName : existingBeanNames) {
BeanDefinition beanDefinition = registry.getBeanDefinition(existingBeanName);
if (beanDefinition.isPrimary()) {
if (primaryBeanName != null) {
throw new NoUniqueBeanDefinitionException(type.resolve(),
existingBeanNames.size(),
"more than one 'primary' bean found among candidates: "
+ existingBeanNames);
}
primaryBeanName = existingBeanName;
}
}
return primaryBeanName;
}

private static Set<String> findCandidateBeans(
ConfigurableListableBeanFactory beanFactory, Definition definition) {
QualifierDefinition qualifier = definition.getQualifier();
Set<String> candidates = new TreeSet<>();
for (String candidate : getExistingBeans(beanFactory, definition.getType())) {
if (qualifier == null || qualifier.matches(beanFactory, candidate)) {
candidates.add(candidate);
}
}
return candidates;
}

private static Set<String> getExistingBeans(
ConfigurableListableBeanFactory beanFactory, ResolvableType type) {
Set<String> beans = new LinkedHashSet<>(
Arrays.asList(beanFactory.getBeanNamesForType(type)));
String resolvedTypeName = type.resolve(Object.class).getName();
for (String beanName : beanFactory.getBeanNamesForType(FactoryBean.class)) {
beanName = BeanFactoryUtils.transformedBeanName(beanName);
BeanDefinition beanDefinition = beanFactory.getBeanDefinition(beanName);
if (resolvedTypeName
.equals(beanDefinition.getAttribute(FACTORY_BEAN_OBJECT_TYPE))) {
beans.add(beanName);
}
}
beans.removeIf(ScopedProxyUtils::isScopedTarget);
return beans;
}
}
Loading