Skip to content

Commit

Permalink
refactor testExecutors to base class
Browse files Browse the repository at this point in the history
  • Loading branch information
vikasgupta78 committed Feb 18, 2024
1 parent 2c7f744 commit e18399e
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import zingg.common.client.ZinggClientException;
import zingg.common.client.util.ColName;

public abstract class MatcherTester<S, D, R, C, T> extends ExecutorTester<S, D, R, C, T> {
public class MatcherTester<S, D, R, C, T> extends ExecutorTester<S, D, R, C, T> {

public static final Log LOG = LogFactory.getLog(MatcherTester.class);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package zingg.common.core.executor;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.jupiter.api.Test;

import zingg.common.client.ArgumentsUtil;
import zingg.common.client.IArguments;
Expand Down Expand Up @@ -42,6 +44,34 @@ public String setupArgs() throws ZinggClientException, IOException {

public abstract String getConfigFile();


@Test
public void testExecutors() throws ZinggClientException {
List<ExecutorTester<S, D, R, C, T>> executorTesterList = new ArrayList<ExecutorTester<S, D, R, C, T>>();

TrainingDataFinderTester<S, D, R, C, T> tdft = new TrainingDataFinderTester<S, D, R, C, T>(getTrainingDataFinder());
executorTesterList.add(tdft);

LabellerTester<S, D, R, C, T> lt = new LabellerTester<S, D, R, C, T>(getLabeller());
executorTesterList.add(lt);

// training and labelling needed twice to get sufficient data
TrainingDataFinderTester<S, D, R, C, T> tdft2 = new TrainingDataFinderTester<S, D, R, C, T>(getTrainingDataFinder());
executorTesterList.add(tdft2);

LabellerTester<S, D, R, C, T> lt2 = new LabellerTester<S, D, R, C, T>(getLabeller());
executorTesterList.add(lt2);

TrainerTester<S, D, R, C, T> tt = new TrainerTester<S, D, R, C, T>(getTrainer());
executorTesterList.add(tt);

MatcherTester<S, D, R, C, T> mt = new MatcherTester(getMatcher());
executorTesterList.add(mt);

testExecutors(executorTesterList);
}


public void testExecutors(List<ExecutorTester<S, D, R, C, T>> executorTesterList) throws ZinggClientException {
for (ExecutorTester<S, D, R, C, T> executorTester : executorTesterList) {
executorTester.execute();
Expand All @@ -51,4 +81,12 @@ public void testExecutors(List<ExecutorTester<S, D, R, C, T>> executorTesterList

public abstract void tearDown();

protected abstract TrainingDataFinder<S, D, R, C, T> getTrainingDataFinder() throws ZinggClientException;

protected abstract Labeller<S, D, R, C, T> getLabeller() throws ZinggClientException;

protected abstract Trainer<S, D, R, C, T> getTrainer() throws ZinggClientException;

protected abstract Matcher<S, D, R, C, T> getMatcher() throws ZinggClientException;

}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand All @@ -13,16 +11,10 @@
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataType;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;

import zingg.common.client.ZinggClientException;
import zingg.common.core.executor.ExecutorTester;
import zingg.common.core.executor.Labeller;
import zingg.common.core.executor.LabellerTester;
import zingg.common.core.executor.MatcherTester;
import zingg.common.core.executor.TestExecutorsGeneric;
import zingg.common.core.executor.TrainerTester;
import zingg.common.core.executor.TrainingDataFinderTester;
import zingg.spark.core.context.ZinggSparkContext;

public class TestSparkExecutors extends TestExecutorsGeneric<SparkSession,Dataset<Row>,Row,Column,DataType> {
Expand Down Expand Up @@ -52,50 +44,25 @@ public String getConfigFile() {
return CONFIG_FILE;
}

@Test
public void testExecutors() throws ZinggClientException {
List<ExecutorTester<SparkSession,Dataset<Row>,Row,Column,DataType>> executorTesterList = new ArrayList<ExecutorTester<SparkSession,Dataset<Row>,Row,Column,DataType>>();

TrainingDataFinderTester<SparkSession,Dataset<Row>,Row,Column,DataType> tdft = new TrainingDataFinderTester<SparkSession,Dataset<Row>,Row,Column,DataType>(getTrainingDataFinder());
executorTesterList.add(tdft);

LabellerTester<SparkSession,Dataset<Row>,Row,Column,DataType> lt = new LabellerTester<SparkSession,Dataset<Row>,Row,Column,DataType>(getLabeller());
executorTesterList.add(lt);

// training and labelling needed twice to get sufficient data
TrainingDataFinderTester<SparkSession,Dataset<Row>,Row,Column,DataType> tdft2 = new TrainingDataFinderTester<SparkSession,Dataset<Row>,Row,Column,DataType>(getTrainingDataFinder());
executorTesterList.add(tdft2);

LabellerTester<SparkSession,Dataset<Row>,Row,Column,DataType> lt2 = new LabellerTester<SparkSession,Dataset<Row>,Row,Column,DataType>(getLabeller());
executorTesterList.add(lt2);

TrainerTester<SparkSession,Dataset<Row>,Row,Column,DataType> tt = new TrainerTester<SparkSession,Dataset<Row>,Row,Column,DataType>(getTrainer());
executorTesterList.add(tt);

MatcherTester<SparkSession,Dataset<Row>,Row,Column,DataType> mt = new SparkMatcherTester(getMatcher());
executorTesterList.add(mt);

super.testExecutors(executorTesterList);
}

@Override
protected SparkTrainingDataFinder getTrainingDataFinder() throws ZinggClientException {
SparkTrainingDataFinder stdf = new SparkTrainingDataFinder(ctx);
stdf.init(args);
return stdf;
}

@Override
protected Labeller<SparkSession,Dataset<Row>,Row,Column,DataType> getLabeller() throws ZinggClientException {
JunitSparkLabeller jlbl = new JunitSparkLabeller(ctx);
jlbl.init(args);
return jlbl;
}

@Override
protected SparkTrainer getTrainer() throws ZinggClientException {
SparkTrainer st = new SparkTrainer(ctx);
st.init(args);
return st;
}

@Override
protected SparkMatcher getMatcher() throws ZinggClientException {
SparkMatcher sm = new SparkMatcher(ctx);
sm.init(args);
Expand Down

0 comments on commit e18399e

Please sign in to comment.