Skip to content

Commit

Permalink
Merge pull request #803 from zinggAI/selectCols
Browse files Browse the repository at this point in the history
use IPairBuilder for building pairs
  • Loading branch information
sonalgoyal committed Mar 6, 2024
2 parents 5aab06a + e910eb8 commit 9f33945
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 29 deletions.
22 changes: 14 additions & 8 deletions common/core/src/main/java/zingg/common/core/executor/Linker.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,27 @@
import zingg.common.client.options.ZinggOptions;
import zingg.common.client.util.ColName;
import zingg.common.client.util.ColValues;
import zingg.common.core.pairs.IPairBuilder;
import zingg.common.core.pairs.SelfPairBuilderSourceSensitive;



public abstract class Linker<S,D,R,C,T> extends Matcher<S,D,R,C,T> {

private static final long serialVersionUID = 1L;
protected static String name = "zingg.Linker";
public static final Log LOG = LogFactory.getLog(Linker.class);

/**
 * Creates a linker and sets its Zingg option to {@link ZinggOptions#LINK}.
 */
public Linker() {
    super();
    setZinggOption(ZinggOptions.LINK);
}

/**
 * Joins the blocked frame with itself on the hash column via
 * {@code joinWithItselfSourceSensitive} and returns the cached result.
 *
 * @param blocked frame to be joined with itself on {@code ColName.HASH_COL}
 * @param bAll not used by this implementation
 * @return the cached result of the source sensitive self join
 * @throws Exception declared by the method signature
 */
public ZFrame<D,R,C> getBlocks(ZFrame<D,R,C> blocked, ZFrame<D,R,C> bAll) throws Exception{
    // THIS LOG IS NEEDED FOR PLAN CALCULATION USING COUNT, DO NOT REMOVE
    String countMessage = "in getBlocks, blocked count is " + blocked.count();
    LOG.info(countMessage);
    return getDSUtil()
            .joinWithItselfSourceSensitive(blocked, ColName.HASH_COL, args)
            .cache();
}


/**
 * Returns the blocked frame as-is; no column selection is applied here.
 *
 * @param blocked the blocked frame
 * @return the same {@code blocked} instance that was passed in
 */
@Override
public ZFrame<D,R,C> selectColsFromBlocked(ZFrame<D,R,C> blocked) {
    return blocked;
}

@Override
public void writeOutput(ZFrame<D,R,C> sampleOrginal, ZFrame<D,R,C> dupes) throws ZinggClientException {
try {
// input dupes are pairs
Expand All @@ -53,12 +52,19 @@ public void writeOutput(ZFrame<D,R,C> sampleOrginal, ZFrame<D,R,C> dupes) throws
}
}

/**
 * Keeps only the rows whose prediction column equals the match prediction value.
 *
 * @param dupes frame containing a {@code ColName.PREDICTION_COL} column
 * @return {@code dupes} filtered to rows where the prediction equals
 *         {@code ColValues.IS_MATCH_PREDICTION}
 */
@Override
public ZFrame<D,R,C> getDupesActualForGraph(ZFrame<D,R,C> dupes) {
    return dupes.filter(
            dupes.equalTo(ColName.PREDICTION_COL, ColValues.IS_MATCH_PREDICTION));
}


/**
 * Returns the pair builder for linking, creating a
 * {@link SelfPairBuilderSourceSensitive} on first use when none has been set.
 *
 * @return the existing pair builder, or a newly created source sensitive one
 */
@Override
public IPairBuilder<S, D, R, C> getIPairBuilder() {
    if (iPairBuilder != null) {
        return iPairBuilder;
    }
    iPairBuilder = new SelfPairBuilderSourceSensitive<S, D, R, C>(getDSUtil(), args);
    return iPairBuilder;
}

}
42 changes: 21 additions & 21 deletions common/core/src/main/java/zingg/common/core/executor/Matcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import zingg.common.core.block.Canopy;
import zingg.common.core.block.Tree;
import zingg.common.core.model.Model;
import zingg.common.core.pairs.IPairBuilder;
import zingg.common.core.pairs.SelfPairBuilder;
import zingg.common.core.preprocess.StopWordsRemover;
import zingg.common.core.util.Analytics;
import zingg.common.core.util.Metric;
Expand All @@ -25,6 +27,7 @@ public abstract class Matcher<S,D,R,C,T> extends ZinggBase<S,D,R,C,T>{
protected static String name = "zingg.Matcher";
public static final Log LOG = LogFactory.getLog(Matcher.class);

protected IPairBuilder<S, D, R, C> iPairBuilder;

public Matcher() {
setZinggOption(ZinggOptions.MATCH);
Expand All @@ -50,26 +53,8 @@ public ZFrame<D,R,C> getBlocked( ZFrame<D,R,C> testData) throws Exception, Zin
return blocked1;
}

/*
 * Builds pairs from the blocked frame: self join on the hash column, filter on the id column,
 * then two joins against bAll (once on the id column, once on the prefixed id column).
 * NOTE(review): in this view the body runs straight into getPairs below without its own
 * closing brace; this looks like the pre-change side of a diff hunk - confirm before editing.
 */
public ZFrame<D,R,C> getBlocks(ZFrame<D,R,C>blocked, ZFrame<D,R,C>bAll) throws Exception{
// self join of the blocked frame on the hash column; the result is cached
ZFrame<D,R,C>joinH = getDSUtil().joinWithItself(blocked, ColName.HASH_COL, true).cache();
/*ZFrame<D,R,C>joinH = blocked.as("first").joinOnCol(blocked.as("second"), ColName.HASH_COL)
.selectExpr("first.z_zid as z_zid", "second.z_zid as z_z_zid");
*/
//joinH.show();
// presumably keeps each pair only once by comparing ids - TODO confirm against ZFrame.gt
joinH = joinH.filter(joinH.gt(ColName.ID_COL));
LOG.warn("Num comparisons " + joinH.count());
// repartition both sides on the id column before joining them on it
joinH = joinH.repartition(args.getNumPartitions(), joinH.col(ColName.ID_COL));
bAll = bAll.repartition(args.getNumPartitions(), bAll.col(ColName.ID_COL));
joinH = joinH.joinOnCol(bAll, ColName.ID_COL);
LOG.warn("Joining with actual values");
//joinH.show();
// second join uses the prefixed-column version of bAll, keyed on the prefixed id column
bAll = getDSUtil().getPrefixedColumnsDS(bAll);
//bAll.show();
joinH = joinH.repartition(args.getNumPartitions(), joinH.col(ColName.COL_PREFIX + ColName.ID_COL));
joinH = joinH.joinOnCol(bAll, ColName.COL_PREFIX + ColName.ID_COL);
LOG.warn("Joining again with actual values");
//joinH.show();
return joinH;
/**
 * Delegates pair building to the builder returned by {@code getIPairBuilder()}.
 *
 * @param blocked the blocked frame passed to the builder
 * @param bAll the second frame passed to the builder
 * @return whatever the pair builder's {@code getPairs} returns
 * @throws Exception propagated from the pair builder
 */
public ZFrame<D,R,C> getPairs(ZFrame<D,R,C>blocked, ZFrame<D,R,C>bAll) throws Exception{
    IPairBuilder<S, D, R, C> pairBuilder = getIPairBuilder();
    return pairBuilder.getPairs(blocked, bAll);
}

/**
 * Supplies the {@code Model} for this matcher; implemented by concrete subclasses.
 *
 * @throws ZinggClientException declared by the signature; conditions are subclass specific
 */
protected abstract Model getModel() throws ZinggClientException;
Expand All @@ -91,7 +76,7 @@ protected ZFrame<D,R,C> predictOnBlocks(ZFrame<D,R,C>blocks) throws Exception, Z
}

/**
 * Builds pairs from the blocked frame, runs prediction on them and passes the result
 * through {@code getDupesActualForGraph}.
 *
 * @param blocked blocked frame; passed through {@code selectColsFromBlocked} first
 * @param testData second frame handed to the pair building step
 * @return the output of {@code getDupesActualForGraph} applied to the predictions
 */
protected ZFrame<D,R,C> getActualDupes(ZFrame<D,R,C> blocked, ZFrame<D,R,C> testData) throws Exception, ZinggClientException{
// NOTE(review): this line and the next both declare `blocks`; they look like the removed
// (getBlocks) and added (getPairs) sides of a diff hunk - only one should exist in the real file.
ZFrame<D,R,C> blocks = getBlocks(selectColsFromBlocked(blocked), testData);
ZFrame<D,R,C> blocks = getPairs(selectColsFromBlocked(blocked), testData);
ZFrame<D,R,C>dupesActual = predictOnBlocks(blocks);
return getDupesActualForGraph(dupesActual);
}
Expand Down Expand Up @@ -285,6 +270,21 @@ protected ZFrame<D,R,C> selectColsFromDupes(ZFrame<D,R,C>dupesActual) {

/**
 * Supplies the {@code StopWordsRemover} for this matcher; implemented by concrete subclasses.
 */
protected abstract StopWordsRemover<S,D,R,C,T> getStopWords();

/**
 * Returns the pair builder used by this matcher, creating a {@link SelfPairBuilder}
 * on first use when none has been set. Subclasses may override this method to supply
 * their own {@link IPairBuilder} implementation.
 *
 * @return the existing pair builder, or a newly created {@link SelfPairBuilder}
 */
public IPairBuilder<S, D, R, C> getIPairBuilder() {
    if (iPairBuilder != null) {
        return iPairBuilder;
    }
    iPairBuilder = new SelfPairBuilder<S, D, R, C>(getDSUtil(), args);
    return iPairBuilder;
}

/**
 * Replaces the pair builder held by this matcher.
 *
 * @param iPairBuilder the builder that {@code getIPairBuilder()} will return from now on;
 *        passing {@code null} makes the getter create a default builder on its next call
 */
public void setIPairBuilder(IPairBuilder<S, D, R, C> iPairBuilder) {
    this.iPairBuilder = iPairBuilder;
}



}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package zingg.common.core.pairs;

import zingg.common.client.ZFrame;

/**
 * Strategy for turning a blocked frame into a frame of pairs.
 *
 * @param <S> type parameter forwarded to implementations (not used by this interface itself)
 * @param <D> first type parameter of {@link ZFrame}
 * @param <R> second type parameter of {@link ZFrame}
 * @param <C> third type parameter of {@link ZFrame}
 */
public interface IPairBuilder<S, D, R, C> {

    /**
     * Builds pairs from the given frames.
     *
     * @param blocked the blocked frame
     * @param bAll a second frame made available to the implementation
     * @return the frame of pairs
     * @throws Exception if the implementation fails while building pairs
     */
    ZFrame<D, R, C> getPairs(ZFrame<D, R, C> blocked, ZFrame<D, R, C> bAll) throws Exception;

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package zingg.common.core.pairs;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import zingg.common.client.IArguments;
import zingg.common.client.ZFrame;
import zingg.common.client.util.ColName;
import zingg.common.client.util.DSUtil;

/**
 * {@link IPairBuilder} that self joins the blocked frame on the hash column and then
 * joins the result twice against the full frame: once on the id column and once on the
 * prefixed id column.
 */
public class SelfPairBuilder<S, D, R, C> implements IPairBuilder<S, D, R, C> {

    public static final Log LOG = LogFactory.getLog(SelfPairBuilder.class);

    protected DSUtil<S, D, R, C> dsUtil;
    protected IArguments args;

    /**
     * @param dsUtil utility used for the self join and for prefixing columns
     * @param args arguments object; its {@code getNumPartitions()} drives repartitioning
     */
    public SelfPairBuilder(DSUtil<S, D, R, C> dsUtil, IArguments args) {
        this.dsUtil = dsUtil;
        this.args = args;
    }

    /**
     * Builds the pair frame from {@code blocked} and attaches columns from {@code bAll}.
     *
     * @param blocked frame that is joined with itself on {@code ColName.HASH_COL}
     * @param bAll frame joined in on the id column and, after prefixing, on the prefixed id column
     * @return the frame produced by the final join
     * @throws Exception declared by the interface method
     */
    @Override
    public ZFrame<D, R, C> getPairs(ZFrame<D, R, C> blocked, ZFrame<D, R, C> bAll) throws Exception {
        // Self join on the hash column; the joined frame is cached before it is filtered and counted.
        ZFrame<D, R, C> selfJoined =
                getDSUtil().joinWithItself(blocked, ColName.HASH_COL, true).cache();

        // Presumably keeps each pair only once by comparing ids - TODO confirm against ZFrame.gt.
        ZFrame<D, R, C> candidates = selfJoined.filter(selfJoined.gt(ColName.ID_COL));
        LOG.warn("Num comparisons " + candidates.count());

        // First join: both sides repartitioned on the id column, then joined on it.
        ZFrame<D, R, C> candidatesById =
                candidates.repartition(args.getNumPartitions(), candidates.col(ColName.ID_COL));
        ZFrame<D, R, C> allById =
                bAll.repartition(args.getNumPartitions(), bAll.col(ColName.ID_COL));
        ZFrame<D, R, C> withFirstValues = candidatesById.joinOnCol(allById, ColName.ID_COL);
        LOG.warn("Joining with actual values");

        // Second join: prefixed copy of the repartitioned frame, keyed on the prefixed id column.
        ZFrame<D, R, C> allPrefixed = getDSUtil().getPrefixedColumnsDS(allById);
        ZFrame<D, R, C> byPrefixedId = withFirstValues.repartition(
                args.getNumPartitions(),
                withFirstValues.col(ColName.COL_PREFIX + ColName.ID_COL));
        ZFrame<D, R, C> pairs =
                byPrefixedId.joinOnCol(allPrefixed, ColName.COL_PREFIX + ColName.ID_COL);
        LOG.warn("Joining again with actual values");
        return pairs;
    }

    /** @return the utility used for joins and column prefixing */
    public DSUtil<S, D, R, C> getDSUtil() {
        return dsUtil;
    }

    /** @param dsUtil the utility to use for joins and column prefixing */
    public void setDSUtil(DSUtil<S, D, R, C> dsUtil) {
        this.dsUtil = dsUtil;
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package zingg.common.core.pairs;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import zingg.common.client.IArguments;
import zingg.common.client.ZFrame;
import zingg.common.client.util.ColName;
import zingg.common.client.util.DSUtil;

/**
 * {@link SelfPairBuilder} variant whose {@code getPairs} uses
 * {@code joinWithItselfSourceSensitive} on the hash column instead of the plain self join
 * followed by joins against the full frame.
 */
public class SelfPairBuilderSourceSensitive<S, D, R, C> extends SelfPairBuilder<S, D, R, C> {

    public static final Log LOG = LogFactory.getLog(SelfPairBuilderSourceSensitive.class);

    /**
     * @param dsUtil utility used for the source sensitive self join
     * @param args arguments object passed through to the join
     */
    public SelfPairBuilderSourceSensitive(DSUtil<S, D, R, C> dsUtil, IArguments args) {
        super(dsUtil, args);
    }

    /**
     * Logs the row count of {@code blocked} and returns the cached source sensitive self join.
     *
     * @param blocked frame that is joined with itself on {@code ColName.HASH_COL}
     * @param bAll not used by this implementation
     * @return the cached result of {@code joinWithItselfSourceSensitive}
     * @throws Exception declared by the overridden method
     */
    @Override
    public ZFrame<D,R,C> getPairs(ZFrame<D,R,C> blocked, ZFrame<D,R,C> bAll) throws Exception{
        // THIS LOG IS NEEDED FOR PLAN CALCULATION USING COUNT, DO NOT REMOVE
        // The message names this method (getPairs) rather than the older getBlocks name.
        LOG.info("in getPairs, blocked count is " + blocked.count());
        return getDSUtil().joinWithItselfSourceSensitive(blocked, ColName.HASH_COL, args).cache();
    }

}

0 comments on commit 9f33945

Please sign in to comment.