Skip to content

Commit

Permalink
python arbitrary phases start running
Browse files Browse the repository at this point in the history
  • Loading branch information
sonalgoyal committed Jul 14, 2022
1 parent 1bbe7c2 commit fd1229f
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 8 deletions.
7 changes: 6 additions & 1 deletion client/src/main/java/zingg/client/Arguments.java
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,16 @@ public class Arguments implements Serializable {
boolean showConcise = false;
float stopWordsCutoff = 0.1f;
long blockSize = 100L;

private String confFile;
private static final String ENV_VAR_MARKER_START = "$";
private static final String ENV_VAR_MARKER_END = "$";
private static final String ESC = "\\";
private static final String PATTERN_ENV_VAR = ESC + ENV_VAR_MARKER_START + "(.+?)" + ESC + ENV_VAR_MARKER_END;

/**
 * Returns the value of the {@code confFile} field.
 *
 * <p>{@code createArgumentsFromJSON} stores the file path it parsed in this
 * field. The field has no initializer, so it is presumably {@code null} for
 * instances created any other way (confirm against the rest of the class).
 *
 * @return the configuration file path, possibly {@code null}
 */
public String getConfFile() {
    return this.confFile;
}

/**
 * Returns the current value of the {@code threshold} field.
 *
 * @return the threshold as a {@code double}
 */
public double getThreshold() {
    return this.threshold;
}
Expand Down Expand Up @@ -169,6 +173,7 @@ public static final Arguments createArgumentsFromJSON(String filePath, String ph
Arguments args = mapper.readValue(new File(filePath), Arguments.class);
LOG.warn("phase is " + phase);
checkValid(args, phase);
args.confFile = filePath;
return args;
} catch (Exception e) {
//e.printStackTrace();
Expand Down
13 changes: 13 additions & 0 deletions client/src/main/java/zingg/client/Client.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

import zingg.client.util.Email;
import zingg.client.util.EmailBody;
Expand All @@ -21,6 +22,7 @@ public class Client implements Serializable {
private Arguments arguments;
private IZingg zingg;
private ClientOptions options;
private SparkSession session;

public static final Log LOG = LogFactory.getLog(Client.class);

Expand Down Expand Up @@ -49,6 +51,16 @@ public Client(Arguments args, ClientOptions options) throws ZinggClientException
}
}

/**
 * Creates a client that runs against an already existing Spark session.
 *
 * <p>Delegates to {@link #Client(Arguments, ClientOptions)} and then keeps
 * {@code session} so that {@link #init()} can pass it to the underlying
 * {@code IZingg} through {@code setSpark}.
 *
 * @param args    arguments forwarded to the two-argument constructor
 * @param options client options forwarded to the two-argument constructor
 * @param session Spark session to reuse; must not be {@code null}
 * @throws ZinggClientException if the two-argument constructor fails
 * @throws NullPointerException if {@code session} is {@code null}
 */
public Client(Arguments args, ClientOptions options, SparkSession session) throws ZinggClientException {
    this(args, options);
    // A null session is rejected with a descriptive message instead of an
    // incidental NullPointerException from dereferencing it.
    if (session == null) {
        throw new NullPointerException("session must not be null");
    }
    this.session = session;
    // NOTE(review): the return value is ignored, mirroring the equivalent call
    // in setZingg; confirm whether this call is still needed.
    JavaSparkContext.jarOfClass(IZingg.class);
}




public void setZingg(Arguments args, ClientOptions options) throws Exception{
JavaSparkContext.jarOfClass(IZinggFactory.class);
IZinggFactory zf = (IZinggFactory) Class.forName("zingg.ZFactory").newInstance();
Expand Down Expand Up @@ -205,6 +217,7 @@ else if (options.get(ClientOptions.CONF).value.endsWith("env")) {

/**
 * Initializes the wrapped {@code zingg} instance.
 *
 * <p>Calls {@code zingg.init} with the current arguments and an empty license
 * string. If this client holds a Spark session, it is then handed to
 * {@code zingg} via {@code setSpark}.
 *
 * @throws ZinggClientException declared for callers; presumably propagated
 *         from {@code zingg.init}
 */
public void init() throws ZinggClientException {
    zingg.init(getArguments(), "");
    if (session == null) {
        // No externally supplied session: nothing more to configure.
        return;
    }
    zingg.setSpark(session);
}


Expand Down
3 changes: 3 additions & 0 deletions client/src/main/java/zingg/client/IZingg.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public interface IZingg {

Expand Down Expand Up @@ -32,4 +33,6 @@ public void init(Arguments args, String license)

/**
 * Computes a statistic over the given marked records.
 *
 * <p>NOTE(review): judging by the name, this is presumably the number of
 * records marked as "unsure"; the contract is not visible here, so confirm
 * against the implementing class.
 *
 * @param markedRecords dataset of marked records to inspect
 * @return the statistic as a {@code Long}
 */
public Long getUnsureMarkedRecordsStat(Dataset<Row> markedRecords);

/**
 * Supplies a Spark session to this instance.
 *
 * <p>{@code Client.init()} calls this right after {@code init(...)} when the
 * client was constructed with a session. How implementations use the session
 * is not visible here.
 *
 * @param session the Spark session to set
 */
public void setSpark(SparkSession session);

}
8 changes: 7 additions & 1 deletion core/src/main/java/zingg/PeekModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,13 @@ public void execute() throws ZinggClientException {
try {
LOG.info("Generic Python phase starts");

PythonRunner.main(new String[]{"python/phases/assessModel.py", "--conf", "test.json"});
PythonRunner.main(new String[]{"python/phases/assessModel.py",
"pyFiles",
"--phase",
"peekModel",
"--conf",
args.getConfFile()
});

LOG.info("Generic Python phase ends");
} catch (Exception e) {
Expand Down
5 changes: 3 additions & 2 deletions python/phases/assessModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
LOG = logging.getLogger("zingg.assessModel")

def main():
LOG.info("Phase AssessModel starts")
LOG.info("Phase AssessModel starts")
print("arguments are ", sys.argv[0:])

#excluding argv[0] that is nothing but the current executable file
options = ClientOptions(sys.argv[1:])
options.setPhase("peekModel")
arguments = Arguments.createArgumentsFromJSON(options.getConf(), options.getPhase())
client = Zingg(arguments, options)
client = ZinggWithSpark(arguments, options)
client.init()

pMarkedDF = client.getPandasDfFromDs(client.getMarkedRecords())
Expand Down
28 changes: 24 additions & 4 deletions python/zingg/zingg.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class Zingg:
def __init__(self, args, options):
    """Create the JVM-side ``zingg.client.Client`` wrapped by this object.

    :param args: object whose ``getArgs()`` result is passed to the Java client
    :param options: object whose ``getClientOptions()`` result is passed to the Java client
    """
    client_factory = jvm.zingg.client.Client
    java_args = args.getArgs()
    java_options = options.getClientOptions()
    self.client = client_factory(java_args, java_options)


def init(self):
    """Initialize the wrapped JVM client by delegating to its ``init()`` method."""
    jvm_client = self.client
    jvm_client.init()
Expand Down Expand Up @@ -158,6 +159,21 @@ def getPandasDfFromDs(self, data):
return pd.DataFrame(df.collect(), columns=df.columns)


class ZinggWithSpark(Zingg):
    """Zingg client that hands an existing Spark session to the JVM side.

    Works like :class:`Zingg`, except that the Java ``zingg.client.Client`` is
    built with a third argument, ``spark._jsparkSession``, taken from the
    ``spark`` object in scope in this module.

    :param args: arguments for training and matching
    :type args: Arguments
    :param options: client option for this class object
    :type options: ClientOptions
    """

    def __init__(self, args, options):
        # Zingg.__init__ is not invoked; the client is created here with the
        # three-argument Java constructor instead.
        client_factory = jvm.zingg.client.Client
        java_args = args.getArgs()
        java_options = options.getClientOptions()
        self.client = client_factory(java_args, java_options, spark._jsparkSession)


class Arguments:
""" This class helps supply match arguments to Zingg. There are 3 basic steps in any match process.
Expand Down Expand Up @@ -295,10 +311,14 @@ class ClientOptions:
""":LOCATION: location parameter for this class"""

def __init__(self, args = None):
    """Build the JVM-side ``zingg.client.ClientOptions``.

    The license and email options are always appended to the supplied
    arguments before they are handed to the Java constructor.

    :param args: command-line style option list, e.g. ``sys.argv[1:]``;
        ``None`` is treated as an empty list
    :type args: list of str, optional
    """
    # Work on a copy so the caller's list is never modified in place.
    args = [] if args is None else list(args)
    # NOTE(review): appended unconditionally, even when the caller already
    # passed these options; confirm how the Java parser treats duplicates.
    args.extend([self.LICENSE, "zinggLic.txt", self.EMAIL, "zingg@zingg.ai"])
    print("arguments for client options are ", args)
    self.co = jvm.zingg.client.ClientOptions(args)

def getClientOptions(self):
""" Method to get pointer address of this class
Expand Down

0 comments on commit fd1229f

Please sign in to comment.