Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix entity-manager retrieval in spring-data-jpa #38323

Merged
merged 1 commit into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public class AdditionalJpaOperations {
public static PanacheQuery<?> find(AbstractJpaOperations<?> jpaOperations, Class<?> entityClass, String query,
String countQuery, Sort sort, Map<String, Object> params) {
String findQuery = createFindQuery(entityClass, query, jpaOperations.paramCount(params));
EntityManager em = jpaOperations.getEntityManager();
EntityManager em = jpaOperations.getEntityManager(entityClass);
Query jpaQuery = em.createQuery(sort != null ? findQuery + toOrderBy(sort) : findQuery);
JpaOperations.bindParameters(jpaQuery, params);
return new CustomCountPanacheQuery(em, jpaQuery, countQuery, params);
Expand All @@ -47,14 +47,14 @@ public static PanacheQuery<?> find(AbstractJpaOperations<?> jpaOperations, Class
public static PanacheQuery<?> find(AbstractJpaOperations<?> jpaOperations, Class<?> entityClass, String query,
String countQuery, Sort sort, Object... params) {
String findQuery = createFindQuery(entityClass, query, jpaOperations.paramCount(params));
EntityManager em = jpaOperations.getEntityManager();
EntityManager em = jpaOperations.getEntityManager(entityClass);
Query jpaQuery = em.createQuery(sort != null ? findQuery + toOrderBy(sort) : findQuery);
JpaOperations.bindParameters(jpaQuery, params);
return new CustomCountPanacheQuery(em, jpaQuery, countQuery, params);
}

public static long deleteAllWithCascade(AbstractJpaOperations<?> jpaOperations, Class<?> entityClass) {
EntityManager em = jpaOperations.getEntityManager();
EntityManager em = jpaOperations.getEntityManager(entityClass);
//detecting the case where there are cascade-delete associations, and do the bulk delete query otherwise.
if (deleteOnCascadeDetected(jpaOperations, entityClass)) {
int count = 0;
Expand All @@ -77,7 +77,7 @@ public static long deleteAllWithCascade(AbstractJpaOperations<?> jpaOperations,
* @return true if cascading delete is needed. False otherwise
*/
private static boolean deleteOnCascadeDetected(AbstractJpaOperations<?> jpaOperations, Class<?> entityClass) {
EntityManager em = jpaOperations.getEntityManager();
EntityManager em = jpaOperations.getEntityManager(entityClass);
Metamodel metamodel = em.getMetamodel();
EntityType<?> entity1 = metamodel.entity(entityClass);
Set<Attribute<?, ?>> declaredAttributes = ((EntityTypeImpl) entity1).getDeclaredAttributes();
Expand All @@ -96,7 +96,7 @@ private static boolean deleteOnCascadeDetected(AbstractJpaOperations<?> jpaOpera

public static <PanacheQueryType> long deleteWithCascade(AbstractJpaOperations<PanacheQueryType> jpaOperations,
Class<?> entityClass, String query, Object... params) {
EntityManager em = jpaOperations.getEntityManager();
EntityManager em = jpaOperations.getEntityManager(entityClass);
if (deleteOnCascadeDetected(jpaOperations, entityClass)) {
int count = 0;
List<?> objects = jpaOperations.list(jpaOperations.find(entityClass, query, params));
Expand All @@ -112,7 +112,7 @@ public static <PanacheQueryType> long deleteWithCascade(AbstractJpaOperations<Pa
public static <PanacheQueryType> long deleteWithCascade(AbstractJpaOperations<PanacheQueryType> jpaOperations,
Class<?> entityClass, String query,
Map<String, Object> params) {
EntityManager em = jpaOperations.getEntityManager();
EntityManager em = jpaOperations.getEntityManager(entityClass);
if (deleteOnCascadeDetected(jpaOperations, entityClass)) {
int count = 0;
List<?> objects = jpaOperations.list(jpaOperations.find(entityClass, query, params));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ public void add(ClassCreator classCreator, FieldDescriptor entityClassFieldDescr
// and if so generate the implementation while also keeping the proper records

generateSave(classCreator, generatedClassName, entityDotName, entityTypeStr,
allMethodsToBeImplementedToResult);
allMethodsToBeImplementedToResult, entityClassFieldDescriptor);
generateSaveAndFlush(classCreator, generatedClassName, entityDotName, entityTypeStr,
allMethodsToBeImplementedToResult);
allMethodsToBeImplementedToResult, entityClassFieldDescriptor);
generateSaveAll(classCreator, entityClassFieldDescriptor, generatedClassName, entityDotName, entityTypeStr,
allMethodsToBeImplementedToResult);
generateFlush(classCreator, generatedClassName, allMethodsToBeImplementedToResult);
Expand Down Expand Up @@ -121,7 +121,8 @@ public void add(ClassCreator classCreator, FieldDescriptor entityClassFieldDescr

private void generateSave(ClassCreator classCreator, String generatedClassName,
DotName entityDotName, String entityTypeStr,
Map<MethodDescriptor, Boolean> allMethodsToBeImplementedToResult) {
Map<MethodDescriptor, Boolean> allMethodsToBeImplementedToResult,
FieldDescriptor entityClassFieldDescriptor) {

MethodDescriptor saveDescriptor = MethodDescriptor.ofMethod(generatedClassName, "save", entityTypeStr,
entityTypeStr);
Expand All @@ -144,7 +145,7 @@ private void generateSave(ClassCreator classCreator, String generatedClassName,
entity);
BranchResult isNewBranch = save.ifTrue(isNew);
generatePersistAndReturn(entity, isNewBranch.trueBranch());
generateMergeAndReturn(entity, isNewBranch.falseBranch());
generateMergeAndReturn(entity, isNewBranch.falseBranch(), entityClassFieldDescriptor);
} else {
AnnotationTarget idAnnotationTarget = getIdAnnotationTarget(entityDotName, index);
ResultHandle idValue = generateObtainValue(save, entityDotName, entity, idAnnotationTarget);
Expand All @@ -167,7 +168,7 @@ private void generateSave(ClassCreator classCreator, String generatedClassName,
versionValueTarget.get());
BranchResult versionValueIsNullBranch = save.ifNull(versionValue);
generatePersistAndReturn(entity, versionValueIsNullBranch.trueBranch());
generateMergeAndReturn(entity, versionValueIsNullBranch.falseBranch());
generateMergeAndReturn(entity, versionValueIsNullBranch.falseBranch(), entityClassFieldDescriptor);
}

BytecodeCreator idValueUnset;
Expand All @@ -192,7 +193,7 @@ private void generateSave(ClassCreator classCreator, String generatedClassName,
idValueUnset = idValueNullBranch.trueBranch();
}
generatePersistAndReturn(entity, idValueUnset);
generateMergeAndReturn(entity, idValueSet);
generateMergeAndReturn(entity, idValueSet, entityClassFieldDescriptor);
}
}
try (MethodCreator bridgeSave = classCreator.getMethodCreator(bridgeSaveDescriptor)) {
Expand Down Expand Up @@ -236,10 +237,13 @@ private void generatePersistAndReturn(ResultHandle entity, BytecodeCreator bytec
bytecodeCreator.returnValue(entity);
}

private void generateMergeAndReturn(ResultHandle entity, BytecodeCreator bytecodeCreator) {
private void generateMergeAndReturn(ResultHandle entity, BytecodeCreator bytecodeCreator,
FieldDescriptor entityClassFieldDescriptor) {
ResultHandle entityClass = bytecodeCreator.readInstanceField(entityClassFieldDescriptor, bytecodeCreator.getThis());
ResultHandle entityManager = bytecodeCreator.invokeVirtualMethod(
ofMethod(AbstractJpaOperations.class, "getEntityManager", EntityManager.class),
bytecodeCreator.readStaticField(operationsField));
ofMethod(AbstractJpaOperations.class, "getEntityManager", EntityManager.class, Class.class),
bytecodeCreator.readStaticField(operationsField),
entityClass);
entity = bytecodeCreator.invokeInterfaceMethod(
MethodDescriptor.ofMethod(EntityManager.class, "merge", Object.class, Object.class),
entityManager, entity);
Expand Down Expand Up @@ -280,7 +284,7 @@ private Type getTypeOfTarget(AnnotationTarget idAnnotationTarget) {

private void generateSaveAndFlush(ClassCreator classCreator,
String generatedClassName, DotName entityDotName, String entityTypeStr,
Map<MethodDescriptor, Boolean> allMethodsToBeImplementedToResult) {
Map<MethodDescriptor, Boolean> allMethodsToBeImplementedToResult, FieldDescriptor entityClassFieldDescriptor) {

MethodDescriptor saveAndFlushDescriptor = MethodDescriptor.ofMethod(generatedClassName, "saveAndFlush", entityTypeStr,
entityTypeStr);
Expand All @@ -298,7 +302,7 @@ private void generateSaveAndFlush(ClassCreator classCreator,
// we need to force the generation of findById since this method depends on it
allMethodsToBeImplementedToResult.put(save, false);
generateSave(classCreator, generatedClassName, entityDotName, entityTypeStr,
allMethodsToBeImplementedToResult);
allMethodsToBeImplementedToResult, entityClassFieldDescriptor);

try (MethodCreator saveAndFlush = classCreator.getMethodCreator(saveAndFlushDescriptor)) {
saveAndFlush.addAnnotation(Transactional.class);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
package io.quarkus.spring.data.deployment.multiple_pu;

import static org.assertj.core.api.Assertions.assertThat;

import java.util.function.Supplier;

import jakarta.inject.Inject;

import org.hamcrest.Matchers;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;

import io.quarkus.narayana.jta.QuarkusTransaction;
import io.quarkus.spring.data.deployment.multiple_pu.first.FirstEntity;
import io.quarkus.spring.data.deployment.multiple_pu.first.FirstEntityRepository;
import io.quarkus.spring.data.deployment.multiple_pu.second.SecondEntity;
Expand All @@ -21,6 +32,17 @@ public class MultiplePersistenceUnitConfigTest {
PanacheTestResource.class)
.addAsResource("application-multiple-persistence-units.properties", "application.properties"));

@Inject
private FirstEntityRepository repository1;
@Inject
private SecondEntityRepository repository2;

@BeforeEach
void beforeEach() {
repository1.deleteAll();
repository2.deleteAll();
}

@Test
public void panacheOperations() {
/**
Expand All @@ -35,4 +57,64 @@ public void panacheOperations() {
RestAssured.when().get("/persistence-unit/second/name-1").then().body(Matchers.is("1"));
RestAssured.when().get("/persistence-unit/second/name-2").then().body(Matchers.is("2"));
}

@Test
public void entityLifecycle() {
var detached = repository2.save(new SecondEntity());
assertThat(detached.id).isNotNull();
assertThat(inTx(repository2::count)).isEqualTo(1);

detached.name = "name";
repository2.save(detached);
assertThat(inTx(repository2::count)).isEqualTo(1);

inTx(() -> {
var lazyRef = repository2.getOne(detached.id);
assertThat(lazyRef.name).isEqualTo(detached.name);
return null;
});

repository2.deleteByName("otherThan" + detached.name);
assertThat(inTx(() -> repository2.findById(detached.id))).isPresent();

repository2.deleteByName(detached.name);
assertThat(inTx(() -> repository2.findById(detached.id))).isEmpty();
}

@Test
void pagedQueries() {
var newEntity = new SecondEntity();
newEntity.name = "name";
var detached = repository2.save(newEntity);

Pageable pageable = PageRequest.of(0, 10, Sort.Direction.DESC, "id");

var page = inTx(() -> repository2.findByName(detached.name, pageable));
assertThat(page.getContent()).extracting(e -> e.id).containsExactly(detached.id);

var pageIndexParam = inTx(() -> repository2.findByNameQueryIndexed(detached.name, pageable));
assertThat(pageIndexParam.getContent()).extracting(e -> e.id).containsExactly(detached.id);

var pageNamedParam = inTx(() -> repository2.findByNameQueryNamed(detached.name, pageable));
assertThat(pageNamedParam.getContent()).extracting(e -> e.id).containsExactly(detached.id);
}

@Test
void cascading() {
var newParent = new SecondEntity();
newParent.name = "parent";
var newChild = new SecondEntity();
newChild.name = "child";
newParent.child = newChild;
var detachedParent = repository2.save(newParent);

assertThat(inTx(repository2::count)).isEqualTo(2);

repository2.deleteByName(detachedParent.name);
assertThat(inTx(repository2::count)).isZero();
}

private <T> T inTx(Supplier<T> action) {
return QuarkusTransaction.requiringNew().call(action::get);
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package io.quarkus.spring.data.deployment.multiple_pu.second;

import jakarta.persistence.Entity;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.Id;
import jakarta.persistence.*;

@Entity
public class SecondEntity {
Expand All @@ -12,4 +10,7 @@ public class SecondEntity {
public Long id;

public String name;

@OneToOne(cascade = CascadeType.ALL)
public SecondEntity child;
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
package io.quarkus.spring.data.deployment.multiple_pu.second;

import java.util.Optional;

import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.Param;
import org.springframework.stereotype.Repository;

@Repository
Expand All @@ -8,4 +14,21 @@ public interface SecondEntityRepository extends org.springframework.data.reposit
SecondEntity save(SecondEntity entity);

long count();

Optional<SecondEntity> findById(Long id);

SecondEntity getOne(Long id);

void deleteAll();

void deleteByName(String name);

Page<SecondEntity> findByName(String name, Pageable pageable);

@Query(value = "SELECT se FROM SecondEntity se WHERE name=?1", countQuery = "SELECT COUNT(*) FROM SecondEntity se WHERE name=?1")
Page<SecondEntity> findByNameQueryIndexed(String name, Pageable pageable);

@Query(value = "SELECT se FROM SecondEntity se WHERE name=:name", countQuery = "SELECT COUNT(*) FROM SecondEntity se WHERE name=:name")
Page<SecondEntity> findByNameQueryNamed(@Param("name") String name, Pageable pageable);

}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public static void deleteAll(AbstractJpaOperations<PanacheQuery<?>> operations,
}

public static Object getOne(AbstractJpaOperations<PanacheQuery<?>> operations, Class<?> entityClass, Object id) {
return operations.getEntityManager().getReference(entityClass, id);
return operations.getEntityManager(entityClass).getReference(entityClass, id);
}

public static void clear(Class<?> clazz) {
Expand Down
Loading