Skip to content

Commit

Permalink
Using Java8 stream to find jdbc connection
Browse files Browse the repository at this point in the history
refs #29
  • Loading branch information
rmpestano committed Aug 28, 2016
1 parent 1319e2c commit 46f9814
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 32 deletions.
73 changes: 41 additions & 32 deletions junit5/src/main/java/com/github/dbunit/junit5/DBUnitExtension.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@
import com.github.dbunit.rules.api.dataset.DataSetModel;
import com.github.dbunit.rules.api.dataset.ExpectedDataSet;
import com.github.dbunit.rules.dataset.DataSetExecutorImpl;
import org.junit.jupiter.api.extension.*;
import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.TestExtensionContext;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Optional;

import static com.github.dbunit.rules.util.EntityManagerProvider.em;
import static com.github.dbunit.rules.util.EntityManagerProvider.isEntityManagerActive;
Expand All @@ -20,36 +25,35 @@
public class DBUnitExtension implements BeforeTestExecutionCallback, AfterTestExecutionCallback {



@Override
public void beforeTestExecution(TestExtensionContext testExtensionContext) throws Exception {

if (!shouldCreateDataSet(testExtensionContext)) {
return;
}

if(isEntityManagerActive()){
if (isEntityManagerActive()) {
em().clear();
}

ConnectionHolder connectionHolder = findTestConnection(testExtensionContext);

DataSet annotation = testExtensionContext.getTestMethod().get().getAnnotation(DataSet.class);
if(annotation == null){
if (annotation == null) {
//try to infer from class level annotation
annotation = testExtensionContext.getTestClass().get().getAnnotation(DataSet.class);
}

if(annotation == null){
throw new RuntimeException("Could not find DataSet annotation for test "+testExtensionContext.getTestMethod().get().getName());
if (annotation == null) {
throw new RuntimeException("Could not find DataSet annotation for test " + testExtensionContext.getTestMethod().get().getName());
}

final DataSetModel model = new DataSetModel().from(annotation);
DataSetExecutor executor = DataSetExecutorImpl.instance(model.getExecutorId(), connectionHolder);

ExtensionContext.Namespace namespace = getExecutorNamespace(testExtensionContext);//one executor per test class
testExtensionContext.getStore(namespace).put("executor",executor);
testExtensionContext.getStore(namespace).put("model",model);
testExtensionContext.getStore(namespace).put("executor", executor);
testExtensionContext.getStore(namespace).put("model", model);
try {
executor.createDataSet(model);
} catch (final Exception e) {
Expand All @@ -71,11 +75,10 @@ private boolean shouldCompareDataSet(TestExtensionContext testExtensionContext)
}



@Override
public void afterTestExecution(TestExtensionContext testExtensionContext) throws Exception {

if(shouldCompareDataSet(testExtensionContext)){
if (shouldCompareDataSet(testExtensionContext)) {
ExpectedDataSet expectedDataSet = testExtensionContext.getTestMethod().get().getAnnotation(ExpectedDataSet.class);
if (expectedDataSet == null) {
//try to infer from class level annotation
Expand All @@ -100,37 +103,43 @@ private ExtensionContext.Namespace getExecutorNamespace(TestExtensionContext tes
}



private ConnectionHolder findTestConnection(TestExtensionContext testExtensionContext) {
Class<?> testClass = testExtensionContext.getTestClass().get();
try {
for (Field field : testClass.getDeclaredFields()) {
if (field.getType() == ConnectionHolder.class) {
if (!field.isAccessible()) {
field.setAccessible(true);
}
ConnectionHolder connectionHolder = ConnectionHolder.class.cast(field.get(testExtensionContext.getTestInstance()));
if (connectionHolder == null || connectionHolder.getConnection() == null) {
throw new RuntimeException("ConnectionHolder not initialized correctly");
}
return connectionHolder;
Optional<Field> fieldFound = Arrays.stream(testClass.getDeclaredFields()).
filter(f -> f.getType() == ConnectionHolder.class).
findFirst();

if (fieldFound.isPresent()) {
Field field = fieldFound.get();
if (!field.isAccessible()) {
field.setAccessible(true);
}
ConnectionHolder connectionHolder = ConnectionHolder.class.cast(field.get(testExtensionContext.getTestInstance()));
if (connectionHolder == null || connectionHolder.getConnection() == null) {
throw new RuntimeException("ConnectionHolder not initialized correctly");
}
return connectionHolder;
}

for (Method method : testClass.getDeclaredMethods()) {
if (method.getReturnType() == ConnectionHolder.class) {
if (!method.isAccessible()) {
method.setAccessible(true);
}
ConnectionHolder connectionHolder = ConnectionHolder.class.cast(method.invoke(testExtensionContext.getTestInstance()));
if (connectionHolder == null || connectionHolder.getConnection() == null) {
throw new RuntimeException("ConnectionHolder not initialized correctly");
}
return connectionHolder;
//try to get connection from method

Optional<Method> methodFound = Arrays.stream(testClass.getDeclaredMethods()).
filter(m -> m.getReturnType() == ConnectionHolder.class).
findFirst();

if (methodFound.isPresent()) {
Method method = methodFound.get();
if (!method.isAccessible()) {
method.setAccessible(true);
}
ConnectionHolder connectionHolder = ConnectionHolder.class.cast(method.invoke(testExtensionContext.getTestInstance()));
if (connectionHolder == null || connectionHolder.getConnection() == null) {
throw new RuntimeException("ConnectionHolder not initialized correctly");
}
return connectionHolder;
}


} catch (Exception e) {
throw new RuntimeException("Could not get database connection for test " + testClass, e);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package com.github.dbunit.junit5;

import com.github.dbunit.junit5.model.User;
import com.github.dbunit.rules.api.connection.ConnectionHolder;
import com.github.dbunit.rules.api.dataset.DataSet;
import com.github.dbunit.rules.api.dataset.ExpectedDataSet;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.platform.runner.JUnitPlatform;
import org.junit.runner.RunWith;

import java.util.List;

import static com.github.dbunit.rules.util.EntityManagerProvider.*;
import static org.assertj.core.api.Assertions.assertThat;

/**
* Created by pestano on 28/08/16.
*/
@ExtendWith(DBUnitExtension.class)
@RunWith(JUnitPlatform.class)
public class DBUnitJUnit5WithMethodConnectionHolderTest {

//DBUnitExtension will get connection by reflection so either declare a field or a method with ConncetionHolder as return typr
private ConnectionHolder getConnection(){
return () -> instance("junit5-pu").connection();
}

@Test
@DataSet("users.yml")
public void shouldListUsers() {
List<User> users = em().createQuery("select u from User u").getResultList();
assertThat(users).isNotNull().isNotEmpty().hasSize(2);
}

@Test
@DataSet(cleanBefore=true) //avoid conflict with other tests
public void shouldInsertUser() {
User user = new User();
user.setName("user");
user.setName("@rmpestano");
tx().begin();
em().persist(user);
tx().commit();
User insertedUser = (User)em().createQuery("select u from User u where u.name = '@rmpestano'").getSingleResult();
assertThat(insertedUser).isNotNull();
assertThat(insertedUser.getId()).isNotNull();
}

@Test
@DataSet("users.yml") //no need for clean before because DBUnit uses CLEAN_INSERT seeding strategy which clears involved tables before seeding
public void shouldUpdateUser() {
User user = (User) em().createQuery("select u from User u where u.id = 1").getSingleResult();
assertThat(user).isNotNull();
assertThat(user.getName()).isEqualTo("@realpestano");
//tx().begin();
user.setName("@rmpestano");
em().merge(user);
//tx().commit(); //no needed because of first level cache
User updatedUser = getUser(1L);
assertThat(updatedUser).isNotNull();
assertThat(updatedUser.getName()).isEqualTo("@rmpestano");
}

@Test
@DataSet(value = "users.yml", transactional = true)
@ExpectedDataSet("expectedUser.yml")
public void shouldDeleteUser() {
User user = (User) em().createQuery("select u from User u where u.id = 1").getSingleResult();
assertThat(user).isNotNull();
assertThat(user.getName()).isEqualTo("@realpestano");
em().remove(user);
}


public User getUser(Long id){
return (User) em().createQuery("select u from User u where u.id = :id").
setParameter("id", id).getSingleResult();
}
}

0 comments on commit 46f9814

Please sign in to comment.