Skip to content

Commit

Permalink
Merge pull request #103 from zinggAI/0.3.1
Browse files Browse the repository at this point in the history
0.3.1
  • Loading branch information
sonalgoyal committed Jan 2, 2022
2 parents 3543ece + 28dc52b commit ccc90bd
Show file tree
Hide file tree
Showing 11 changed files with 716 additions and 68 deletions.
9 changes: 5 additions & 4 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
FROM docker.io/bitnami/spark:3
FROM docker.io/bitnami/spark:3.1.2
ENV SPARK_MASTER local[*]
ENV ZINGG_HOME /zingg-0.3.0-SNAPSHOT
ENV ZINGG_HOME /zingg-0.3.1-SNAPSHOT
WORKDIR /
USER root
WORKDIR /zingg-0.3.0-SNAPSHOT
RUN curl --location https://github.com/zinggAI/zingg/releases/download/v0.3.0/zingg-0.3.0-SNAPSHOT-spark-3.0.3.tar.gz | \
WORKDIR /zingg-0.3.1-SNAPSHOT
RUN curl --location https://github.com/zinggAI/zingg/releases/download/v0.3.1/zingg-0.3.1-SNAPSHOT-spark-3.1.2.tar.gz | \
tar --extract --gzip --strip=1
RUN chmod -R +rwx /zingg-0.3.1-SNAPSHOT/models
#RUN chmod +x zingg-0.3.0-SNAPSHOT-spark-3.0.3.tar.gz
#RUN tar --extract --gzip --strip=1 /tmp/zingg-0.3.0-SNAPSHOT-spark-3.0.3.tar.gz
USER 1001
Expand Down
2 changes: 1 addition & 1 deletion client/src/main/java/zingg/client/Client.java
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ else if (args.getJobId() != -1) {
}

