From 8060482c2681716980339445e7b7c8c93165a6ef Mon Sep 17 00:00:00 2001 From: Tobin Baker Date: Thu, 9 Nov 2017 17:44:03 -0800 Subject: [PATCH] add string support to Python UDFs --- python/MyriaPythonWorker.py | 27 ++++++++++++++++++- .../escience/myria/MyriaConstants.java | 4 ++- .../evaluate/PythonUDFEvaluator.java | 14 +++++++--- 3 files changed, 40 insertions(+), 5 deletions(-) diff --git a/python/MyriaPythonWorker.py b/python/MyriaPythonWorker.py index aeec09d3c..4073af3bd 100644 --- a/python/MyriaPythonWorker.py +++ b/python/MyriaPythonWorker.py @@ -15,12 +15,13 @@ class SpecialLengths(object): class DataType(object): + EXCEPTION = -1 INT = 1 LONG = 2 FLOAT = 3 DOUBLE = 4 BLOB = 5 - EXCEPTION = 6 + STRING = 6 class Serializer(object): @@ -53,6 +54,16 @@ def read_int(stream): raise EOFError return struct.unpack("!i", obj)[0] + @staticmethod + def read_string(stream): + # this conforms to the DataOutput.writeUTF() specification: + # https://docs.oracle.com/javase/8/docs/api/java/io/DataOutput.html#writeUTF-java.lang.String- + strlen = struct.unpack("!H", stream.read(2))[0] + obj = stream.read(strlen) + if not obj: + raise EOFError + return obj.decode('utf-8') + @staticmethod def write_int(value, stream): stream.write(struct.pack("!i", value)) @@ -69,6 +80,14 @@ def write_double(value, stream): def write_long(value, stream): stream.write(struct.pack("!q", value)) + @staticmethod + def write_string(value, stream): + # this conforms to the DataInput.readUTF() specification: + # https://docs.oracle.com/javase/8/docs/api/java/io/DataInput.html#readUTF-- + bytestr = value.encode('utf-8') + stream.write(struct.pack("!H", len(bytestr))) + stream.write(bytestr) + class PickleSerializer(Serializer): @@ -83,8 +102,12 @@ def read_item(cls, stream, item_type, length): obj = cls.read_float(stream) elif item_type == DataType.DOUBLE: obj = cls.read_double(stream) + elif item_type == DataType.STRING: + obj = cls.read_string(stream) elif item_type == DataType.BLOB: obj = cls.loads(stream.read(length)) + else: + raise ValueError("Unknown item type %d" % item_type) return obj @classmethod @@ -128,6 +151,8 @@ def write_with_length(cls, obj, stream, output_type): assert type(obj) is str cls.write_int(len(obj), stream) stream.write(obj) + else: + raise ValueError("Unknown output type %d" % output_type) @classmethod def read_command(cls, stream): diff --git a/src/edu/washington/escience/myria/MyriaConstants.java b/src/edu/washington/escience/myria/MyriaConstants.java index b74f54d01..050d1963a 100644 --- a/src/edu/washington/escience/myria/MyriaConstants.java +++ b/src/edu/washington/escience/myria/MyriaConstants.java @@ -369,11 +369,13 @@ public int getVal() { * Python type enum. */ public static enum PythonType { + EXCEPTION(-1), INT(1), LONG(2), FLOAT(3), DOUBLE(4), - BLOB(5); + BLOB(5), + STRING(6); private int val; PythonType(final int val) { diff --git a/src/edu/washington/escience/myria/expression/evaluate/PythonUDFEvaluator.java b/src/edu/washington/escience/myria/expression/evaluate/PythonUDFEvaluator.java index 81a86e9c8..af1cdd0c7 100644 --- a/src/edu/washington/escience/myria/expression/evaluate/PythonUDFEvaluator.java +++ b/src/edu/washington/escience/myria/expression/evaluate/PythonUDFEvaluator.java @@ -202,8 +202,6 @@ public void evalGroups(final MutableTupleBuffer state, final int col) throws DbE /** * @param count number of tuples returned. * @param result writable column - * @param result2 appendable table - * @param resultColIdx id of the result column. * @throws DbException in case of error. */ public void readFromStream(final WritableColumn count, final WritableColumn result) @@ -241,6 +239,8 @@ public void readFromStream(final WritableColumn count, final WritableColumn resu result.appendInt(dIn.readInt()); } else if (type == MyriaConstants.PythonType.LONG.getVal()) { result.appendLong(dIn.readLong()); + } else if (type == MyriaConstants.PythonType.STRING.getVal()) { + result.appendString(dIn.readUTF()); } else if (type == MyriaConstants.PythonType.BLOB.getVal()) { int l = dIn.readInt(); if (l > 0) { @@ -297,7 +297,15 @@ private void writeToStream(@Nonnull final ReadableTable tb, final int row, final dOut.writeLong(tb.getLong(columnIdx, row)); break; case STRING_TYPE: - LOGGER.debug("STRING type is not yet supported for python function "); + String str = tb.getString(columnIdx, row); + if (str != null && str.length() > 0) { + int byteSize = (str.getBytes(StandardCharsets.UTF_8)).length; + dOut.writeInt(MyriaConstants.PythonType.LONG.getVal()); + dOut.writeInt(byteSize + 2); // size in bytes + 2 bytes length prefix + dOut.writeUTF(str); + } else { + dOut.writeInt(MyriaConstants.PythonSpecialLengths.NULL_LENGTH.getVal()); + } break; case DATETIME_TYPE: LOGGER.debug("date time not yet supported for python function ");