diff --git a/common/client/src/main/java/zingg/common/client/ILabelDataViewHelper.java b/common/client/src/main/java/zingg/common/client/ILabelDataViewHelper.java index 89e2ae44..6385bc7f 100644 --- a/common/client/src/main/java/zingg/common/client/ILabelDataViewHelper.java +++ b/common/client/src/main/java/zingg/common/client/ILabelDataViewHelper.java @@ -8,7 +8,7 @@ public interface ILabelDataViewHelper { List getClusterIds(ZFrame lines); - List getDisplayColumns(ZFrame lines, IArguments args); +// List getDisplayColumns(ZFrame lines, IArguments args); ZFrame getCurrentPair(ZFrame lines, int index, List clusterIds, ZFrame clusterLines); diff --git a/common/client/src/main/java/zingg/common/client/cols/FieldDefSelectedCols.java b/common/client/src/main/java/zingg/common/client/cols/FieldDefSelectedCols.java index d359b6c0..f0cf06f8 100644 --- a/common/client/src/main/java/zingg/common/client/cols/FieldDefSelectedCols.java +++ b/common/client/src/main/java/zingg/common/client/cols/FieldDefSelectedCols.java @@ -4,27 +4,38 @@ import java.util.List; import zingg.common.client.FieldDefinition; +import zingg.common.client.MatchType; public class FieldDefSelectedCols extends SelectedCols { - public FieldDefSelectedCols(List fieldDefs, boolean showConcise) { + protected FieldDefSelectedCols() { + + } + + public FieldDefSelectedCols(List fieldDefs, boolean showConcise) { + List colList = getColList(fieldDefs, showConcise); + setCols(colList); + } + + protected List getColList(List fieldDefs) { + return getColList(fieldDefs,false); + } - List namedList = new ArrayList<>(); + protected List getColList(List fieldDefs, boolean showConcise) { + List namedList = new ArrayList(); for (FieldDefinition fieldDef : fieldDefs) { - if (showConcise && fieldDef.isDontUse()) { + if (showConcise && fieldDef.matchType.contains(MatchType.DONT_USE)) { continue; } namedList.add(fieldDef); } - - namedList.add(new FieldDefinition()); List stringList = convertNamedListToStringList(namedList); - setCols(stringList); - } + return stringList; + } - private List convertNamedListToStringList(List namedList) { - List stringList = new ArrayList<>(); + protected List convertNamedListToStringList(List namedList) { + List stringList = new ArrayList(); for (FieldDefinition named : namedList) { stringList.add(named.getName()); } diff --git a/common/client/src/main/java/zingg/common/client/cols/ZidAndFieldDefSelector.java b/common/client/src/main/java/zingg/common/client/cols/ZidAndFieldDefSelector.java index 8511ea43..62f5aac7 100644 --- a/common/client/src/main/java/zingg/common/client/cols/ZidAndFieldDefSelector.java +++ b/common/client/src/main/java/zingg/common/client/cols/ZidAndFieldDefSelector.java @@ -1,16 +1,24 @@ package zingg.common.client.cols; -import java.util.Arrays; import java.util.List; +import zingg.common.client.FieldDefinition; import zingg.common.client.util.ColName; -public class ZidAndFieldDefSelector extends SelectedCols { +public class ZidAndFieldDefSelector extends FieldDefSelectedCols { - public ZidAndFieldDefSelector(String[] fieldDefs) { + public ZidAndFieldDefSelector(List fieldDefs) { + this(fieldDefs, true, false); + } + + public ZidAndFieldDefSelector(List fieldDefs, boolean includeZid, boolean showConcise) { + List colList = getColList(fieldDefs, showConcise); + + if (includeZid) colList.add(0, ColName.ID_COL); + + colList.add(ColName.SOURCE_COL); + + setCols(colList); + } - List fieldDefList = Arrays.asList(fieldDefs); - fieldDefList.add(0, ColName.ID_COL); - setCols(fieldDefList); - } } \ No newline at end of file diff --git a/common/core/src/main/java/zingg/common/core/executor/LabelDataViewHelper.java b/common/core/src/main/java/zingg/common/core/executor/LabelDataViewHelper.java index 9948fd4f..d5bd5970 100644 --- a/common/core/src/main/java/zingg/common/core/executor/LabelDataViewHelper.java +++ b/common/core/src/main/java/zingg/common/core/executor/LabelDataViewHelper.java @@ -6,11 +6,9 @@ import org.apache.commons.logging.LogFactory; import zingg.common.client.ClientOptions; -import zingg.common.client.IArguments; import zingg.common.client.ILabelDataViewHelper; import zingg.common.client.ZFrame; import zingg.common.client.ZinggClientException; -import zingg.common.client.options.ZinggOptions; import zingg.common.client.util.ColName; import zingg.common.client.util.ColValues; import zingg.common.core.context.Context; @@ -39,11 +37,11 @@ public List getClusterIds(ZFrame lines) { } - @Override - public List getDisplayColumns(ZFrame lines, IArguments args) { - return getDSUtil().getFieldDefColumns(lines, args, false, args.getShowConcise()); - } - +// @Override +// public List getDisplayColumns(ZFrame lines, IArguments args) { +// return getDSUtil().getFieldDefColumns(lines, args, false, args.getShowConcise()); +// } +// @Override public ZFrame getCurrentPair(ZFrame lines, int index, List clusterIds, ZFrame clusterLines) { diff --git a/common/core/src/main/java/zingg/common/core/executor/LabelUpdater.java b/common/core/src/main/java/zingg/common/core/executor/LabelUpdater.java index 0143dfd2..f8049329 100644 --- a/common/core/src/main/java/zingg/common/core/executor/LabelUpdater.java +++ b/common/core/src/main/java/zingg/common/core/executor/LabelUpdater.java @@ -1,6 +1,5 @@ package zingg.common.core.executor; -import java.util.List; import java.util.Scanner; import org.apache.commons.logging.Log; @@ -8,6 +7,7 @@ import zingg.common.client.ZFrame; import zingg.common.client.ZinggClientException; +import zingg.common.client.cols.ZidAndFieldDefSelector; import zingg.common.client.options.ZinggOptions; import zingg.common.client.pipe.Pipe; import zingg.common.client.util.ColName; @@ -125,14 +125,14 @@ protected ZFrame getUpdatedRecords(ZFrame updatedRecords, int } protected int getUserInput(ZFrame lines,ZFrame currentPair,String cluster_id) { - - List displayCols = getDSUtil().getFieldDefColumns(lines, args, false, args.getShowConcise()); - +// List displayCols = getDSUtil().getFieldDefColumns(lines, args, false, args.getShowConcise()); + ZidAndFieldDefSelector zidAndFieldDefSelector = new ZidAndFieldDefSelector(args.getFieldDefinition(), false, args.getShowConcise()); int matchFlag = currentPair.getAsInt(currentPair.head(),ColName.MATCH_FLAG_COL); String preMsg = String.format("\n\tThe record pairs belonging to the input cluster id %s are:", cluster_id); String matchType = LabelMatchType.get(matchFlag).msg; String postMsg = String.format("\tThe above pair is labeled as %s\n", matchType); - int selectedOption = displayRecordsAndGetUserInput(getDSUtil().select(currentPair, displayCols), preMsg, postMsg); +// int selectedOption = displayRecordsAndGetUserInput(getDSUtil().select(currentPair, displayCols), preMsg, postMsg); + int selectedOption = displayRecordsAndGetUserInput(currentPair.select(zidAndFieldDefSelector.getCols()), preMsg, postMsg); getTrainingDataModel().updateLabellerStat(selectedOption, INCREMENT); getTrainingDataModel().updateLabellerStat(matchFlag, -1*INCREMENT); getLabelDataViewHelper().printMarkedRecordsStat( diff --git a/common/core/src/main/java/zingg/common/core/executor/Labeller.java b/common/core/src/main/java/zingg/common/core/executor/Labeller.java index f58020a1..3c496445 100644 --- a/common/core/src/main/java/zingg/common/core/executor/Labeller.java +++ b/common/core/src/main/java/zingg/common/core/executor/Labeller.java @@ -10,6 +10,7 @@ import zingg.common.client.ITrainingDataModel; import zingg.common.client.ZFrame; import zingg.common.client.ZinggClientException; +import zingg.common.client.cols.ZidAndFieldDefSelector; import zingg.common.client.options.ZinggOptions; import zingg.common.client.util.ColName; @@ -79,7 +80,8 @@ public ZFrame processRecordsCli(ZFrame lines) throws ZinggClientE ); lines = lines.cache(); - List displayCols = getLabelDataViewHelper().getDisplayColumns(lines, args); +// List displayCols = getLabelDataViewHelper().getDisplayColumns(lines, args); + ZidAndFieldDefSelector zidAndFieldDefSelector = new ZidAndFieldDefSelector(args.getFieldDefinition(), false, args.getShowConcise()); //have to introduce as snowframe can not handle row.getAs with column //name and row and lines are out of order for the code to work properly //snow getAsString expects row to have same struc as dataframe which is @@ -104,7 +106,8 @@ public ZFrame processRecordsCli(ZFrame lines) throws ZinggClientE msg2 = getLabelDataViewHelper().getMsg2(prediction, score); //String msgHeader = msg1 + msg2; - selectedOption = displayRecordsAndGetUserInput(getDSUtil().select(currentPair, displayCols), msg1, msg2); +// selectedOption = displayRecordsAndGetUserInput(getDSUtil().select(currentPair, displayCols), msg1, msg2); + selectedOption = displayRecordsAndGetUserInput(currentPair.select(zidAndFieldDefSelector.getCols()), msg1, msg2); getTrainingDataModel().updateLabellerStat(selectedOption, INCREMENT); getLabelDataViewHelper().printMarkedRecordsStat( getTrainingDataModel().getPositivePairsCount(), diff --git a/common/core/src/main/java/zingg/common/core/executor/Matcher.java b/common/core/src/main/java/zingg/common/core/executor/Matcher.java index dfcd050d..17109e26 100644 --- a/common/core/src/main/java/zingg/common/core/executor/Matcher.java +++ b/common/core/src/main/java/zingg/common/core/executor/Matcher.java @@ -8,6 +8,7 @@ import zingg.common.client.ZFrame; import zingg.common.client.ZinggClientException; +import zingg.common.client.cols.ZidAndFieldDefSelector; import zingg.common.client.options.ZinggOptions; import zingg.common.client.util.ColName; import zingg.common.client.util.ColValues; @@ -35,7 +36,9 @@ public ZFrame getTestData() throws ZinggClientException{ } public ZFrame getFieldDefColumnsDS(ZFrame testDataOriginal) { - return getDSUtil().getFieldDefColumnsDS(testDataOriginal, args, true); + ZidAndFieldDefSelector zidAndFieldDefSelector = new ZidAndFieldDefSelector(args.getFieldDefinition()); + return testDataOriginal.select(zidAndFieldDefSelector.getCols()); +// return getDSUtil().getFieldDefColumnsDS(testDataOriginal, args, true); } @@ -46,13 +49,7 @@ public ZFrame getBlocked( ZFrame testData) throws Exception, Zin ZFrame blocked1 = blocked.repartition(args.getNumPartitions(), blocked.col(ColName.HASH_COL)); //.cache(); return blocked1; } - - - public ZFrame getBlocks(ZFrameblocked) throws Exception{ - return getDSUtil().joinWithItself(blocked, ColName.HASH_COL, true).cache(); - } - public ZFrame getBlocks(ZFrameblocked, ZFramebAll) throws Exception{ ZFramejoinH = getDSUtil().joinWithItself(blocked, ColName.HASH_COL, true).cache(); /*ZFramejoinH = blocked.as("first").joinOnCol(blocked.as("second"), ColName.HASH_COL) diff --git a/common/core/src/main/java/zingg/common/core/executor/TrainingDataFinder.java b/common/core/src/main/java/zingg/common/core/executor/TrainingDataFinder.java index 625750a5..3c291968 100644 --- a/common/core/src/main/java/zingg/common/core/executor/TrainingDataFinder.java +++ b/common/core/src/main/java/zingg/common/core/executor/TrainingDataFinder.java @@ -1,10 +1,13 @@ package zingg.common.core.executor; +import java.util.Arrays; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import zingg.common.client.ZFrame; import zingg.common.client.ZinggClientException; +import zingg.common.client.cols.ZidAndFieldDefSelector; import zingg.common.client.options.ZinggOptions; import zingg.common.client.pipe.Pipe; import zingg.common.client.util.ColName; @@ -79,7 +82,7 @@ public void execute() throws ZinggClientException { if (negPairs!= null) negPairs = negPairs.cache(); //create random samples for blocking ZFrame sampleOrginal = data.sample(false, args.getLabelDataSampleSize()).repartition(args.getNumPartitions()).cache(); - sampleOrginal = getDSUtil().getFieldDefColumnsDS(sampleOrginal, args, true); + sampleOrginal = getFieldDefColumnsDS(sampleOrginal); LOG.info("Preprocessing DS for stopWords"); ZFrame sample = getStopWords().preprocessForStopWords(sampleOrginal); @@ -188,7 +191,7 @@ public ZFrame getPositiveSamples(ZFrame data) throws Exception { } ZFrame posSample = data.sample(false, args.getLabelDataSampleSize()); //select only those columns which are mentioned in the field definitions - posSample = getDSUtil().getFieldDefColumnsDS(posSample, args, true); + posSample = getFieldDefColumnsDS(posSample); if (LOG.isDebugEnabled()) { LOG.debug("Sampled " + posSample.count()); } @@ -202,8 +205,13 @@ public ZFrame getPositiveSamples(ZFrame data) throws Exception { return posPairs; } + protected ZFrame getFieldDefColumnsDS(ZFrame data) { + ZidAndFieldDefSelector zidAndFieldDefSelector = new ZidAndFieldDefSelector(args.getFieldDefinition()); + String[] cols = zidAndFieldDefSelector.getCols(); + return data.select(cols); + //return getDSUtil().getFieldDefColumnsDS(data, args, true); + } + protected abstract StopWordsRemover getStopWords(); - - }