Skip to content

Commit

Permalink
support for float,int, double, long and bytes primitives in python UDF
Browse files Browse the repository at this point in the history
  • Loading branch information
parmitam committed Jun 24, 2016
1 parent 4d0bdef commit 9c5d9e9
Show file tree
Hide file tree
Showing 8 changed files with 184 additions and 98 deletions.
74 changes: 37 additions & 37 deletions python/MyriaPythonWorker/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,26 +50,43 @@ def read_int(stream):
def write_int(value, stream):
stream.write(struct.pack("!i",value))

def write_with_length(obj, stream):
write_int(len(obj),stream)
stream.write(obj)
def write_float(value, stream):
stream.write(struct.pack("!f",value))

def write_double(value, stream):
stream.write(struct.pack("!d",value))

class PickleSerializer(object):
def write_long(value, stream):
stream.write(struct.pack("!q",value))


def write_with_length(obj, stream, outputType):
print("in write_with_length")
print("Output type: "+ str(outputType))

def dump_stream(self,iterator,stream):
for obj in iterator:
self._write_with_length(obj,stream)
if(outputType ==DataType.INT):
print("trying to send back an int")
write_int(DataType.INT, stream)
write_int(obj,stream)
elif(outputType == DataType.LONG):
write_int(DataType.LONG,stream)
write_long(stream.write(obj))
elif(outputType == DataType.FLOAT):
write_int(DataType.FLOAT,stream)
write_float(stream.write(obj))
elif(outputType == DataType.DOUBLE ):
write_int(DataType.DOUBLE,stream)
write_double(stream.write(obj))
elif(outputType == DataType.BYTES):
write_int(DataType.BYTES,stream)
write_int(len(obj),stream)
stream.write(obj)


def load_stream(self,stream, size):
while True:
try:
(yield self._read_with_length(stream, size))
except EOFError:
return


class PickleSerializer(object):

def write_with_length(self, obj, stream):
serialized = self.dumps(obj)

Expand All @@ -81,37 +98,20 @@ def write_with_length(self, obj, stream):
write_int(len(serialized), stream)
stream.write(serialized)


def read_with_length(self,stream):

length = read_int(stream)
if length < 0:
print("this is a command!")
if (length == SpecialLengths.NULL):
print("got a null value")
return 0
else:
return None
obj = stream.read(length)
if len(obj) < length:
raise EOFError

return obj

def write_tuple(self, stream, tuptype, tuplesize):
if(len(tuptype)!=tuplesize):
raise ValueError("type list is not the same as tuple size")

def read_item(self, stream,elementType,length):
if(elementType ==DataType.INT):
def read_item(self, stream, itemType,length):
if(itemType ==DataType.INT):
obj = read_int(stream)
elif(elementType == DataType.LONG):
elif(itemType == DataType.LONG):
obj = read_long(stream)
elif(elementType == DataType.FLOAT):
elif(itemType == DataType.FLOAT):
obj = read_float(stream)
elif(elementType == DataType.DOUBLE ):
elif(itemType == DataType.DOUBLE ):
obj = read_double(stream)
elif(elementType == DataType.BYTES):
elif(itemType == DataType.BYTES):
obj = self.loads(stream.read(length))

return obj
Expand Down Expand Up @@ -143,7 +143,7 @@ def read_tuple(self, stream, tuplesize):



def _read_command(self,stream):
def read_command(self,stream):
length = read_int(stream)

if length < 0:
Expand Down
12 changes: 8 additions & 4 deletions python/MyriaPythonWorker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ def main(infile, outfile):
try:
#get code

func = pickleSer._read_command(infile)
func = pickleSer.read_command(infile)
print("read command!")
tuplesize = read_int(infile)
print("read tuple size")
outputType = read_int(infile)
print("read output type")



Expand All @@ -30,15 +32,17 @@ def main(infile, outfile):
print("python process trying to read tuple")
tup =pickleSer.read_tuple(infile,tuplesize)

print("python process done reading tuple, now writing ")
pickleSer.write_with_length(func(tup),outfile)
print("python process done reading tuple, now writing ")
retval = func(tup)
write_with_length(retval, outfile, outputType)
#pickleSer.write_with_length(func(tup),outfile)
outfile.flush()


