Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

InjectMock should not create a new contextual instance #32949

Merged
merged 1 commit into from
Apr 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ Collection<Resource> generateSyntheticBean(BeanInfo bean) {
implementGetStereotypes(bean, beanCreator, stereotypes.getFieldDescriptor());
}
implementGetBeanClass(bean, beanCreator);
implementGetImplementationClass(bean, beanCreator);
implementGetName(bean, beanCreator);
if (bean.isDefaultBean()) {
implementIsDefaultBean(bean, beanCreator);
Expand Down Expand Up @@ -487,6 +488,7 @@ Collection<Resource> generateProducerMethodBean(BeanInfo bean, MethodInfo produc
implementGetStereotypes(bean, beanCreator, stereotypes.getFieldDescriptor());
}
implementGetBeanClass(bean, beanCreator);
implementGetImplementationClass(bean, beanCreator);
implementGetName(bean, beanCreator);
if (bean.isDefaultBean()) {
implementIsDefaultBean(bean, beanCreator);
Expand Down Expand Up @@ -567,6 +569,7 @@ Collection<Resource> generateProducerFieldBean(BeanInfo bean, FieldInfo producer
implementGetStereotypes(bean, beanCreator, stereotypes.getFieldDescriptor());
}
implementGetBeanClass(bean, beanCreator);
implementGetImplementationClass(bean, beanCreator);
implementGetName(bean, beanCreator);
if (bean.isDefaultBean()) {
implementIsDefaultBean(bean, beanCreator);
Expand Down Expand Up @@ -2068,6 +2071,13 @@ protected void implementGetBeanClass(BeanInfo bean, ClassCreator beanCreator) {
getBeanClass.returnValue(getBeanClass.loadClass(bean.getBeanClass().toString()));
}

protected void implementGetImplementationClass(BeanInfo bean, ClassCreator beanCreator) {
MethodCreator getImplementationClass = beanCreator.getMethodCreator("getImplementationClass", Class.class)
.setModifiers(ACC_PUBLIC);
getImplementationClass.returnValue(bean.getImplClazz() != null ? getImplementationClass.loadClass(bean.getImplClazz())
: getImplementationClass.loadNull());
}

protected void implementGetName(BeanInfo bean, ClassCreator beanCreator) {
if (bean.getName() != null) {
MethodCreator getName = beanCreator.getMethodCreator("getName", String.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,23 @@ default int getPriority() {
return 0;
}

/**
* The return value depends on the {@link #getKind()}.
*
* <ul>
* <li>For managed beans, interceptors, decorators and built-in beans, the bean class is returned.</li>
* <li>For a producer method, the class of the return type is returned.</li>
* <li>For a producer field, the class of the field is returned.</li>
* <li>For a synthetic bean, the implementation class defined by the registrar is returned.
* </ul>
*
* @return the implementation class, or null in case of a producer of a primitive type or an array
* @see Kind
*/
default Class<?> getImplementationClass() {
return getBeanClass();
}

enum Kind {

CLASS,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package io.quarkus.it.mockbean;

import java.util.concurrent.atomic.AtomicBoolean;

import jakarta.annotation.PostConstruct;
import jakarta.enterprise.context.RequestScoped;

import io.quarkus.arc.Unremovable;

@Unremovable
@RequestScoped
public class RequestScopedFoo {

static final AtomicBoolean CONSTRUCTED = new AtomicBoolean();

public String ping() {
return "bar";
}

@PostConstruct
void init() {
CONSTRUCTED.set(true);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package io.quarkus.it.mockbean;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.mockito.Mockito.when;

import org.junit.jupiter.api.Test;

import io.quarkus.test.junit.QuarkusTest;
import io.quarkus.test.junit.mockito.InjectMock;

@QuarkusTest
class RequestScopedFooMockTest {

@InjectMock
RequestScopedFoo foo;

@Test
void testMock() {
when(foo.ping()).thenReturn("pong");
assertEquals("pong", foo.ping());
assertFalse(RequestScopedFoo.CONSTRUCTED.get());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@

import io.quarkus.arc.Arc;
import io.quarkus.arc.ArcContainer;
import io.quarkus.arc.ClientProxy;
import io.quarkus.arc.InstanceHandle;
import io.quarkus.arc.Subclass;
import io.quarkus.test.junit.callback.QuarkusTestAfterConstructCallback;
import io.quarkus.test.junit.mockito.InjectMock;

Expand All @@ -29,24 +27,25 @@ public void afterConstruct(Object testInstance) {
InjectMock injectMockAnnotation = field.getAnnotation(InjectMock.class);
if (injectMockAnnotation != null) {
boolean returnsDeepMocks = injectMockAnnotation.returnsDeepMocks();
Object contextualReference = getContextualReference(testInstance, field, InjectMock.class);
Optional<Object> result = createMockAndSetTestField(testInstance, field, contextualReference,
InstanceHandle<?> beanHandle = getBeanHandle(testInstance, field, InjectMock.class);
Optional<Object> result = createMockAndSetTestField(testInstance, field, beanHandle,
new MockConfiguration(returnsDeepMocks));
if (result.isPresent()) {
MockitoMocksTracker.track(testInstance, result.get(), contextualReference);
MockitoMocksTracker.track(testInstance, result.get(), beanHandle.get());
}
}
}
current = current.getSuperclass();
}
}

private Optional<Object> createMockAndSetTestField(Object testInstance, Field field, Object contextualReference,
private Optional<Object> createMockAndSetTestField(Object testInstance, Field field, InstanceHandle<?> beanHandle,
MockConfiguration mockConfiguration) {
Class<?> implementationClass = getImplementationClass(contextualReference);
Class<?> implementationClass = beanHandle.getBean().getImplementationClass();
Object mock;
boolean isNew;
Optional<Object> currentMock = MockitoMocksTracker.currentMock(testInstance, contextualReference);
// Note that beanHandle.get() returns a client proxy for normal scoped beans; i.e. the contextual instance is not created
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to figure this out on my own only to see this comment a few seconds later :-)

Optional<Object> currentMock = MockitoMocksTracker.currentMock(testInstance, beanHandle.get());
if (currentMock.isPresent()) {
mock = currentMock.get();
isNew = false;
Expand All @@ -71,15 +70,7 @@ private Optional<Object> createMockAndSetTestField(Object testInstance, Field fi
}
}

/**
* Contextual reference of a normal scoped bean is a client proxy.
*
* @param testInstance
* @param field
* @param annotationType
* @return a contextual reference of a bean
*/
static Object getContextualReference(Object testInstance, Field field, Class<? extends Annotation> annotationType) {
static InstanceHandle<?> getBeanHandle(Object testInstance, Field field, Class<? extends Annotation> annotationType) {
Type fieldType = field.getGenericType();
ArcContainer container = Arc.container();
BeanManager beanManager = container.beanManager();
Expand All @@ -100,15 +91,7 @@ static Object getContextualReference(Object testInstance, Field field, Class<? e
+ ". Offending field is " + field.getName() + " of test class "
+ testInstance.getClass());
}
return handle.get();
}

static Class<?> getImplementationClass(Object contextualReference) {
// Unwrap the client proxy if needed
Object contextualInstance = ClientProxy.unwrap(contextualReference);
// If the contextual instance is an intercepted subclass then mock the extended implementation class
return contextualInstance instanceof Subclass ? contextualInstance.getClass().getSuperclass()
: contextualInstance.getClass();
return handle;
}

static Annotation[] getQualifiers(Field fieldToMock, BeanManager beanManager) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import org.mockito.Mockito;

import io.quarkus.arc.ClientProxy;
import io.quarkus.arc.InstanceHandle;
import io.quarkus.test.junit.callback.QuarkusTestAfterConstructCallback;
import io.quarkus.test.junit.mockito.InjectSpy;

Expand All @@ -18,22 +19,22 @@ public void afterConstruct(Object testInstance) {
for (Field field : current.getDeclaredFields()) {
InjectSpy injectSpyAnnotation = field.getAnnotation(InjectSpy.class);
if (injectSpyAnnotation != null) {
Object contextualReference = CreateMockitoMocksCallback.getContextualReference(testInstance, field,
InstanceHandle<?> beanHandle = CreateMockitoMocksCallback.getBeanHandle(testInstance, field,
InjectSpy.class);
Object spy = createSpyAndSetTestField(testInstance, field, contextualReference,
Object spy = createSpyAndSetTestField(testInstance, field, beanHandle,
injectSpyAnnotation.delegate());
MockitoMocksTracker.track(testInstance, spy, contextualReference);
MockitoMocksTracker.track(testInstance, spy, beanHandle.get());
}
}
current = current.getSuperclass();
}
}

private Object createSpyAndSetTestField(Object testInstance, Field field, Object contextualReference, boolean delegate) {
private Object createSpyAndSetTestField(Object testInstance, Field field, InstanceHandle<?> beanHandle, boolean delegate) {
Object spy;
Object contextualInstance = ClientProxy.unwrap(contextualReference);
Object contextualInstance = ClientProxy.unwrap(beanHandle.get());
if (delegate) {
spy = Mockito.mock(CreateMockitoMocksCallback.getImplementationClass(contextualReference),
spy = Mockito.mock(beanHandle.getBean().getImplementationClass(),
AdditionalAnswers.delegatesTo(contextualInstance));
} else {
// Unwrap the client proxy if needed
Expand Down