Skip to content

Commit

Permalink
refactor init to base class
Browse files Browse the repository at this point in the history
  • Loading branch information
vikasgupta78 committed Feb 18, 2024
1 parent e18399e commit 6423793
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ protected void assessAccuracy() throws ZinggClientException {
LOG.info("precision " + (tpCount*1.0d/(tpCount+fpCount)));

Check warning

Code scanning / PMD

Logger calls should be surrounded by log level guards. Warning

Logger calls should be surrounded by log level guards.
LOG.info("recall " + tpCount + " denom " + (tpCount+fnCount) + " overall " + (tpCount*1.0d/(tpCount+fnCount)));

Check warning

Code scanning / PMD

Logger calls should be surrounded by log level guards. Warning

Logger calls should be surrounded by log level guards.

assertTrue(0.8 < (tpCount*1.0d/(tpCount+fpCount)));
assertTrue(0.8 < (tpCount*1.0d/(tpCount+fnCount)));
assertTrue(0.8 < Math.round(tpCount*1.0d/(tpCount+fpCount)));
assertTrue(0.8 < Math.round(tpCount*1.0d/(tpCount+fnCount)));
}

public ZFrame<D, R, C> getOutputData() throws ZinggClientException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,35 @@ public String setupArgs() throws ZinggClientException, IOException {
public void testExecutors() throws ZinggClientException {
List<ExecutorTester<S, D, R, C, T>> executorTesterList = new ArrayList<ExecutorTester<S, D, R, C, T>>();

TrainingDataFinderTester<S, D, R, C, T> tdft = new TrainingDataFinderTester<S, D, R, C, T>(getTrainingDataFinder());
TrainingDataFinder<S, D, R, C, T> trainingDataFinder = getTrainingDataFinder();
trainingDataFinder.init(args);
TrainingDataFinderTester<S, D, R, C, T> tdft = new TrainingDataFinderTester<S, D, R, C, T>(trainingDataFinder);
executorTesterList.add(tdft);

LabellerTester<S, D, R, C, T> lt = new LabellerTester<S, D, R, C, T>(getLabeller());
Labeller<S, D, R, C, T> labeller = getLabeller();
labeller.init(args);
LabellerTester<S, D, R, C, T> lt = new LabellerTester<S, D, R, C, T>(labeller);
executorTesterList.add(lt);

// training and labelling needed twice to get sufficient data
TrainingDataFinderTester<S, D, R, C, T> tdft2 = new TrainingDataFinderTester<S, D, R, C, T>(getTrainingDataFinder());
TrainingDataFinder<S, D, R, C, T> trainingDataFinder2 = getTrainingDataFinder();
trainingDataFinder2.init(args);
TrainingDataFinderTester<S, D, R, C, T> tdft2 = new TrainingDataFinderTester<S, D, R, C, T>(trainingDataFinder2);
executorTesterList.add(tdft2);

LabellerTester<S, D, R, C, T> lt2 = new LabellerTester<S, D, R, C, T>(getLabeller());
Labeller<S, D, R, C, T> labeller2 = getLabeller();
labeller2.init(args);
LabellerTester<S, D, R, C, T> lt2 = new LabellerTester<S, D, R, C, T>(labeller2);
executorTesterList.add(lt2);

TrainerTester<S, D, R, C, T> tt = new TrainerTester<S, D, R, C, T>(getTrainer());
Trainer<S, D, R, C, T> trainer = getTrainer();
trainer.init(args);
TrainerTester<S, D, R, C, T> tt = new TrainerTester<S, D, R, C, T>(trainer);
executorTesterList.add(tt);

MatcherTester<S, D, R, C, T> mt = new MatcherTester(getMatcher());
Matcher<S, D, R, C, T> matcher = getMatcher();
matcher.init(args);
MatcherTester<S, D, R, C, T> mt = new MatcherTester(matcher);
executorTesterList.add(mt);

testExecutors(executorTesterList);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,25 +47,21 @@ public String getConfigFile() {
@Override
protected SparkTrainingDataFinder getTrainingDataFinder() throws ZinggClientException {
SparkTrainingDataFinder stdf = new SparkTrainingDataFinder(ctx);
stdf.init(args);
return stdf;
}
@Override
protected Labeller<SparkSession,Dataset<Row>,Row,Column,DataType> getLabeller() throws ZinggClientException {
JunitSparkLabeller jlbl = new JunitSparkLabeller(ctx);
jlbl.init(args);
return jlbl;
}
@Override
protected SparkTrainer getTrainer() throws ZinggClientException {
SparkTrainer st = new SparkTrainer(ctx);
st.init(args);
return st;
}
@Override
protected SparkMatcher getMatcher() throws ZinggClientException {
SparkMatcher sm = new SparkMatcher(ctx);
sm.init(args);
return sm;
}

Expand Down

0 comments on commit 6423793

Please sign in to comment.