except Exception:
try:
write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN,outfile)
write_with_length(traceback.format_exc().encode("utf-8"),outfile)
write_with_length(traceback.format_exc().encode("utf-8"),outfile,5)
print(traceback.format_exc(), file=sys.stderr)
except Exception:
print("python process failed with exception: ", file=sys.stderr)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ public Type getOutputType(final ExpressionOperatorParameter parameters) {

@Override
public String getJavaString(final ExpressionOperatorParameter parameters) {
LOGGER.info("looking for expression value for parameter");

switch (valueType) {
case BOOLEAN_TYPE:
case DOUBLE_TYPE:
Expand Down
18 changes: 12 additions & 6 deletions src/edu/washington/escience/myria/expression/PyUDFExpression.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ public class PyUDFExpression extends BinaryExpression {
/** The name of the python function. */
@JsonProperty
private final String name;
@JsonProperty
private final Type outputType;

// private int leftColumnIdx;
// private int rightColumnIdx;
Expand All @@ -31,22 +33,22 @@ public class PyUDFExpression extends BinaryExpression {
@SuppressWarnings("unused")
private PyUDFExpression() {
name = "";
outputType = Type.BYTES_TYPE;
}

public PyUDFExpression(final ExpressionOperator left, final ExpressionOperator right, final String name) {
public PyUDFExpression(final ExpressionOperator left, final ExpressionOperator right, final String name,
final Type outputType) {
super(left, right);
this.name = name;

LOGGER.info("left string :" + left.toString());
LOGGER.info("right string :" + right.toString());
// setColumnId(left, right);
this.outputType = outputType;

}

@Override
public Type getOutputType(final ExpressionOperatorParameter parameters) {
// look at the output schema of from the expressionOperatorParameter?
return Type.BYTES_TYPE;

return outputType;

}

Expand All @@ -57,6 +59,10 @@ public String getName() {
return name;
}

public Type getOutput() {
return outputType;
}

// don't need the getJavaSubstring
@Override
public String getJavaString(final ExpressionOperatorParameter parameters) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ public class PythonUDFEvaluator extends GenericEvaluator {
private boolean bLeftState = false;
private boolean bRightState = false;

private final Type outputType;

/**
* Default constructor.
*
Expand All @@ -65,6 +67,7 @@ public PythonUDFEvaluator(final Expression expression, final ExpressionOperatorP

// parameters and expression are saved in the super
LOGGER.info("Output name for the python expression" + getExpression().getOutputName());

if (pyFuncReg != null) {
pyFunction = pyFuncReg;
} else {
Expand All @@ -74,12 +77,13 @@ public PythonUDFEvaluator(final Expression expression, final ExpressionOperatorP
needsState = true;
}
PyUDFExpression op = (PyUDFExpression) expression.getRootExpressionOperator();
outputType = op.getOutput();

ExpressionOperator left = op.getLeft();
LOGGER.info("left string :" + left.toString());
// LOGGER.info("left string :" + left.toString());

ExpressionOperator right = op.getRight();
LOGGER.info("right string :" + right.toString());
// LOGGER.info("right string :" + right.toString());

if (left.getClass().equals(VariableExpression.class)) {
leftColumnIdx = ((VariableExpression) left).getColumnIdx();
Expand All @@ -106,12 +110,18 @@ private void initEvaluator() throws DbException {
ExpressionOperator op = getExpression().getRootExpressionOperator();

String pyFunc = ((PyUDFExpression) op).getName();

LOGGER.info(pyFunc);

try {

String pyCodeString = pyFunction.getUDF(pyFunc);
LOGGER.info("length of string from postgres " + pyCodeString.length());
pyWorker.sendCodePickle(pyCodeString, 2);// tuple size is always 2 for binary expression.
if (pyCodeString == null) {
LOGGER.info("no python UDF with name {} registered.", pyFunc);
throw new DbException("No Python UDf with given name registered.");
} else {
pyWorker.sendCodePickle(pyCodeString, 2, outputType);// tuple size is always 2 for binary expression.
}

} catch (Exception e) {
LOGGER.info(e.getMessage());
Expand All @@ -134,7 +144,35 @@ public void compile() {
public void eval(final ReadableTable tb, final int rowIdx, final WritableColumn result, final ReadableTable state)
throws DbException {
Object obj = evaluatePython(tb, rowIdx, state);
result.appendByteBuffer(ByteBuffer.wrap((byte[]) obj));
if (obj == null) {
throw new DbException("python process returned null!");
}
try {
switch (outputType) {
case DOUBLE_TYPE:
result.appendDouble((Double) obj);
break;
case BYTES_TYPE:
result.appendByteBuffer(ByteBuffer.wrap((byte[]) obj));
break;
case FLOAT_TYPE:
result.appendFloat((float) obj);
break;
case INT_TYPE:
result.appendInt((int) obj);
break;
case LONG_TYPE:
result.appendLong((long) obj);
break;
default:
LOGGER.info("type not supported as python Output");
break;

}
} catch (Exception e) {
throw new DbException(e);
}

}

public void evalUpdatePyExpression(final ReadableTable tb, final int rowIdx, final AppendableTable result,
Expand All @@ -149,8 +187,32 @@ public void evalUpdatePyExpression(final ReadableTable tb, final int rowIdx, fin
resultcol = rightColumnIdx;
}
LOGGER.info("trying to update state on column: " + resultcol);
try {
switch (outputType) {
case DOUBLE_TYPE:
result.putDouble(resultcol, (Double) obj);
break;
case BYTES_TYPE:
result.putByteBuffer(resultcol, (ByteBuffer.wrap((byte[]) obj)));
break;
case FLOAT_TYPE:
result.putFloat(resultcol, (float) obj);
break;
case INT_TYPE:
result.putInt(resultcol, (int) obj);
break;
case LONG_TYPE:
result.putLong(resultcol, (long) obj);
break;

default:
LOGGER.info("type not supported as python Output");
break;

result.putByteBuffer(resultcol, (ByteBuffer.wrap((byte[]) obj)));
}
} catch (Exception e) {
throw new DbException(e);
}

}

Expand Down Expand Up @@ -210,30 +272,41 @@ public Column<?> evaluateColumn(final TupleBatch tb) throws DbException {

private Object readFromStream() throws DbException {
LOGGER.info("trying to read now");
int length = 0;
byte[] obj = null;
int type = 0;
Object obj = null;
DataInputStream dIn = pyWorker.getDataInputStream();

try {
length = dIn.readInt(); // read length of incoming message
switch (length) {
case PYTHON_EXCEPTION:
int excepLength = dIn.readInt();
byte[] excp = new byte[excepLength];
dIn.readFully(excp);
throw new DbException(new String(excp));

default:
if (length > 0) {
type = dIn.readInt(); // read length of incoming message
if (type == PYTHON_EXCEPTION) {
int excepLength = dIn.readInt();
byte[] excp = new byte[excepLength];
dIn.readFully(excp);
throw new DbException(new String(excp));
} else {
LOGGER.info("type read: " + type);
if (type == MyriaConstants.PythonType.DOUBLE.getVal()) {
obj = dIn.readDouble();
} else if (type == MyriaConstants.PythonType.FLOAT.getVal()) {
obj = dIn.readFloat();
} else if (type == MyriaConstants.PythonType.INT.getVal()) {
LOGGER.info("trying to read int ");
obj = dIn.readInt();
} else if (type == MyriaConstants.PythonType.LONG.getVal()) {
obj = dIn.readLong();
} else if (type == MyriaConstants.PythonType.BYTES.getVal()) {
int l = dIn.readInt();
if (l > 0) {
LOGGER.info("length greater than zero!");
obj = new byte[length];
dIn.readFully(obj);
obj = new byte[l];
dIn.readFully((byte[]) obj);
}
break;
}

}

} catch (Exception e) {
LOGGER.info("Error reading int from stream");
LOGGER.info("Error reading from stream");
throw new DbException(e);
}
return obj;
Expand Down Expand Up @@ -305,11 +378,4 @@ private void writeToStream(final ReadableTable tb, final int row, final int colu

}

/**
* @param from
* @param row
* @param stateTuple
* @param stateTuple2
*/

}

0 comments on commit 9c5d9e9

Please sign in to comment.