Skip to content

Commit

Permalink
Merge pull request #628 from zinggAI/obviousDupes
Browse files Browse the repository at this point in the history
Obvious dupes performance optimisation
  • Loading branch information
sonalgoyal committed Jul 31, 2023
2 parents bb3102c + 9dba2c9 commit 8430d42
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 33 deletions.
12 changes: 8 additions & 4 deletions common/client/src/main/java/zingg/common/client/ZFrame.java
Expand Up @@ -72,6 +72,8 @@ public interface ZFrame<D, R, C> {
public ZFrame<D, R, C> groupByCount(String colName, String countColName);

public ZFrame<D, R, C> union(ZFrame<D, R, C> other);

public ZFrame<D, R, C> unionAll(ZFrame<D, R, C> other);

public ZFrame<D, R, C> unionByName(ZFrame<D, R, C> other, boolean flag);

Expand All @@ -91,6 +93,8 @@ public interface ZFrame<D, R, C> {
public ZFrame<D, R, C> coalesce(int num);

public C gt(String c);

public C gt(ZFrame<D, R, C> other, String c);

public C gt(String c, double val);

Expand Down Expand Up @@ -155,12 +159,12 @@ public interface ZFrame<D, R, C> {

public ZFrame<D, R, C> filterNullCond(String colName);

public C getObviousDupesFilter(String obviousDupeString);
public C getObviousDupesFilter(String obviousDupeString, C extraAndCond);

public C getObviousDupesFilter(ZFrame<D, R, C> dfToJoin, String obviousDupeString);
public C getObviousDupesFilter(ZFrame<D, R, C> dfToJoin, String obviousDupeString, C extraAndCond);

public C getReverseObviousDupesFilter(String obviousDupeString);
public C getReverseObviousDupesFilter(String obviousDupeString, C extraAndCond);

public C getReverseObviousDupesFilter(ZFrame<D, R, C> dfToJoin, String obviousDupeString);
public C getReverseObviousDupesFilter(ZFrame<D, R, C> dfToJoin, String obviousDupeString, C extraAndCond);

}
72 changes: 53 additions & 19 deletions common/core/src/main/java/zingg/common/core/executor/Matcher.java
Expand Up @@ -125,12 +125,16 @@ public void execute() throws ZinggClientException {
//get obvious dupes
ZFrame<D, R, C> obvDupePairs = getObvDupePairs(blocked);
if (obvDupePairs != null) {
LOG.info("obvDupePairs count " + obvDupePairs.count());
blocks = removeObvDupesFromBlocks(blocks);
long obvDupeCount = obvDupePairs.count();
LOG.info("obvDupePairs count " + obvDupeCount);
if (obvDupeCount > 0) {
blocks = removeObvDupesFromBlocks(blocks);
}
}

//send remaining to model
Model model = getModel();

//blocks.cache().withColumn("partition_id", functions.spark_partition_id())
// .groupBy("partition_id").agg(functions.count("z_id")).ias("zid").orderBy("partition_id").;
/*
Expand All @@ -139,17 +143,24 @@ public void execute() throws ZinggClientException {
blocksRe.withColumn("partition_id", functions.spark_partition_id())
.groupBy("partition_id").agg(functions.count("z_zid")).as("zid").orderBy("partition_id").toJavaRDD().saveAsTextFile("/tmp/zblocksPart");
*/
ZFrame<D,R,C>dupes = model.predict(blocks); //.exceptAll(allEqual));

ZFrame<D,R,C>dupes = model.predict(blocks);

//.exceptAll(allEqual));

//allEqual = massageAllEquals(allEqual);
dupes = addObvDupes(obvDupePairs, dupes);

if (LOG.isDebugEnabled()) {
LOG.debug("Found dupes " + dupes.count());
}
//dupes = dupes.cache();

//allEqual = allEqual.cache();
//writeOutput(blocked, dupes.union(allEqual).cache());

ZFrame<D,R,C>dupesActual = getDupesActualForGraph(dupes);
dupesActual = addObvDupes(obvDupePairs, dupesActual);

//dupesActual.explain();
//dupesActual.toJavaRDD().saveAsTextFile("/tmp/zdupes");

Expand All @@ -161,39 +172,62 @@ public void execute() throws ZinggClientException {
throw new ZinggClientException(e.getMessage());
}
}

protected ZFrame<D, R, C> addObvDupes(ZFrame<D, R, C> obvDupePairs, ZFrame<D, R, C> dupes) {
protected ZFrame<D, R, C> addObvDupes(ZFrame<D, R, C> obvDupePairs, ZFrame<D, R, C> dupesActual) {
if (obvDupePairs != null) {
// unionByName as positions may differ
dupes = dupes.unionByName(obvDupePairs, false);
// ensure same columns in both
obvDupePairs = selectColsFromDupes(obvDupePairs);
dupesActual = dupesActual.unionAll(obvDupePairs);
}
return dupes;
return dupesActual;
}

protected ZFrame<D, R, C> removeObvDupesFromBlocks(ZFrame<D, R, C> blocks) {
LOG.info("blocks count before removing obvDupePairs " + blocks.count());
C reverseOBVDupeDFFilter = blocks.getReverseObviousDupesFilter(args.getObviousDupeCondition());
C reverseOBVDupeDFFilter = blocks.getReverseObviousDupesFilter(args.getObviousDupeCondition(),null);
if (reverseOBVDupeDFFilter != null) {
// remove dupes as already considered in obvDupePairs
blocks = blocks.filter(reverseOBVDupeDFFilter);
}
LOG.info("blocks count after removing obvDupePairs " + blocks.count());
return blocks;
}

protected ZFrame<D,R,C> getObvDupePairs(ZFrame<D,R,C> blocked) {

ZFrame<D,R,C> prefixedColsDF = getDSUtil().getPrefixedColumnsDS(blocked);
C obvDupeDFFilter = blocked.getObviousDupesFilter(prefixedColsDF,args.getObviousDupeCondition());
if (obvDupeDFFilter == null) {
String obviousDupeString = args.getObviousDupeCondition();

if (obviousDupeString == null || obviousDupeString.trim().isEmpty()) {
return null;
}

ZFrame<D,R,C> prefixBlocked = getDSUtil().getPrefixedColumnsDS(blocked);
C gtCond = blocked.gt(prefixBlocked,ColName.ID_COL);

ZFrame<D,R,C> onlyIds = null;

// split on || (orSeperator)
String[] obvDupeORConditions = obviousDupeString.trim().split(ZFrame.orSeperator);
// loop thru the values and build a filter condition
for (int i = 0; i < obvDupeORConditions.length; i++) {

C obvDupeDFFilter = blocked.getObviousDupesFilter(prefixBlocked,obvDupeORConditions[i],gtCond);
ZFrame<D,R,C> onlyIdsTemp = blocked
.joinOnCol(prefixBlocked, obvDupeDFFilter).select(ColName.ID_COL, ColName.COL_PREFIX + ColName.ID_COL);

if(onlyIds==null) {
onlyIds = onlyIdsTemp;
} else {
onlyIds = onlyIds.unionAll(onlyIdsTemp);
}

}

ZFrame<D, R, C> obvDupePairs = blocked.joinOnCol(prefixedColsDF, obvDupeDFFilter).cache();
obvDupePairs = obvDupePairs.filter(obvDupePairs.gt(ColName.ID_COL));
obvDupePairs = massageAllEquals(obvDupePairs);
// remove duplicate pairs
onlyIds = onlyIds.distinct();
onlyIds = massageAllEquals(onlyIds);
onlyIds = onlyIds.cache();

return obvDupePairs;
return onlyIds;
}

public void writeOutput( ZFrame<D,R,C> blocked, ZFrame<D,R,C> dupesActual) throws ZinggClientException {
Expand Down Expand Up @@ -308,7 +342,7 @@ public void writeOutput( ZFrame<D,R,C> blocked, ZFrame<D,R,C> dupesActual) th
}

protected ZFrame<D,R,C>getDupesActualForGraph(ZFrame<D,R,C>dupes) {
ZFrame<D,R,C> dupesActual = selectColsFromDupes(dupes);
dupes = selectColsFromDupes(dupes);
LOG.debug("dupes al");
if (LOG.isDebugEnabled()) dupes.show();
return dupes.filter(dupes.equalTo(ColName.PREDICTION_COL,ColValues.IS_MATCH_PREDICTION));
Expand Down
36 changes: 28 additions & 8 deletions spark/client/src/main/java/zingg/spark/client/SparkFrame.java
Expand Up @@ -174,6 +174,11 @@ public ZFrame<Dataset<Row>, Row, Column> union(ZFrame<Dataset<Row>, Row, Column>
return new SparkFrame(df.union(other.df()));
}

/**
 * Combines this frame with {@code other} by delegating to the wrapped
 * dataset's {@code unionAll}.
 *
 * @param other frame whose underlying dataset is appended to this one
 * @return a new {@link SparkFrame} wrapping the combined dataset
 */
@Override
public ZFrame<Dataset<Row>, Row, Column> unionAll(ZFrame<Dataset<Row>, Row, Column> other) {
    Dataset<Row> combined = df.unionAll(other.df());
    return new SparkFrame(combined);
}

/**
 * Combines this frame with {@code other}, resolving columns by name rather
 * than by position, by delegating to the wrapped dataset's
 * {@code unionByName}.
 *
 * @param other frame whose underlying dataset is appended to this one
 * @param flag  forwarded unchanged as the second argument of the dataset's
 *              {@code unionByName} (presumably Spark's allowMissingColumns
 *              switch - confirm against the Spark version in use)
 * @return a new {@link SparkFrame} wrapping the combined dataset
 */
@Override
public ZFrame<Dataset<Row>, Row, Column> unionByName(ZFrame<Dataset<Row>, Row, Column> other, boolean flag) {
    return new SparkFrame(df.unionByName(other.df(), flag));
}
Expand All @@ -197,10 +202,16 @@ public ZFrame<Dataset<Row>, Row, Column> repartition(int nul, Column c){
return new SparkFrame(df.repartition(nul, c));
}

@Override
public Column gt(String c) {
return df.col(c).gt(df.col(ColName.COL_PREFIX + c));
return gt(this,c);
}

/**
 * Builds a "greater than" condition between column {@code c} of this frame
 * and the column of {@code other} named {@code ColName.COL_PREFIX + c}.
 *
 * @param other frame supplying the prefixed right-hand column
 * @param c     unprefixed column name; used as-is on this frame and with
 *              {@code ColName.COL_PREFIX} prepended on {@code other}
 * @return the comparison column expression
 */
@Override
public Column gt(ZFrame<Dataset<Row>, Row, Column> other, String c) {
    final Column ownCol = df.col(c);
    final String prefixedName = ColName.COL_PREFIX + c;
    return ownCol.gt(other.col(prefixedName));
}

@Override
public Column gt(String c, double val) {
return df.col(c).gt(val);
Expand Down Expand Up @@ -383,8 +394,8 @@ public ZFrame<Dataset<Row>, Row, Column> filterNullCond(String colName) {
* @param obviousDupeString
* @return
*/
public Column getObviousDupesFilter(String obviousDupeString) {
return getObviousDupesFilter(this,obviousDupeString);
public Column getObviousDupesFilter(String obviousDupeString, Column extraAndCond) {
return getObviousDupesFilter(this,obviousDupeString,extraAndCond);
}

/**
Expand All @@ -395,7 +406,7 @@ public Column getObviousDupesFilter(String obviousDupeString) {
* @return
*/
@Override
public Column getObviousDupesFilter(ZFrame<Dataset<Row>, Row, Column> dfToJoin, String obviousDupeString) {
public Column getObviousDupesFilter(ZFrame<Dataset<Row>, Row, Column> dfToJoin, String obviousDupeString, Column extraAndCond) {

if (dfToJoin==null || obviousDupeString == null || obviousDupeString.trim().isEmpty()) {
return null;
Expand Down Expand Up @@ -456,6 +467,15 @@ public Column getObviousDupesFilter(ZFrame<Dataset<Row>, Row, Column> dfToJoin,
}

}

if (extraAndCond != null) {
if (filterExpr != null) {
filterExpr = filterExpr.and(extraAndCond);
} else {
filterExpr = extraAndCond;
}
}

return filterExpr;
}

Expand All @@ -467,8 +487,8 @@ public Column getObviousDupesFilter(ZFrame<Dataset<Row>, Row, Column> dfToJoin,
* @return
*/
@Override
public Column getReverseObviousDupesFilter(String obviousDupeString) {
return getReverseObviousDupesFilter(this,obviousDupeString);
public Column getReverseObviousDupesFilter(String obviousDupeString, Column extraAndCond) {
return getReverseObviousDupesFilter(this,obviousDupeString,extraAndCond);
}

/**
Expand All @@ -479,8 +499,8 @@ public Column getReverseObviousDupesFilter(String obviousDupeString) {
* @return
*/
@Override
public Column getReverseObviousDupesFilter(ZFrame<Dataset<Row>, Row, Column> dfToJoin, String obviousDupeString) {
return functions.not(getObviousDupesFilter(dfToJoin,obviousDupeString));
public Column getReverseObviousDupesFilter(ZFrame<Dataset<Row>, Row, Column> dfToJoin, String obviousDupeString, Column extraAndCond) {
return functions.not(getObviousDupesFilter(dfToJoin,obviousDupeString,extraAndCond));
}

}
20 changes: 18 additions & 2 deletions spark/client/src/test/java/zingg/client/TestSparkFrame.java
Expand Up @@ -325,20 +325,36 @@ public void testGetObviousDupesFilter() throws ZinggClientException {

SparkFrame posDF = getPosPairDF();

Column filter = posDF.getObviousDupesFilter("name & event & comment | dob | comment & year");
Column filter = posDF.getObviousDupesFilter("name & event & comment | dob | comment & year",null);

String expectedCond = "(((((((name = z_name) AND (name IS NOT NULL)) AND (z_name IS NOT NULL)) AND (((event = z_event) AND (event IS NOT NULL)) AND (z_event IS NOT NULL))) AND (((comment = z_comment) AND (comment IS NOT NULL)) AND (z_comment IS NOT NULL))) OR (((dob = z_dob) AND (dob IS NOT NULL)) AND (z_dob IS NOT NULL))) OR ((((comment = z_comment) AND (comment IS NOT NULL)) AND (z_comment IS NOT NULL)) AND (((year = z_year) AND (year IS NOT NULL)) AND (z_year IS NOT NULL))))";

assertEquals(expectedCond,filter.toString());

}

/**
 * Verifies that a non-null extra condition is AND-ed onto the complete
 * obvious-dupes filter: the expected expression is the OR of the three
 * '|'-separated clauses, wrapped and followed by the greater-than
 * condition built from {@code gt("z_zid")}.
 */
@Test
public void testGetObviousDupesFilterWithExtraCond() throws ZinggClientException {

    SparkFrame posDF = getPosPairDF();
    Column gtCond = posDF.gt("z_zid");

    Column filter = posDF.getObviousDupesFilter("name & event & comment | dob | comment & year",gtCond);

    // No console dump of the filter: assertEquals already reports the actual string on mismatch.
    String expectedCond = "((((((((name = z_name) AND (name IS NOT NULL)) AND (z_name IS NOT NULL)) AND (((event = z_event) AND (event IS NOT NULL)) AND (z_event IS NOT NULL))) AND (((comment = z_comment) AND (comment IS NOT NULL)) AND (z_comment IS NOT NULL))) OR (((dob = z_dob) AND (dob IS NOT NULL)) AND (z_dob IS NOT NULL))) OR ((((comment = z_comment) AND (comment IS NOT NULL)) AND (z_comment IS NOT NULL)) AND (((year = z_year) AND (year IS NOT NULL)) AND (z_year IS NOT NULL)))) AND (z_zid > z_z_zid))";

    assertEquals(expectedCond,filter.toString());

}

@Test
public void testGetReverseObviousDupesFilter() throws ZinggClientException {

SparkFrame posDF = getPosPairDF();

Column filter = posDF.getReverseObviousDupesFilter("name & event & comment | dob | comment & year");
Column filter = posDF.getReverseObviousDupesFilter("name & event & comment | dob | comment & year",null);

String expectedCond = "(NOT (((((((name = z_name) AND (name IS NOT NULL)) AND (z_name IS NOT NULL)) AND (((event = z_event) AND (event IS NOT NULL)) AND (z_event IS NOT NULL))) AND (((comment = z_comment) AND (comment IS NOT NULL)) AND (z_comment IS NOT NULL))) OR (((dob = z_dob) AND (dob IS NOT NULL)) AND (z_dob IS NOT NULL))) OR ((((comment = z_comment) AND (comment IS NOT NULL)) AND (z_comment IS NOT NULL)) AND (((year = z_year) AND (year IS NOT NULL)) AND (z_year IS NOT NULL)))))";

Expand Down

0 comments on commit 8430d42

Please sign in to comment.