public static void printBanner() {
String versionStr = "0.3";
String versionStr = "0.3.1";
LOG.info("");
LOG.info("********************************************************");
LOG.info("* Zingg AI *");
Expand Down
3 changes: 2 additions & 1 deletion client/src/main/java/zingg/client/ZinggOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ public enum ZinggOptions {
FIND_TRAINING_DATA("findTrainingData"),
LABEL("label"),
LINK("link"),
GENERATE_DOCS("generateDocs");
GENERATE_DOCS("generateDocs"),
UPDATE_LABEL("updateLabel");

private String value;

Expand Down
2 changes: 1 addition & 1 deletion core/src/main/java/zingg/Documenter.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public void execute() throws ZinggClientException {
LOG.info("Document generation in progress");
Dataset<Row> markedRecords = PipeUtil.read(spark, false, false, PipeUtil.getTrainingDataMarkedPipe(args));
markedRecords = markedRecords.cache();
List<Column> displayCols = DSUtil.getFieldDefColumns(markedRecords, args, false);
//List<Column> displayCols = DSUtil.getFieldDefColumns(markedRecords, args, false);
List<Row> clusterIDs = markedRecords.select(ColName.CLUSTER_COLUMN).distinct().collectAsList();
int totalPairs = clusterIDs.size();
/* Create a data-model */
Expand Down
118 changes: 118 additions & 0 deletions core/src/main/java/zingg/LabelUpdater.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package zingg;

import java.util.List;
import java.util.Scanner;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;

import zingg.client.ZinggClientException;
import zingg.client.ZinggOptions;
import zingg.client.pipe.Pipe;
import zingg.client.util.ColName;
import zingg.client.util.ColValues;
import zingg.util.DSUtil;
import zingg.util.LabelMatchType;
import zingg.util.PipeUtil;

public class LabelUpdater extends Labeller {
protected static String name = "zingg.LabelUpdater";
public static final Log LOG = LogFactory.getLog(LabelUpdater.class);

public LabelUpdater() {
setZinggOptions(ZinggOptions.UPDATE_LABEL);
}

public void execute() throws ZinggClientException {
try {
LOG.info("Reading inputs for updateLabelling phase ...");
Dataset<Row> markedRecords = PipeUtil.read(spark, false, false, PipeUtil.getTrainingDataMarkedPipe(args));
processRecordsCli(markedRecords);
LOG.info("Finished updataLabelling phase");
} catch (Exception e) {
e.printStackTrace();
throw new ZinggClientException(e.getMessage());
}
}

public void processRecordsCli(Dataset<Row> lines) throws ZinggClientException {
LOG.info("Processing Records for CLI updateLabelling");
getMarkedRecordsStat(lines);
printMarkedRecordsStat();
if (lines == null || lines.count() == 0) {
LOG.info("There is no marked record for updating. Please run findTrainingData/label jobs to generate training data.");
return;
}

List<Column> displayCols = DSUtil.getFieldDefColumns(lines, args, false);
try {
int matchFlag;
Dataset<Row> updatedRecords = null;
Dataset<Row> recordsToUpdate = lines;
int selectedOption = -1;
String postMsg;

Scanner sc = new Scanner(System.in);
do {
System.out.print("\n\tPlease enter the cluster id (or 9 to exit): ");
String cluster_id = sc.next();
if (cluster_id.equals("9")) {
LOG.info("User has exit in the middle. Updating the records.");
break;
}
Dataset<Row> currentPair = lines.filter(lines.col(ColName.CLUSTER_COLUMN).equalTo(cluster_id));
if (currentPair.isEmpty()) {
System.out.println("\tInvalid cluster id. Enter '9' to exit");
continue;
}

matchFlag = currentPair.head().getAs(ColName.MATCH_FLAG_COL);
String preMsg = String.format("\n\tThe record pairs belonging to the input cluster id %s are:", cluster_id);
String matchType = LabelMatchType.get(matchFlag).msg;
postMsg = String.format("\tThe above pair is labeled as %s\n", matchType);
selectedOption = displayRecordsAndGetUserInput(DSUtil.select(currentPair, displayCols), preMsg, postMsg);
updateLabellerStat(selectedOption, +1);
updateLabellerStat(matchFlag, -1);
printMarkedRecordsStat();
if (selectedOption == 9) {
LOG.info("User has quit in the middle. Updating the records.");
break;
}
recordsToUpdate = recordsToUpdate
.filter(recordsToUpdate.col(ColName.CLUSTER_COLUMN).notEqual(cluster_id));
if (updatedRecords != null) {
updatedRecords = updatedRecords
.filter(updatedRecords.col(ColName.CLUSTER_COLUMN).notEqual(cluster_id));
}
updatedRecords = updateRecords(selectedOption, currentPair, updatedRecords);
} while (selectedOption != 9);

if (updatedRecords != null) {
updatedRecords = updatedRecords.union(recordsToUpdate);
}
writeLabelledOutput(updatedRecords);
sc.close();
LOG.info("Processing finished.");
} catch (Exception e) {
if (LOG.isDebugEnabled()) {
e.printStackTrace();
}
LOG.warn("An error has occured while Updating Label. " + e.getMessage());
throw new ZinggClientException(e.getMessage());
}
return;
}




protected Pipe getOutputPipe() {
Pipe p = PipeUtil.getTrainingDataMarkedPipe(args);
p.setMode(SaveMode.Overwrite);
return p;
}
}
100 changes: 41 additions & 59 deletions core/src/main/java/zingg/Labeller.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,21 @@ public Dataset<Row> getUnmarkedRecords() throws ZinggClientException {
unmarkedRecords = unmarkedRecords.join(markedRecords,
unmarkedRecords.col(ColName.CLUSTER_COLUMN).equalTo(markedRecords.col(ColName.CLUSTER_COLUMN)),
"left_anti");
positivePairsCount = markedRecords.filter(markedRecords.col(ColName.MATCH_FLAG_COL).equalTo(ColValues.MATCH_TYPE_MATCH)).count() / 2;
negativePairsCount = markedRecords.filter(markedRecords.col(ColName.MATCH_FLAG_COL).equalTo(ColValues.MATCH_TYPE_NOT_A_MATCH)).count() / 2;
notSurePairsCount = markedRecords.filter(markedRecords.col(ColName.MATCH_FLAG_COL).equalTo(ColValues.MATCH_TYPE_NOT_SURE)).count() / 2;
totalCount = markedRecords.count() / 2;
getMarkedRecordsStat(markedRecords);
}
} catch (Exception e) {
LOG.warn("No unmarked record for labelling");
}
return unmarkedRecords;
}

/**
 * Recomputes the labelling counters from the given marked records.
 *
 * <p>Each counter is the number of rows carrying the corresponding match flag,
 * halved; the total is the overall row count, halved. Presumably every labelled
 * pair is stored as two rows, which is why row counts are divided by two
 * (confirm against the writer of the marked data).
 *
 * @param markedRecords the already labelled records to count
 */
protected void getMarkedRecordsStat(Dataset<Row> markedRecords) {
    Dataset<Row> matchRows = markedRecords.filter(
            markedRecords.col(ColName.MATCH_FLAG_COL).equalTo(ColValues.MATCH_TYPE_MATCH));
    positivePairsCount = matchRows.count() / 2;

    Dataset<Row> nonMatchRows = markedRecords.filter(
            markedRecords.col(ColName.MATCH_FLAG_COL).equalTo(ColValues.MATCH_TYPE_NOT_A_MATCH));
    negativePairsCount = nonMatchRows.count() / 2;

    Dataset<Row> unsureRows = markedRecords.filter(
            markedRecords.col(ColName.MATCH_FLAG_COL).equalTo(ColValues.MATCH_TYPE_NOT_SURE));
    notSurePairsCount = unsureRows.count() / 2;

    totalCount = markedRecords.count() / 2;
}

public void processRecordsCli(Dataset<Row> lines) throws ZinggClientException {
LOG.info("Processing Records for CLI Labelling");
printMarkedRecordsStat();
Expand Down Expand Up @@ -96,14 +100,15 @@ public void processRecordsCli(Dataset<Row> lines) throws ZinggClientException {
score = currentPair.head().getAs(ColName.SCORE_COL);
prediction = currentPair.head().getAs(ColName.PREDICTION_COL);

msg1 = String.format("\tRecord pair %d out of %d records to be labelled by the user.\n", index, totalPairs);
msg1 = String.format("\tCurrent labelling round : %d/%d pairs labelled\n", index, totalPairs);
String matchType = LabelMatchType.get(prediction).msg;
msg2 = String.format("\tZingg predicts the records %s with a similarity score of %.2f\n",
msg2 = String.format("\tZingg predicts the above records %s with a similarity score of %.2f",
matchType, score);
String msgHeader = msg1 + msg2;
//String msgHeader = msg1 + msg2;

selected_option = displayRecordsAndGetUserInput(DSUtil.select(currentPair, displayCols), msgHeader);
updateLabellerStat(selected_option);
selected_option = displayRecordsAndGetUserInput(DSUtil.select(currentPair, displayCols), msg1, msg2);
updateLabellerStat(selected_option, 1);
printMarkedRecordsStat();
if (selected_option == 9) {
LOG.info("User has quit in the middle. Updating the records.");
break;
Expand All @@ -123,15 +128,17 @@ public void processRecordsCli(Dataset<Row> lines) throws ZinggClientException {
}


private int displayRecordsAndGetUserInput(Dataset<Row> records, String preMessage) {
System.out.println();
protected int displayRecordsAndGetUserInput(Dataset<Row> records, String preMessage, String postMessage) {
//System.out.println();
System.out.println(preMessage);
records.show(false);
System.out.println(postMessage);
System.out.println("\tWhat do you think? Your choices are: ");
int selection = readCliInput();
return selection;
}

private Dataset<Row> updateRecords(int matchValue, Dataset<Row> newRecords, Dataset<Row> updatedRecords) {
protected Dataset<Row> updateRecords(int matchValue, Dataset<Row> newRecords, Dataset<Row> updatedRecords) {
newRecords = newRecords.withColumn(ColName.MATCH_FLAG_COL, functions.lit(matchValue));
if (updatedRecords == null) {
updatedRecords = newRecords;
Expand All @@ -142,52 +149,23 @@ private Dataset<Row> updateRecords(int matchValue, Dataset<Row> newRecords, Data
}


private List<String> getDisplayColumns(Dataset<Row> lines) {
List<String> cols = Arrays.asList(lines.columns());
List<String> skipCols = getExcludedColumns();
List<String> displayCols = new ArrayList<>();
for (String key : cols) {
if (!skipCols.contains(key)) {
displayCols.add(key);
}
}
return displayCols;
}

private List<String> getDisplayData(Row row, List<String> cols) {
List<String> strArray = new ArrayList<>();
for (String key : cols) {
strArray.add(row.getAs(key).toString());
}
return strArray;
}

private List<String> getExcludedColumns() {
List<String> columns = new ArrayList<>();
columns.add(ColName.ID_COL);
columns.add(ColName.CLUSTER_COLUMN);
columns.add(ColName.SCORE_COL);
columns.add(ColName.PREDICTION_COL);
columns.add(ColName.MATCH_FLAG_COL);

return columns;
}


int readCliInput() {
Scanner sc = new Scanner(System.in);
System.out.println();
System.out.println("\tPlease select from the following choices");

System.out.println("\tNo, they do not match : 0");
System.out.println("\tYes, they match : 1");
System.out.println("\tNot sure : 2");
System.out.println("");
System.out.println();
System.out.println("\tTo exit : 9");
System.out.println();
System.out.print("\tPlease enter your choice [0,1,2 or 9]: ");

while (!sc.hasNext("[0129]")) {
sc.next();
System.out.println("Nope, enter one of the allowed option!");
System.out.println("Nope, please enter one of the allowed options!");
}
String word = sc.next();
int selection = Integer.parseInt(word);
Expand All @@ -196,36 +174,40 @@ int readCliInput() {
return selection;
}

private void updateLabellerStat(int selected_option) {
protected void updateLabellerStat(int selected_option, int increment) {
totalCount += increment;
if (selected_option == ColValues.MATCH_TYPE_MATCH) {
++positivePairsCount;
++totalCount;
positivePairsCount += increment;
}
else if (selected_option == ColValues.MATCH_TYPE_NOT_A_MATCH) {
++negativePairsCount;
++totalCount;
negativePairsCount += increment;
}
else if (selected_option == ColValues.MATCH_TYPE_NOT_SURE) {
++notSurePairsCount;
++totalCount;
notSurePairsCount += increment;
}
printMarkedRecordsStat();
}

private void printMarkedRecordsStat() {
protected void printMarkedRecordsStat() {
String msg = String.format(
"\tLabelled Pairs : %d/%d MATCH, %d/%d DO NOT MATCH, %d/%d NOT SURE", positivePairsCount, totalCount,
"\tLabelled pairs so far : %d/%d MATCH, %d/%d DO NOT MATCH, %d/%d NOT SURE", positivePairsCount, totalCount,
negativePairsCount, totalCount, notSurePairsCount, totalCount);

System.out.println();
System.out.println();
System.out.println();
System.out.println(msg);
}

void writeLabelledOutput(Dataset<Row> records) {
protected void writeLabelledOutput(Dataset<Row> records) {
if (records == null) {
LOG.warn("No records to be labelled.");
return;
}
Pipe p = PipeUtil.getTrainingDataMarkedPipe(args);
PipeUtil.write(records, args, ctx, p);
}
PipeUtil.write(records, args, ctx, getOutputPipe());
}

/**
 * Returns the pipe that labelled records are written to.
 *
 * <p>This implementation returns the marked-training-data pipe obtained from
 * {@code PipeUtil.getTrainingDataMarkedPipe(args)}. It is a separate protected
 * method so that subclasses can supply a differently configured pipe.
 *
 * @return the pipe used for writing labelled output
 */
protected Pipe getOutputPipe() {
return PipeUtil.getTrainingDataMarkedPipe(args);
}
}

Expand Down
2 changes: 2 additions & 0 deletions core/src/main/java/zingg/Trainer.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ public void execute() throws ZinggClientException {
tra = tra.cache();
positives = tra.filter(tra.col(ColName.MATCH_FLAG_COL).equalTo(ColValues.MATCH_TYPE_MATCH));
negatives = tra.filter(tra.col(ColName.MATCH_FLAG_COL).equalTo(ColValues.MATCH_TYPE_NOT_A_MATCH));
LOG.warn("Training on positive pairs - " + positives.count());
LOG.warn("Training on negative pairs - " + negatives.count());

Dataset<Row> testData = PipeUtil.read(spark, true, args.getNumPartitions(), false, args.getData());
Tree<Canopy> blockingTree = BlockingTreeUtil.createBlockingTreeFromSample(testData, positives, 0.5,
Expand Down
1 change: 1 addition & 0 deletions core/src/main/java/zingg/ZFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public ZFactory() {}
zinggers.put(ZinggOptions.TRAIN_MATCH, TrainMatcher.name);
zinggers.put(ZinggOptions.LINK, Linker.name);
zinggers.put(ZinggOptions.GENERATE_DOCS, Documenter.name);
zinggers.put(ZinggOptions.UPDATE_LABEL, LabelUpdater.name);
}

public IZingg get(ZinggOptions z) throws InstantiationException, IllegalAccessException, ClassNotFoundException {
Expand Down

0 comments on commit ccc90bd

Please sign in to comment.