diff --git a/jsonQueries/globalJoin_jwang/ingest_smallTable.json b/jsonQueries/globalJoin_jwang/ingest_smallTable.json index 5caf8baab..24d1f52d0 100644 --- a/jsonQueries/globalJoin_jwang/ingest_smallTable.json +++ b/jsonQueries/globalJoin_jwang/ingest_smallTable.json @@ -12,7 +12,7 @@ "dataType" : "Bytes", "bytes" : "MSA0NAoyIDUxCjQ2IDE3CjYzIDM0CjU0IDYzCjIwIDk0CjEyIDY2Cjc5IDQyCjEgMTAKODggMjAKMTAgNDIKNTYgNDQKMTAgMTIKNzkgMzcKMzAgNjYKODMgMTMKMzEgMQozMSA5OQo4MSAzNQo3MCAyNgo0IDUxCjE1IDY2Cjg4IDY2CjI3IDE3CjMxIDgyCjc2IDc0Cjk2IDY1CjYyIDIyCjkwIDU5CjEzIDI5CjQ0IDQyCjM1IDYyCjk5IDE1Cjk1IDc3CjEwIDcwCjI0IDMwCjgyIDY0CjQ0IDQ4CjY1IDc0CjE4IDg1CjQ5IDE0Cjc1IDk5CjU3IDk1CjQyIDk2CjQxIDY5CjE0IDY1CjE2IDExCjcyIDIyCjc2IDgyCjY2IDY4Cjc0IDg4CjQ3IDYKNTYgMAo2IDkKNTAgODAKNiAzMQo3NiA0NAo0OSAzMAo0NyAxNgo4MiA3NwoxIDgxCjIwIDQwCjE4IDU2CjI4IDkyCjU4IDE2CjgyIDEzCjcxIDc1CjYwIDQxCjIzIDkKMiA1MQo4NiA5NQo4IDgxCjk3IDc5CjE4IDQxCjg5IDQ4CjU5IDUxCjIxIDg2CjYzIDc2CjQyIDIyCjczIDM4CjI0IDE3CjggMzQKNzggMTUKOTMgMTUKMzEgMjIKNzMgMjkKOTMgMTYKODcgOTUKNSA1Nwo0MiA4OAoxNSA4NwozOCA5NwowIDc2CjU3IDUxCjMwIDE5CjUyIDI4CjQyIDE0CjczIDI4CjM3IDY5CjQzIDQ3Cg==" }, - "partitionFunction": { + "distributeFunction": { "type": "Hash", "indexes": [0] }, diff --git a/jsonQueries/pythonUDF/ingest_blob.json b/jsonQueries/pythonUDF/ingest_blob.json index 333dcc4d7..cb6197b88 100644 --- a/jsonQueries/pythonUDF/ingest_blob.json +++ b/jsonQueries/pythonUDF/ingest_blob.json @@ -8,10 +8,9 @@ "columnTypes" : ["LONG_TYPE", "LONG_TYPE","LONG_TYPE","STRING_TYPE"], "columnNames" : ["id", "subjid","imgid" ,"image"] }, - "s3Source" : { - "dataType" : "S3", - "s3Uri" : "s3://imagedb-data/dmridatasample.csv" + "source" : { + "dataType" : "URI", + "uri" : "https://s3-us-west-2.amazonaws.com/imagedb-data/dmridatasample.csv" }, - "delimiter": ",", - "workers": [1,2] + "delimiter": "," } diff --git a/jsonQueries/pythonUDF/udfAgg.json b/jsonQueries/pythonUDF/udfAgg.json index f57a3bf04..d19f84b71 100644 --- a/jsonQueries/pythonUDF/udfAgg.json +++ b/jsonQueries/pythonUDF/udfAgg.json @@ -10,7 +10,7 @@ "opId":3 }, { - "opType":"MultiGroupByAggregate", + "opType":"Aggregate", "argGroupFields":[1,2], "aggregators":[ { diff --git a/jsonQueries/pythonUDF/udfAggSingleColumn.json b/jsonQueries/pythonUDF/udfAggSingleColumn.json index 5cfaf25b2..38f729102 100644 --- a/jsonQueries/pythonUDF/udfAggSingleColumn.json +++ b/jsonQueries/pythonUDF/udfAggSingleColumn.json @@ -90,8 +90,8 @@ } ], "argChild":2, - "argGroupField":1, - "opType":"SingleGroupByAggregate", + "argGroupFields":[1], + "opType":"Aggregate", "opId":3 }, { diff --git a/python/MyriaPythonWorker/serializers.py b/python/MyriaPythonWorker/serializers.py index cf725da7a..066ad40d2 100644 --- a/python/MyriaPythonWorker/serializers.py +++ b/python/MyriaPythonWorker/serializers.py @@ -102,15 +102,15 @@ def write_with_length(self, obj, stream): def read_item(self, stream, itemType, length): obj = None - if(itemType == DataType.INT): + if itemType == DataType.INT: obj = read_int(stream) - elif(itemType == DataType.LONG): + elif itemType == DataType.LONG: obj = read_long(stream) - elif(itemType == DataType.FLOAT): + elif itemType == DataType.FLOAT: obj = read_float(stream) - elif(itemType == DataType.DOUBLE): + elif itemType == DataType.DOUBLE: obj = read_double(stream) - elif(itemType == DataType.BLOB): + elif itemType == DataType.BLOB: obj = self.loads(stream.read(length)) return obj @@ -122,10 +122,10 @@ def read_tuple(self, stream, tuplesize): # Second read the length length = read_int(stream) - if (length == SpecialLengths.NULL): + if length == 
SpecialLengths.NULL):
+        if length == SpecialLengths.NULL or length == 0:
             datalist.append(0)
-        # length is >0, read the item now
-        elif (length > 0):
+        # length is > 0, read the item now
+        elif length > 0:
             obj = self.read_item(stream, elementType, length)
             datalist.append(obj)
diff --git a/python/README.md b/python/README.md
index 232bb2e41..0a11b10cb 100644
--- a/python/README.md
+++ b/python/README.md
@@ -1,4 +1,4 @@
 #Myria Python Worker.
 Online documentation for [Myria](http://myria.cs.washington.edu/)
-Myria Python worker is used for executing python UDFs.
+Myria Python worker is used for executing python UDFs.
diff --git a/src/edu/washington/escience/myria/api/DatasetResource.java b/src/edu/washington/escience/myria/api/DatasetResource.java
index f9a7034b5..cbe735967 100644
--- a/src/edu/washington/escience/myria/api/DatasetResource.java
+++ b/src/edu/washington/escience/myria/api/DatasetResource.java
@@ -104,11 +104,11 @@ public Response getDataset(
       @PathParam("programName") final String programName,
       @PathParam("relationName") final String relationName)
       throws DbException {
-    DatasetStatus status =
-        server.getDatasetStatus(RelationKey.of(userName, programName, relationName));
+    RelationKey relationKey = RelationKey.of(userName, programName, relationName);
+    DatasetStatus status = server.getDatasetStatus(relationKey);
     if (status == null) {
       /* Not found, throw a 404 (Not Found) */
-      throw new MyriaApiException(Status.NOT_FOUND, "That dataset was not found");
+      throw new MyriaApiException(Status.NOT_FOUND, "Dataset " + relationKey + " was not found");
     }
     status.setUri(getCanonicalResourcePath(uriInfo, status.getRelationKey()));
     /* Yay, worked! */
@@ -356,13 +356,12 @@ public Response deleteDataset(
       @PathParam("programName") final String programName,
       @PathParam("relationName") final String relationName)
       throws DbException {
-    DatasetStatus status =
-        server.getDatasetStatus(RelationKey.of(userName, programName, relationName));
+    RelationKey relationKey = RelationKey.of(userName, programName, relationName);
+    DatasetStatus status = server.getDatasetStatus(relationKey);
     if (status == null) {
       /* Dataset not found, throw a 404 (Not Found) */
-      throw new MyriaApiException(Status.NOT_FOUND, "That dataset was not found");
+      throw new MyriaApiException(Status.NOT_FOUND, "Dataset " + relationKey + " was not found");
     }
-    RelationKey relationKey = status.getRelationKey();
     // delete command
     try {
       server.deleteDataset(relationKey);
@@ -566,7 +565,7 @@ private Response doIngest(
     /* Check overwriting existing dataset.
*/ try { if (!MoreObjects.firstNonNull(overwrite, false) && server.getSchema(relationKey) != null) { - throw new MyriaApiException(Status.CONFLICT, "That dataset already exists."); + throw new MyriaApiException(Status.CONFLICT, "Dataset " + relationKey + " already exists."); } } catch (CatalogException e) { throw new DbException(e); @@ -642,7 +641,8 @@ public Response addDatasetToCatalog(final DatasetEncoding dataset, @Context fina if (!MoreObjects.firstNonNull(dataset.overwrite, Boolean.FALSE) && server.getSchema(dataset.relationKey) != null) { /* Found, throw a 409 (Conflict) */ - throw new MyriaApiException(Status.CONFLICT, "That dataset already exists."); + throw new MyriaApiException( + Status.CONFLICT, "Dataset " + dataset.relationKey + " already exists."); } } catch (CatalogException e) { throw new DbException(e); diff --git a/src/edu/washington/escience/myria/api/FunctionResource.java b/src/edu/washington/escience/myria/api/FunctionResource.java index 39ab82b67..cab86ca9e 100644 --- a/src/edu/washington/escience/myria/api/FunctionResource.java +++ b/src/edu/washington/escience/myria/api/FunctionResource.java @@ -36,7 +36,6 @@ */ /** * This is the class that handles API calls to create or fetch functions. - * */ @Consumes(MediaType.APPLICATION_JSON) @Produces(MyriaApiConstants.JSON_UTF_8) @@ -51,7 +50,6 @@ public class FunctionResource { protected static final org.slf4j.Logger LOGGER = LoggerFactory.getLogger(FunctionResource.class); /** - * * @return a list of function, names only. * @throws DbException if there is an error accessing the Catalog. */ @@ -74,7 +72,7 @@ public Response createFunction(final CreateFunctionEncoding encoding) throws DbE encoding.binary, encoding.workers); } catch (Exception e) { - throw new DbException(); + throw new DbException(e); } /* Build the response to return the queryId */ ResponseBuilder response = Response.ok(); @@ -82,7 +80,7 @@ public Response createFunction(final CreateFunctionEncoding encoding) throws DbE } /** - * @param name function name + * @param name function name * @return details of a registered function. * @throws DbException if there is an error accessing the Catalog. */ diff --git a/src/edu/washington/escience/myria/api/encoding/AggregateEncoding.java b/src/edu/washington/escience/myria/api/encoding/AggregateEncoding.java index a0782d680..035f606dc 100644 --- a/src/edu/washington/escience/myria/api/encoding/AggregateEncoding.java +++ b/src/edu/washington/escience/myria/api/encoding/AggregateEncoding.java @@ -6,11 +6,11 @@ /** JSON wrapper for Aggregate. */ public class AggregateEncoding extends UnaryOperatorEncoding { - /** aggregators. 
*/ + @Required public int[] argGroupFields; @Required public AggregatorFactory[] aggregators; @Override public Aggregate construct(ConstructArgs args) { - return new Aggregate(null, aggregators); + return new Aggregate(null, argGroupFields, aggregators); } } diff --git a/src/edu/washington/escience/myria/api/encoding/MultiGroupByAggregateEncoding.java b/src/edu/washington/escience/myria/api/encoding/MultiGroupByAggregateEncoding.java deleted file mode 100644 index 7c4eac14a..000000000 --- a/src/edu/washington/escience/myria/api/encoding/MultiGroupByAggregateEncoding.java +++ /dev/null @@ -1,16 +0,0 @@ -package edu.washington.escience.myria.api.encoding; - -import edu.washington.escience.myria.api.encoding.QueryConstruct.ConstructArgs; -import edu.washington.escience.myria.operator.agg.AggregatorFactory; -import edu.washington.escience.myria.operator.agg.MultiGroupByAggregate; - -public class MultiGroupByAggregateEncoding extends UnaryOperatorEncoding { - - @Required public int[] argGroupFields; - @Required public AggregatorFactory[] aggregators; - - @Override - public MultiGroupByAggregate construct(ConstructArgs args) { - return new MultiGroupByAggregate(null, argGroupFields, aggregators); - } -} diff --git a/src/edu/washington/escience/myria/api/encoding/OperatorEncoding.java b/src/edu/washington/escience/myria/api/encoding/OperatorEncoding.java index 2ecfdd2ff..5a2d6cd65 100644 --- a/src/edu/washington/escience/myria/api/encoding/OperatorEncoding.java +++ b/src/edu/washington/escience/myria/api/encoding/OperatorEncoding.java @@ -51,7 +51,6 @@ @Type(name = "LocalMultiwayProducer", value = LocalMultiwayProducerEncoding.class), @Type(name = "Merge", value = MergeEncoding.class), @Type(name = "MergeJoin", value = MergeJoinEncoding.class), - @Type(name = "MultiGroupByAggregate", value = MultiGroupByAggregateEncoding.class), @Type(name = "NChiladaFileScan", value = NChiladaFileScanEncoding.class), @Type(name = "RightHashCountingJoin", value = RightHashCountingJoinEncoding.class), @Type(name = "RightHashJoin", value = RightHashJoinEncoding.class), @@ -62,7 +61,6 @@ @Type(name = "SetGlobal", value = SetGlobalEncoding.class), @Type(name = "ShuffleConsumer", value = GenericShuffleConsumerEncoding.class), @Type(name = "ShuffleProducer", value = GenericShuffleProducerEncoding.class), - @Type(name = "SingleGroupByAggregate", value = SingleGroupByAggregateEncoding.class), @Type(name = "Singleton", value = SingletonEncoding.class), @Type(name = "StatefulApply", value = StatefulApplyEncoding.class), @Type(name = "SymmetricHashJoin", value = SymmetricHashJoinEncoding.class), diff --git a/src/edu/washington/escience/myria/api/encoding/QueryConstruct.java b/src/edu/washington/escience/myria/api/encoding/QueryConstruct.java index c8e0b2329..43ef9c040 100644 --- a/src/edu/washington/escience/myria/api/encoding/QueryConstruct.java +++ b/src/edu/washington/escience/myria/api/encoding/QueryConstruct.java @@ -39,9 +39,9 @@ import edu.washington.escience.myria.operator.Operator; import edu.washington.escience.myria.operator.RootOperator; import edu.washington.escience.myria.operator.UpdateCatalog; -import edu.washington.escience.myria.operator.agg.MultiGroupByAggregate; +import edu.washington.escience.myria.operator.agg.Aggregate; import edu.washington.escience.myria.operator.agg.PrimitiveAggregator.AggregationOp; -import edu.washington.escience.myria.operator.agg.SingleColumnAggregatorFactory; +import edu.washington.escience.myria.operator.agg.PrimitiveAggregatorFactory; import 
edu.washington.escience.myria.operator.network.CollectProducer; import edu.washington.escience.myria.operator.network.Consumer; import edu.washington.escience.myria.operator.network.EOSController; @@ -663,9 +663,9 @@ public static SubQuery getRelationTupleUpdateSubQuery( /* Master plan: collect, sum, insert the updates. */ Consumer consumer = new Consumer(schema, collectId, workerPlans.keySet()); - MultiGroupByAggregate aggCounts = - new MultiGroupByAggregate( - consumer, new int[] {0, 1, 2}, new SingleColumnAggregatorFactory(3, AggregationOp.SUM)); + Aggregate aggCounts = + new Aggregate( + consumer, new int[] {0, 1, 2}, new PrimitiveAggregatorFactory(3, AggregationOp.SUM)); UpdateCatalog catalog = new UpdateCatalog(aggCounts, server); SubQueryPlan masterPlan = new SubQueryPlan(catalog); diff --git a/src/edu/washington/escience/myria/api/encoding/SingleGroupByAggregateEncoding.java b/src/edu/washington/escience/myria/api/encoding/SingleGroupByAggregateEncoding.java deleted file mode 100644 index 38241360a..000000000 --- a/src/edu/washington/escience/myria/api/encoding/SingleGroupByAggregateEncoding.java +++ /dev/null @@ -1,16 +0,0 @@ -package edu.washington.escience.myria.api.encoding; - -import edu.washington.escience.myria.api.encoding.QueryConstruct.ConstructArgs; -import edu.washington.escience.myria.operator.agg.AggregatorFactory; -import edu.washington.escience.myria.operator.agg.SingleGroupByAggregate; - -public class SingleGroupByAggregateEncoding extends UnaryOperatorEncoding { - - @Required public AggregatorFactory[] aggregators; - @Required public int argGroupField; - - @Override - public SingleGroupByAggregate construct(ConstructArgs args) { - return new SingleGroupByAggregate(null, argGroupField, aggregators); - } -} diff --git a/src/edu/washington/escience/myria/api/encoding/SymmetricHashJoinEncoding.java b/src/edu/washington/escience/myria/api/encoding/SymmetricHashJoinEncoding.java index 3707342f2..e1cccacea 100644 --- a/src/edu/washington/escience/myria/api/encoding/SymmetricHashJoinEncoding.java +++ b/src/edu/washington/escience/myria/api/encoding/SymmetricHashJoinEncoding.java @@ -8,29 +8,27 @@ public class SymmetricHashJoinEncoding extends BinaryOperatorEncoding { - public List argColumnNames; @Required public int[] argColumns1; @Required public int[] argColumns2; @Required public int[] argSelect1; @Required public int[] argSelect2; + public List argColumnNames; public boolean argSetSemanticsLeft = false; public boolean argSetSemanticsRight = false; public JoinPullOrder argOrder = JoinPullOrder.ALTERNATE; @Override public SymmetricHashJoin construct(final ConstructArgs args) { - SymmetricHashJoin join = - new SymmetricHashJoin( - argColumnNames, - null, - null, - argColumns1, - argColumns2, - argSelect1, - argSelect2, - argSetSemanticsLeft, - argSetSemanticsRight); - join.setPullOrder(argOrder); - return join; + return new SymmetricHashJoin( + null, + null, + argColumns1, + argColumns2, + argSelect1, + argSelect2, + argSetSemanticsLeft, + argSetSemanticsRight, + argColumnNames, + argOrder); } } diff --git a/src/edu/washington/escience/myria/column/builder/BlobColumnBuilder.java b/src/edu/washington/escience/myria/column/builder/BlobColumnBuilder.java index 485779d89..09e6dd56f 100644 --- a/src/edu/washington/escience/myria/column/builder/BlobColumnBuilder.java +++ b/src/edu/washington/escience/myria/column/builder/BlobColumnBuilder.java @@ -18,15 +18,15 @@ import edu.washington.escience.myria.proto.DataProto.ColumnMessage; import 
edu.washington.escience.myria.storage.TupleUtils;
 import edu.washington.escience.myria.util.MyriaUtils;
+
 /**
  * A column of Blob values.
- *
  */
 public final class BlobColumnBuilder extends ColumnBuilder<ByteBuffer> {

   /**
    * The internal representation of the data.
-   * */
+   */
   private final ByteBuffer[] data;

   /** Number of elements in this column. */
@@ -34,7 +34,7 @@ public final class BlobColumnBuilder extends ColumnBuilder<ByteBuffer> {

   /**
    * If the builder has built the column.
-   * */
+   */
   private boolean built = false;

   /** Constructs an empty column that can hold up to TupleBatch.BATCH_SIZE elements. */
@@ -48,21 +48,16 @@ public BlobColumnBuilder() {
    *
    * @param numDates the actual num strings in the data
    * @param data the underlying data
-   * */
+   */
   private BlobColumnBuilder(final ByteBuffer[] data, final int numBB) {
     this.numBB = numBB;
     this.data = data;
   }

-  /*
-   * Constructs a BlobColumn by deserializing the given ColumnMessage.
-   *
+  /* Constructs a BlobColumn by deserializing the given ColumnMessage.
    * @param message a ColumnMessage containing the contents of this column.
-   *
    * @param numTuples num tuples in the column message
-   *
-   * @return the built column
-   */
+   * @return the built column */
   public static BlobColumn buildFromProtobuf(final ColumnMessage message, final int numTuples) {
     Preconditions.checkArgument(
         message.getType().ordinal() == ColumnMessage.Type.BLOB_VALUE,
@@ -86,9 +81,12 @@ public static BlobColumn buildFromProtobuf(final ColumnMessage message, final in
   }

   @Override
-  public BlobColumnBuilder appendBlob(final ByteBuffer value) throws BufferOverflowException {
+  public BlobColumnBuilder appendBlob(ByteBuffer value) throws BufferOverflowException {
     Preconditions.checkState(
         !built, "No further changes are allowed after the builder has built the column.");
+    if (value == null) {
+      value = ByteBuffer.allocate(0);
+    }
     Objects.requireNonNull(value, "value");
     if (numBB >= TupleUtils.getBatchSize(Type.BLOB_TYPE)) {
       throw new BufferOverflowException();
diff --git a/src/edu/washington/escience/myria/column/builder/ColumnBuilder.java b/src/edu/washington/escience/myria/column/builder/ColumnBuilder.java
index ea10f8317..0fabcdc05 100644
--- a/src/edu/washington/escience/myria/column/builder/ColumnBuilder.java
+++ b/src/edu/washington/escience/myria/column/builder/ColumnBuilder.java
@@ -1,6 +1,7 @@
 package edu.washington.escience.myria.column.builder;

 import java.nio.BufferOverflowException;
+import java.nio.ByteBuffer;
 import java.sql.ResultSet;
 import java.sql.SQLException;

@@ -11,6 +12,7 @@
 import com.almworks.sqlite4java.SQLiteException;
 import com.almworks.sqlite4java.SQLiteStatement;

+import edu.washington.escience.myria.Type;
 import edu.washington.escience.myria.column.Column;
 import edu.washington.escience.myria.column.mutable.MutableColumn;
 import edu.washington.escience.myria.storage.ReadableColumn;
@@ -18,7 +20,6 @@

 /**
  * @param <T> type of the objects in this column.
- *
 */
 public abstract class ColumnBuilder<T extends Comparable<?>>
     implements ReadableColumn, WritableColumn, ReplaceableColumn {
@@ -193,6 +194,31 @@ public void replaceString(@Nonnull final String value, final int row) {
     throw new UnsupportedOperationException(getClass().getName());
   }

+  /**
+   * @param type the type of the column to be returned.
+   * @return a new empty column of the specified type.
+   */
+  public static ColumnBuilder<?> of(final Type type) {
+    switch (type) {
+      case BOOLEAN_TYPE:
+        return new BooleanColumnBuilder();
+      case DATETIME_TYPE:
+        return new DateTimeColumnBuilder();
+      case DOUBLE_TYPE:
+        return new DoubleColumnBuilder();
+      case FLOAT_TYPE:
+        return new FloatColumnBuilder();
+      case INT_TYPE:
+        return new IntColumnBuilder();
+      case LONG_TYPE:
+        return new LongColumnBuilder();
+      case STRING_TYPE:
+        return new StringColumnBuilder();
+      default:
+        throw new IllegalArgumentException("Type " + type + " is invalid");
+    }
+  }
+
   @Override
   public void replaceBlob(@Nonnull final ByteBuffer value, final int row) {
     throw new UnsupportedOperationException(getClass().getName());
   }
diff --git a/src/edu/washington/escience/myria/column/builder/StringColumnBuilder.java b/src/edu/washington/escience/myria/column/builder/StringColumnBuilder.java
index 7b181068c..abf5796f5 100644
--- a/src/edu/washington/escience/myria/column/builder/StringColumnBuilder.java
+++ b/src/edu/washington/escience/myria/column/builder/StringColumnBuilder.java
@@ -16,26 +16,24 @@
 import edu.washington.escience.myria.column.mutable.StringMutableColumn;
 import edu.washington.escience.myria.proto.DataProto.ColumnMessage;
 import edu.washington.escience.myria.proto.DataProto.StringColumnMessage;
-import edu.washington.escience.myria.storage.TupleBatch;
 import edu.washington.escience.myria.storage.TupleUtils;
 import edu.washington.escience.myria.util.MyriaUtils;

 /**
  * A column of String values.
- *
 */
 public final class StringColumnBuilder extends ColumnBuilder<String> {

   /**
    * The internal representation of the data.
-   * */
+   */
   private final String[] data;
   /** Number of elements in this column. */
   private int numStrings;

   /**
    * If the builder has built the column.
-   * */
+   */
   private boolean built = false;

   /** Constructs an empty column that can hold up to TupleBatch.BATCH_SIZE elements. */
@@ -49,7 +47,7 @@ public StringColumnBuilder() {
    *
    * @param numStrings the actual num strings in the data
    * @param data the underlying data
-   * */
+   */
   private StringColumnBuilder(final String[] data, final int numStrings) {
     this.numStrings = numStrings;
     this.data = data;
   }
@@ -84,7 +82,6 @@ public static StringColumn buildFromProtobuf(final ColumnMessage message, final
   public StringColumnBuilder appendString(final String value) throws BufferOverflowException {
     Preconditions.checkState(
         !built, "No further changes are allowed after the builder has built the column.");
-    Objects.requireNonNull(value, "value");
     if (numStrings >= TupleUtils.getBatchSize(Type.STRING_TYPE)) {
       throw new BufferOverflowException();
     }
diff --git a/src/edu/washington/escience/myria/expression/ConstantExpression.java b/src/edu/washington/escience/myria/expression/ConstantExpression.java
index cb3e85684..c4d06404e 100644
--- a/src/edu/washington/escience/myria/expression/ConstantExpression.java
+++ b/src/edu/washington/escience/myria/expression/ConstantExpression.java
@@ -95,6 +95,7 @@ public ConstantExpression(final boolean value) {
   public ConstantExpression(final String value) {
     this(Type.STRING_TYPE, value);
   }
+
   /**
    * Construct Blob constant.
* diff --git a/src/edu/washington/escience/myria/expression/Expression.java b/src/edu/washington/escience/myria/expression/Expression.java index f7e0f8a07..9bc467637 100644 --- a/src/edu/washington/escience/myria/expression/Expression.java +++ b/src/edu/washington/escience/myria/expression/Expression.java @@ -33,26 +33,20 @@ public class Expression implements Serializable { */ @JsonProperty private final ExpressionOperator rootExpressionOperator; - /** - * Variable name of result. - */ + /** Variable name of input tuple batch. */ + public static final String INPUT = "input"; + /** Variable name of row index of input. */ + public static final String INPUTROW = "inputRow"; + /** Variable name of state. */ + public static final String STATE = "state"; + /** Variable name of row index of state. */ + public static final String STATEROW = "stateRow"; + /** Variable name of result. */ public static final String RESULT = "result"; - /** - * Variable name of result count. - */ + /** Variable name of result count. */ public static final String COUNT = "count"; - /** - * Variable name of input tuple batch. - */ - public static final String TB = "tb"; - /** - * Variable name of row index. - */ - public static final String ROW = "row"; - /** - * Variable name of state. - */ - public static final String STATE = "state"; + /** Variable name of column offset of state. */ + public static final String STATECOLOFFSET = "stateColOffset"; /** * This is not really unused, it's used automagically by Jackson deserialization. @@ -192,7 +186,7 @@ public boolean isConstant() { * * @return if this expression contains a python UDF. */ - public boolean isRegisteredUDF() { + public boolean isRegisteredPythonUDF() { return hasOperator(PyUDFExpression.class); } diff --git a/src/edu/washington/escience/myria/expression/ExpressionOperator.java b/src/edu/washington/escience/myria/expression/ExpressionOperator.java index 316895bd9..c6d1fd43a 100644 --- a/src/edu/washington/escience/myria/expression/ExpressionOperator.java +++ b/src/edu/washington/escience/myria/expression/ExpressionOperator.java @@ -6,6 +6,8 @@ import java.io.Serializable; import java.util.List; +import javax.annotation.Nonnull; + import com.fasterxml.jackson.annotation.JsonSubTypes; import com.fasterxml.jackson.annotation.JsonSubTypes.Type; import com.fasterxml.jackson.annotation.JsonTypeInfo; @@ -105,5 +107,6 @@ public boolean hasArrayOutputType() { /** * @return all children */ + @Nonnull public abstract List getChildren(); } diff --git a/src/edu/washington/escience/myria/expression/NAryExpression.java b/src/edu/washington/escience/myria/expression/NAryExpression.java index 3ceccec6f..f877156ad 100644 --- a/src/edu/washington/escience/myria/expression/NAryExpression.java +++ b/src/edu/washington/escience/myria/expression/NAryExpression.java @@ -15,9 +15,7 @@ import edu.washington.escience.myria.expression.evaluate.ExpressionOperatorParameter; /** - * * @author dominik - * */ public abstract class NAryExpression extends ExpressionOperator { @@ -33,7 +31,7 @@ public abstract class NAryExpression extends ExpressionOperator { * This is not really unused, it's used automagically by Jackson deserialization. 
*/ protected NAryExpression() { - children = null; + children = ImmutableList.of(); } /** diff --git a/src/edu/washington/escience/myria/expression/StateExpression.java b/src/edu/washington/escience/myria/expression/StateExpression.java index 73ec452c4..0f4c0756b 100644 --- a/src/edu/washington/escience/myria/expression/StateExpression.java +++ b/src/edu/washington/escience/myria/expression/StateExpression.java @@ -48,7 +48,9 @@ public String getJavaString(final ExpressionOperatorParameter parameters) { .append(getOutputType(parameters).getName()) .append("(") .append(getColumnIdx()) - .append(", 0)") + .append(",") + .append(Expression.STATEROW) + .append(")") .toString(); } diff --git a/src/edu/washington/escience/myria/expression/VariableExpression.java b/src/edu/washington/escience/myria/expression/VariableExpression.java index 8c71ddd84..f455c1f79 100644 --- a/src/edu/washington/escience/myria/expression/VariableExpression.java +++ b/src/edu/washington/escience/myria/expression/VariableExpression.java @@ -42,13 +42,13 @@ public Type getOutputType(final ExpressionOperatorParameter parameters) { @Override public String getJavaString(final ExpressionOperatorParameter parameters) { // We generate a variable access into the tuple buffer. - return new StringBuilder(Expression.TB) + return new StringBuilder(Expression.INPUT) .append(".get") .append(getOutputType(parameters).getName()) .append("(") .append(columnIdx) .append(", ") - .append(Expression.ROW) + .append(Expression.INPUTROW) .append(")") .toString(); } diff --git a/src/edu/washington/escience/myria/expression/evaluate/BooleanEvaluator.java b/src/edu/washington/escience/myria/expression/evaluate/BooleanEvaluator.java index e1f427e0f..01103eb0f 100644 --- a/src/edu/washington/escience/myria/expression/evaluate/BooleanEvaluator.java +++ b/src/edu/washington/escience/myria/expression/evaluate/BooleanEvaluator.java @@ -52,7 +52,7 @@ public void compile() throws DbException { se.createFastEvaluator( getJavaExpressionWithAppend(), BooleanEvalInterface.class, - new String[] {Expression.TB, Expression.ROW}); + new String[] {Expression.INPUT, Expression.INPUTROW}); } catch (Exception e) { throw new DbException("Error when compiling expression " + this, e); } diff --git a/src/edu/washington/escience/myria/expression/evaluate/ConstantEvaluator.java b/src/edu/washington/escience/myria/expression/evaluate/ConstantEvaluator.java index f1ae4ea14..66b382162 100644 --- a/src/edu/washington/escience/myria/expression/evaluate/ConstantEvaluator.java +++ b/src/edu/washington/escience/myria/expression/evaluate/ConstantEvaluator.java @@ -9,8 +9,8 @@ import edu.washington.escience.myria.DbException; import edu.washington.escience.myria.MyriaConstants; +import edu.washington.escience.myria.Schema; import edu.washington.escience.myria.Type; -import edu.washington.escience.myria.column.Column; import edu.washington.escience.myria.column.ConstantValueColumn; import edu.washington.escience.myria.column.builder.WritableColumn; import edu.washington.escience.myria.expression.Expression; @@ -19,6 +19,7 @@ import edu.washington.escience.myria.operator.StatefulApply; import edu.washington.escience.myria.storage.ReadableTable; import edu.washington.escience.myria.storage.TupleBatch; +import edu.washington.escience.myria.storage.TupleUtils; /** * An Expression evaluator for generic expressions that produces a constant such as the initial state in @@ -76,15 +77,6 @@ public ConstantEvaluator( */ private final ExpressionEvaluator evaluator; - /** - * Creates an {@link 
ExpressionEvaluator} from the {@link #javaExpression}. This does not really compile the - * expression and is thus faster. - */ - @Override - public void compile() { - /* Do nothing! */ - } - /** * Evaluates the {@link #getJavaExpressionWithAppend()} using the {@link #evaluator}. * @@ -96,19 +88,24 @@ public Object eval() { @Override public void eval( - final ReadableTable tb, - final int rowIdx, - final WritableColumn count, + final ReadableTable input, + final int inputRow, + final ReadableTable state, + final int stateRow, final WritableColumn result, - final ReadableTable state) { - throw new UnsupportedOperationException( - "Should not be here. Should be using evaluateColumn() instead"); + final WritableColumn count) { + result.appendObject(value); + count.appendInt(1); } @Override - public EvaluatorResult evaluateColumn(final TupleBatch tb) { - return new EvaluatorResult( - new ConstantValueColumn((Comparable) value, type, tb.numTuples()), - new ConstantValueColumn(1, Type.INT_TYPE, tb.numTuples())); + public EvaluatorResult evalTupleBatch(final TupleBatch tb, final Schema outputSchema) + throws DbException { + if (TupleUtils.getBatchSize(outputSchema) == tb.getBatchSize()) { + return new EvaluatorResult( + new ConstantValueColumn((Comparable) value, type, tb.numTuples()), + new ConstantValueColumn(1, Type.INT_TYPE, tb.numTuples())); + } + return super.evalTupleBatch(tb, outputSchema); } } diff --git a/src/edu/washington/escience/myria/expression/evaluate/Evaluator.java b/src/edu/washington/escience/myria/expression/evaluate/Evaluator.java index 064021465..f02a01d6d 100644 --- a/src/edu/washington/escience/myria/expression/evaluate/Evaluator.java +++ b/src/edu/washington/escience/myria/expression/evaluate/Evaluator.java @@ -111,16 +111,6 @@ public boolean isCopyFromInput() { return rootOp instanceof VariableExpression; } - /** - * An expression does not have to be compiled when it only renames or copies a column. This is an optimization to - * avoid evaluating the expression and avoid autoboxing values. - * - * @return true if the expression does not have to be compiled. - */ - public boolean needsCompiling() { - return !(isCopyFromInput() || isConstant() || isRegisteredUDF()); - } - /** * @return true if the expression evaluates to a constant */ @@ -134,10 +124,11 @@ public boolean isConstant() { public boolean needsState() { return needsState; } + /** * @return true if the expression is a contains a python UDF expression. */ public boolean isRegisteredUDF() { - return getExpression().isRegisteredUDF(); + return getExpression().isRegisteredPythonUDF(); } } diff --git a/src/edu/washington/escience/myria/expression/evaluate/ExpressionEvalAppendInterface.java b/src/edu/washington/escience/myria/expression/evaluate/ExpressionEvalAppendInterface.java new file mode 100644 index 000000000..11e527c99 --- /dev/null +++ b/src/edu/washington/escience/myria/expression/evaluate/ExpressionEvalAppendInterface.java @@ -0,0 +1,29 @@ +package edu.washington.escience.myria.expression.evaluate; + +import edu.washington.escience.myria.column.builder.WritableColumn; +import edu.washington.escience.myria.storage.ReadableTable; + +/** + * Interface for evaluating a single {@link edu.washington.escience.myria.expression.Expression} and appending the + * results to a column, along with a count of results. 
+ */ +public interface ExpressionEvalAppendInterface extends ExpressionEvalInterface { + /** + * The interface evaluates a single {@link edu.washington.escience.myria.expression.Expression} and appends the + * results and (optional) counts to the given columns. + * + * @param input the input tuple batch + * @param inputRow row index of the input tuple batch + * @param state optional state that is passed during evaluation + * @param stateRow row index of the state + * @param result a table storing evaluation results + * @param count a column storing the number of results returned from this row + */ + void evaluate( + final ReadableTable input, + final int inputRow, + final ReadableTable state, + final int stateRow, + final WritableColumn result, + final WritableColumn count); +} diff --git a/src/edu/washington/escience/myria/expression/evaluate/ExpressionEvalInterface.java b/src/edu/washington/escience/myria/expression/evaluate/ExpressionEvalInterface.java index f8610a2a0..831d7a825 100644 --- a/src/edu/washington/escience/myria/expression/evaluate/ExpressionEvalInterface.java +++ b/src/edu/washington/escience/myria/expression/evaluate/ExpressionEvalInterface.java @@ -1,30 +1,7 @@ package edu.washington.escience.myria.expression.evaluate; -import edu.washington.escience.myria.column.builder.WritableColumn; -import edu.washington.escience.myria.storage.ReadableTable; - /** * Interface for evaluating a single {@link edu.washington.escience.myria.expression.Expression} and appending the * results to a column, along with a count of results. */ -public interface ExpressionEvalInterface { - /** - * The interface evaluating a single {@link edu.washington.escience.myria.expression.Expression} and appending it to a - * column. We only need a reference to the tuple batch and a row id, plus the optional state of e.g. an - * {@link edu.washington.escience.myria.operator.agg.Aggregate} or a - * {@link edu.washington.escience.myria.operator.StatefulApply}. The variables will be fetched from the tuple buffer - * using the rowId provided in {@link edu.washington.escience.myria.expression.VariableExpression}. - * - * @param tb a tuple batch - * @param row index of the row in the tb that should be used - * @param count a column storing the number of results returned from this row - * @param result a table storing evaluation results - * @param state optional state that is passed during evaluation - */ - void evaluate( - final ReadableTable tb, - final int row, - final WritableColumn count, - final WritableColumn result, - final ReadableTable state); -} +public interface ExpressionEvalInterface {} diff --git a/src/edu/washington/escience/myria/expression/evaluate/ExpressionEvalReplaceInterface.java b/src/edu/washington/escience/myria/expression/evaluate/ExpressionEvalReplaceInterface.java new file mode 100644 index 000000000..47b03ad2c --- /dev/null +++ b/src/edu/washington/escience/myria/expression/evaluate/ExpressionEvalReplaceInterface.java @@ -0,0 +1,27 @@ +package edu.washington.escience.myria.expression.evaluate; + +import edu.washington.escience.myria.storage.MutableTupleBuffer; +import edu.washington.escience.myria.storage.ReadableTable; + +/** + * Interface for evaluating a single {@link edu.washington.escience.myria.expression.Expression} and appending the + * results to a column, along with a count of results. 
+ */ +public interface ExpressionEvalReplaceInterface extends ExpressionEvalInterface { + /** + * The interface evaluating a single {@link edu.washington.escience.myria.expression.Expression} and replace old + * values in a state column with the results. + * + * @param input the input tuple batch + * @param inputRow row index of the input tuple batch + * @param state optional state that is passed during evaluation + * @param stateRow row index of the state + * @param stateColOffset column offset of the state + */ + void evaluate( + final ReadableTable input, + final int inputRow, + final MutableTupleBuffer state, + final int stateRow, + final int stateColOffset); +} diff --git a/src/edu/washington/escience/myria/expression/evaluate/ExpressionOperatorParameter.java b/src/edu/washington/escience/myria/expression/evaluate/ExpressionOperatorParameter.java index 4b414a6d3..b1dd9d5ea 100644 --- a/src/edu/washington/escience/myria/expression/evaluate/ExpressionOperatorParameter.java +++ b/src/edu/washington/escience/myria/expression/evaluate/ExpressionOperatorParameter.java @@ -1,6 +1,7 @@ package edu.washington.escience.myria.expression.evaluate; import edu.washington.escience.myria.Schema; +import edu.washington.escience.myria.functions.PythonFunctionRegistrar; /** * Object that carries parameters down the expression tree. @@ -11,7 +12,9 @@ public class ExpressionOperatorParameter { /** The schema of the state. */ private final Schema stateSchema; /** The id of the worker that is running the expression. */ - private final Integer workerID; + private Integer workerID = null; + /** Python function registrar. */ + private PythonFunctionRegistrar pyFuncReg = null; /** * Simple constructor. @@ -19,7 +22,6 @@ public class ExpressionOperatorParameter { public ExpressionOperatorParameter() { schema = null; stateSchema = null; - workerID = null; } /** @@ -28,7 +30,6 @@ public ExpressionOperatorParameter() { public ExpressionOperatorParameter(final Schema schema) { this.schema = schema; stateSchema = null; - workerID = null; } /** @@ -38,7 +39,18 @@ public ExpressionOperatorParameter(final Schema schema) { public ExpressionOperatorParameter(final Schema schema, final Schema stateSchema) { this.schema = schema; this.stateSchema = stateSchema; - workerID = null; + } + + /** + * @param schema the input schema + * @param stateSchema the state schema + * @param pyFuncReg Python function registrar + */ + public ExpressionOperatorParameter( + final Schema schema, final Schema stateSchema, final PythonFunctionRegistrar pyFuncReg) { + this.schema = schema; + this.stateSchema = stateSchema; + this.pyFuncReg = pyFuncReg; } /** @@ -63,6 +75,23 @@ public ExpressionOperatorParameter( this.workerID = workerID; } + /** + * @param schema the input schema + * @param stateSchema the schema of the state + * @param workerID id of the worker that is running the expression + * @param pyFuncReg Python function registrar + */ + public ExpressionOperatorParameter( + final Schema schema, + final Schema stateSchema, + final int workerID, + final PythonFunctionRegistrar pyFuncReg) { + this.schema = schema; + this.stateSchema = stateSchema; + this.workerID = workerID; + this.pyFuncReg = pyFuncReg; + } + /** * @return the input schema */ @@ -83,4 +112,11 @@ public Schema getStateSchema() { public int getWorkerId() { return workerID; } + + /** + * @return the Python function registrar + */ + public PythonFunctionRegistrar getPythonFunctionRegistrar() { + return pyFuncReg; + } } diff --git 
a/src/edu/washington/escience/myria/expression/evaluate/GenericEvaluator.java b/src/edu/washington/escience/myria/expression/evaluate/GenericEvaluator.java index 404a8402f..6c2ce8c3a 100644 --- a/src/edu/washington/escience/myria/expression/evaluate/GenericEvaluator.java +++ b/src/edu/washington/escience/myria/expression/evaluate/GenericEvaluator.java @@ -1,7 +1,5 @@ package edu.washington.escience.myria.expression.evaluate; -import java.io.IOException; -import java.lang.reflect.InvocationTargetException; import java.util.List; import javax.annotation.Nonnull; @@ -27,10 +25,12 @@ import edu.washington.escience.myria.expression.ExpressionOperator; import edu.washington.escience.myria.expression.VariableExpression; import edu.washington.escience.myria.operator.Apply; +import edu.washington.escience.myria.storage.MutableTupleBuffer; import edu.washington.escience.myria.storage.ReadableColumn; import edu.washington.escience.myria.storage.ReadableTable; import edu.washington.escience.myria.storage.TupleBatch; import edu.washington.escience.myria.storage.TupleBuffer; +import edu.washington.escience.myria.storage.TupleUtils; /** * An Expression evaluator for generic expressions. Used in {@link Apply}. @@ -41,10 +41,10 @@ public class GenericEvaluator extends Evaluator { private static final org.slf4j.Logger LOGGER = org.slf4j.LoggerFactory.getLogger(GenericEvaluator.class); - /** - * Expression evaluator. - */ + /** Expression evaluator. */ private ExpressionEvalInterface evaluator; + /** The script. */ + private String script; /** * Default constructor. @@ -55,6 +55,20 @@ public class GenericEvaluator extends Evaluator { public GenericEvaluator( final Expression expression, final ExpressionOperatorParameter parameters) { super(expression, parameters); + this.script = getExpression().getJavaExpressionWithAppend(getParameters()); + } + + /** + * @param expression + * @param script + * @param parameters + */ + public GenericEvaluator( + final Expression expression, + final String script, + final ExpressionOperatorParameter parameters) { + super(expression, parameters); + this.script = script; } /** @@ -64,11 +78,6 @@ public GenericEvaluator( */ @Override public void compile() throws DbException { - Preconditions.checkArgument( - needsCompiling() || (getStateSchema() != null), - "This expression does not need to be compiled."); - - String javaExpression = getJavaExpressionWithAppend(); IScriptEvaluator se; try { se = CompilerFactoryFactory.getDefaultCompilerFactory().newScriptEvaluator(); @@ -76,25 +85,39 @@ public void compile() throws DbException { LOGGER.error("Could not create expression evaluator", e); throw new DbException("Could not create expression evaluator", e); } - se.setDefaultImports(MyriaConstants.DEFAULT_JANINO_IMPORTS); - try { - evaluator = - (ExpressionEvalInterface) - se.createFastEvaluator( - javaExpression, - ExpressionEvalInterface.class, - new String[] { - Expression.TB, - Expression.ROW, - Expression.COUNT, - Expression.RESULT, - Expression.STATE - }); + if (script.contains("append")) { + evaluator = + (ExpressionEvalAppendInterface) + se.createFastEvaluator( + script, + ExpressionEvalAppendInterface.class, + new String[] { + Expression.INPUT, + Expression.INPUTROW, + Expression.STATE, + Expression.STATEROW, + Expression.RESULT, + Expression.COUNT + }); + } else { + evaluator = + (ExpressionEvalReplaceInterface) + se.createFastEvaluator( + script, + ExpressionEvalReplaceInterface.class, + new String[] { + Expression.INPUT, + Expression.INPUTROW, + Expression.STATE, + 
Expression.STATEROW, + Expression.STATECOLOFFSET + }); + } } catch (CompileException e) { - LOGGER.error("Error when compiling expression {}: {}", javaExpression, e); - throw new DbException("Error when compiling expression: " + javaExpression, e); + LOGGER.error("Error when compiling expression {}: {}", script, e); + throw new DbException("Error when compiling expression: " + script, e); } } @@ -102,42 +125,58 @@ public void compile() throws DbException { * Evaluates the {@link #getJavaExpressionWithAppend()} using the {@link #evaluator}. Prefer to use * {@link #evaluateColumn(TupleBatch)} since it can evaluate an entire TupleBatch at a time for better locality. * - * @param tb a tuple batch - * @param rowIdx index of the row that should be used for input data - * @param count column storing number of results (null for single-valued expressions) - * @param result the table storing the result + * @param input a tuple batch + * @param inputRow index of the row that should be used for input data + * @param state additional state that affects the computation + * @param stateRow index of the row that should be used for state + * @param stateColOffset the column offset of the state + * @throws DbException in case of error. + */ + public void updateState( + @Nonnull final ReadableTable input, + final int inputRow, + @Nonnull final MutableTupleBuffer state, + final int stateRow, + final int stateColOffset) + throws DbException { + ((ExpressionEvalReplaceInterface) evaluator) + .evaluate(input, inputRow, state, stateRow, stateColOffset); + } + + /** + * Evaluates the {@link #getJavaExpressionWithAppend()} using the {@link #evaluator}. Prefer to use + * {@link #evaluateColumn(TupleBatch)} since it can evaluate an entire TupleBatch at a time for better locality. + * + * @param input a tuple batch + * @param inputRow index of the row that should be used for input data * @param state additional state that affects the computation - * @throws InvocationTargetException exception thrown from janino. + * @param stateRow index of the row that should be used for state + * @param result the table storing the result + * @param count column storing number of results (null for single-valued expressions) * @throws DbException in case of error. */ public void eval( - @Nonnull final ReadableTable tb, - final int rowIdx, - @Nullable final WritableColumn count, + @Nullable final ReadableTable input, + final int inputRow, + @Nullable final ReadableTable state, + final int stateRow, @Nonnull final WritableColumn result, - @Nullable final ReadableTable state) - throws InvocationTargetException, DbException { + @Nullable final WritableColumn count) + throws DbException { Preconditions.checkArgument( evaluator != null, "Call compile first or copy the data if it is the same in the input."); Preconditions.checkArgument( getExpression().isMultiValued() != (count == null), "count must be null for a single-valued expression and non-null for a multivalued expression."); try { - evaluator.evaluate(tb, rowIdx, count, result, state); + ((ExpressionEvalAppendInterface) evaluator) + .evaluate(input, inputRow, state, stateRow, result, count); } catch (Exception e) { - LOGGER.error(getJavaExpressionWithAppend(), e); + LOGGER.error(script, e); throw e; } } - /** - * @return the Java form of this expression. - */ - @Override - public String getJavaExpressionWithAppend() { - return getExpression().getJavaExpressionWithAppend(getParameters()); - } - /** * Holder class for results and result counts from {@link #evaluateColumn}. 
   */
@@ -148,9 +187,8 @@ public static class EvaluatorResult {
     protected EvaluatorResult(
         @Nonnull final TupleBuffer results, @Nonnull final Column<?> resultCounts) {
-      final List<TupleBatch> resultBatches = results.finalResult();
       ImmutableList.Builder<Column<?>> resultColumnsBuilder = ImmutableList.builder();
-      for (final TupleBatch tb : resultBatches) {
+      for (final TupleBatch tb : results.finalResult()) {
         resultColumnsBuilder.add(tb.getDataColumns().get(0));
       }
       this.resultColumns = resultColumnsBuilder.build();
@@ -188,40 +226,39 @@ public ReadableColumn getResultCounts() {
   }

   /**
-   * Evaluate an expression over an entire TupleBatch and return the column(s) of results, along with a column of result counts from each tuple. This method cannot take state
-   * into consideration.
+   * Evaluate an expression over an entire TupleBatch and return the column(s) of results, along with a column of result
+   * counts from each tuple. This method cannot take state into consideration.
    *
    * @param tb the tuples to be input to this expression
-   * @return an {@link EvaluatorResult} containing the results and result counts of evaluating this expression on the entire TupleBatch
-   * @throws InvocationTargetException exception thrown from janino
+   * @param outputSchema the schema that results from this evaluator belong to, used to determine the tuple batch size
+   * @return an {@link EvaluatorResult} containing the results and result counts of evaluating this expression on the
+   *     entire TupleBatch
    * @throws DbException
    */
-  public EvaluatorResult evaluateColumn(final TupleBatch tb)
-      throws InvocationTargetException, DbException {
-    // Optimization for result counts of single-valued expressions.
+  public EvaluatorResult evalTupleBatch(final TupleBatch tb, final Schema outputSchema)
      throws DbException {
     final Column<?> constCounts = new ConstantValueColumn(1, Type.INT_TYPE, tb.numTuples());
-    final WritableColumn countsWriter;
-    if (getExpression().isMultiValued()) {
-      countsWriter = ColumnFactory.allocateColumn(Type.INT_TYPE);
-    } else {
-      // For single-valued expressions, the Java expression will never attempt to write to `countsWriter`.
-      countsWriter = null;
-    }
-    ExpressionOperator op = getExpression().getRootExpressionOperator();
+    int batchSize = TupleUtils.getBatchSize(outputSchema);
     // Critical optimization: return a zero-copy reference to a column referenced by a pure `VariableExpression`.
-    if (isCopyFromInput()) {
+    if (isCopyFromInput() && batchSize >= tb.numTuples()) {
+      ExpressionOperator op = getExpression().getRootExpressionOperator();
       return new EvaluatorResult(
           tb.getDataColumns().get(((VariableExpression) op).getColumnIdx()), constCounts);
     }
-    // For multivalued expressions, we may get more than `TupleBatch.BATCH_SIZE` results,
-    // so we need to pass in a `TupleBuffer` rather than a `ColumnBuilder` to `eval()`,
-    // and return a `List` rather than a `Column` of results.
-    final Type type = getOutputType();
+    /* For multivalued expressions, we may get more than batchSize results, so we need to pass in a `TupleBuffer` rather
+     * than a `ColumnBuilder` to `eval()`, and return a `List` rather than a `Column` of results. */
     final TupleBuffer resultsBuffer =
-        new TupleBuffer(Schema.ofFields(getExpression().getOutputName(), type));
+        new TupleBuffer(
+            Schema.ofFields(getExpression().getOutputName(), getOutputType()), batchSize);
     final WritableColumn resultsWriter = resultsBuffer.asWritableColumn(0);
+    // For single-valued expressions, the Java expression will never attempt to write to `countsWriter`.
+ WritableColumn countsWriter = null; + if (getExpression().isMultiValued()) { + countsWriter = ColumnFactory.allocateColumn(Type.INT_TYPE); + } for (int rowIdx = 0; rowIdx < tb.numTuples(); ++rowIdx) { - eval(tb, rowIdx, countsWriter, resultsWriter, null); + /* Hack, tb is either Expression.INPUT or Expression.STATE */ + eval(tb, rowIdx, tb, rowIdx, resultsWriter, countsWriter); } final Column resultCounts; if (getExpression().isMultiValued()) { @@ -231,4 +268,11 @@ public EvaluatorResult evaluateColumn(final TupleBatch tb) } return new EvaluatorResult(resultsBuffer, resultCounts); } + + /** + * @return the script + */ + public String getScript() { + return script; + } } diff --git a/src/edu/washington/escience/myria/expression/evaluate/PythonUDFEvaluator.java b/src/edu/washington/escience/myria/expression/evaluate/PythonUDFEvaluator.java index 217156b59..af47d023c 100644 --- a/src/edu/washington/escience/myria/expression/evaluate/PythonUDFEvaluator.java +++ b/src/edu/washington/escience/myria/expression/evaluate/PythonUDFEvaluator.java @@ -3,28 +3,28 @@ */ package edu.washington.escience.myria.expression.evaluate; -import javax.annotation.Nonnull; -import javax.annotation.Nullable; -import javax.validation.constraints.NotNull; - import java.io.DataInputStream; import java.io.DataOutputStream; -import java.io.IOException; import java.nio.ByteBuffer; -import java.util.Arrays; +import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; import org.codehaus.janino.ExpressionEvaluator; import com.google.common.base.Preconditions; +import com.gs.collections.api.iterator.IntIterator; +import com.gs.collections.impl.map.mutable.primitive.IntObjectHashMap; import edu.washington.escience.myria.DbException; import edu.washington.escience.myria.MyriaConstants; import edu.washington.escience.myria.Schema; import edu.washington.escience.myria.Type; -import edu.washington.escience.myria.accessmethod.ConnectionInfo; import edu.washington.escience.myria.api.encoding.FunctionStatus; -import edu.washington.escience.myria.column.Column; import edu.washington.escience.myria.column.builder.ColumnBuilder; import edu.washington.escience.myria.column.builder.ColumnFactory; import edu.washington.escience.myria.column.builder.WritableColumn; @@ -37,11 +37,10 @@ import edu.washington.escience.myria.functions.PythonWorker; import edu.washington.escience.myria.operator.Apply; import edu.washington.escience.myria.operator.StatefulApply; -import edu.washington.escience.myria.profiling.ProfilingLogger; -import edu.washington.escience.myria.storage.AppendableTable; +import edu.washington.escience.myria.storage.MutableTupleBuffer; import edu.washington.escience.myria.storage.ReadableTable; -import edu.washington.escience.myria.storage.Tuple; import edu.washington.escience.myria.storage.TupleBatch; +import edu.washington.escience.myria.storage.TupleBatchBuffer; /** * An Expression evaluator for Python UDFs. Used in {@link Apply} and {@link StatefulApply}. @@ -51,23 +50,22 @@ public class PythonUDFEvaluator extends GenericEvaluator { /** logger for this class. */ private static final org.slf4j.Logger LOGGER = org.slf4j.LoggerFactory.getLogger(PythonUDFEvaluator.class); - /** python function registrar from which to fetch function pickle.*/ + /** python function registrar from which to fetch function pickle. */ private final PythonFunctionRegistrar pyFuncRegistrar; - /** python worker process. 
 */
   private PythonWorker pyWorker;
   /** index of state column. */
-  private final boolean[] isStateColumn;
-  /** tuple size to be sent to the python process, this is equal to the number of children of the expression. */
-  private int numColumns = -1;
-
+  private Set<Integer> stateColumns;
+  /** column indices of child ops. */
   private int[] columnIdxs = null;
-
   /** Output Type of the expression. */
   private Type outputType = null;
-
   /** is expression a flatmap? */
   private Boolean isMultiValued = false;
+  /** Tuple buffers for each group key. */
+  private IntObjectHashMap<TupleBatchBuffer> buffer;
+  /** The internal state schema. */
+  private Schema stateSchema;

   /**
    * Default constructor.
@@ -77,64 +75,44 @@ public class PythonUDFEvaluator extends GenericEvaluator {
    * @param pyFuncReg python function registrar to get the python function.
    */
   public PythonUDFEvaluator(
-      final Expression expression,
-      final ExpressionOperatorParameter parameters,
-      @NotNull final PythonFunctionRegistrar pyFuncReg) {
+      final Expression expression, final ExpressionOperatorParameter parameters)
+      throws DbException {
     super(expression, parameters);
-    pyFuncRegistrar = pyFuncReg;
+    pyFuncRegistrar = parameters.getPythonFunctionRegistrar();
+    if (pyFuncRegistrar == null) {
+      throw new RuntimeException("PythonRegistrar should not be null in PythonUDFEvaluator.");
+    }
     PyUDFExpression op = (PyUDFExpression) expression.getRootExpressionOperator();
     outputType = op.getOutputType(parameters);
     List<ExpressionOperator> childops = op.getChildren();
-    numColumns = childops.size();
-    columnIdxs = new int[numColumns];
-    isStateColumn = new boolean[numColumns];
-
-    Arrays.fill(isStateColumn, false);
-    Arrays.fill(columnIdxs, -1);
-  }
-
-  /**
-   * Initializes the python evaluator.
-   * @throws DbException in case of error.
-   */
-  private void initEvaluator() throws DbException {
-    ExpressionOperator op = getExpression().getRootExpressionOperator();
-    String pyFunctionName = ((PyUDFExpression) op).getName();
-
-    try {
-      if (pyFuncRegistrar != null) {
-        FunctionStatus fs = pyFuncRegistrar.getFunctionStatus(pyFunctionName);
-        if (fs == null) {
-          throw new DbException("No Python UDf with given name registered.");
-        }
-        isMultiValued = fs.getIsMultivalued(); //if the function is multivalued.
-        pyWorker.sendCodePickle(fs.getBinary(), numColumns, outputType, isMultiValued);
-
-        List<ExpressionOperator> childops = op.getChildren();
-        if (childops != null) {
-
-          for (int i = 0; i < childops.size(); i++) {
-
-            if (childops.get(i).getClass().equals(StateExpression.class)) {
-              isStateColumn[i] = true;
-              columnIdxs[i] = ((StateExpression) childops.get(i)).getColumnIdx();
-
-            } else if (childops.get(i).getClass().equals(VariableExpression.class)) {
-              columnIdxs[i] = ((VariableExpression) childops.get(i)).getColumnIdx();
-            } else {
-              throw new DbException(
-                  "Python expression can only have State or Variable expression as child expressions.");
-            }
-          }
-        }
+    columnIdxs = new int[childops.size()];
+    stateColumns = new HashSet<>();
+    List<Type> types = new ArrayList<>();
+    for (int i = 0; i < childops.size(); i++) {
+      if (childops.get(i) instanceof StateExpression) {
+        stateColumns.add(i);
+        columnIdxs[i] = ((StateExpression) childops.get(i)).getColumnIdx();
+        types.add(((StateExpression) childops.get(i)).getOutputType(parameters));
+      } else if (childops.get(i) instanceof VariableExpression) {
+        columnIdxs[i] = ((VariableExpression) childops.get(i)).getColumnIdx();
+        types.add(((VariableExpression) childops.get(i)).getOutputType(parameters));
       } else {
-        throw new DbException("PythonRegistrar should not be null in PythonUDFEvaluator.");
+        throw new IllegalStateException(
+            "Python expression can only have State or Variable expression as child expressions.");
       }
+    }
+    stateSchema = new Schema(types);

-    } catch (Exception e) {
-      throw new DbException(e);
+    String pyFunctionName = op.getName();
+    FunctionStatus fs = pyFuncRegistrar.getFunctionStatus(pyFunctionName);
+    if (fs == null) {
+      throw new DbException("No Python UDF with given name registered.");
     }
+    isMultiValued = fs.getIsMultivalued();
+    pyWorker = new PythonWorker();
+    pyWorker.sendCodePickle(fs.getBinary(), columnIdxs.length, outputType, isMultiValued);
+    buffer = new IntObjectHashMap<>();
   }

   /**
@@ -148,103 +126,86 @@ public void compile() {
   }

   @Override
   public void eval(
-      @Nonnull final ReadableTable tb,
-      final int rowIdx,
-      @Nullable final WritableColumn count,
+      @Nonnull final ReadableTable input,
+      final int inputRow,
+      @Nullable final ReadableTable state,
+      final int stateRow,
       @Nonnull final WritableColumn result,
-      @Nullable final ReadableTable state)
+      @Nullable final WritableColumn count)
       throws DbException {
-
-    if (pyWorker == null) {
-      pyWorker = new PythonWorker();
-      initEvaluator();
-    }
-    int resultColIdx = -1;
-
-    try {
-      DataOutputStream dOut = pyWorker.getDataOutputStream();
-      pyWorker.sendNumTuples(1);
-      for (int i = 0; i < numColumns; i++) {
-        if (isStateColumn[i]) {
-          writeToStream(state, rowIdx, columnIdxs[i], dOut);
-        } else {
-          writeToStream(tb, rowIdx, columnIdxs[i], dOut);
-        }
+    pyWorker.sendNumTuples(1);
+    for (int i = 0; i < columnIdxs.length; ++i) {
+      if (stateColumns.contains(i)) {
+        writeToStream(state, stateRow, columnIdxs[i]);
+      } else {
+        writeToStream(input, inputRow, columnIdxs[i]);
       }
-
-      // read response back
-      readFromStream(count, result, null, resultColIdx);
-
-    } catch (Exception e) {
-      throw new DbException(e);
     }
+    readFromStream(count, result);
   }

-  /**
-   *
-   * @param ltb list of tuple batch
-   * @param result result table.
-   * @param state state column.
-   * @throws DbException in case of error
-   * @throws IOException in case of error.
- */ - public void evalBatch( - final List<TupleBatch> ltb, final AppendableTable result, final ReadableTable state) - throws DbException, IOException { - if (pyWorker == null) { - pyWorker = new PythonWorker(); - initEvaluator(); + + @Override + public void updateState( + @Nonnull final ReadableTable input, + final int inputRow, + @Nonnull final MutableTupleBuffer state, + final int stateRow, + final int stateColoffset) + throws DbException { + if (!buffer.containsKey(stateRow)) { + buffer.put(stateRow, new TupleBatchBuffer(stateSchema)); } - int resultcol = -1; - for (int i = 0; i < numColumns; i++) { - if (isStateColumn[i]) { - resultcol = columnIdxs[i]; + TupleBatchBuffer tb = buffer.get(stateRow); + for (int i = 0; i < columnIdxs.length; ++i) { + if (stateColumns.contains(i)) { + tb.appendFromColumn(i, state.asColumn(columnIdxs[i] + stateColoffset), stateRow); + } else { + tb.appendFromColumn(i, input.asColumn(columnIdxs[i]), inputRow); } - break; } + }; - try { - - DataOutputStream dOut = pyWorker.getDataOutputStream(); - int numTuples = 0; - for (int j = 0; j < ltb.size(); j++) { - numTuples += ltb.get(j).numTuples(); - } - pyWorker.sendNumTuples(numTuples); - for (int tbIdx = 0; tbIdx < ltb.size(); tbIdx++) { - TupleBatch tb = ltb.get(tbIdx); + /** + * @param state the state tuple buffer + * @param col column index of the state to be written to. + * @throws DbException in case of error + */ + public void evalGroups(final MutableTupleBuffer state, final int col) throws DbException { + IntIterator iter = buffer.keySet().intIterator(); + while (iter.hasNext()) { + int key = iter.next(); + pyWorker.sendNumTuples(buffer.get(key).numTuples()); + for (TupleBatch tb : buffer.get(key).getAll()) { for (int tup = 0; tup < tb.numTuples(); tup++) { - for (int col = 0; col < numColumns; col++) { - writeToStream(tb, tup, columnIdxs[col], dOut); + for (int i = 0; i < tb.numColumns(); ++i) { + writeToStream(tb, tup, i); } } } - - // read result back - readFromStream(null, null, result, resultcol); - - } catch (Exception e) { - throw new DbException(e); + ColumnBuilder<?> output = ColumnFactory.allocateColumn(outputType); + /* TODO: Leaving the count column to be null for now since it's not used by the Python evaluator for aggregates. + * A better design is to let the Aggregator emit two columns or even multiple columns. */ + readFromStream(null, output); + if (output.size() > 1) { + throw new RuntimeException("PythonUDFEvaluator cannot be multivalued for Aggregate"); + } + for (int i = 0; i < output.size(); ++i) { + state.replace(col, key, output, i); + } } } /** - *@param count number of tuples returned. - *@param result writable column - *@param result2 appendable table - *@param resultColIdx id of the result column. - * @return Object output from python process. + * @param count number of tuples returned. + * @param result writable column * @throws DbException in case of error. */ - private Object readFromStream( - final WritableColumn count, - final WritableColumn result, - final AppendableTable result2, - final int resultColIdx) + public void readFromStream(final WritableColumn count, final WritableColumn result) throws DbException { - int type = 0; - Object obj = null; DataInputStream dIn = pyWorker.getDataInputStream(); - int c = 1; // single valued expressions only return 1 tuple. try { // if it is a flat map operation, read number of tuples to be read.
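The rewritten writeToStream/readFromStream pair above speaks a simple type-tagged wire format: every value exchanged with the Python process is an int type tag followed by its payload, variable-width values carry an int length prefix, and reserved negative lengths signal nulls and Python exceptions. A minimal sketch of the write side of that framing, using hypothetical stand-in constants for the real MyriaConstants.PythonType and MyriaConstants.PythonSpecialLengths values:

import java.io.DataOutputStream;
import java.io.IOException;

/** Illustrative only: mirrors the framing, not the actual Myria classes. */
final class PythonFramingSketch {
  private static final int TYPE_LONG = 3; // stand-in for PythonType.LONG.getVal()
  private static final int TYPE_BLOB = 6; // stand-in for PythonType.BLOB.getVal()
  private static final int NULL_LENGTH = -5; // stand-in for PythonSpecialLengths.NULL_LENGTH.getVal()

  /** Fixed-width value: type tag, then payload. */
  static void writeLong(final DataOutputStream out, final long v) throws IOException {
    out.writeInt(TYPE_LONG);
    out.writeLong(v);
  }

  /** Variable-width value: type tag, length prefix, then bytes; a negative length encodes null. */
  static void writeBlob(final DataOutputStream out, final byte[] v) throws IOException {
    out.writeInt(TYPE_BLOB);
    if (v != null) {
      out.writeInt(v.length);
      out.write(v);
    } else {
      out.writeInt(NULL_LENGTH);
    }
  }
}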
@@ -258,8 +219,8 @@ private Object readFromStream( } for (int i = 0; i < c; i++) { - //then read the type of tuple - type = dIn.readInt(); + // then read the type of tuple + int type = dIn.readInt(); // if the 'type' is exception, throw exception if (type == MyriaConstants.PythonSpecialLengths.PYTHON_EXCEPTION.getVal()) { int excepLength = dIn.readInt(); @@ -269,82 +230,47 @@ private Object readFromStream( } else { // read the rest of the tuple if (type == MyriaConstants.PythonType.DOUBLE.getVal()) { - obj = dIn.readDouble(); - if (resultColIdx == -1) { - result.appendDouble((Double) obj); - } else { - result2.putDouble(resultColIdx, (Double) obj); - } + result.appendDouble(dIn.readDouble()); } else if (type == MyriaConstants.PythonType.FLOAT.getVal()) { - obj = dIn.readFloat(); - if (resultColIdx == -1) { - result.appendFloat((float) obj); - } else { - result2.putFloat(resultColIdx, (float) obj); - } - + result.appendFloat(dIn.readFloat()); } else if (type == MyriaConstants.PythonType.INT.getVal()) { - obj = dIn.readInt(); - if (resultColIdx == -1) { - result.appendInt((int) obj); - } else { - result2.putInt(resultColIdx, (int) obj); - } - + result.appendInt(dIn.readInt()); } else if (type == MyriaConstants.PythonType.LONG.getVal()) { - obj = dIn.readLong(); - if (resultColIdx == -1) { - result.appendLong((long) obj); - } else { - result2.putLong(resultColIdx, (long) obj); - } - + result.appendLong(dIn.readLong()); } else if (type == MyriaConstants.PythonType.BLOB.getVal()) { - int l = dIn.readInt(); if (l > 0) { - obj = new byte[l]; - dIn.readFully((byte[]) obj); - if (resultColIdx == -1) { - result.appendBlob(ByteBuffer.wrap((byte[]) obj)); - } else { - result2.putBlob(resultColIdx, ByteBuffer.wrap((byte[]) obj)); - } + byte[] obj = new byte[l]; + dIn.readFully(obj); + result.appendBlob(ByteBuffer.wrap(obj)); } - } else { throw new DbException("Type not supported by python"); } } } - } catch (Exception e) { throw new DbException(e); } - return obj; } /** - *helper function to write to python process. + * helper function to write to python process. + * * @param tb - input tuple buffer. * @param row - row being evaluated. - * @param columnIdx -columnto be written to the py process. - * @param dOut -output stream + * @param columnIdx - column to be written to the py process. * @throws DbException in case of error. 
*/ - private void writeToStream( - final ReadableTable tb, final int row, final int columnIdx, final DataOutputStream dOut) + private void writeToStream(final ReadableTable tb, final int row, final int columnIdx) throws DbException { + DataOutputStream dOut = pyWorker.getDataOutputStream(); Preconditions.checkNotNull(tb, "input tuple cannot be null"); Preconditions.checkNotNull(dOut, "Output stream for python process cannot be null"); - - Schema tbsc = tb.getSchema(); try { - Type type = tbsc.getColumnType(columnIdx); - - switch (type) { + switch (tb.getSchema().getColumnType(columnIdx)) { case BOOLEAN_TYPE: - LOGGER.debug("BOOLEAN type not supported for python function "); + LOGGER.debug("BOOLEAN type not supported for python function"); break; case DOUBLE_TYPE: dOut.writeInt(MyriaConstants.PythonType.DOUBLE.getVal()); @@ -374,20 +300,15 @@ private void writeToStream( break; case BLOB_TYPE: dOut.writeInt(MyriaConstants.PythonType.BLOB.getVal()); - ByteBuffer input = tb.getBlob(columnIdx, row); - if (input != null && input.hasArray()) { - dOut.writeInt(input.array().length); dOut.write(input.array()); } else { - dOut.writeInt(MyriaConstants.PythonSpecialLengths.NULL_LENGTH.getVal()); } } dOut.flush(); - } catch (Exception e) { throw new DbException(e); } diff --git a/src/edu/washington/escience/myria/expression/evaluate/ScriptEvalInterface.java b/src/edu/washington/escience/myria/expression/evaluate/ScriptEvalInterface.java deleted file mode 100644 index c94b6d89f..000000000 --- a/src/edu/washington/escience/myria/expression/evaluate/ScriptEvalInterface.java +++ /dev/null @@ -1,24 +0,0 @@ -package edu.washington.escience.myria.expression.evaluate; - -import edu.washington.escience.myria.storage.AppendableTable; -import edu.washington.escience.myria.storage.ReadableTable; -import edu.washington.escience.myria.storage.Tuple; - -/** - * Interface for evaluators that take multiple expressions and may write multiple columns. - */ -public interface ScriptEvalInterface { - /** - * The interface for applying expressions. We only need a reference to the tuple batch and a row id. The variables - * will be fetched from the tuple buffer using the rowId provided in - * {@link edu.washington.escience.myria.expression.VariableExpression} or - * {@link edu.washington.escience.myria.expression.StateExpression}. - * - * @param tb a tuple batch - * @param row index of the row in the tb that should be used. - * @param result where the output should be written. - * @param state state that is passed during evaluation, and written after the new state is computed. - */ - void evaluate( - final ReadableTable tb, final int row, final AppendableTable result, final Tuple state); -} diff --git a/src/edu/washington/escience/myria/functions/PythonFunctionRegistrar.java b/src/edu/washington/escience/myria/functions/PythonFunctionRegistrar.java index 30ceaa3bf..763c6b296 100644 --- a/src/edu/washington/escience/myria/functions/PythonFunctionRegistrar.java +++ b/src/edu/washington/escience/myria/functions/PythonFunctionRegistrar.java @@ -16,7 +16,7 @@ import edu.washington.escience.myria.storage.TupleBatchBuffer; /** - *This class sets and gets python functions on a postgres instance on a worker. + * This class sets and gets python functions on a postgres instance on a worker. */ public class PythonFunctionRegistrar { @@ -28,14 +28,13 @@ public class PythonFunctionRegistrar { private JdbcAccessMethod accessMethod; /** Buffer for UDFs registered. 
*/ private final TupleBatchBuffer pyFunctions; - /** connection information for reconnection if connection is closed.*/ + /** connection information for reconnection if connection is closed. */ private final ConnectionInfo connectionInfo; /** * Default constructor. * * @param connectionInfo connection information - * * @throws DbException if any error occurs */ public PythonFunctionRegistrar(final ConnectionInfo connectionInfo) throws DbExc pyFunctions = new TupleBatchBuffer(MyriaConstants.PYUDF_SCHEMA); } + /** Helper function to connect for storing and retrieving UDFs. */ private void connect() throws DbException { /* open the database connection */ @@ -63,7 +63,7 @@ private void connect() throws DbException { * Add function to current worker. * * @param name function name - * @param description of function + * @param description of function * @param outputType of function * @param isMultiValued does function return multiple tuples. * @param binary binary function @@ -121,35 +121,27 @@ public FunctionStatus getFunctionStatus(final String pyFunctionName) throws DbEx sb.append(" where function_name='"); sb.append(pyFunctionName); sb.append("'"); - try { - Iterator<TupleBatch> tuples = - accessMethod.tupleBatchIteratorFromQuery(sb.toString(), MyriaConstants.PYUDF_SCHEMA); - - if (tuples.hasNext()) { - - final TupleBatch tb = tuples.next(); - if (tb.numTuples() > 0) { - - FunctionStatus fs = - new FunctionStatus( - pyFunctionName, - tb.getString(1, 0), - tb.getString(2, 0), - tb.getBoolean(3, 0), - FunctionLanguage.PYTHON, - tb.getString(4, 0)); - - return fs; - } + Iterator<TupleBatch> tuples = + accessMethod.tupleBatchIteratorFromQuery(sb.toString(), MyriaConstants.PYUDF_SCHEMA); + + if (tuples.hasNext()) { + final TupleBatch tb = tuples.next(); + if (tb.numTuples() > 0) { + FunctionStatus fs = + new FunctionStatus( + pyFunctionName, + tb.getString(1, 0), + tb.getString(2, 0), + tb.getBoolean(3, 0), + FunctionLanguage.PYTHON, + tb.getString(4, 0)); + return fs; } - } catch (Exception e) { - throw new DbException(e); } return null; }; /** - * * @return {@code true} if the current JDBC connection is active. */ public boolean isValid() { diff --git a/src/edu/washington/escience/myria/functions/PythonWorker.java b/src/edu/washington/escience/myria/functions/PythonWorker.java index d81c51cb0..6ee0ef344 100644 --- a/src/edu/washington/escience/myria/functions/PythonWorker.java +++ b/src/edu/washington/escience/myria/functions/PythonWorker.java @@ -28,23 +28,22 @@ public class PythonWorker { /***/ private static final long serialVersionUID = 1L; - /** logger*/ + /** logger */ private static final org.slf4j.Logger LOGGER = org.slf4j.LoggerFactory.getLogger(PythonWorker.class); - /** server socket for python worker.*/ + /** server socket for python worker. */ private ServerSocket serverSocket = null; - /** client sock for python worker.*/ + /** client sock for python worker. */ private Socket clientSock = null; - /** python worker process.*/ + /** python worker process. */ private Process worker = null; - /** output stream from python worker.*/ + /** output stream from python worker. */ private DataOutputStream dOut; - /**input stream from python worker.*/ + /** input stream from python worker.
*/ private DataInputStream dIn; /** - * * @throws DbException */ public PythonWorker() throws DbException { @@ -59,7 +58,6 @@ } /** - * * @param pyCodeString - python function string * @param numColumns number of columns to be written to python process. * @param outputType output type of the python function. @@ -73,7 +71,6 @@ public void sendCodePickle( final Boolean isFlatMap) throws DbException { Preconditions.checkNotNull(pyCodeString); - try { if (pyCodeString.length() > 0 && dOut != null) { byte[] bytes = pyCodeString.getBytes(StandardCharsets.UTF_8); @@ -87,32 +84,30 @@ dOut.writeInt(0); } dOut.flush(); - } else { throw new DbException("Can't write Python Code to worker!"); } - } catch (Exception e) { + } catch (IOException e) { LOGGER.debug("failed to send python code pickle"); throw new DbException(e); } } + /** - * - * @param numTuples: number of tuples to be sent to python function. + * @param numTuples number of tuples to be sent to python function. * @throws IOException * @throws DbException */ - public void sendNumTuples(final int numTuples) throws IOException, DbException { + public void sendNumTuples(final int numTuples) throws DbException { Preconditions.checkArgument(numTuples > 0, "number of tuples: %s", numTuples); try { dOut.writeInt(numTuples); - } catch (Exception e) { + } catch (IOException e) { throw new DbException(e); } } /** - * * @return the data output stream for the python worker. */ public DataOutputStream getDataOutputStream() { @@ -121,7 +116,6 @@ } /** - * * @return the data input stream for the python worker. */ public DataInputStream getDataInputStream() { @@ -130,7 +124,6 @@ } /** - * * @throws IOException */ public void close() throws IOException { @@ -148,7 +141,6 @@ } /** - * * @throws UnknownHostException * @throws IOException */ @@ -157,7 +149,6 @@ private void createServerSocket() throws UnknownHostException, IOException { } /** - * * @throws IOException in case of error. */ private void startPythonWorker() throws IOException { @@ -181,7 +172,6 @@ } /** - * * @param outputType : output type for python function * @throws IOException in case of error. * @throws DbException in case of error. @@ -209,7 +199,6 @@ private void writeOutputType(final Type outputType) throws IOException, DbExcept } /** - * * @throws IOException in case of error. */ private void setupStreams() throws IOException { @@ -218,16 +207,4 @@ dIn = new DataInputStream(clientSock.getInputStream()); } } - /** - * @param eos Send end of stream to cleanly close the python process. - * @throws DbException in case of error.
- */ - public void sendEos(final int eos) throws DbException { - try { - dOut.writeInt(eos); - close(); - } catch (Exception e) { - throw new DbException(e); - } - } } diff --git a/src/edu/washington/escience/myria/operator/Apply.java b/src/edu/washington/escience/myria/operator/Apply.java index 8b5469641..7ccb59042 100644 --- a/src/edu/washington/escience/myria/operator/Apply.java +++ b/src/edu/washington/escience/myria/operator/Apply.java @@ -18,6 +18,7 @@ import edu.washington.escience.myria.MyriaConstants; import edu.washington.escience.myria.Schema; import edu.washington.escience.myria.Type; +import edu.washington.escience.myria.column.Column; import edu.washington.escience.myria.expression.Expression; import edu.washington.escience.myria.expression.evaluate.ConstantEvaluator; import edu.washington.escience.myria.expression.evaluate.ExpressionOperatorParameter; @@ -55,6 +56,7 @@ public class Apply extends UnaryOperator { * AddCounter to the returning tuplebatch. */ private Boolean addCounter = false; + /** * @return the {@link #emitExpressions} */ @@ -89,8 +91,7 @@ private boolean onlySingleValuedExpressions() { } /** - * @return number of columns that return more than one value for this - * Apply operator. + * @return number of columns that return more than one value for this Apply operator. */ private int numberOfMultiValuedExpressions() { int i = 0; @@ -101,8 +102,10 @@ private int numberOfMultiValuedExpressions() { } return i; } + /** * Should a counter be added? + * * @return */ private boolean getAddCounter() { @@ -119,7 +122,6 @@ private void setAddCounter(Boolean addCounter) { private static final org.slf4j.Logger LOGGER = org.slf4j.LoggerFactory.getLogger(Apply.class); /** - * * @param child child operator that data is fetched from * @param emitExpressions expression that created the output */ @@ -150,23 +152,31 @@ protected TupleBatch fetchNextReady() throws DbException, InvocationTargetExcept while (!outputBuffer.hasFilledTB()) { TupleBatch inputTuples = getChild().nextReady(); if (inputTuples != null) { - // Evaluate expressions on each column and store counts and results. - List resultCountColumns = new ArrayList<>(); - List resultColumns = new ArrayList<>(); - for (final GenericEvaluator eval : emitEvaluators) { - EvaluatorResult evalResult = eval.evaluateColumn(inputTuples); - resultCountColumns.add(evalResult.getResultCounts()); - resultColumns.add(evalResult.getResults()); - } if (onlySingleValuedExpressions()) { - int[] iteratorIndexes = new int[emitEvaluators.size()]; - Arrays.fill(iteratorIndexes, 0); - for (int rowIdx = 0; rowIdx < inputTuples.numTuples(); ++rowIdx) { - for (int i = 0; i < iteratorIndexes.length; ++i) { - outputBuffer.appendFromColumn(i, resultColumns.get(i), rowIdx); + List<List<Column<?>>> tbs = new ArrayList<List<Column<?>>>(); + for (final GenericEvaluator eval : emitEvaluators) { + EvaluatorResult evalResult = eval.evalTupleBatch(inputTuples, getSchema()); + List<Column<?>> cols = evalResult.getResultColumns(); + for (int i = 0; i < cols.size(); ++i) { + if (tbs.size() <= i) { + tbs.add(new ArrayList<Column<?>>()); + } + tbs.get(i).add(cols.get(i)); } } + for (List<Column<?>> tb : tbs) { + outputBuffer.absorb(new TupleBatch(getSchema(), tb), true); + } } else { + // Evaluate expressions on each column and store counts and results.
+ List resultCountColumns = new ArrayList<>(); + List resultColumns = new ArrayList<>(); + for (final GenericEvaluator eval : emitEvaluators) { + EvaluatorResult evalResult = eval.evalTupleBatch(inputTuples, getSchema()); + resultCountColumns.add(evalResult.getResultCounts()); + resultColumns.add(evalResult.getResults()); + } + // Generate the Cartesian product and append to output buffer. int[] resultCounts = new int[emitEvaluators.size()]; int[] cumResultCounts = new int[emitEvaluators.size()]; @@ -235,7 +245,7 @@ protected TupleBatch fetchNextReady() throws DbException, InvocationTargetExcept * returns false. * * @param upperBounds an immutable array of elements representing the sets we are forming the Cartesian product of, - * where each set is of the form [0, i), where i is an element of {@link upperBounds} + * where each set is of the form [0, i), where i is an element of {@link upperBounds} * @param iteratorIndexes a mutable array of elements representing the current element of the Cartesian product * @return if we have exhausted all elements of the Cartesian product */ @@ -271,20 +281,19 @@ protected void init(final ImmutableMap execEnvVars) throws DbExc List evals = new ArrayList<>(); final ExpressionOperatorParameter parameters = - new ExpressionOperatorParameter(inputSchema, getNodeID()); + new ExpressionOperatorParameter( + inputSchema, null, getNodeID(), getPythonFunctionRegistrar()); for (Expression expr : emitExpressions) { GenericEvaluator evaluator; if (expr.isConstant()) { evaluator = new ConstantEvaluator(expr, parameters); - } else if (expr.isRegisteredUDF()) { - evaluator = new PythonUDFEvaluator(expr, parameters, getPythonFunctionRegistrar()); + } else if (expr.isRegisteredPythonUDF()) { + evaluator = new PythonUDFEvaluator(expr, parameters); } else { evaluator = new GenericEvaluator(expr, parameters); } - if (evaluator.needsCompiling()) { - evaluator.compile(); - } + evaluator.compile(); Preconditions.checkArgument(!evaluator.needsState()); evals.add(evaluator); } diff --git a/src/edu/washington/escience/myria/operator/Filter.java b/src/edu/washington/escience/myria/operator/Filter.java index bf7267d29..509da1779 100644 --- a/src/edu/washington/escience/myria/operator/Filter.java +++ b/src/edu/washington/escience/myria/operator/Filter.java @@ -22,7 +22,7 @@ public final class Filter extends UnaryOperator { private static final long serialVersionUID = 1L; /** * The operator. 
- * */ + */ private final Expression predicate; /** @@ -70,16 +70,11 @@ protected TupleBatch fetchNextReady() throws DbException { @Override protected void init(final ImmutableMap execEnvVars) throws DbException { Preconditions.checkNotNull(predicate); - Schema inputSchema = getChild().getSchema(); - final ExpressionOperatorParameter parameters = new ExpressionOperatorParameter(inputSchema, getNodeID()); - evaluator = new BooleanEvaluator(predicate, parameters); - if (evaluator.needsCompiling()) { - evaluator.compile(); - } + evaluator.compile(); } @Override diff --git a/src/edu/washington/escience/myria/operator/StatefulApply.java b/src/edu/washington/escience/myria/operator/StatefulApply.java index 4a98889c2..43f0bf1e2 100644 --- a/src/edu/washington/escience/myria/operator/StatefulApply.java +++ b/src/edu/washington/escience/myria/operator/StatefulApply.java @@ -1,7 +1,6 @@ package edu.washington.escience.myria.operator; import java.io.IOException; -import java.lang.reflect.InvocationTargetException; import java.util.ArrayList; import java.util.List; @@ -99,7 +98,7 @@ private void setUpdateExpressions(final List updaterExpressions) { } @Override - protected TupleBatch fetchNextReady() throws DbException, InvocationTargetException, IOException { + protected TupleBatch fetchNextReady() throws DbException, IOException { Operator child = getChild(); if (child.eoi() || getChild().eos()) { @@ -123,7 +122,7 @@ protected TupleBatch fetchNextReady() throws DbException, InvocationTargetExcept !evaluator.getExpression().isMultiValued(), "A multivalued expression cannot be used in StatefulApply."); if (!evaluator.needsState() || evaluator.isCopyFromInput()) { - output.set(columnIdx, evaluator.evaluateColumn(tb).getResultColumns().get(0)); + output.set(columnIdx, evaluator.evalTupleBatch(tb, getSchema()).getResultColumns().get(0)); } else { needState.add(columnIdx); } @@ -144,14 +143,14 @@ protected TupleBatch fetchNextReady() throws DbException, InvocationTargetExcept updateEvaluators .get(columnIdx) - .eval(tb, rowIdx, null, newState.getColumn(columnIdx), state); + .eval(tb, rowIdx, state, 0, newState.asWritableColumn(columnIdx), null); } state = newState; // apply expression for (int index = 0; index < needState.size(); index++) { final GenericEvaluator evaluator = getEmitEvaluators().get(needState.get(index)); // TODO: optimize the case where the state is copied directly - evaluator.eval(tb, rowIdx, null, columnBuilders.get(index), state); + evaluator.eval(tb, rowIdx, state, 0, columnBuilders.get(index), null); } } @@ -175,21 +174,21 @@ protected void init(final ImmutableMap execEnvVars) throws DbExc // these can only be generic or python expressions. 
for (Expression expr : getEmitExpressions()) { GenericEvaluator evaluator; - if (expr.isRegisteredUDF()) { + if (expr.isConstant()) { + evaluator = + new ConstantEvaluator(expr, new ExpressionOperatorParameter(inputSchema, getNodeID())); + } else if (expr.isRegisteredPythonUDF()) { evaluator = new PythonUDFEvaluator( expr, - new ExpressionOperatorParameter(inputSchema, getStateSchema(), getNodeID()), - getPythonFunctionRegistrar()); - + new ExpressionOperatorParameter( + inputSchema, getStateSchema(), getNodeID(), getPythonFunctionRegistrar())); } else { evaluator = new GenericEvaluator( expr, new ExpressionOperatorParameter(inputSchema, getStateSchema(), getNodeID())); } - if (evaluator.needsCompiling()) { - evaluator.compile(); - } + evaluator.compile(); evaluators.add(evaluator); } setEmitEvaluators(evaluators); @@ -205,25 +204,23 @@ protected void init(final ImmutableMap execEnvVars) throws DbExc ConstantEvaluator evaluator = new ConstantEvaluator(expr, new ExpressionOperatorParameter(inputSchema, getNodeID())); evaluator.compile(); - state.set(columnIdx, evaluator.eval()); + state.putObject(columnIdx, evaluator.eval()); } // initialize update evaluators -- these can be generic or python evaluators for (Expression expr : updateExpressions) { GenericEvaluator evaluator; - if (expr.isRegisteredUDF()) { + if (expr.isRegisteredPythonUDF()) { evaluator = new PythonUDFEvaluator( expr, - new ExpressionOperatorParameter(inputSchema, getStateSchema(), getNodeID()), - getPythonFunctionRegistrar()); - + new ExpressionOperatorParameter( + inputSchema, getStateSchema(), getNodeID(), getPythonFunctionRegistrar())); } else { evaluator = new GenericEvaluator( expr, new ExpressionOperatorParameter(inputSchema, getStateSchema(), getNodeID())); } - evaluator.compile(); updateEvaluators.add(evaluator); } diff --git a/src/edu/washington/escience/myria/operator/SymmetricHashJoin.java b/src/edu/washington/escience/myria/operator/SymmetricHashJoin.java index 6eab66cf0..94fbb537e 100644 --- a/src/edu/washington/escience/myria/operator/SymmetricHashJoin.java +++ b/src/edu/washington/escience/myria/operator/SymmetricHashJoin.java @@ -7,21 +7,16 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import com.gs.collections.api.block.procedure.primitive.IntProcedure; -import com.gs.collections.impl.list.mutable.primitive.IntArrayList; -import com.gs.collections.impl.map.mutable.primitive.IntObjectHashMap; +import com.gs.collections.api.iterator.IntIterator; import edu.washington.escience.myria.DbException; import edu.washington.escience.myria.MyriaConstants; import edu.washington.escience.myria.Schema; import edu.washington.escience.myria.Type; -import edu.washington.escience.myria.column.Column; import edu.washington.escience.myria.parallel.QueryExecutionMode; import edu.washington.escience.myria.storage.MutableTupleBuffer; import edu.washington.escience.myria.storage.TupleBatch; import edu.washington.escience.myria.storage.TupleBatchBuffer; -import edu.washington.escience.myria.storage.TupleUtils; -import edu.washington.escience.myria.util.HashUtils; import edu.washington.escience.myria.util.MyriaArrayUtils; /** @@ -32,145 +27,24 @@ public final class SymmetricHashJoin extends BinaryOperator { /** Required for Java serialization. */ private static final long serialVersionUID = 1L; - /** - * The names of the output columns. - */ + /** The names of the output columns. 
*/ private final ImmutableList outputColumns; - - /** - * The column indices for comparing of child 1. - */ - private final int[] leftCompareIndx; - /** - * The column indices for comparing of child 2. - */ - private final int[] rightCompareIndx; - /** - * A hash table for tuples from child 1. {Hashcode -> List of tuple indices with the same hash code} - */ - private transient IntObjectHashMap leftHashTableIndices; - /** - * A hash table for tuples from child 2. {Hashcode -> List of tuple indices with the same hash code} - */ - private transient IntObjectHashMap rightHashTableIndices; - - /** - * The buffer holding the valid tuples from left. - */ - private transient MutableTupleBuffer hashTable1; - /** - * The buffer holding the valid tuples from right. - */ - private transient MutableTupleBuffer hashTable2; - /** - * The buffer holding the results. - */ - private transient TupleBatchBuffer ans; + /** The column indices for comparing of the left child. */ + private final int[] leftCompareColumns; + /** The column indices for comparing of the right child. */ + private final int[] rightCompareColumns; /** Which columns in the left child are to be output. */ private final int[] leftAnswerColumns; /** Which columns in the right child are to be output. */ private final int[] rightAnswerColumns; - - /** - * Traverse through the list of tuples with the same hash code. - */ - private final class JoinProcedure implements IntProcedure { - - /** serialization id. */ - private static final long serialVersionUID = 1L; - - /** - * Hash table. - */ - private MutableTupleBuffer joinAgainstHashTable; - - /** - * - * */ - private int[] inputCmpColumns; - - /** - * the columns to compare against. - */ - private int[] joinAgainstCmpColumns; - /** - * row index of the tuple. - */ - private int row; - - /** - * input TupleBatch. - */ - private TupleBatch inputTB; - /** - * if the tuple which is comparing against the list of tuples with the same hash code is from left child. - */ - private boolean fromLeft; - - @Override - public void value(final int index) { - if (TupleUtils.tupleEquals( - inputTB, inputCmpColumns, row, joinAgainstHashTable, joinAgainstCmpColumns, index)) { - addToAns(inputTB, row, joinAgainstHashTable, index, fromLeft); - } - } - }; - - /** - * Traverse through the list of tuples with the same hash code. - */ - private final class ReplaceProcedure implements IntProcedure { - - /** serialization id. */ - private static final long serialVersionUID = 1L; - - /** - * Hash table. - */ - private MutableTupleBuffer hashTable; - - /** - * the columns to compare against. - */ - private int[] keyColumns; - /** - * row index of the tuple. - */ - private int row; - - /** - * input TupleBatch. - */ - private TupleBatch inputTB; - - /** if found a replacement. */ - private boolean replaced; - - @Override - public void value(final int index) { - if (TupleUtils.tupleEquals(inputTB, keyColumns, row, hashTable, keyColumns, index)) { - replaced = true; - List> columns = inputTB.getDataColumns(); - for (int j = 0; j < inputTB.numColumns(); ++j) { - hashTable.replace(j, index, columns.get(j), row); - } - } - } - }; - - /** - * Traverse through the list of tuples. - */ - private transient JoinProcedure doJoin; - - /** - * Traverse through the list of tuples and replace old values. - */ - private transient ReplaceProcedure doReplace; - + /** The buffer holding the valid tuples from left. */ + private transient TupleHashTable leftHashTable; + /** The buffer holding the valid tuples from right. 
*/ + private transient TupleHashTable rightHashTable; + /** The buffer holding the results. */ + private transient TupleBatchBuffer ans; /** Whether the last child polled was the left child. */ private boolean pollLeft = false; - /** Join pull order, default: ALTERNATE. */ private JoinPullOrder order = JoinPullOrder.ALTERNATE; @@ -179,76 +53,6 @@ public void value(final int index) { /** if the hash table of the right child should use set semantics. */ private boolean setSemanticsRight = false; - /** - * Construct an EquiJoin operator. It returns all columns from both children when the corresponding columns in - * compareIndx1 and compareIndx2 match. - * - * @param left the left child. - * @param right the right child. - * @param compareIndx1 the columns of the left child to be compared with the right. Order matters. - * @param compareIndx2 the columns of the right child to be compared with the left. Order matters. - * @throw IllegalArgumentException if there are duplicated column names from the children. - */ - public SymmetricHashJoin( - final Operator left, - final Operator right, - final int[] compareIndx1, - final int[] compareIndx2) { - this(null, left, right, compareIndx1, compareIndx2); - } - - /** - * Construct an EquiJoin operator. It returns the specified columns from both children when the corresponding columns - * in compareIndx1 and compareIndx2 match. - * - * @param left the left child. - * @param right the right child. - * @param compareIndx1 the columns of the left child to be compared with the right. Order matters. - * @param compareIndx2 the columns of the right child to be compared with the left. Order matters. - * @param answerColumns1 the columns of the left child to be returned. Order matters. - * @param answerColumns2 the columns of the right child to be returned. Order matters. - * @throw IllegalArgumentException if there are duplicated column names in outputSchema, or if - * outputSchema does not have the correct number of columns and column types. - */ - public SymmetricHashJoin( - final Operator left, - final Operator right, - final int[] compareIndx1, - final int[] compareIndx2, - final int[] answerColumns1, - final int[] answerColumns2) { - this(null, left, right, compareIndx1, compareIndx2, answerColumns1, answerColumns2); - } - - /** - * Construct an SymmetricHashJoin operator. It returns the specified columns from both children when the corresponding - * columns in compareIndx1 and compareIndx2 match. - * - * @param left the left child. - * @param right the right child. - * @param compareIndx1 the columns of the left child to be compared with the right. Order matters. - * @param compareIndx2 the columns of the right child to be compared with the left. Order matters. - * @param answerColumns1 the columns of the left child to be returned. Order matters. - * @param answerColumns2 the columns of the right child to be returned. Order matters. - * @param setSemanticsLeft if the hash table of the left child should use set semantics. - * @param setSemanticsRight if the hash table of the right child should use set semantics. - * @throw IllegalArgumentException if there are duplicated column names in outputSchema, or if - * outputSchema does not have the correct number of columns and column types. 
- */ - public SymmetricHashJoin( - final Operator left, - final Operator right, - final int[] compareIndx1, - final int[] compareIndx2, - final int[] answerColumns1, - final int[] answerColumns2, - final boolean setSemanticsLeft, - final boolean setSemanticsRight) { - this(null, left, right, compareIndx1, compareIndx2, answerColumns1, answerColumns2); - this.setSemanticsLeft = setSemanticsLeft; - this.setSemanticsRight = setSemanticsRight; - } - /** * Construct an SymmetricHashJoin operator. It returns the specified columns from both children when the corresponding * columns in compareIndx1 and compareIndx2 match. @@ -257,29 +61,33 @@ public SymmetricHashJoin( * copied from the children. * @param left the left child. * @param right the right child. - * @param compareIndx1 the columns of the left child to be compared with the right. Order matters. - * @param compareIndx2 the columns of the right child to be compared with the left. Order matters. - * @param answerColumns1 the columns of the left child to be returned. Order matters. - * @param answerColumns2 the columns of the right child to be returned. Order matters. * @param setSemanticsLeft if - * the hash table of the left child should use set semantics. - * @param setSemanticsLeft if the hash table of the left child should use set semantics. - * @param setSemanticsRight if the hash table of the right child should use set semantics. + * @param leftCompareColumns the columns of the left child to be compared with the right. Order matters. + * @param rightCompareColumns the columns of the right child to be compared with the left. Order matters. + * @param leftAnswerColumns the columns of the left child to be returned. Order matters. + * @param rightAnswerColumns the columns of the right child to be returned. Order matters. * @param setSemanticsLeft + * if the hash table of the left child should use set semantics. * @throw IllegalArgumentException if there are duplicated column names in outputColumns, or if * outputColumns does not have the correct number of columns and column types. */ public SymmetricHashJoin( - final List outputColumns, final Operator left, final Operator right, - final int[] compareIndx1, - final int[] compareIndx2, - final int[] answerColumns1, - final int[] answerColumns2, - final boolean setSemanticsLeft, - final boolean setSemanticsRight) { - this(outputColumns, left, right, compareIndx1, compareIndx2, answerColumns1, answerColumns2); - this.setSemanticsLeft = setSemanticsLeft; - this.setSemanticsRight = setSemanticsRight; + final int[] leftCompareColumns, + final int[] rightCompareColumns, + final int[] leftAnswerColumns, + final int[] rightAnswerColumns) { + /* Only used by tests */ + this( + left, + right, + leftCompareColumns, + rightCompareColumns, + leftAnswerColumns, + rightAnswerColumns, + false, + false, + null, + JoinPullOrder.ALTERNATE); } /** @@ -290,26 +98,32 @@ public SymmetricHashJoin( * copied from the children. * @param left the left child. * @param right the right child. - * @param compareIndx1 the columns of the left child to be compared with the right. Order matters. - * @param compareIndx2 the columns of the right child to be compared with the left. Order matters. - * @param answerColumns1 the columns of the left child to be returned. Order matters. - * @param answerColumns2 the columns of the right child to be returned. Order matters. + * @param leftCompareColumns the columns of the left child to be compared with the right. Order matters. 
+ * @param rightCompareColumns the columns of the right child to be compared with the left. Order matters. + * @param leftAnswerColumns the columns of the left child to be returned. Order matters. + * @param rightAnswerColumns the columns of the right child to be returned. Order matters. + * @param setSemanticsLeft if the hash table of the left child should use set semantics. + * @param setSemanticsRight if the hash table of the right child should use set semantics. + * @param order the join pull order policy. * @throw IllegalArgumentException if there are duplicated column names in outputColumns, or if * outputColumns does not have the correct number of columns and column types. */ public SymmetricHashJoin( - final List outputColumns, final Operator left, final Operator right, - final int[] compareIndx1, - final int[] compareIndx2, - final int[] answerColumns1, - final int[] answerColumns2) { + final int[] leftCompareColumns, + final int[] rightCompareColumns, + final int[] leftAnswerColumns, + final int[] rightAnswerColumns, + final boolean setSemanticsLeft, + final boolean setSemanticsRight, + final List outputColumns, + final JoinPullOrder order) { super(left, right); - Preconditions.checkArgument(compareIndx1.length == compareIndx2.length); + Preconditions.checkArgument(leftCompareColumns.length == rightCompareColumns.length); if (outputColumns != null) { Preconditions.checkArgument( - outputColumns.size() == answerColumns1.length + answerColumns2.length, + outputColumns.size() == leftAnswerColumns.length + rightAnswerColumns.length, "length mismatch between output column names and columns selected for output"); Preconditions.checkArgument( ImmutableSet.copyOf(outputColumns).size() == outputColumns.size(), @@ -318,53 +132,13 @@ public SymmetricHashJoin( } else { this.outputColumns = null; } - leftCompareIndx = MyriaArrayUtils.warnIfNotSet(compareIndx1); - rightCompareIndx = MyriaArrayUtils.warnIfNotSet(compareIndx2); - leftAnswerColumns = MyriaArrayUtils.warnIfNotSet(answerColumns1); - rightAnswerColumns = MyriaArrayUtils.warnIfNotSet(answerColumns2); - } - - /** - * Construct an EquiJoin operator. It returns all columns from both children when the corresponding columns in - * compareIndx1 and compareIndx2 match. - * - * @param outputColumns the names of the columns in the output schema. If null, the corresponding columns will be - * copied from the children. - * @param left the left child. - * @param right the right child. - * @param compareIndx1 the columns of the left child to be compared with the right. Order matters. - * @param compareIndx2 the columns of the right child to be compared with the left. Order matters. - * @throw IllegalArgumentException if there are duplicated column names in outputSchema, or if - * outputSchema does not have the correct number of columns and column types. - */ - public SymmetricHashJoin( - final List outputColumns, - final Operator left, - final Operator right, - final int[] compareIndx1, - final int[] compareIndx2) { - this( - outputColumns, - left, - right, - compareIndx1, - compareIndx2, - range(left.getSchema().numColumns()), - range(right.getSchema().numColumns())); - } - - /** - * Helper function that generates an array of the numbers 0..max-1. - * - * @param max the size of the array. - * @return an array of the numbers 0..max-1. 
- */ - private static int[] range(final int max) { - int[] ret = new int[max]; - for (int i = 0; i < max; ++i) { - ret[i] = i; - } - return ret; + this.leftCompareColumns = MyriaArrayUtils.warnIfNotSet(leftCompareColumns); + this.rightCompareColumns = MyriaArrayUtils.warnIfNotSet(rightCompareColumns); + this.leftAnswerColumns = MyriaArrayUtils.warnIfNotSet(leftAnswerColumns); + this.rightAnswerColumns = MyriaArrayUtils.warnIfNotSet(rightAnswerColumns); + this.setSemanticsLeft = setSemanticsLeft; + this.setSemanticsRight = setSemanticsRight; + this.order = order; } @Override @@ -381,9 +155,9 @@ protected Schema generateSchema() { ImmutableList.Builder names = ImmutableList.builder(); /* Assert that the compare index types are the same. */ - for (int i = 0; i < rightCompareIndx.length; ++i) { - int leftIndex = leftCompareIndx[i]; - int rightIndex = rightCompareIndx[i]; + for (int i = 0; i < rightCompareColumns.length; ++i) { + int leftIndex = leftCompareColumns[i]; + int rightIndex = rightCompareColumns[i]; Type leftType = leftSchema.getColumnType(leftIndex); Type rightType = rightSchema.getColumnType(rightIndex); Preconditions.checkState( @@ -405,7 +179,6 @@ protected Schema generateSchema() { types.add(rightSchema.getColumnType(i)); names.add(rightSchema.getColumnName(i)); } - if (outputColumns != null) { return new Schema(types.build(), outputColumns); } else { @@ -445,8 +218,8 @@ protected void addToAns( @Override protected void cleanup() throws DbException { - hashTable1 = null; - hashTable2 = null; + leftHashTable = null; + rightHashTable = null; ans = null; } @@ -639,21 +412,12 @@ protected TupleBatch fetchNextReady() throws DbException { @Override public void init(final ImmutableMap execEnvVars) throws DbException { - final Operator left = getLeft(); - final Operator right = getRight(); - leftHashTableIndices = new IntObjectHashMap(); - rightHashTableIndices = new IntObjectHashMap(); - - hashTable1 = new MutableTupleBuffer(left.getSchema()); - hashTable2 = new MutableTupleBuffer(right.getSchema()); - + leftHashTable = new TupleHashTable(getLeft().getSchema(), leftCompareColumns); + rightHashTable = new TupleHashTable(getRight().getSchema(), rightCompareColumns); ans = new TupleBatchBuffer(getSchema()); - nonBlocking = (QueryExecutionMode) execEnvVars.get(MyriaConstants.EXEC_ENV_VAR_EXECUTION_MODE) == QueryExecutionMode.NON_BLOCKING; - doJoin = new JoinProcedure(); - doReplace = new ReplaceProcedure(); } /** @@ -668,66 +432,33 @@ public void init(final ImmutableMap execEnvVars) throws DbExcept protected void processChildTB(final TupleBatch tb, final boolean fromLeft) { final Operator left = getLeft(); final Operator right = getRight(); - - if (left.eos() && rightHashTableIndices != null) { - /* delete right child's hash table if the left child is EOS, since there will be no incoming tuples from right as - * it will never be probed again. */ - rightHashTableIndices = null; - hashTable2 = null; + /* delete one child's hash table if the other reaches EOS. */ + if (left.eos()) { + rightHashTable = null; } - if (right.eos() && leftHashTableIndices != null) { - /* delete left child's hash table if the right child is EOS, since there will be no incoming tuples from left as - * it will never be probed again. 
*/ - leftHashTableIndices = null; - hashTable1 = null; + if (right.eos()) { + leftHashTable = null; } final boolean useSetSemantics = fromLeft && setSemanticsLeft || !fromLeft && setSemanticsRight; - MutableTupleBuffer hashTable1Local = null; - IntObjectHashMap hashTable1IndicesLocal = null; - IntObjectHashMap hashTable2IndicesLocal = null; + TupleHashTable buildHashTable = null; + TupleHashTable probeHashTable = null; + int[] buildCompareColumns = null; if (fromLeft) { - hashTable1Local = hashTable1; - doJoin.joinAgainstHashTable = hashTable2; - hashTable1IndicesLocal = leftHashTableIndices; - hashTable2IndicesLocal = rightHashTableIndices; - doJoin.inputCmpColumns = leftCompareIndx; - doJoin.joinAgainstCmpColumns = rightCompareIndx; - if (useSetSemantics) { - doReplace.hashTable = hashTable1; - doReplace.keyColumns = leftCompareIndx; - } + buildHashTable = leftHashTable; + probeHashTable = rightHashTable; + buildCompareColumns = leftCompareColumns; } else { - hashTable1Local = hashTable2; - doJoin.joinAgainstHashTable = hashTable1; - hashTable1IndicesLocal = rightHashTableIndices; - hashTable2IndicesLocal = leftHashTableIndices; - doJoin.inputCmpColumns = rightCompareIndx; - doJoin.joinAgainstCmpColumns = leftCompareIndx; - if (useSetSemantics) { - doReplace.hashTable = hashTable2; - doReplace.keyColumns = rightCompareIndx; - } + buildHashTable = rightHashTable; + probeHashTable = leftHashTable; + buildCompareColumns = rightCompareColumns; } - doJoin.fromLeft = fromLeft; - doJoin.inputTB = tb; - if (useSetSemantics) { - doReplace.inputTB = tb; - } - for (int row = 0; row < tb.numTuples(); ++row) { - final int cntHashCode = HashUtils.hashSubRow(tb, doJoin.inputCmpColumns, row); - IntArrayList tuplesWithHashCode = hashTable2IndicesLocal.get(cntHashCode); - if (tuplesWithHashCode != null) { - doJoin.row = row; - tuplesWithHashCode.forEach(doJoin); - } - - if (hashTable1Local != null) { - // only build hash table on two sides if none of the children is EOS - addToHashTable( - tb, row, hashTable1Local, hashTable1IndicesLocal, cntHashCode, useSetSemantics); + IntIterator iter = probeHashTable.getIndices(tb, buildCompareColumns, row).intIterator(); + while (iter.hasNext()) { + addToAns(tb, row, probeHashTable.getData(), iter.next(), fromLeft); } + addToHashTable(tb, buildCompareColumns, row, buildHashTable, useSetSemantics); } } @@ -737,48 +468,32 @@ protected void processChildTB(final TupleBatch tb, final boolean fromLeft) { * @param hashTable the target hash table * @param hashTable1IndicesLocal hash table 1 indices local * @param hashCode the hashCode of the tb. - * @param useSetSemantics if need to update the hash table using set semantics. + * @param replace if need to replace the hash table with new values. */ private void addToHashTable( final TupleBatch tb, + final int[] compareIndx, final int row, - final MutableTupleBuffer hashTable, - final IntObjectHashMap hashTable1IndicesLocal, - final int hashCode, - final boolean useSetSemantics) { - - final int nextIndex = hashTable.numTuples(); - IntArrayList tupleIndicesList = hashTable1IndicesLocal.get(hashCode); - if (tupleIndicesList == null) { - tupleIndicesList = new IntArrayList(1); - hashTable1IndicesLocal.put(hashCode, tupleIndicesList); - } - - doReplace.replaced = false; - if (useSetSemantics) { - doReplace.row = row; - tupleIndicesList.forEach(doReplace); - } - if (!doReplace.replaced) { - /* not using set semantics || using set semantics but found nothing to replace (i.e. 
new) */ - tupleIndicesList.add(nextIndex); - List<Column<?>> inputColumns = tb.getDataColumns(); - for (int column = 0; column < tb.numColumns(); column++) { - hashTable.put(column, inputColumns.get(column), row); + final TupleHashTable hashTable, + final boolean replace) { + if (replace) { + if (hashTable.replace(tb, compareIndx, row)) { + return; } } + hashTable.addTuple(tb, compareIndx, row, false); } /** - * @return the sum of the numbers of tuples in both hash tables. + * @return the total number of tuples in hash tables */ public long getNumTuplesInHashTables() { long sum = 0; - if (hashTable1 != null) { - sum += hashTable1.numTuples(); + if (leftHashTable != null) { + sum += leftHashTable.numTuples(); } - if (hashTable2 != null) { - sum += hashTable2.numTuples(); + if (rightHashTable != null) { + sum += rightHashTable.numTuples(); } return sum; } @@ -796,13 +511,4 @@ public enum JoinPullOrder { /** Pull from the right child until it reaches EOS. */ RIGHT_EOS } - - /** - * Set the pull order. - * - * @param order the pull order. - */ - public void setPullOrder(final JoinPullOrder order) { - this.order = order; - } } diff --git a/src/edu/washington/escience/myria/operator/TupleHashTable.java b/src/edu/washington/escience/myria/operator/TupleHashTable.java new file mode 100644 index 000000000..975293c8c --- /dev/null +++ b/src/edu/washington/escience/myria/operator/TupleHashTable.java @@ -0,0 +1,131 @@ +package edu.washington.escience.myria.operator; + +import java.io.Serializable; + +import com.gs.collections.api.iterator.IntIterator; +import com.gs.collections.impl.list.mutable.primitive.IntArrayList; +import com.gs.collections.impl.map.mutable.primitive.IntObjectHashMap; + +import edu.washington.escience.myria.Schema; +import edu.washington.escience.myria.storage.MutableTupleBuffer; +import edu.washington.escience.myria.storage.TupleBatch; +import edu.washington.escience.myria.storage.TupleUtils; +import edu.washington.escience.myria.util.HashUtils; + +/** + * An abstraction of a hash table of tuples. + */ +public final class TupleHashTable implements Serializable { +  /** Required for Java serialization. */ +  private static final long serialVersionUID = 1L; + +  /** Map from hash codes to indices. */ +  private transient IntObjectHashMap<IntArrayList> keyHashCodesToIndices; +  /** The table containing keys and values. */ +  private transient MutableTupleBuffer data; +  /** Key column indices. */ +  private final int[] keyColumns; + +  /** +   * @param schema schema +   * @param keyColumns key column indices +   */ +  public TupleHashTable(final Schema schema, final int[] keyColumns) { +    this.keyColumns = keyColumns; +    data = new MutableTupleBuffer(schema); +    keyHashCodesToIndices = new IntObjectHashMap<IntArrayList>(); +  } + +  /** +   * @return the number of tuples this hash table has. +   */ +  public int numTuples() { +    return data.numTuples(); +  } + +  /** +   * Get the data table indices given key columns from a tuple in a tuple batch.
+ * + * @param tb the input tuple batch + * @param key the key columns + * @param row the row index of the tuple + * @return the indices + */ + public IntArrayList getIndices(final TupleBatch tb, final int[] key, final int row) { + IntArrayList ret = new IntArrayList(); + IntArrayList indices = keyHashCodesToIndices.get(HashUtils.hashSubRow(tb, key, row)); + if (indices != null) { + IntIterator iter = indices.intIterator(); + while (iter.hasNext()) { + int i = iter.next(); + if (TupleUtils.tupleEquals(tb, key, row, data, keyColumns, i)) { + ret.add(i); + } + } + } + return ret; + } + + /** + * Replace tuples in the hash table with the input tuple if they have the same key. + * + * @param tb the input tuple batch + * @param keyColumns the key columns + * @param row the row index of the input tuple + * @return if at least one tuple is replaced + */ + public boolean replace(final TupleBatch tb, final int[] keyColumns, final int row) { + IntIterator iter = getIndices(tb, keyColumns, row).intIterator(); + if (!iter.hasNext()) { + return false; + } + while (iter.hasNext()) { + int i = iter.next(); + for (int j = 0; j < data.numColumns(); ++j) { + data.replace(j, i, tb.getDataColumns().get(j), row); + } + } + return true; + } + + /** + * @param tb tuple batch of the input tuple + * @param keyColumns key column indices + * @param row row index of the input tuple + * @param keyOnly only add keyColumns + */ + public void addTuple( + final TupleBatch tb, final int[] keyColumns, final int row, final boolean keyOnly) { + int hashcode = HashUtils.hashSubRow(tb, keyColumns, row); + IntArrayList indices = keyHashCodesToIndices.get(hashcode); + if (indices == null) { + indices = new IntArrayList(); + keyHashCodesToIndices.put(hashcode, indices); + } + indices.add(numTuples()); + if (keyOnly) { + for (int i = 0; i < keyColumns.length; ++i) { + data.put(i, tb.getDataColumns().get(keyColumns[i]), row); + } + } else { + for (int i = 0; i < data.numColumns(); ++i) { + data.put(i, tb.getDataColumns().get(i), row); + } + } + } + + /** + * @return the data + */ + public MutableTupleBuffer getData() { + return data; + } + + /** + * Clean up the hash table. + */ + public void cleanup() { + keyHashCodesToIndices = new IntObjectHashMap(); + data = new MutableTupleBuffer(data.getSchema()); + } +} diff --git a/src/edu/washington/escience/myria/operator/agg/AggUtils.java b/src/edu/washington/escience/myria/operator/agg/AggUtils.java index 8f67482f8..7b2584fd0 100644 --- a/src/edu/washington/escience/myria/operator/agg/AggUtils.java +++ b/src/edu/washington/escience/myria/operator/agg/AggUtils.java @@ -82,40 +82,4 @@ public static boolean needsMax(final Set aggOps) { public static boolean needsStats(final Set aggOps) { return !Sets.intersection(STATS_OPS, aggOps).isEmpty(); } - - /** - * Utility class to allocate a set of aggregators from the factories. - * - * @param factories The factories that will produce the aggregators. - * @param inputSchema The schema of the input tuples. - * @param pyFuncRegistrar - python function registrar to get functions from workers. - * @return the aggregators for this operator. - * @throws DbException if there is an error. 
- */ - public static Aggregator[] allocateAggs( - final AggregatorFactory[] factories, - final Schema inputSchema, - final PythonFunctionRegistrar pyFuncRegistrar) - throws DbException { - Aggregator[] aggregators = new Aggregator[factories.length]; - for (int j = 0; j < factories.length; ++j) { - - aggregators[j] = factories[j].get(inputSchema, pyFuncRegistrar); - } - return aggregators; - } - - /** - * Utility class to allocate the initial aggregation states from a set of {@link Aggregator}s. - * - * @param aggregators the {@link Aggregator}s that will update the states. - * @return the initial aggregation states for the specified {@link Aggregator}s. - */ - public static Object[] allocateAggStates(final Aggregator[] aggregators) { - Object[] states = new Object[aggregators.length]; - for (int j = 0; j < aggregators.length; ++j) { - states[j] = aggregators[j].getInitialState(); - } - return states; - } } diff --git a/src/edu/washington/escience/myria/operator/agg/Aggregate.java b/src/edu/washington/escience/myria/operator/agg/Aggregate.java index bc05348c3..b39d9caec 100644 --- a/src/edu/washington/escience/myria/operator/agg/Aggregate.java +++ b/src/edu/washington/escience/myria/operator/agg/Aggregate.java @@ -1,117 +1,210 @@ package edu.washington.escience.myria.operator.agg; -import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; import javax.annotation.Nullable; import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.gs.collections.api.iterator.IntIterator; import edu.washington.escience.myria.DbException; import edu.washington.escience.myria.Schema; -import edu.washington.escience.myria.Type; +import edu.washington.escience.myria.column.Column; +import edu.washington.escience.myria.expression.Expression; +import edu.washington.escience.myria.expression.evaluate.ExpressionOperatorParameter; +import edu.washington.escience.myria.expression.evaluate.GenericEvaluator; +import edu.washington.escience.myria.expression.evaluate.PythonUDFEvaluator; +import edu.washington.escience.myria.functions.PythonFunctionRegistrar; import edu.washington.escience.myria.operator.Operator; +import edu.washington.escience.myria.operator.TupleHashTable; import edu.washington.escience.myria.operator.UnaryOperator; import edu.washington.escience.myria.storage.TupleBatch; import edu.washington.escience.myria.storage.TupleBatchBuffer; +import edu.washington.escience.myria.util.MyriaArrayUtils; /** - * The Aggregation operator that computes an aggregate. - * - * This class does not do group by. + * The Aggregation operator that computes an aggregate (e.g., sum, avg, max, min). This variant supports aggregates over + * multiple columns, group by multiple columns. */ -public final class Aggregate extends UnaryOperator { +public class Aggregate extends UnaryOperator { - /** Required for Java serialization. */ + /** Java requires this. **/ private static final long serialVersionUID = 1L; - /** Use to create the aggregators. */ + /** The hash table containing groups and states. */ + protected transient TupleHashTable groupStates; + /** Factories to make the Aggregators. **/ private final AggregatorFactory[] factories; - /** The actual aggregators. */ - private Aggregator[] aggregators; - /** The state of the aggregators. */ - private Object[] aggregatorStates; - /** - * buffer for holding results. 
-   */
-  private transient TupleBatchBuffer aggBuffer;
+  /** Aggregators of the internal state. */
+  protected List<Aggregator> internalAggs;
+  /** Expressions that emit output. */
+  protected List<GenericEvaluator> emitEvals;
+  /** Group fields. An empty array means no grouping. */
+  protected final int[] gfields;
+  /** Buffer for storing results. */
+  protected TupleBatchBuffer resultBuffer;

   /**
-   * Computes the value of one or more aggregates over the entire input relation.
+   * Groups the input tuples according to the specified grouping fields, then produces the specified aggregates.
    *
    * @param child The Operator that is feeding us tuples.
-   * @param aggregators The {@link AggregatorFactory}s that creators the {@link Aggregator}s.
+   * @param gfields The columns over which we are grouping the result. An empty array means no group by.
+   * @param factories The factories that will produce the {@link Aggregator}s.
    */
-  public Aggregate(@Nullable final Operator child, final AggregatorFactory... aggregators) {
+  public Aggregate(
+      @Nullable final Operator child, final int[] gfields, final AggregatorFactory... factories) {
     super(child);
-    Preconditions.checkNotNull(aggregators, "aggregators");
-    int i = 0;
-    for (AggregatorFactory agg : aggregators) {
-      Preconditions.checkNotNull(agg, "aggregators[%s]", i);
-      ++i;
-    }
-    factories = aggregators;
+    this.gfields = gfields;
+    this.factories = Objects.requireNonNull(factories, "factories");
   }

   @Override
-  protected TupleBatch fetchNextReady() throws DbException, IOException {
-    TupleBatch tb = null;
-    final Operator child = getChild();
+  protected void cleanup() throws DbException {
+    groupStates.cleanup();
+    resultBuffer.clear();
+  }

+  /**
+   * Returns the next result batch. The first few columns are the group-by fields, if any, followed by the columns of
+   * aggregate results generated by {@link Aggregate#emitEvals}.
+   *
+   * @throws DbException if any error occurs.
+   * @return a result batch, or null if none is available yet.
+   */
+  @Override
+  protected TupleBatch fetchNextReady() throws DbException {
+    final Operator child = getChild();
+    TupleBatch tb = child.nextReady();
+    while (tb != null) {
+      for (int row = 0; row < tb.numTuples(); ++row) {
+        IntIterator iter = groupStates.getIndices(tb, gfields, row).intIterator();
+        int index;
+        if (!iter.hasNext()) {
+          /* A new group: store its key columns, then let each aggregator append its initial state. */
+          groupStates.addTuple(tb, gfields, row, true);
+          int offset = gfields.length;
+          for (Aggregator agg : internalAggs) {
+            agg.initState(groupStates.getData(), offset);
+            offset += agg.getStateSize();
+          }
+          index = groupStates.getData().numTuples() - 1;
+        } else {
+          index = iter.next();
+        }
+        int offset = gfields.length;
+        for (Aggregator agg : internalAggs) {
+          agg.addRow(tb, row, groupStates.getData(), index, offset);
+          offset += agg.getStateSize();
+        }
+      }
+      tb = child.nextReady();
+    }
     if (child.eos()) {
-      return aggBuffer.popAny();
+      generateResult();
+      return resultBuffer.popAny();
     }
+    return null;
+  }

-    while ((tb = child.nextReady()) != null) {
-      for (int agg = 0; agg < aggregators.length; ++agg) {
-        aggregators[agg].add(tb, aggregatorStates[agg]);
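
The two offset loops in fetchNextReady() rely on one invariant: a hash-table row is laid out as the gfields.length key columns followed by each aggregator's state columns, in factory order. The helper below restates that arithmetic; it is illustrative, not part of the patch:

    import java.util.List;

    import edu.washington.escience.myria.operator.agg.Aggregator;

    final class StateLayout {
      /**
       * Column where aggregator i's state begins: the group columns come first,
       * then the state blocks of all aggregators before i. This mirrors the
       * running `offset` in fetchNextReady() above.
       */
      static int stateOffset(final int numGroupFields, final List<Aggregator> aggs, final int i) {
        int offset = numGroupFields;
        for (int j = 0; j < i; ++j) {
          offset += aggs.get(j).getStateSize();
        }
        return offset;
      }
    }
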
+  /**
+   * Drains the grouped aggregation states into {@link Aggregate#resultBuffer}; called once the child reaches EOS.
+   *
+   * @throws DbException if there is an error.
+   */
+  protected void generateResult() throws DbException {
+    if (groupStates.numTuples() == 0) {
+      return;
+    }
+    int stateOffset = gfields.length;
+    for (Aggregator agg : internalAggs) {
+      if (agg instanceof UserDefinedAggregator) {
+        ((UserDefinedAggregator) agg).finalizePythonUpdaters(groupStates.getData(), stateOffset);
       }
+      stateOffset += agg.getStateSize();
     }
-
-    if (child.eos()) {
-      int fromIndex = 0;
-      for (int agg = 0; agg < aggregators.length; ++agg) {
-        aggregators[agg].getResult(aggBuffer, fromIndex, aggregatorStates[agg]);
-        fromIndex += aggregators[agg].getResultSchema().numColumns();
+    Schema inputSchema = getChild().getSchema();
+    for (TupleBatch tb : groupStates.getData().getAll()) {
+      List<Column<?>> columns = new ArrayList<Column<?>>();
+      columns.addAll(tb.getDataColumns().subList(0, gfields.length));
+      stateOffset = gfields.length;
+      int emitOffset = 0;
+      for (AggregatorFactory factory : factories) {
+        int stateSize = factory.generateStateSchema(inputSchema).numColumns();
+        int emitSize = factory.generateSchema(inputSchema).numColumns();
+        TupleBatch state = tb.selectColumns(MyriaArrayUtils.range(stateOffset, stateSize));
+        for (GenericEvaluator eval : emitEvals.subList(emitOffset, emitOffset + emitSize)) {
+          columns.add(eval.evalTupleBatch(state, getSchema()).getResultColumns().get(0));
+        }
+        stateOffset += stateSize;
+        emitOffset += emitSize;
       }
-      return aggBuffer.popAny();
+      addToResult(columns);
     }
-    return null;
+    groupStates.cleanup();
   }

-  @Override
-  protected void init(final ImmutableMap<String, Object> execEnvVars) throws DbException {
-    Preconditions.checkState(getSchema() != null, "unable to determine schema in init");
-    aggregators =
-        AggUtils.allocateAggs(factories, getChild().getSchema(), getPythonFunctionRegistrar());
-    aggregatorStates = AggUtils.allocateAggStates(aggregators);
-    aggBuffer = new TupleBatchBuffer(getSchema());
+  /**
+   * @param columns result columns.
+   */
+  protected void addToResult(List<Column<?>> columns) {
+    resultBuffer.absorb(new TupleBatch(getSchema(), columns), true);
   }
-
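
generateResult() pairs each factory's state block with its emit evaluators by column counting: a factory consumes generateStateSchema().numColumns() state columns and contributes generateSchema().numColumns() output columns. For intuition, this is how an AVG presumably decomposes (the exact decomposition is not spelled out in this patch): COUNT and SUM live in the state, and the emit expression divides them:

    /** Standalone arithmetic check of the assumed AVG decomposition. */
    final class AvgFromState {
      public static void main(final String[] args) {
        final long count = 4; // internal COUNT state column
        final long sum = 26;  // internal SUM state column
        final double avg = (double) sum / count; // what the emit expression computes
        System.out.println(avg); // prints 6.5
      }
    }
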
+  /**
+   * The schema of the aggregate output: the grouping fields first, followed by the aggregate fields in the order in
+   * which the factories were specified.
+   *
+   * @return the resulting schema
+   */
   @Override
   protected Schema generateSchema() {
     if (getChild() == null) {
       return null;
     }
-    final Schema inputSchema = getChild().getSchema();
-    if (inputSchema == null) {
-      return null;
+    Schema inputSchema = getChild().getSchema();
+    Schema aggSchema = Schema.EMPTY_SCHEMA;
+    for (int i = 0; i < factories.length; ++i) {
+      aggSchema = Schema.merge(aggSchema, factories[i].generateSchema(inputSchema));
     }
+    return Schema.merge(inputSchema.getSubSchema(gfields), aggSchema);
+  }

-    final ImmutableList.Builder<Type> gTypes = ImmutableList.builder();
-    final ImmutableList.Builder<String> gNames = ImmutableList.builder();
-
-    try {
-      for (Aggregator agg :
-          AggUtils.allocateAggs(factories, inputSchema, getPythonFunctionRegistrar())) {
-        Schema s = agg.getResultSchema();
-        gTypes.addAll(s.getColumnTypes());
-        gNames.addAll(s.getColumnNames());
+  @Override
+  protected void init(final ImmutableMap<String, Object> execEnvVars) throws DbException {
+    Schema inputSchema = getChild().getSchema();
+    Preconditions.checkState(inputSchema != null, "unable to determine schema in init");
+    internalAggs = new ArrayList<Aggregator>();
+    emitEvals = new ArrayList<GenericEvaluator>();
+    Schema groupingSchema = inputSchema.getSubSchema(gfields);
+    Schema stateSchema = Schema.EMPTY_SCHEMA;
+    PythonFunctionRegistrar pyFuncReg = getPythonFunctionRegistrar();
+    for (AggregatorFactory factory : factories) {
+      factory.setPyFuncReg(pyFuncReg);
+      internalAggs.addAll(factory.generateInternalAggs(inputSchema));
+      List<Expression> emits = factory.generateEmitExpressions(inputSchema);
+      Schema newStateSchema = factory.generateStateSchema(inputSchema);
+      stateSchema = Schema.merge(stateSchema, newStateSchema);
+      for (Expression exp : emits) {
+        GenericEvaluator evaluator = null;
+        if (exp.isRegisteredPythonUDF()) {
+          evaluator =
+              new PythonUDFEvaluator(
+                  exp,
+                  new ExpressionOperatorParameter(
+                      inputSchema, stateSchema, getPythonFunctionRegistrar()));
+        } else {
+          evaluator =
+              new GenericEvaluator(
+                  exp,
+                  new ExpressionOperatorParameter(
+                      newStateSchema, newStateSchema, getPythonFunctionRegistrar()));
+        }
+        evaluator.compile();
+        emitEvals.add(evaluator);
       }
-    } catch (DbException e) {
-      throw new RuntimeException("unable to allocate aggregators", e);
     }
-    return new Schema(gTypes, gNames);
+    groupStates =
+        new TupleHashTable(
+            Schema.merge(groupingSchema, stateSchema), MyriaArrayUtils.range(0, gfields.length));
+    resultBuffer = new TupleBatchBuffer(getSchema());
   }
-}
+}
diff --git a/src/edu/washington/escience/myria/operator/agg/Aggregator.java b/src/edu/washington/escience/myria/operator/agg/Aggregator.java
index 992d144e7..0ba709c75 100644
--- a/src/edu/washington/escience/myria/operator/agg/Aggregator.java
+++ b/src/edu/washington/escience/myria/operator/agg/Aggregator.java
@@ -1,67 +1,44 @@
 package edu.washington.escience.myria.operator.agg;

-import java.io.IOException;
 import java.io.Serializable;
-import java.util.List;

 import edu.washington.escience.myria.DbException;
-import edu.washington.escience.myria.Schema;
-import edu.washington.escience.myria.storage.AppendableTable;
-import edu.washington.escience.myria.storage.ReadableTable;
+import edu.washington.escience.myria.storage.MutableTupleBuffer;
 import edu.washington.escience.myria.storage.TupleBatch;

 /**
- * The interface for any aggregation.
+ * The interface for any aggregator.
 */
 public interface Aggregator extends Serializable {

   /**
-   * Update this aggregate using all rows of the specified table.
+   * Update the aggregate state using the specified row of the specified table.
    *
-   * @param from the source {@link ReadableTable}.
-   * @param state the initial state of the aggregate, which will be mutated.
-   * @throws DbException if there is an error.
+   * @param from the TupleBatch containing the source tuple.
+   * @param fromRow the row index of the source tuple.
+   * @param to the MutableTupleBuffer containing the state.
+   * @param toRow the row index of the state.
+   * @param offset the column index in {@code to} where this aggregator's state begins.
+   * @throws DbException if there is an error.
    */
-  void add(ReadableTable from, Object state) throws DbException;
+  void addRow(
+      final TupleBatch from,
+      final int fromRow,
+      final MutableTupleBuffer to,
+      final int toRow,
+      final int offset)
+      throws DbException;

   /**
-   * Update this aggregate using the specified row of the specified table.
-   *
-   * @param from the source {@link ReadableTable}.
-   * @param row the specified row.
-   * @param state the initial state of the aggregate, which will be mutated.
-   * @throws DbException if there is an error.
-   */
-  void addRow(ReadableTable from, int row, Object state) throws DbException;
-
-  /**
-   * Append the aggregate result(s) to the given table starting from the given column.
-   *
-   * @param dest where to store the aggregate result.
-   * @param destColumn the starting index into which aggregates will be output.
-   * @param state the initial state of the aggregate, which will be mutated.
-   * @throws DbException if there is an error.
-   * @throws IOException in case of error.
+   * @return the size (number of columns) of this aggregator's state schema.
    */
-  void getResult(AppendableTable dest, int destColumn, Object state) throws DbException;
+  int getStateSize();
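
Under this contract an aggregator owns a fixed-width block of state columns and updates it in place. A minimal sketch of a custom implementation with one LONG state column counting rows whose column `col` is positive; it assumes MutableTupleBuffer exposes AppendableTable.putLong() for appending the fresh state row, as the appendInitValue() methods further below rely on:

    import edu.washington.escience.myria.operator.agg.Aggregator;
    import edu.washington.escience.myria.storage.MutableTupleBuffer;
    import edu.washington.escience.myria.storage.ReplaceableColumn;
    import edu.washington.escience.myria.storage.TupleBatch;

    final class PositiveCountAggregator implements Aggregator {
      private static final long serialVersionUID = 1L;
      private final int col; // input column to test

      PositiveCountAggregator(final int col) {
        this.col = col;
      }

      @Override
      public int getStateSize() {
        return 1; // one LONG state column
      }

      @Override
      public void initState(final MutableTupleBuffer state, final int offset) {
        state.putLong(offset, 0); // append the initial count for the new group row
      }

      @Override
      public void addRow(final TupleBatch from, final int fromRow,
          final MutableTupleBuffer to, final int toRow, final int offset) {
        if (from.getLong(col, fromRow) > 0) {
          final ReplaceableColumn state = to.getColumn(offset, toRow);
          final int row = to.getInColumnIndex(toRow);
          state.replaceLong(state.getLong(row) + 1, row);
        }
      }
    }
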
   /**
-   * Compute and return the initial state tuple for instances of this {@link Aggregator}.
+   * Initialize a new state by appending initial values to a new row.
    *
-   * @return the initial state tuple for instances of this {@link Aggregator}.
-   */
-  Object getInitialState();
-
-  /**
-   * Compute and return the schema of the outputs of this {@link Aggregator}.
-   *
-   * @return the schema of the outputs of this {@link Aggregator}.
-   */
-  Schema getResultSchema();
-  /**
-   * @param from list of tuple batch to aggregate.
-   * @param state object to which state is written.
-   * @throws DbException in case of error.
+   * @param state the table containing internal states
+   * @param offset the column index of state to start from
+   * @throws DbException if there is an error.
    */
-  void add(List<TupleBatch> from) throws DbException;
+  void initState(final MutableTupleBuffer state, final int offset) throws DbException;
 }
diff --git a/src/edu/washington/escience/myria/operator/agg/AggregatorFactory.java b/src/edu/washington/escience/myria/operator/agg/AggregatorFactory.java
index 52358f62e..670168be4 100644
--- a/src/edu/washington/escience/myria/operator/agg/AggregatorFactory.java
+++ b/src/edu/washington/escience/myria/operator/agg/AggregatorFactory.java
@@ -1,8 +1,7 @@
 package edu.washington.escience.myria.operator.agg;

 import java.io.Serializable;
-
-import javax.annotation.Nonnull;
+import java.util.List;

 import com.fasterxml.jackson.annotation.JsonSubTypes;
 import com.fasterxml.jackson.annotation.JsonSubTypes.Type;
@@ -10,6 +9,7 @@

 import edu.washington.escience.myria.DbException;
 import edu.washington.escience.myria.Schema;
+import edu.washington.escience.myria.expression.Expression;
 import edu.washington.escience.myria.functions.PythonFunctionRegistrar;

 /**
@@ -17,25 +17,45 @@
 */
 @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "type")
 @JsonSubTypes({
-  @Type(value = CountAllAggregatorFactory.class, name = "CountAll"),
-  @Type(value = SingleColumnAggregatorFactory.class, name = "SingleColumn"),
+  @Type(value = PrimitiveAggregatorFactory.class, name = "Primitive"),
   @Type(value = UserDefinedAggregatorFactory.class, name = "UserDefined")
 })
 public interface AggregatorFactory extends Serializable {

   /**
-   * Create a new aggregator for tuples of the specified schema.
+   * Generate the expressions that emit this factory's output columns.
    *
    * @param inputSchema the schema that incoming tuples will take.
-   * @return a new aggregator for tuples of the specified schema.
-   * @throws DbException if there is an error creating the aggregator.
+   * @return the expressions that emit outputs.
+   * @throws DbException if there is an error creating the expressions.
+   */
+  List<Expression> generateEmitExpressions(final Schema inputSchema) throws DbException;
+
+  /**
+   * Generate the internal aggregators; each aggregator maintains its own columns of the internal state.
+   *
+   * @param inputSchema the schema that incoming tuples will take.
+   * @return the aggregators for the internal state.
+   * @throws DbException if there is an error creating the aggregators.
+   */
+  List<Aggregator> generateInternalAggs(final Schema inputSchema) throws DbException;
+
+  /**
+   * Generate the schema of the internal state of aggregators generated by this factory.
+   *
+   * @param inputSchema the {@link Schema} of the input tuples.
+   * @return the internal state schema.
    */
-  @Nonnull
-  Aggregator get(Schema inputSchema) throws DbException;
+  Schema generateStateSchema(final Schema inputSchema);

   /**
-   * @param inputSchema input schema for aggregator.
-   * @param pyFuncReg python function registrar.
-   * @return a new aggregator for tuples of the specified schema.
-   * @throws DbException in case of error.
+   * Generate the output schema of the aggregators generated by this factory.
+   *
+   * @param inputSchema the {@link Schema} of the input tuples.
+   * @return the output schema.
    */
-  Aggregator get(Schema inputSchema, PythonFunctionRegistrar pyFuncReg) throws DbException;
+  Schema generateSchema(final Schema inputSchema);
+
+  /**
+   * @param pyFuncReg the python function registrar, used to resolve Python UDFs on the workers.
+   */
+  void setPyFuncReg(PythonFunctionRegistrar pyFuncReg);
 }
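
Putting the four methods together: a factory describes its internal state (generateStateSchema plus generateInternalAggs) separately from its visible output (generateSchema plus generateEmitExpressions). A sketch of a COUNT-style factory over column 0; the Expression and VariableExpression constructor shapes are assumptions based on the rest of the Myria expression package, and the aggregator constructors are protected, so this sketch lives in the same package:

    package edu.washington.escience.myria.operator.agg;

    import java.util.List;

    import com.google.common.collect.ImmutableList;

    import edu.washington.escience.myria.Schema;
    import edu.washington.escience.myria.Type;
    import edu.washington.escience.myria.expression.Expression;
    import edu.washington.escience.myria.expression.VariableExpression;
    import edu.washington.escience.myria.functions.PythonFunctionRegistrar;

    /** Sketch of the new factory contract for a COUNT aggregate. */
    public final class CountFactorySketch implements AggregatorFactory {
      private static final long serialVersionUID = 1L;

      @Override
      public List<Aggregator> generateInternalAggs(final Schema inputSchema) {
        // One internal aggregator maintaining one LONG state column.
        return ImmutableList.<Aggregator>of(
            new LongAggregator("count", 0, PrimitiveAggregator.AggregationOp.COUNT));
      }

      @Override
      public List<Expression> generateEmitExpressions(final Schema inputSchema) {
        // COUNT needs no arithmetic at emit time: output = state column 0.
        return ImmutableList.of(new Expression("count", new VariableExpression(0)));
      }

      @Override
      public Schema generateStateSchema(final Schema inputSchema) {
        return Schema.ofFields(Type.LONG_TYPE, "count_state");
      }

      @Override
      public Schema generateSchema(final Schema inputSchema) {
        return Schema.ofFields(Type.LONG_TYPE, "count");
      }

      @Override
      public void setPyFuncReg(final PythonFunctionRegistrar pyFuncReg) {
        // No Python UDFs here; nothing to register.
      }
    }
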
diff --git a/src/edu/washington/escience/myria/operator/agg/BooleanAggregator.java b/src/edu/washington/escience/myria/operator/agg/BooleanAggregator.java
index 31ec37371..233cc6cb6 100644
--- a/src/edu/washington/escience/myria/operator/agg/BooleanAggregator.java
+++ b/src/edu/washington/escience/myria/operator/agg/BooleanAggregator.java
@@ -1,112 +1,67 @@
 package edu.washington.escience.myria.operator.agg;

 import java.util.Objects;
-import java.util.Set;

-import com.google.common.collect.ImmutableSet;
-import com.google.common.math.LongMath;
-
-import edu.washington.escience.myria.DbException;
 import edu.washington.escience.myria.Type;
 import edu.washington.escience.myria.storage.AppendableTable;
-import edu.washington.escience.myria.storage.ReadableTable;
+import edu.washington.escience.myria.storage.MutableTupleBuffer;
+import edu.washington.escience.myria.storage.ReplaceableColumn;
+import edu.washington.escience.myria.storage.TupleBatch;

 /**
 * Knows how to compute some aggregates over a BooleanColumn.
 */
 public final class BooleanAggregator extends PrimitiveAggregator {

+  protected BooleanAggregator(final String inputName, final int column, final AggregationOp aggOp) {
+    super(inputName, column, aggOp);
+  }
+
   /** Required for Java serialization. */
   private static final long serialVersionUID = 1L;
-  /** Which column of the input this aggregator operates over. */
-  private final int fromColumn;
-
-  /**
-   * Aggregate operations applicable for boolean columns.
-   */
-  public static final Set<AggregationOp> AVAILABLE_AGG = ImmutableSet.of(AggregationOp.COUNT);
-
-  /**
-   * @param aFieldName aggregate field name for use in output schema.
-   * @param aggOps the aggregate operation to simultaneously compute.
-   * @param column the column being aggregated over.
-   */
-  public BooleanAggregator(
-      final String aFieldName, final AggregationOp[] aggOps, final int column) {
-    super(aFieldName, aggOps);
-    fromColumn = column;
-  }

   @Override
-  public void add(final ReadableTable from, final Object state) {
+  public void addRow(
+      final TupleBatch from,
+      final int fromRow,
+      final MutableTupleBuffer to,
+      final int toRow,
+      final int offset) {
     Objects.requireNonNull(from, "from");
-    BooleanAggState b = (BooleanAggState) state;
-    b.count += from.numTuples();
-  }
-
-  /**
-   * Add the specified value to this aggregator.
-   *
-   * @param value the value to be added.
-   * @param state the current state of the aggregate.
- */ - public void addBoolean(final boolean value, final Object state) { - BooleanAggState b = (BooleanAggState) state; - if (needsCount) { - b.count = LongMath.checkedAdd(b.count, 1); - } - } - - @Override - public void getResult(final AppendableTable dest, final int destColumn, final Object state) { - BooleanAggState b = (BooleanAggState) state; - Objects.requireNonNull(dest, "dest"); - int idx = destColumn; - for (AggregationOp op : aggOps) { - switch (op) { - case COUNT: - dest.putLong(idx, b.count); - break; - case AVG: - case MAX: - case MIN: - case STDEV: - case SUM: - throw new UnsupportedOperationException("Aggregate " + op + " on type Boolean"); - } - idx++; + ReplaceableColumn toCol = to.getColumn(offset, toRow); + final int inColumnRow = to.getInColumnIndex(toRow); + switch (aggOp) { + case COUNT: + toCol.replaceLong(toCol.getLong(inColumnRow) + 1, inColumnRow); + break; + default: + throw new IllegalArgumentException(aggOp + " is invalid"); } } @Override - public Type getType() { - return Type.BOOLEAN_TYPE; + protected boolean isSupported(final AggregationOp aggOp) { + return aggOp.equals(AggregationOp.COUNT); } @Override - protected Set getAvailableAgg() { - return AVAILABLE_AGG; - } - - @Override - protected Type getSumType() { - throw new UnsupportedOperationException("SUM of Boolean values"); - } - - @Override - public void addRow(final ReadableTable from, final int row, final Object state) - throws DbException { - addBoolean(from.getBoolean(fromColumn, row), state); - } + protected Type getOutputType() { + switch (aggOp) { + case COUNT: + return Type.LONG_TYPE; + default: + throw new IllegalArgumentException("Type " + aggOp + " is invalid"); + } + }; @Override - public Object getInitialState() { - return new BooleanAggState(); - } - - /** Private internal class that wraps the state required by this Aggregator as an object. */ - private final class BooleanAggState { - /** The number of tuples seen so far. */ - private long count = 0; + public void appendInitValue(AppendableTable data, final int column) { + switch (aggOp) { + case COUNT: + data.putLong(column, 0); + break; + default: + throw new IllegalArgumentException("Type " + aggOp + " is invalid"); + } } } diff --git a/src/edu/washington/escience/myria/operator/agg/CountAllAggregator.java b/src/edu/washington/escience/myria/operator/agg/CountAllAggregator.java deleted file mode 100644 index a4e15de61..000000000 --- a/src/edu/washington/escience/myria/operator/agg/CountAllAggregator.java +++ /dev/null @@ -1,64 +0,0 @@ -package edu.washington.escience.myria.operator.agg; - -import java.util.List; - -import com.google.common.math.LongMath; - -import edu.washington.escience.myria.DbException; -import edu.washington.escience.myria.Schema; -import edu.washington.escience.myria.Type; -import edu.washington.escience.myria.storage.AppendableTable; -import edu.washington.escience.myria.storage.ReadableTable; -import edu.washington.escience.myria.storage.TupleBatch; - -/** - * An aggregator that counts the number of rows in its input. - */ -public final class CountAllAggregator implements Aggregator { - - /** Required for Java serialization. */ - private static final long serialVersionUID = 1L; - /** The schema of the aggregate results. 
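
The pattern above replaces the old heap-allocated per-group state objects (BooleanAggState and friends): each update reads the running value out of the shared state buffer and writes the new value back in place. Isolated for reference, mirroring the addRow() bodies in this patch:

    import edu.washington.escience.myria.storage.MutableTupleBuffer;
    import edu.washington.escience.myria.storage.ReplaceableColumn;

    final class InPlaceCount {
      /** COUNT += 1, written the way the addRow() implementations write state. */
      static void bump(final MutableTupleBuffer to, final int offset, final int toRow) {
        final ReplaceableColumn toCol = to.getColumn(offset, toRow);
        // Map the logical row index to its position within the column's batch.
        final int inColumnRow = to.getInColumnIndex(toRow);
        toCol.replaceLong(toCol.getLong(inColumnRow) + 1, inColumnRow);
      }
    }
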
*/ - public static final Schema SCHEMA = Schema.ofFields(Type.LONG_TYPE, "count_all"); - - @Override - public void add(final ReadableTable from, final Object state) throws DbException { - CountAllState c = (CountAllState) state; - c.count = LongMath.checkedAdd(c.count, from.numTuples()); - } - - @Override - public void addRow(final ReadableTable from, final int row, final Object state) - throws DbException { - CountAllState c = (CountAllState) state; - c.count = LongMath.checkedAdd(c.count, 1); - } - - @Override - public void getResult(final AppendableTable dest, final int destColumn, final Object state) - throws DbException { - CountAllState c = (CountAllState) state; - dest.putLong(destColumn, c.count); - } - - @Override - public Schema getResultSchema() { - return SCHEMA; - } - - @Override - public Object getInitialState() { - return new CountAllState(); - } - - /** Private internal class that wraps the state required by this Aggregator as an object. */ - private final class CountAllState { - /** The number of tuples seen so far. */ - private long count = 0; - } - - @Override - public void add(final List from) throws DbException { - throw new DbException(" method not implemented"); - } -} diff --git a/src/edu/washington/escience/myria/operator/agg/CountAllAggregatorFactory.java b/src/edu/washington/escience/myria/operator/agg/CountAllAggregatorFactory.java deleted file mode 100644 index 4a4e9994e..000000000 --- a/src/edu/washington/escience/myria/operator/agg/CountAllAggregatorFactory.java +++ /dev/null @@ -1,30 +0,0 @@ -package edu.washington.escience.myria.operator.agg; - -import edu.washington.escience.myria.DbException; -import edu.washington.escience.myria.Schema; -import edu.washington.escience.myria.functions.PythonFunctionRegistrar; - -/** - * A factory for the CountAll aggregator. - */ -public final class CountAllAggregatorFactory implements AggregatorFactory { - - /** Required for Java serialization. */ - private static final long serialVersionUID = 1L; - - /** Instantiate a CountAllFactory. */ - public CountAllAggregatorFactory() { - /** Nothing needed here. */ - } - - @Override - public Aggregator get(final Schema inputSchema) throws DbException { - return new CountAllAggregator(); - } - - @Override - public Aggregator get(final Schema inputSchema, PythonFunctionRegistrar pyFuncReg) - throws DbException { - return get(inputSchema); - } -} diff --git a/src/edu/washington/escience/myria/operator/agg/DateTimeAggregator.java b/src/edu/washington/escience/myria/operator/agg/DateTimeAggregator.java index 21a453c5a..88f03858f 100644 --- a/src/edu/washington/escience/myria/operator/agg/DateTimeAggregator.java +++ b/src/edu/washington/escience/myria/operator/agg/DateTimeAggregator.java @@ -1,156 +1,99 @@ package edu.washington.escience.myria.operator.agg; import java.util.Objects; -import java.util.Set; import org.joda.time.DateTime; import com.google.common.collect.ImmutableSet; -import com.google.common.math.LongMath; import edu.washington.escience.myria.Type; import edu.washington.escience.myria.storage.AppendableTable; -import edu.washington.escience.myria.storage.ReadableTable; +import edu.washington.escience.myria.storage.MutableTupleBuffer; +import edu.washington.escience.myria.storage.ReadableColumn; +import edu.washington.escience.myria.storage.ReplaceableColumn; +import edu.washington.escience.myria.storage.TupleBatch; /** * Knows how to compute some aggregate over a DateTimeColumn. 
*/ public final class DateTimeAggregator extends PrimitiveAggregator { - /** Required for Java serialization. */ - private static final long serialVersionUID = 1L; - /** Which column of the input this aggregator operates over. */ - private final int fromColumn; - - /** - * Aggregate operations applicable for string columns. - */ - public static final Set AVAILABLE_AGG = - ImmutableSet.of(AggregationOp.COUNT, AggregationOp.MAX, AggregationOp.MIN); - - /** - * @param aFieldName aggregate field name for use in output schema. - * @param aggOps the aggregate operation to simultaneously compute. - * @param column the column being aggregated over. - */ - public DateTimeAggregator( - final String aFieldName, final AggregationOp[] aggOps, final int column) { - super(aFieldName, aggOps); - fromColumn = column; + protected DateTimeAggregator( + final String inputName, final int column, final AggregationOp aggOp) { + super(inputName, column, aggOp); } - /** - * Add the specified value to this aggregator. - * - * @param value the value to be added - * @param state the state of the aggregate, which will be mutated. - */ - public void addDateTime(final DateTime value, final Object state) { - Objects.requireNonNull(value, "value"); - DateTimeAggState d = (DateTimeAggState) state; - if (needsCount) { - d.count = LongMath.checkedAdd(d.count, 1); - } - if (needsStats) { - addDateTimeStats(value, d); - } - } + /** Required for Java serialization. */ + private static final long serialVersionUID = 1L; @Override - public void add(final ReadableTable from, final Object state) { + public void addRow( + final TupleBatch from, + final int fromRow, + final MutableTupleBuffer to, + final int toRow, + final int offset) { Objects.requireNonNull(from, "from"); - DateTimeAggState d = (DateTimeAggState) state; - final int numTuples = from.numTuples(); - if (numTuples == 0) { - return; - } - if (needsCount) { - d.count = LongMath.checkedAdd(d.count, numTuples); - } - if (needsStats) { - for (int row = 0; row < numTuples; ++row) { - addDateTimeStats(from.getDateTime(fromColumn, row), d); - } - } - } - - @Override - public void addRow(final ReadableTable table, final int row, final Object state) { - addDateTime(Objects.requireNonNull(table, "table").getDateTime(fromColumn, row), state); - } - - /** - * Helper function to add value to this aggregator. Note this does NOT update count. - * - * @param value the value to be added - * @param state the state of the aggregate, which will be mutated. 
- */ - private void addDateTimeStats(final DateTime value, final DateTimeAggState state) { - Objects.requireNonNull(value, "value"); - if (needsMin) { - if ((state.min == null) || (state.min.compareTo(value) > 0)) { - state.min = value; - } - } - if (needsMax) { - if ((state.max == null) || (state.max.compareTo(value) < 0)) { - state.max = value; - } - } - } - - @Override - public void getResult(final AppendableTable dest, final int destColumn, final Object state) { - DateTimeAggState d = (DateTimeAggState) state; - Objects.requireNonNull(dest, "dest"); - int idx = destColumn; - for (AggregationOp op : aggOps) { - switch (op) { - case COUNT: - dest.putLong(idx, d.count); - break; - case MAX: - dest.putDateTime(idx, d.max); + ReadableColumn fromCol = from.asColumn(column); + ReplaceableColumn toCol = to.getColumn(offset, toRow); + final int inColumnRow = to.getInColumnIndex(toRow); + switch (aggOp) { + case COUNT: + toCol.replaceLong(toCol.getLong(inColumnRow) + 1, inColumnRow); + break; + case MAX: + { + DateTime value = fromCol.getDateTime(fromRow); + if (value.compareTo(toCol.getDateTime(inColumnRow)) > 0) { + toCol.replaceDateTime(value, inColumnRow); + } break; - case MIN: - dest.putDateTime(idx, d.min); + } + case MIN: + { + DateTime value = fromCol.getDateTime(fromRow); + if (value.compareTo(toCol.getDateTime(inColumnRow)) < 0) { + toCol.replaceDateTime(value, inColumnRow); + } break; - case AVG: - case STDEV: - case SUM: - throw new UnsupportedOperationException("Aggregate " + op + " on type DateTime"); - } - idx++; + } + default: + throw new IllegalArgumentException(aggOp + " is invalid"); } } @Override - protected Type getSumType() { - throw new UnsupportedOperationException("SUM of DateTime values"); - } - - @Override - public Type getType() { - return Type.DATETIME_TYPE; - } - - @Override - protected Set getAvailableAgg() { - return AVAILABLE_AGG; + protected boolean isSupported(final AggregationOp aggOp) { + return ImmutableSet.of(AggregationOp.COUNT, AggregationOp.MIN, AggregationOp.MAX) + .contains(aggOp); } @Override - public Object getInitialState() { - return new DateTimeAggState(); - } - - /** Private internal class that wraps the state required by this Aggregator as an object. */ - private final class DateTimeAggState { - /** The number of tuples seen so far. */ - private long count = 0; - /** The minimum value in the aggregated column. */ - private DateTime min = null; - /** The maximum value in the aggregated column. 
*/ - private DateTime max = null; + protected Type getOutputType() { + switch (aggOp) { + case COUNT: + return Type.LONG_TYPE; + case MAX: + case MIN: + return Type.DATETIME_TYPE; + default: + throw new IllegalArgumentException("Type " + aggOp + " is invalid"); + } + }; + + public void appendInitValue(AppendableTable data, final int column) { + switch (aggOp) { + case COUNT: + data.putLong(column, 0); + break; + case MAX: + data.putDateTime(column, new DateTime(Long.MIN_VALUE)); + break; + case MIN: + data.putDateTime(column, new DateTime(Long.MAX_VALUE)); + break; + default: + throw new IllegalArgumentException("Type " + aggOp + " is invalid"); + } } } diff --git a/src/edu/washington/escience/myria/operator/agg/DoubleAggregator.java b/src/edu/washington/escience/myria/operator/agg/DoubleAggregator.java index 9fd700bc7..0fa12ce0e 100644 --- a/src/edu/washington/escience/myria/operator/agg/DoubleAggregator.java +++ b/src/edu/washington/escience/myria/operator/agg/DoubleAggregator.java @@ -1,163 +1,108 @@ package edu.washington.escience.myria.operator.agg; -import java.util.Objects; -import java.util.Set; - import com.google.common.collect.ImmutableSet; -import com.google.common.math.LongMath; import edu.washington.escience.myria.Type; import edu.washington.escience.myria.storage.AppendableTable; -import edu.washington.escience.myria.storage.ReadableTable; +import edu.washington.escience.myria.storage.MutableTupleBuffer; +import edu.washington.escience.myria.storage.ReadableColumn; +import edu.washington.escience.myria.storage.ReplaceableColumn; +import edu.washington.escience.myria.storage.TupleBatch; /** * Knows how to compute some aggregates over a DoubleColumn. */ public final class DoubleAggregator extends PrimitiveAggregator { + protected DoubleAggregator(final String inputName, final int column, final AggregationOp aggOp) { + super(inputName, column, aggOp); + } + /** Required for Java serialization. */ private static final long serialVersionUID = 1L; - /** Which column of the input this aggregator operates over. */ - private final int fromColumn; - - /** - * Aggregate operations applicable for double columns. - */ - public static final Set AVAILABLE_AGG = - ImmutableSet.of( - AggregationOp.COUNT, - AggregationOp.SUM, - AggregationOp.MAX, - AggregationOp.MIN, - AggregationOp.AVG, - AggregationOp.STDEV); - - /** - * @param aFieldName aggregate field name for use in output schema. - * @param aggOps the aggregate operation to simultaneously compute. - * @param column the column being aggregated over. 
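
The DateTime seeds above work because every real timestamp compares strictly against them: new DateTime(Long.MIN_VALUE) loses every MAX comparison and new DateTime(Long.MAX_VALUE) loses every MIN comparison, so the first addRow() always installs a real value. A standalone check:

    import org.joda.time.DateTime;

    final class DateTimeSeedDemo {
      public static void main(final String[] args) {
        DateTime max = new DateTime(Long.MIN_VALUE); // seed written by appendInitValue
        final DateTime value = new DateTime(2016, 1, 15, 0, 0);
        if (value.compareTo(max) > 0) { // the MAX branch of addRow
          max = value;
        }
        System.out.println(max); // 2016-01-15T00:00:00.000 in the default zone
      }
    }
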
- */ - public DoubleAggregator(final String aFieldName, final AggregationOp[] aggOps, final int column) { - super(aFieldName, aggOps); - fromColumn = column; - } @Override - public void add(final ReadableTable from, final Object state) { - Objects.requireNonNull(from, "from"); - DoubleAggState d = (DoubleAggState) state; - final int numTuples = from.numTuples(); - if (numTuples == 0) { - return; - } - if (needsCount) { - d.count = LongMath.checkedAdd(d.count, numTuples); - } - if (!needsStats) { - return; - } - for (int i = 0; i < numTuples; i++) { - addDoubleStats(from.getDouble(fromColumn, i), d); + public void addRow( + final TupleBatch from, + final int fromRow, + final MutableTupleBuffer to, + final int toRow, + final int offset) { + ReadableColumn fromCol = from.asColumn(column); + ReplaceableColumn toCol = to.getColumn(offset, toRow); + final int inColumnRow = to.getInColumnIndex(toRow); + switch (aggOp) { + case COUNT: + toCol.replaceLong(toCol.getLong(inColumnRow) + 1, inColumnRow); + break; + case MAX: + toCol.replaceDouble( + Math.max(fromCol.getDouble(fromRow), toCol.getDouble(inColumnRow)), inColumnRow); + break; + case MIN: + toCol.replaceDouble( + Math.min(fromCol.getDouble(fromRow), toCol.getDouble(inColumnRow)), inColumnRow); + break; + case SUM: + toCol.replaceDouble(fromCol.getDouble(fromRow) + toCol.getDouble(inColumnRow), inColumnRow); + break; + case SUM_SQUARED: + toCol.replaceDouble( + fromCol.getDouble(fromRow) * fromCol.getDouble(fromRow) + toCol.getDouble(inColumnRow), + inColumnRow); + break; + default: + throw new IllegalArgumentException(aggOp + " is invalid"); } } @Override - public void addRow(final ReadableTable table, final int row, final Object state) { - Objects.requireNonNull(table, "table"); - DoubleAggState d = (DoubleAggState) state; - if (needsCount) { - d.count = LongMath.checkedAdd(d.count, 1); - } - if (needsStats) { - addDoubleStats(table.getDouble(fromColumn, row), d); - } - } - - /** - * Helper function to add value to this aggregator. Note this does NOT update count. - * - * @param value the value to be added - * @param state the state of the aggregate, which will be mutated. 
- */ - private void addDoubleStats(final double value, final DoubleAggState state) { - if (needsSum) { - state.sum += value; - } - if (needsSumSq) { - state.sumSquared += value * value; - } - if (needsMin) { - state.min = Math.min(state.min, value); - } - if (needsMax) { - state.max = Math.max(state.max, value); - } + protected boolean isSupported(final AggregationOp aggOp) { + return ImmutableSet.of( + AggregationOp.COUNT, + AggregationOp.MIN, + AggregationOp.MAX, + AggregationOp.SUM, + AggregationOp.AVG, + AggregationOp.STDEV, + AggregationOp.SUM_SQUARED) + .contains(aggOp); } @Override - public void getResult(final AppendableTable dest, final int destColumn, final Object state) { - Objects.requireNonNull(dest, "dest"); - DoubleAggState d = (DoubleAggState) state; - int idx = destColumn; - for (AggregationOp op : aggOps) { - switch (op) { - case AVG: - dest.putDouble(idx, d.sum * 1.0 / d.count); - break; - case COUNT: - dest.putLong(idx, d.count); - break; - case MAX: - dest.putDouble(idx, d.max); - break; - case MIN: - dest.putDouble(idx, d.min); - break; - case STDEV: - double first = d.sumSquared / d.count; - double second = d.sum / d.count; - double stdev = Math.sqrt(first - second * second); - dest.putDouble(idx, stdev); - break; - case SUM: - dest.putDouble(idx, d.sum); - break; - } - idx++; + protected Type getOutputType() { + switch (aggOp) { + case COUNT: + return Type.LONG_TYPE; + case MAX: + case MIN: + case SUM: + case AVG: + case STDEV: + return Type.DOUBLE_TYPE; + default: + throw new IllegalArgumentException("Type " + aggOp + " is invalid"); } - } - - @Override - public Type getType() { - return Type.DOUBLE_TYPE; - } + }; @Override - protected Set getAvailableAgg() { - return AVAILABLE_AGG; - } - - @Override - protected Type getSumType() { - return Type.DOUBLE_TYPE; - } - - @Override - public Object getInitialState() { - return new DoubleAggState(); - } - - /** Private internal class that wraps the state required by this Aggregator as an object. */ - private final class DoubleAggState { - /** The number of tuples seen so far. */ - private long count = 0; - /** The minimum value in the aggregated column. */ - private double min = Double.MAX_VALUE; - /** The maximum value in the aggregated column. */ - private double max = Double.MIN_VALUE; - /** The sum of values in the aggregated column. */ - private double sum = 0; - /** private temp variables for computing stdev. 
*/ - private double sumSquared = 0; + public void appendInitValue(AppendableTable data, final int column) { + switch (aggOp) { + case COUNT: + data.putLong(column, 0); + break; + case SUM: + case SUM_SQUARED: + data.putDouble(column, 0); + break; + case MAX: + data.putDouble(column, Double.MIN_VALUE); + break; + case MIN: + data.putDouble(column, Double.MAX_VALUE); + break; + default: + throw new IllegalArgumentException("Type " + aggOp + " is invalid"); + } } } diff --git a/src/edu/washington/escience/myria/operator/agg/FloatAggregator.java b/src/edu/washington/escience/myria/operator/agg/FloatAggregator.java index a9babf95f..78d5a2937 100644 --- a/src/edu/washington/escience/myria/operator/agg/FloatAggregator.java +++ b/src/edu/washington/escience/myria/operator/agg/FloatAggregator.java @@ -1,165 +1,110 @@ package edu.washington.escience.myria.operator.agg; -import java.util.Objects; -import java.util.Set; - import com.google.common.collect.ImmutableSet; -import com.google.common.math.LongMath; import edu.washington.escience.myria.Type; import edu.washington.escience.myria.storage.AppendableTable; -import edu.washington.escience.myria.storage.ReadableTable; +import edu.washington.escience.myria.storage.MutableTupleBuffer; +import edu.washington.escience.myria.storage.ReadableColumn; +import edu.washington.escience.myria.storage.ReplaceableColumn; +import edu.washington.escience.myria.storage.TupleBatch; /** * Knows how to compute some aggregates over a FloatColumn. */ public final class FloatAggregator extends PrimitiveAggregator { + protected FloatAggregator(final String inputName, final int column, final AggregationOp aggOp) { + super(inputName, column, aggOp); + } + /** Required for Java serialization. */ private static final long serialVersionUID = 1L; - /** Which column of the input this aggregator operates over. */ - private final int fromColumn; - - /** - * Aggregate operations applicable for float columns. - */ - public static final Set AVAILABLE_AGG = - ImmutableSet.of( - AggregationOp.COUNT, - AggregationOp.SUM, - AggregationOp.MAX, - AggregationOp.MIN, - AggregationOp.AVG, - AggregationOp.STDEV); - - /** - * @param aFieldName aggregate field name for use in output schema. - * @param aggOps the aggregate operation to simultaneously compute. - * @param column the column being aggregated over. 
- */ - public FloatAggregator(final String aFieldName, final AggregationOp[] aggOps, final int column) { - super(aFieldName, aggOps); - fromColumn = column; - } @Override - public void add(final ReadableTable from, final Object state) { - Objects.requireNonNull(from, "from"); - FloatAggState f = (FloatAggState) state; - final int numTuples = from.numTuples(); - if (numTuples == 0) { - return; - } - - if (needsCount) { - f.count = LongMath.checkedAdd(f.count, numTuples); - } - - if (!needsStats) { - return; - } - for (int i = 0; i < numTuples; i++) { - addFloatStats(from.getFloat(fromColumn, i), f); + public void addRow( + final TupleBatch from, + final int fromRow, + final MutableTupleBuffer to, + final int toRow, + final int offset) { + ReadableColumn fromCol = from.asColumn(column); + ReplaceableColumn toCol = to.getColumn(offset, toRow); + final int inColumnRow = to.getInColumnIndex(toRow); + switch (aggOp) { + case COUNT: + toCol.replaceLong(toCol.getLong(inColumnRow) + 1, inColumnRow); + break; + case MAX: + toCol.replaceFloat( + Math.max(fromCol.getFloat(fromRow), toCol.getFloat(inColumnRow)), inColumnRow); + break; + case MIN: + toCol.replaceFloat( + Math.min(fromCol.getFloat(fromRow), toCol.getFloat(inColumnRow)), inColumnRow); + break; + case SUM: + toCol.replaceDouble(fromCol.getFloat(fromRow) + toCol.getDouble(inColumnRow), inColumnRow); + break; + case SUM_SQUARED: + toCol.replaceDouble( + (double) fromCol.getFloat(fromRow) * fromCol.getFloat(fromRow) + + toCol.getDouble(inColumnRow), + inColumnRow); + break; + default: + throw new IllegalArgumentException(aggOp + " is invalid"); } } @Override - public void addRow(final ReadableTable table, final int row, final Object state) { - Objects.requireNonNull(table, "table"); - FloatAggState f = (FloatAggState) state; - if (needsCount) { - f.count = LongMath.checkedAdd(f.count, 1); - } - if (needsStats) { - addFloatStats(table.getFloat(fromColumn, row), f); - } - } - - /** - * Helper function to add value to this aggregator. Note this does NOT update count. - * - * @param value the value to be added - * @param state the state of the aggregate, which will be mutated. 
- */ - private void addFloatStats(final float value, final FloatAggState state) { - if (needsSum) { - state.sum += value; - } - if (needsSumSq) { - state.sumSquared += value * value; - } - if (needsMin) { - state.min = Math.min(state.min, value); - } - if (needsMax) { - state.max = Math.max(state.max, value); - } + protected boolean isSupported(final AggregationOp aggOp) { + return ImmutableSet.of( + AggregationOp.COUNT, + AggregationOp.MIN, + AggregationOp.MAX, + AggregationOp.SUM, + AggregationOp.AVG, + AggregationOp.STDEV, + AggregationOp.SUM_SQUARED) + .contains(aggOp); } @Override - public void getResult(final AppendableTable dest, final int destColumn, final Object state) { - Objects.requireNonNull(dest, "dest"); - FloatAggState f = (FloatAggState) state; - int idx = destColumn; - for (AggregationOp op : aggOps) { - switch (op) { - case AVG: - dest.putDouble(idx, f.sum * 1.0 / f.count); - break; - case COUNT: - dest.putLong(idx, f.count); - break; - case MAX: - dest.putFloat(idx, f.max); - break; - case MIN: - dest.putFloat(idx, f.min); - break; - case STDEV: - double first = f.sumSquared / f.count; - double second = f.sum / f.count; - double stdev = Math.sqrt(first - second * second); - dest.putDouble(idx, stdev); - break; - case SUM: - dest.putDouble(idx, f.sum); - break; - } - idx++; + protected Type getOutputType() { + switch (aggOp) { + case COUNT: + return Type.LONG_TYPE; + case MAX: + case MIN: + return Type.FLOAT_TYPE; + case SUM: + case AVG: + case STDEV: + return Type.DOUBLE_TYPE; + default: + throw new IllegalArgumentException("Type " + aggOp + " is invalid"); } - } - - @Override - public Type getType() { - return Type.FLOAT_TYPE; - } - - @Override - protected Type getSumType() { - return Type.DOUBLE_TYPE; - } - - @Override - protected Set getAvailableAgg() { - return AVAILABLE_AGG; - } + }; @Override - public Object getInitialState() { - return new FloatAggState(); - } - - /** Private internal class that wraps the state required by this Aggregator as an object. */ - private final class FloatAggState { - /** The number of tuples seen so far. */ - private long count = 0; - /** The minimum value in the aggregated column. */ - private float min = Float.MAX_VALUE; - /** The maximum value in the aggregated column. */ - private float max = Float.MIN_VALUE; - /** The sum of values in the aggregated column. */ - private double sum = 0; - /** private temp variables for computing stdev. 
*/ - private double sumSquared = 0; + public void appendInitValue(AppendableTable data, final int column) { + switch (aggOp) { + case COUNT: + data.putLong(column, 0); + break; + case SUM: + case SUM_SQUARED: + data.putDouble(column, 0); + break; + case MAX: + data.putFloat(column, Float.MIN_VALUE); + break; + case MIN: + data.putFloat(column, Float.MAX_VALUE); + break; + default: + throw new IllegalArgumentException("Type " + aggOp + " is invalid"); + } } } diff --git a/src/edu/washington/escience/myria/operator/agg/IntegerAggregator.java b/src/edu/washington/escience/myria/operator/agg/IntegerAggregator.java index c01464b0a..e6e873337 100644 --- a/src/edu/washington/escience/myria/operator/agg/IntegerAggregator.java +++ b/src/edu/washington/escience/myria/operator/agg/IntegerAggregator.java @@ -1,167 +1,110 @@ package edu.washington.escience.myria.operator.agg; -import java.util.Objects; -import java.util.Set; - import com.google.common.collect.ImmutableSet; import com.google.common.math.LongMath; import edu.washington.escience.myria.Type; import edu.washington.escience.myria.storage.AppendableTable; -import edu.washington.escience.myria.storage.ReadableTable; +import edu.washington.escience.myria.storage.MutableTupleBuffer; +import edu.washington.escience.myria.storage.ReadableColumn; +import edu.washington.escience.myria.storage.ReplaceableColumn; +import edu.washington.escience.myria.storage.TupleBatch; /** * Knows how to compute some aggregate over a set of IntFields. */ public final class IntegerAggregator extends PrimitiveAggregator { + protected IntegerAggregator(final String inputName, final int column, final AggregationOp aggOp) { + super(inputName, column, aggOp); + } + /** Required for Java serialization. */ private static final long serialVersionUID = 1L; - /** Which column of the input this aggregator operates over. */ - private final int fromColumn; - - /** - * Aggregate operations applicable for int columns. - */ - public static final Set AVAILABLE_AGG = - ImmutableSet.of( - AggregationOp.COUNT, - AggregationOp.SUM, - AggregationOp.MAX, - AggregationOp.MIN, - AggregationOp.AVG, - AggregationOp.STDEV); - - /** - * @param aFieldName aggregate field name for use in output schema. - * @param aggOps the aggregate operation to simultaneously compute. - * @param column the column being aggregated over. 
- */ - public IntegerAggregator( - final String aFieldName, final AggregationOp[] aggOps, final int column) { - super(aFieldName, aggOps); - fromColumn = column; - } @Override - public void add(final ReadableTable from, final Object state) { - Objects.requireNonNull(from, "from"); - IntAggState istate = (IntAggState) state; - final int numTuples = from.numTuples(); - if (numTuples == 0) { - return; - } - - if (needsCount) { - istate.count = LongMath.checkedAdd(istate.count, numTuples); - } - - if (!needsStats) { - return; - } - for (int i = 0; i < numTuples; i++) { - addIntStats(from.getInt(fromColumn, i), istate); + public void addRow( + final TupleBatch from, + final int fromRow, + final MutableTupleBuffer to, + final int toRow, + final int offset) { + ReadableColumn fromCol = from.asColumn(column); + ReplaceableColumn toCol = to.getColumn(offset, toRow); + final int inColumnRow = to.getInColumnIndex(toRow); + switch (aggOp) { + case COUNT: + toCol.replaceLong(toCol.getLong(inColumnRow) + 1, inColumnRow); + break; + case MAX: + toCol.replaceInt(Math.max(fromCol.getInt(fromRow), toCol.getInt(inColumnRow)), inColumnRow); + break; + case MIN: + toCol.replaceInt(Math.min(fromCol.getInt(fromRow), toCol.getInt(inColumnRow)), inColumnRow); + break; + case SUM: + toCol.replaceLong( + LongMath.checkedAdd((long) fromCol.getInt(fromRow), toCol.getLong(inColumnRow)), + inColumnRow); + break; + case SUM_SQUARED: + toCol.replaceLong( + LongMath.checkedAdd( + LongMath.checkedMultiply((long) fromCol.getInt(fromRow), fromCol.getInt(fromRow)), + toCol.getLong(inColumnRow)), + inColumnRow); + break; + default: + throw new IllegalArgumentException(aggOp + " is invalid"); } } @Override - public void addRow(final ReadableTable table, final int row, final Object state) { - Objects.requireNonNull(table, "table"); - IntAggState istate = (IntAggState) state; - if (needsCount) { - istate.count = LongMath.checkedAdd(istate.count, 1); - } - if (needsStats) { - addIntStats(table.getInt(fromColumn, row), istate); - } - } - - /** - * Helper function to add value to this aggregator. Note this does NOT update count. - * - * @param value the value to be added - * @param state the state of the aggregate, which will be mutated. 
- */ - private void addIntStats(final int value, final IntAggState state) { - if (needsSum) { - state.sum = LongMath.checkedAdd(state.sum, value); - } - if (needsSumSq) { - // don't need to check value*value since value is an int - state.sumSquared = LongMath.checkedAdd(state.sumSquared, ((long) value) * value); - } - if (needsMin) { - state.min = Math.min(state.min, value); - } - if (needsMax) { - state.max = Math.max(state.max, value); - } + protected boolean isSupported(final AggregationOp aggOp) { + return ImmutableSet.of( + AggregationOp.COUNT, + AggregationOp.MIN, + AggregationOp.MAX, + AggregationOp.SUM, + AggregationOp.AVG, + AggregationOp.STDEV, + AggregationOp.SUM_SQUARED) + .contains(aggOp); } @Override - public void getResult(final AppendableTable dest, final int destColumn, final Object state) { - Objects.requireNonNull(dest, "dest"); - IntAggState istate = (IntAggState) state; - int idx = destColumn; - for (AggregationOp op : aggOps) { - switch (op) { - case AVG: - dest.putDouble(idx, istate.sum * 1.0 / istate.count); - break; - case COUNT: - dest.putLong(idx, istate.count); - break; - case MAX: - dest.putInt(idx, istate.max); - break; - case MIN: - dest.putInt(idx, istate.min); - break; - case STDEV: - double first = ((double) istate.sumSquared) / istate.count; - double second = ((double) istate.sum) / istate.count; - double stdev = Math.sqrt(first - second * second); - dest.putDouble(idx, stdev); - break; - case SUM: - dest.putLong(idx, istate.sum); - break; - } - idx++; + protected Type getOutputType() { + switch (aggOp) { + case COUNT: + case SUM: + return Type.LONG_TYPE; + case MAX: + case MIN: + return Type.INT_TYPE; + case AVG: + case STDEV: + return Type.DOUBLE_TYPE; + default: + throw new IllegalArgumentException("Type " + aggOp + " is invalid"); } - } - - @Override - public Type getType() { - return Type.INT_TYPE; - } - - @Override - protected Set getAvailableAgg() { - return AVAILABLE_AGG; - } - - @Override - protected Type getSumType() { - return Type.LONG_TYPE; - } + }; @Override - public Object getInitialState() { - return new IntAggState(); - } - - /** Private internal class that wraps the state required by this Aggregator as an object. */ - private final class IntAggState { - /** The number of tuples seen so far. */ - private long count = 0; - /** The minimum value in the aggregated column. */ - private int min = Integer.MAX_VALUE; - /** The maximum value in the aggregated column. */ - private int max = Integer.MIN_VALUE; - /** The sum of values in the aggregated column. */ - private long sum = 0; - /** private temp variables for computing stdev. 
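
The SUM and SUM_SQUARED branches make STDEV possible without a second pass, via the population-variance identity Var(X) = E[X^2] - E[X]^2 that the deleted getResult() code above also used; the checked arithmetic trades silent wraparound for an ArithmeticException. A standalone check:

    import com.google.common.math.LongMath;

    final class StdevFromMoments {
      public static void main(final String[] args) {
        final int[] column = {1, 2, 3, 4};
        long count = 0;
        long sum = 0;
        long sumSquared = 0;
        for (final int v : column) {
          count += 1;
          // checkedAdd/checkedMultiply throw on overflow instead of wrapping,
          // matching the guards in addRow() above.
          sum = LongMath.checkedAdd(sum, v);
          sumSquared = LongMath.checkedAdd(sumSquared, LongMath.checkedMultiply((long) v, v));
        }
        final double mean = (double) sum / count;                       // 2.5
        final double stdev = Math.sqrt((double) sumSquared / count - mean * mean);
        System.out.println(stdev);                                      // ~1.118
      }
    }
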
*/ - private long sumSquared = 0; + public void appendInitValue(AppendableTable data, final int column) { + switch (aggOp) { + case COUNT: + case SUM: + case SUM_SQUARED: + data.putLong(column, 0); + break; + case MAX: + data.putInt(column, Integer.MIN_VALUE); + break; + case MIN: + data.putInt(column, Integer.MAX_VALUE); + break; + default: + throw new IllegalArgumentException("Type " + aggOp + " is invalid"); + } } } diff --git a/src/edu/washington/escience/myria/operator/agg/LongAggregator.java b/src/edu/washington/escience/myria/operator/agg/LongAggregator.java index 1ee649d1e..e5c3fd705 100644 --- a/src/edu/washington/escience/myria/operator/agg/LongAggregator.java +++ b/src/edu/washington/escience/myria/operator/agg/LongAggregator.java @@ -1,166 +1,109 @@ package edu.washington.escience.myria.operator.agg; -import java.util.Objects; -import java.util.Set; - import com.google.common.collect.ImmutableSet; import com.google.common.math.LongMath; import edu.washington.escience.myria.Type; import edu.washington.escience.myria.storage.AppendableTable; -import edu.washington.escience.myria.storage.ReadableTable; +import edu.washington.escience.myria.storage.MutableTupleBuffer; +import edu.washington.escience.myria.storage.ReadableColumn; +import edu.washington.escience.myria.storage.ReplaceableColumn; +import edu.washington.escience.myria.storage.TupleBatch; /** * Knows how to compute some aggregates over a LongColumn. */ public final class LongAggregator extends PrimitiveAggregator { + protected LongAggregator(final String inputName, final int column, final AggregationOp aggOp) { + super(inputName, column, aggOp); + } + /** Required for Java serialization. */ private static final long serialVersionUID = 1L; - /** Which column of the input this aggregator operates over. */ - private final int fromColumn; - - /** - * Aggregate operations applicable for long columns. - */ - public static final Set AVAILABLE_AGG = - ImmutableSet.of( - AggregationOp.COUNT, - AggregationOp.SUM, - AggregationOp.MAX, - AggregationOp.MIN, - AggregationOp.AVG, - AggregationOp.STDEV); - - /** - * @param aFieldName aggregate field name for use in output schema. - * @param aggOps the aggregate operation to simultaneously compute. - * @param column the column being aggregated over. 
- */ - public LongAggregator(final String aFieldName, final AggregationOp[] aggOps, final int column) { - super(aFieldName, aggOps); - fromColumn = column; - } @Override - public void add(final ReadableTable from, final Object state) { - Objects.requireNonNull(from, "from"); - LongAggState lstate = (LongAggState) state; - final int numTuples = from.numTuples(); - if (numTuples == 0) { - return; - } - if (needsCount) { - lstate.count = LongMath.checkedAdd(lstate.count, numTuples); - } - - if (!needsStats) { - return; - } - for (int i = 0; i < numTuples; i++) { - addLongStats(from.getLong(fromColumn, i), lstate); + public void addRow( + final TupleBatch from, + final int fromRow, + final MutableTupleBuffer to, + final int toRow, + final int offset) { + ReadableColumn fromCol = from.asColumn(column); + ReplaceableColumn toCol = to.getColumn(offset, toRow); + final int inColumnRow = to.getInColumnIndex(toRow); + switch (aggOp) { + case COUNT: + toCol.replaceLong(toCol.getLong(inColumnRow) + 1, inColumnRow); + break; + case MAX: + toCol.replaceLong( + Math.max(fromCol.getLong(fromRow), toCol.getLong(inColumnRow)), inColumnRow); + break; + case MIN: + toCol.replaceLong( + Math.min(fromCol.getLong(fromRow), toCol.getLong(inColumnRow)), inColumnRow); + break; + case SUM: + toCol.replaceLong( + LongMath.checkedAdd(fromCol.getLong(fromRow), toCol.getLong(inColumnRow)), inColumnRow); + break; + case SUM_SQUARED: + toCol.replaceLong( + LongMath.checkedAdd( + LongMath.checkedMultiply(fromCol.getLong(fromRow), fromCol.getLong(fromRow)), + toCol.getLong(inColumnRow)), + inColumnRow); + break; + default: + throw new IllegalArgumentException(aggOp + " is invalid"); } } @Override - public void addRow(final ReadableTable table, final int row, final Object state) { - Objects.requireNonNull(table, "table"); - LongAggState lstate = (LongAggState) state; - if (needsCount) { - lstate.count = LongMath.checkedAdd(lstate.count, 1); - } - if (needsStats) { - addLongStats(table.getLong(fromColumn, row), lstate); - } - } - - /** - * Helper function to add value to this aggregator. Note this does NOT update count. - * - * @param value the value to be added - * @param state the state of the aggregate, which will be mutated. 
- */ - private void addLongStats(final long value, final LongAggState state) { - if (needsSum) { - state.sum = LongMath.checkedAdd(state.sum, value); - } - if (needsSumSq) { - state.sumSquared = - LongMath.checkedAdd(state.sumSquared, LongMath.checkedMultiply(value, value)); - } - if (needsMin) { - state.min = Math.min(state.min, value); - } - if (needsMax) { - state.max = Math.max(state.max, value); - } + protected boolean isSupported(final AggregationOp aggOp) { + return ImmutableSet.of( + AggregationOp.COUNT, + AggregationOp.MIN, + AggregationOp.MAX, + AggregationOp.SUM, + AggregationOp.AVG, + AggregationOp.STDEV, + AggregationOp.SUM_SQUARED) + .contains(aggOp); } @Override - public void getResult(final AppendableTable dest, final int destColumn, final Object state) { - Objects.requireNonNull(dest, "dest"); - - LongAggState lstate = (LongAggState) state; - int idx = destColumn; - for (AggregationOp op : aggOps) { - switch (op) { - case AVG: - dest.putDouble(idx, lstate.sum * 1.0 / lstate.count); - break; - case COUNT: - dest.putLong(idx, lstate.count); - break; - case MAX: - dest.putLong(idx, lstate.max); - break; - case MIN: - dest.putLong(idx, lstate.min); - break; - case STDEV: - double first = ((double) lstate.sumSquared) / lstate.count; - double second = ((double) lstate.sum) / lstate.count; - double stdev = Math.sqrt(first - second * second); - dest.putDouble(idx, stdev); - break; - case SUM: - dest.putLong(idx, lstate.sum); - break; - } - idx++; + protected Type getOutputType() { + switch (aggOp) { + case COUNT: + case SUM: + case MAX: + case MIN: + return Type.LONG_TYPE; + case AVG: + case STDEV: + return Type.DOUBLE_TYPE; + default: + throw new IllegalArgumentException("Type " + aggOp + " is invalid"); + } + }; + + public void appendInitValue(AppendableTable data, final int column) { + switch (aggOp) { + case COUNT: + case SUM: + case SUM_SQUARED: + data.putLong(column, 0); + break; + case MAX: + data.putLong(column, Long.MIN_VALUE); + break; + case MIN: + data.putLong(column, Long.MAX_VALUE); + break; + default: + throw new IllegalArgumentException("Type " + aggOp + " is invalid"); } - } - - @Override - public Type getType() { - return Type.LONG_TYPE; - } - - @Override - protected Set getAvailableAgg() { - return AVAILABLE_AGG; - } - - @Override - protected Type getSumType() { - return Type.LONG_TYPE; - } - - @Override - public Object getInitialState() { - return new LongAggState(); - } - - /** Private internal class that wraps the state required by this Aggregator as an object. */ - private final class LongAggState { - /** The number of tuples seen so far. */ - private long count = 0; - /** The minimum value in the aggregated column. */ - private long min = Long.MAX_VALUE; - /** The maximum value in the aggregated column. */ - private long max = Long.MIN_VALUE; - /** The sum of values in the aggregated column. */ - private long sum = 0; - /** private temp variables for computing stdev. 
*/ - private long sumSquared = 0; } } diff --git a/src/edu/washington/escience/myria/operator/agg/MultiGroupByAggregate.java b/src/edu/washington/escience/myria/operator/agg/MultiGroupByAggregate.java deleted file mode 100644 index 6b1416499..000000000 --- a/src/edu/washington/escience/myria/operator/agg/MultiGroupByAggregate.java +++ /dev/null @@ -1,351 +0,0 @@ -package edu.washington.escience.myria.operator.agg; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.BitSet; -import java.util.HashMap; -import java.util.List; -import java.util.Objects; - -import javax.annotation.Nullable; - -import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Lists; -import com.gs.collections.impl.list.mutable.primitive.IntArrayList; -import com.gs.collections.impl.map.mutable.primitive.IntObjectHashMap; - -import edu.washington.escience.myria.DbException; -import edu.washington.escience.myria.Schema; -import edu.washington.escience.myria.Type; -import edu.washington.escience.myria.column.Column; -import edu.washington.escience.myria.operator.Operator; -import edu.washington.escience.myria.operator.UnaryOperator; -import edu.washington.escience.myria.storage.TupleBatch; -import edu.washington.escience.myria.storage.TupleBatchBuffer; -import edu.washington.escience.myria.storage.TupleBuffer; -import edu.washington.escience.myria.storage.TupleUtils; -import edu.washington.escience.myria.util.HashUtils; - -/** - * The Aggregation operator that computes an aggregate (e.g., sum, avg, max, min). This variant supports aggregates over - * multiple columns, group by multiple columns. - * - * @see Aggregate - * @see SingleGroupByAggregate - */ -public final class MultiGroupByAggregate extends UnaryOperator { - /** logger for this class. */ - private static final org.slf4j.Logger LOGGER = - org.slf4j.LoggerFactory.getLogger(MultiGroupByAggregate.class); - - /** Java requires this. **/ - private static final long serialVersionUID = 1L; - - /** Holds the distinct grouping keys. */ - private transient TupleBuffer groupKeys; - /** Final group keys. */ - private List groupKeyList; - /** Holds the corresponding aggregation state for each group key in {@link #groupKeys}. */ - private transient List aggStates; - /** Holds the corresponding TB for each group key in {@link #groupKeys}. */ - private transient List> tbgroupState; - /** Holds the bitset for each group key in {@link #groupKeys}. */ - HashMap bs = new HashMap(); - - /** Maps the hash of a grouping key to a list of indices in {@link #groupKeys}. */ - private transient IntObjectHashMap groupKeyMap; - /** The schema of the columns indicated by the group keys. */ - private Schema groupSchema; - /** The schema of the aggregation result. */ - private Schema aggSchema; - - /** Factories to make the Aggregators. **/ - private final AggregatorFactory[] factories; - /** The actual Aggregators. **/ - private Aggregator[] aggregators; - /** Group fields. **/ - private final int[] gfields; - /** An array [0, 1, .., gfields.length-1] used for comparing tuples. */ - private final int[] grpRange; - /** Input Schema. */ - private Schema inputSchema; - - /** - * Groups the input tuples according to the specified grouping fields, then produces the specified aggregates. - * - * @param child The Operator that is feeding us tuples. - * @param gfields The columns over which we are grouping the result. 
- * @param factories The factories that will produce the {@link Aggregator}s for each group.. - */ - public MultiGroupByAggregate( - @Nullable final Operator child, final int[] gfields, final AggregatorFactory... factories) { - super(child); - this.gfields = Objects.requireNonNull(gfields, "gfields"); - this.factories = Objects.requireNonNull(factories, "factories"); - Preconditions.checkArgument( - gfields.length > 1, "to use MultiGroupByAggregate, must group over multiple fields"); - Preconditions.checkArgument( - factories.length != 0, "to use MultiGroupByAggregate, must specify some aggregates"); - - grpRange = new int[gfields.length]; - for (int i = 0; i < gfields.length; ++i) { - grpRange[i] = i; - } - groupKeyList = null; - } - - @Override - protected void cleanup() throws DbException { - groupKeys = null; - aggStates = null; - groupKeyMap = null; - groupKeyList = null; - } - - /** - * Returns the next tuple. If there is a group by field, then the first field is the field by which we are grouping, - * and the second field is the result of computing the aggregate, If there is no group by field, then the result tuple - * should contain one field representing the result of the aggregate. Should return null if there are no more tuples. - * - * @throws DbException if any error occurs. - * @return result TB. - * @throws IOException - */ - @Override - protected TupleBatch fetchNextReady() throws DbException { - final Operator child = getChild(); - - if (child.eos()) { - return getResultBatch(); - } - - TupleBatch tb = child.nextReady(); - - while (tb != null) { - for (int row = 0; row < tb.numTuples(); ++row) { - int rowHash = HashUtils.hashSubRow(tb, gfields, row); - IntArrayList hashMatches = groupKeyMap.get(rowHash); - if (hashMatches == null) { - hashMatches = newKey(rowHash); - newGroup(tb, row, hashMatches); - continue; - } - boolean found = false; - for (int i = 0; i < hashMatches.size(); i++) { - int value = hashMatches.get(i); - if (TupleUtils.tupleEquals(tb, gfields, row, groupKeys, grpRange, value)) { - addBitSet(row, value); - updateGroup(tb, row, aggStates.get(value)); - found = true; - break; - } - } - - if (!found) { - newGroup(tb, row, hashMatches); - } - - Preconditions.checkState(groupKeys.numTuples() == aggStates.size()); - } - updateGroups(tb); - tb = child.nextReady(); - } - - /* - * We know that child.nextReady() has returned null, so we have processed all tuple we can. Child is - * either EOS or we have to wait for more data. - */ - if (child.eos()) { - return getResultBatch(); - } - - return null; - } - /** - * update groups with tuplebatch values. - * @param tb input tuple batch - */ - private void updateGroups(final TupleBatch tb) { - for (int i = 0; i < tbgroupState.size(); i++) { - if (!bs.get(i).isEmpty()) { - TupleBatch filteredtb = tb.filter(bs.get(i)); - tbgroupState.get(i).add(filteredtb); - } - bs.get(i).clear(); - } - } - /** - * Add bitset for the groupid. - * @param row rowid - * @param groupid groupid of the aggregate. - */ - private void addBitSet(final int row, final int groupid) { - bs.get(groupid).set(row); - } - - /** - * Since row row in {@link TupleBatch} tb does not appear in {@link #groupKeys}, create a - * new group for it. - * - * @param tb the source {@link TupleBatch} - * @param row the row in tb that contains the new group - * @param hashMatches the list of all rows in the output {@link TupleBuffer}s that match this hash. - * @throws DbException if there is an error. 
- */ - private void newGroup(final TupleBatch tb, final int row, final IntArrayList hashMatches) - throws DbException { - int newIndex = groupKeys.numTuples(); - for (int column = 0; column < gfields.length; ++column) { - TupleUtils.copyValue(tb, gfields[column], row, groupKeys, column); - } - hashMatches.add(newIndex); - Object[] curAggStates = AggUtils.allocateAggStates(aggregators); - aggStates.add(curAggStates); - - // Allocate a tuple batch list to hold state tuples - List ltb = new ArrayList(); - tbgroupState.add(ltb); - - // create a bitset for this tuplebatch - BitSet curbitSet = new BitSet(tb.numTuples()); - bs.put(newIndex, curbitSet); - addBitSet(row, newIndex); - updateGroup(tb, row, curAggStates); - - Preconditions.checkState( - groupKeys.numTuples() == aggStates.size(), - "groupKeys %s != groupAggs %s", - groupKeys.numTuples(), - aggStates.size()); - } - - /** - * Called when there is no list yet of which output aggregators match the specified hash. Creates a new int list to - * store these matches, and insert it into the {@link #groupKeyMap}. - * - * @param groupHash the hash of the grouping columns in a tuple - * @return the new (empty still) int list storing which output aggregators match the specified hash - */ - private IntArrayList newKey(final int groupHash) { - IntArrayList matches = new IntArrayList(1); - groupKeyMap.put(groupHash, matches); - return matches; - } - - /** - * Update the aggregation states with the tuples in the specified row. - * - * @param tb the source {@link TupleBatch} - * @param row the row in tb that contains the new values - * @param curAggStates the aggregation states to be updated. - * @throws DbException if there is an error. - */ - private void updateGroup(final TupleBatch tb, final int row, final Object[] curAggStates) - throws DbException { - - for (int agg = 0; agg < aggregators.length; ++agg) { - if (!(aggregators[agg] instanceof StatefulUserDefinedAggregator)) { - aggregators[agg].addRow(tb, row, curAggStates[agg]); - } - } - } - - /** - * @return A batch's worth of result tuples from this aggregate. - * @throws DbException if there is an error. - */ - private TupleBatch getResultBatch() throws DbException { - Preconditions.checkState( - getChild().eos(), "cannot extract results from an aggregate until child has reached EOS"); - if (groupKeyList == null) { - groupKeyList = Lists.newLinkedList(groupKeys.finalResult()); - groupKeys = null; - } - if (groupKeyList.isEmpty()) { - return null; - } - TupleBatch curGroupKeys = groupKeyList.remove(0); - TupleBatchBuffer curGroupAggs = new TupleBatchBuffer(aggSchema); - for (int row = 0; row < curGroupKeys.numTuples(); ++row) { - Object[] rowAggs = aggStates.get(row); - List lt = tbgroupState.get(row); - int curCol = 0; - for (int agg = 0; agg < aggregators.length; ++agg) { - if ((aggregators[agg] instanceof StatefulUserDefinedAggregator)) { - aggregators[agg].add(lt); - } - aggregators[agg].getResult(curGroupAggs, curCol, rowAggs[agg]); - - curCol += aggregators[agg].getResultSchema().numColumns(); - } - } - TupleBatch aggResults = curGroupAggs.popAny(); - Preconditions.checkState( - curGroupKeys.numTuples() == aggResults.numTuples(), - "curGroupKeys size %s != aggResults size %s", - curGroupKeys.numTuples(), - aggResults.numTuples()); - - /* Note: as of Java7 sublists of sublists do what we want -- the sublists are at most one deep. 
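
The removed operator's core lookup, visible in `newGroup`/`newKey` above, hashes the grouping columns, keeps a per-hash list of group indices, and falls back to a full key comparison to resolve collisions. Below is a minimal self-contained sketch of that hash-then-verify technique; it uses `long[]` keys and JDK collections instead of Myria's `TupleBuffer` and gs-collections maps, and the names are illustrative.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Sketch of hash-then-verify group lookup, as in the removed MultiGroupByAggregate. */
public class GroupLookupSketch {
  private final List<long[]> groupKeys = new ArrayList<>();        // one entry per distinct key
  private final Map<Integer, List<Integer>> keysByHash = new HashMap<>();

  /** Returns the group index for this key, creating a new group on first sight. */
  int findOrCreate(long[] key) {
    int h = Arrays.hashCode(key);
    List<Integer> candidates = keysByHash.computeIfAbsent(h, k -> new ArrayList<>(1));
    for (int idx : candidates) {            // resolve hash collisions by full comparison
      if (Arrays.equals(groupKeys.get(idx), key)) {
        return idx;
      }
    }
    int idx = groupKeys.size();             // new group: remember its key and index
    groupKeys.add(key.clone());
    candidates.add(idx);
    return idx;
  }

  public static void main(String[] args) {
    GroupLookupSketch g = new GroupLookupSketch();
    System.out.println(g.findOrCreate(new long[] {1, 2})); // 0
    System.out.println(g.findOrCreate(new long[] {3, 4})); // 1
    System.out.println(g.findOrCreate(new long[] {1, 2})); // 0 again
  }
}
```
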
*/ - tbgroupState = tbgroupState.subList(curGroupKeys.numTuples(), tbgroupState.size()); - aggStates = aggStates.subList(curGroupKeys.numTuples(), aggStates.size()); - - return new TupleBatch( - getSchema(), - ImmutableList.>builder() - .addAll(curGroupKeys.getDataColumns()) - .addAll(aggResults.getDataColumns()) - .build()); - } - - /** - * The schema of the aggregate output. Grouping fields first and then aggregate fields. The aggregate - * - * @return the resulting schema - */ - @Override - protected Schema generateSchema() { - Operator child = getChild(); - if (child == null) { - return null; - } - inputSchema = child.getSchema(); - if (inputSchema == null) { - return null; - } - - groupSchema = inputSchema.getSubSchema(gfields); - - /* Build the output schema from the group schema and the aggregates. */ - final ImmutableList.Builder aggTypes = ImmutableList.builder(); - final ImmutableList.Builder aggNames = ImmutableList.builder(); - - try { - for (Aggregator agg : AggUtils.allocateAggs(factories, getChild().getSchema(), null)) { - Schema curAggSchema = agg.getResultSchema(); - aggTypes.addAll(curAggSchema.getColumnTypes()); - aggNames.addAll(curAggSchema.getColumnNames()); - } - - } catch (DbException e) { - throw new RuntimeException("unable to allocate aggregators to determine output schema", e); - } - aggSchema = new Schema(aggTypes, aggNames); - return Schema.merge(groupSchema, aggSchema); - } - - @Override - protected void init(final ImmutableMap execEnvVars) throws DbException { - Preconditions.checkState(getSchema() != null, "unable to determine schema in init"); - inputSchema = getChild().getSchema(); - aggregators = - AggUtils.allocateAggs(factories, getChild().getSchema(), getPythonFunctionRegistrar()); - groupKeys = new TupleBuffer(groupSchema, TupleUtils.getBatchSize(generateSchema())); - aggStates = new ArrayList<>(); - tbgroupState = new ArrayList<>(); - bs = new HashMap(); - groupKeyMap = new IntObjectHashMap<>(); - } -}; diff --git a/src/edu/washington/escience/myria/operator/agg/PrimitiveAggregator.java b/src/edu/washington/escience/myria/operator/agg/PrimitiveAggregator.java index 1f4941189..edd178ea0 100644 --- a/src/edu/washington/escience/myria/operator/agg/PrimitiveAggregator.java +++ b/src/edu/washington/escience/myria/operator/agg/PrimitiveAggregator.java @@ -1,24 +1,14 @@ package edu.washington.escience.myria.operator.agg; import java.io.Serializable; -import java.util.Arrays; -import java.util.LinkedHashSet; -import java.util.List; -import java.util.Objects; -import java.util.Set; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.Sets; - -import edu.washington.escience.myria.DbException; -import edu.washington.escience.myria.Schema; import edu.washington.escience.myria.Type; -import edu.washington.escience.myria.storage.TupleBatch; +import edu.washington.escience.myria.storage.AppendableTable; +import edu.washington.escience.myria.storage.MutableTupleBuffer; /** * Single column aggregator. */ -@SuppressWarnings("checkstyle:visibilitymodifier") public abstract class PrimitiveAggregator implements Aggregator, Serializable { /** Required for Java serialization. */ @@ -34,41 +24,24 @@ public enum AggregationOp { MIN, /** MAX. Applies to all types. Result is same as input type. */ MAX, - /** - * SUM. Applies to numeric types. Result is the bigger numeric type, i.e., {@link Type#INT_TYPE} -> - * {@link Type#LONG_TYPE} and . {@link Type#FLOAT_TYPE} -> {@link Type#DOUBLE_TYPE}. - */ + /** SUM. Applies to numeric types. 
Result is coerced to the largest compatible numeric type (long or double). */
    SUM,
    /** AVG. Applies to numeric types. Result is always {@link Type#DOUBLE_TYPE}. */
    AVG,
    /** STDEV. Applies to numeric types. Result is always {@link Type#DOUBLE_TYPE}. */
-   STDEV
+   STDEV,
+   /**
+    * SUM_SQUARED. Applies to numeric types. Result is coerced to the largest compatible numeric type (long or double).
+    */
+   SUM_SQUARED
  };

-  /** Does this aggregator need to compute the count? */
-  protected final boolean needsCount;
-  /** Does this aggregator need to compute the sum? */
-  protected final boolean needsSum;
-  /** Does this aggregator need to compute the sum squared? */
-  protected final boolean needsSumSq;
-  /** Does this aggregator need to compute the max? */
-  protected final boolean needsMax;
-  /** Does this aggregator need to compute the min? */
-  protected final boolean needsMin;
-  /** Does this aggregator need to compute tuple-level stats? */
-  protected final boolean needsStats;
-  /**
-   * Aggregate operations. A set of all valid aggregation operations, i.e. those in {@link LongAggregator#AVAILABLE_AGG}.
-   *
-   * Note that we use a {@link LinkedHashSet} to ensure that the iteration order is consistent!
-   */
-  protected final LinkedHashSet<AggregationOp> aggOps;
-
-  /**
-   * Result schema. It's automatically generated according to the {@link #aggOps}.
-   */
-  private final Schema resultSchema;
+  /** The aggregate operation. */
+  protected final AggregationOp aggOp;
+  /** The column to aggregate on. */
+  protected final int column;
+  /** The output name of the aggregate. */
+  private final String outputName;

   /**
-   * Instantiate a PrimitiveAggregator that computes the specified aggregates.
+   * Instantiate a PrimitiveAggregator that computes the specified aggregate.
    *
-   * @param fieldName the name of the field being aggregated, for naming output columns.
-   * @param aggOps the set of aggregate operations to be computed.
+   * @param inputName the name of the input field being aggregated, used to name the output column.
+   * @param column the index of the input column being aggregated.
+   * @param aggOp the aggregate operation to be computed.
   */
-  protected PrimitiveAggregator(final String fieldName, final AggregationOp[] aggOps) {
-    Objects.requireNonNull(aggOps, "aggOps");
-    Objects.requireNonNull(fieldName, "fieldName");
-
-    this.aggOps = new LinkedHashSet<>(Arrays.asList(aggOps));
-
-    if (!getAvailableAgg().containsAll(this.aggOps)) {
-      throw new IllegalArgumentException(
-          "Unsupported aggregation(s): " + Sets.difference(this.aggOps, getAvailableAgg()));
+  protected PrimitiveAggregator(
+      final String inputName, final int column, final AggregationOp aggOp) {
+    if (!isSupported(aggOp)) {
+      throw new IllegalArgumentException("Unsupported aggregation " + aggOp);
     }
+    this.aggOp = aggOp;
+    this.column = column;
+    this.outputName = aggOp.toString().toLowerCase() + "_" + inputName;
+  }

-    if (aggOps.length == 0) {
-      throw new IllegalArgumentException("No aggregation operations are selected");
-    }
-
-    needsCount = AggUtils.needsCount(this.aggOps);
-    needsSum = AggUtils.needsSum(this.aggOps);
-    needsSumSq = AggUtils.needsSumSq(this.aggOps);
-    needsMin = AggUtils.needsMin(this.aggOps);
-    needsMax = AggUtils.needsMax(this.aggOps);
-    needsStats = AggUtils.needsStats(this.aggOps);
-
-    final ImmutableList.Builder<Type> types = ImmutableList.builder();
-    final ImmutableList.Builder<String> names = ImmutableList.builder();
-    for (AggregationOp op : this.aggOps) {
-      switch (op) {
-        case COUNT:
-          types.add(Type.LONG_TYPE);
-          names.add("count_" + fieldName);
-          break;
-        case MAX:
-          types.add(getType());
-          names.add("max_" + fieldName);
-          break;
-        case MIN:
-          types.add(getType());
-          names.add("min_" + fieldName);
-          break;
-        case AVG:
-          types.add(Type.DOUBLE_TYPE);
-          names.add("avg_" + fieldName);
-          break;
-        case STDEV:
-          types.add(Type.DOUBLE_TYPE);
-          names.add("stdev_" + fieldName);
-          break;
-        case SUM:
-          types.add(getSumType());
-          names.add("sum_" + fieldName);
-          break;
-      }
-    }
-    resultSchema = new Schema(types, names);
+  @Override
+  public int getStateSize() {
+    return 1;
   }

-  /**
-   * Returns the Type of the SUM aggregate.
-   *
-   * @return the Type of the SUM aggregate.
-   */
-  protected abstract Type getSumType();
+  /** @return The {@link Type} of the values this aggregator handles. */
+  protected abstract Type getOutputType();

   /**
-   * Returns the set of aggregation operations that are supported by this aggregator.
+   * Initialize a state by appending an initial value to a column.
    *
-   * @return the set of aggregation operations that are supported by this aggregator.
+   * @param data the table to append to
+   * @param column the column to append to
    */
-  protected abstract Set<AggregationOp> getAvailableAgg();
+  protected abstract void appendInitValue(AppendableTable data, final int column);

   /**
-   * @return The {@link Type} of the values this aggregator handles.
+   * @param aggOp the aggregation op to check.
+   * @return true if aggOp is supported by this aggregator.
+   */
-  public abstract Type getType();
-
-  @Override
-  public final Schema getResultSchema() {
-    return resultSchema;
-  }
+  protected abstract boolean isSupported(AggregationOp aggOp);

   @Override
-  public void add(final List<TupleBatch> from) throws DbException {
-    throw new DbException(" method not implemented");
+  public void initState(final MutableTupleBuffer state, final int offset) {
+    appendInitValue(state, offset);
   }
 }
diff --git a/src/edu/washington/escience/myria/operator/agg/PrimitiveAggregatorFactory.java b/src/edu/washington/escience/myria/operator/agg/PrimitiveAggregatorFactory.java
new file mode 100644
index 000000000..ada4e009c
--- /dev/null
+++ b/src/edu/washington/escience/myria/operator/agg/PrimitiveAggregatorFactory.java
@@ -0,0 +1,233 @@
+package edu.washington.escience.myria.operator.agg;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Objects;
+import java.util.Set;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+
+import edu.washington.escience.myria.Schema;
+import edu.washington.escience.myria.Type;
+import edu.washington.escience.myria.expression.DivideExpression;
+import edu.washington.escience.myria.expression.Expression;
+import edu.washington.escience.myria.expression.ExpressionOperator;
+import edu.washington.escience.myria.expression.MinusExpression;
+import edu.washington.escience.myria.expression.SqrtExpression;
+import edu.washington.escience.myria.expression.TimesExpression;
+import edu.washington.escience.myria.expression.VariableExpression;
+import edu.washington.escience.myria.functions.PythonFunctionRegistrar;
+import edu.washington.escience.myria.operator.agg.PrimitiveAggregator.AggregationOp;
+
+/**
+ * A factory that generates aggregators for a primitive column.
+ */
+public class PrimitiveAggregatorFactory implements AggregatorFactory {
+
+  /** Required for Java serialization. */
+  private static final long serialVersionUID = 1L;
+  /** Which column of the input to aggregate over. */
+  @JsonProperty private final int column;
+  /** Which aggregate options are requested. See {@link PrimitiveAggregator}. */
+  @JsonProperty private final AggregationOp[] aggOps;
+
+  /**
+   * A wrapper for the {@link PrimitiveAggregator} implementations like {@link IntegerAggregator}.
+   *
+   * @param column which column of the input to aggregate over.
+   * @param aggOps which aggregate operations are requested. See {@link PrimitiveAggregator}.
+   */
+  @JsonCreator
+  public PrimitiveAggregatorFactory(
+      @JsonProperty(value = "column", required = true) final Integer column,
+      @JsonProperty(value = "aggOps", required = true) final AggregationOp[] aggOps) {
+    this.column = Objects.requireNonNull(column, "column").intValue();
+    this.aggOps = Objects.requireNonNull(aggOps, "aggOps");
+    for (int i = 0; i < aggOps.length; ++i) {
+      Preconditions.checkNotNull(aggOps[i], "aggregation operator %s cannot be null", i);
+    }
+  }
+
+  /**
+   * @param column which column of the input to aggregate over.
+   * @param aggOp which aggregate is requested.
+   */
+  public PrimitiveAggregatorFactory(final Integer column, final AggregationOp aggOp) {
+    this(column, new AggregationOp[] {aggOp});
+  }
+
+  @Override
+  public List<Aggregator> generateInternalAggs(final Schema inputSchema) {
+    List<Aggregator> ret = new ArrayList<>();
+    List<AggregationOp> ops = getInternalOps();
+    for (int i = 0; i < ops.size(); ++i) {
+      ret.add(generateAgg(inputSchema, ops.get(i)));
+    }
+    return ret;
+  }
+
+  /**
+   * @param inputSchema the input schema
+   * @param aggOp the aggregation op
+   * @return the generated aggregator
+   */
+  private Aggregator generateAgg(final Schema inputSchema, final AggregationOp aggOp) {
+    String inputName = inputSchema.getColumnName(column);
+    Type type = inputSchema.getColumnType(column);
+    switch (type) {
+      case BOOLEAN_TYPE:
+        return new BooleanAggregator(inputName, column, aggOp);
+      case DATETIME_TYPE:
+        return new DateTimeAggregator(inputName, column, aggOp);
+      case DOUBLE_TYPE:
+        return new DoubleAggregator(inputName, column, aggOp);
+      case FLOAT_TYPE:
+        return new FloatAggregator(inputName, column, aggOp);
+      case INT_TYPE:
+        return new IntegerAggregator(inputName, column, aggOp);
+      case LONG_TYPE:
+        return new LongAggregator(inputName, column, aggOp);
+      case STRING_TYPE:
+        return new StringAggregator(inputName, column, aggOp);
+      default:
+        throw new IllegalArgumentException("Unknown column type: " + type);
+    }
+  }
+
+  @Override
+  public List<Expression> generateEmitExpressions(final Schema inputSchema) {
+    List<AggregationOp> cols = getInternalOps();
+    List<Expression> exps = new ArrayList<>();
+    for (int i = 0; i < aggOps.length; ++i) {
+      String name = aggOps[i].toString().toLowerCase() + "_" + inputSchema.getColumnName(column);
+      switch (aggOps[i]) {
+        case COUNT:
+        case MIN:
+        case MAX:
+        case SUM:
+          exps.add(new Expression(name, new VariableExpression(cols.indexOf(aggOps[i]))));
+          continue;
+        case AVG:
+          exps.add(
+              new Expression(
+                  name,
+                  new DivideExpression(
+                      new VariableExpression(cols.indexOf(AggregationOp.SUM)),
+                      new VariableExpression(cols.indexOf(AggregationOp.COUNT)))));
+          continue;
+        case STDEV:
+          ExpressionOperator sumExp = new VariableExpression(cols.indexOf(AggregationOp.SUM));
+          ExpressionOperator countExp = new VariableExpression(cols.indexOf(AggregationOp.COUNT));
+          ExpressionOperator sumSquaredExp =
+              new VariableExpression(cols.indexOf(AggregationOp.SUM_SQUARED));
+          ExpressionOperator first = new DivideExpression(sumSquaredExp, countExp);
+          ExpressionOperator second = new DivideExpression(sumExp, countExp);
+          exps.add(
+              new Expression(
+                  name,
+                  new SqrtExpression(
+                      new MinusExpression(first, new TimesExpression(second, second)))));
+          continue;
+        default:
+          throw new IllegalArgumentException(aggOps[i] + " is invalid");
+      }
+    }
+    return exps;
+  }
+
+  /**
+   * Generate the internal aggregation ops. Each used op corresponds to one state column.
+   *
+   * @return the list of aggregation ops.
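
The STDEV emit expression built above encodes the population standard deviation purely in terms of the three internal state columns: sqrt(sum_squared/count - (sum/count)^2). A quick standalone check of that identity (plain Java, illustrative only, not the Myria expression tree):

```java
/** Checks that sqrt(sumSq/n - (sum/n)^2) matches a direct population-stdev computation. */
public class StdevIdentitySketch {
  public static void main(String[] args) {
    double[] xs = {2, 4, 4, 4, 5, 5, 7, 9};

    // The three internal state columns: COUNT, SUM, SUM_SQUARED.
    double sum = 0, sumSq = 0;
    for (double x : xs) { sum += x; sumSq += x * x; }
    double n = xs.length;
    double fromMoments = Math.sqrt(sumSq / n - (sum / n) * (sum / n));

    // Direct two-pass population standard deviation for comparison.
    double mean = sum / n, acc = 0;
    for (double x : xs) acc += (x - mean) * (x - mean);
    double direct = Math.sqrt(acc / n);

    System.out.println(fromMoments + " == " + direct); // both print 2.0
  }
}
```

This moment-based form needs only one pass over the data, which is what lets the aggregate state stay a fixed set of scalar columns; the usual caveat is that it can lose precision when the variance is tiny relative to the mean.
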
+   */
+  public List<AggregationOp> getInternalOps() {
+    Set<AggregationOp> ops = new HashSet<>();
+    for (AggregationOp aggOp : aggOps) {
+      ops.addAll(getInternalOps(aggOp));
+    }
+    List<AggregationOp> ret = new ArrayList<>(ops);
+    Collections.sort(ret);
+    return ret;
+  }
+
+  /**
+   * @param op the emit aggregation op
+   * @return the internal aggregation ops needed for computing the emit op
+   */
+  private List<AggregationOp> getInternalOps(AggregationOp op) {
+    switch (op) {
+      case COUNT:
+      case MIN:
+      case MAX:
+      case SUM:
+        return ImmutableList.of(op);
+      case AVG:
+        return ImmutableList.of(AggregationOp.SUM, AggregationOp.COUNT);
+      case STDEV:
+        return ImmutableList.of(AggregationOp.SUM, AggregationOp.SUM_SQUARED, AggregationOp.COUNT);
+      default:
+        throw new IllegalArgumentException(op + " is invalid");
    }
+  }
+
+  /**
+   * @param input the input type
+   * @param op the aggregation op
+   * @return the output type of applying op on the input type
+   */
+  public Type getAggColumnType(Type input, AggregationOp op) {
+    switch (op) {
+      case MIN:
+      case MAX:
+        return input;
+      case COUNT:
+        return Type.LONG_TYPE;
+      case SUM:
+      case SUM_SQUARED:
+        if (input == Type.INT_TYPE || input == Type.LONG_TYPE) {
+          return Type.LONG_TYPE;
+        }
+        if (input == Type.FLOAT_TYPE || input == Type.DOUBLE_TYPE) {
+          return Type.DOUBLE_TYPE;
+        }
+        throw new IllegalArgumentException(op + " on " + input + " is invalid");
+      case AVG:
+      case STDEV:
+        return Type.DOUBLE_TYPE;
+      default:
+        throw new IllegalArgumentException(op + " is invalid");
+    }
+  }
+
+  @Override
+  public Schema generateSchema(final Schema inputSchema) {
+    List<String> names = new ArrayList<>();
+    List<Type> types = new ArrayList<>();
+    for (AggregationOp op : aggOps) {
+      types.add(getAggColumnType(inputSchema.getColumnType(column), op));
+      names.add(op.toString().toLowerCase() + "_" + inputSchema.getColumnName(column));
+    }
+    return Schema.of(types, names);
+  }
+
+  @Override
+  public Schema generateStateSchema(final Schema inputSchema) {
+    List<String> names = new ArrayList<>();
+    List<Type> types = new ArrayList<>();
+    for (AggregationOp op : getInternalOps()) {
+      types.add(getAggColumnType(inputSchema.getColumnType(column), op));
+      names.add(op.toString().toLowerCase() + "_" + inputSchema.getColumnName(column));
+    }
+    return Schema.of(types, names);
+  }
+
+  /** The Python function registrar, used when Python UDF aggregators are in play. */
+  PythonFunctionRegistrar pyFuncReg;
+
+  /** @param pyFuncReg the Python function registrar. */
+  public void setPyFuncReg(PythonFunctionRegistrar pyFuncReg) {
+    this.pyFuncReg = pyFuncReg;
+  }
+}
diff --git a/src/edu/washington/escience/myria/operator/agg/SingleColumnAggregatorFactory.java b/src/edu/washington/escience/myria/operator/agg/SingleColumnAggregatorFactory.java
deleted file mode 100644
index 5c6cea914..000000000
--- a/src/edu/washington/escience/myria/operator/agg/SingleColumnAggregatorFactory.java
+++ /dev/null
@@ -1,75 +0,0 @@
-package edu.washington.escience.myria.operator.agg;
-
-import java.util.Objects;
-
-import com.fasterxml.jackson.annotation.JsonCreator;
-import com.fasterxml.jackson.annotation.JsonProperty;
-import com.google.common.base.Preconditions;
-
-import edu.washington.escience.myria.DbException;
-import edu.washington.escience.myria.Schema;
-import edu.washington.escience.myria.Type;
-import edu.washington.escience.myria.functions.PythonFunctionRegistrar;
-import edu.washington.escience.myria.operator.agg.PrimitiveAggregator.AggregationOp;
-
-/**
- * An aggregator for a column of primitive type.
- */
-public class SingleColumnAggregatorFactory implements AggregatorFactory {
-
-  /** Required for Java serialization.
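
`getInternalOps` above decomposes the requested emit aggregates into a deduplicated, sorted set of internal state columns, so that AVG and STDEV can share SUM and COUNT state with each other and with a plain SUM. A small sketch of that decomposition (standalone enum and names, not the Myria classes):

```java
import java.util.EnumSet;
import java.util.Set;

/** Sketch of decomposing requested aggregates into shared internal state columns. */
public class InternalOpsSketch {
  enum Op { COUNT, MIN, MAX, SUM, SUM_SQUARED, AVG, STDEV }

  static Set<Op> internalOps(Op emit) {
    switch (emit) {
      case AVG: return EnumSet.of(Op.SUM, Op.COUNT);
      case STDEV: return EnumSet.of(Op.SUM, Op.SUM_SQUARED, Op.COUNT);
      default: return EnumSet.of(emit);  // COUNT/MIN/MAX/SUM are their own state
    }
  }

  public static void main(String[] args) {
    Set<Op> state = EnumSet.noneOf(Op.class);
    for (Op emit : new Op[] {Op.AVG, Op.STDEV, Op.SUM}) {
      state.addAll(internalOps(emit));
    }
    // AVG, STDEV and SUM together need only three state columns, not six.
    System.out.println(state); // [COUNT, SUM, SUM_SQUARED]
  }
}
```

Sorting the deduplicated set gives every internal op a stable column position, which is what lets the emit expressions locate SUM, COUNT, and SUM_SQUARED by index.
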
*/ - private static final long serialVersionUID = 1L; - /** Which column of the input to aggregate over. */ - @JsonProperty private final int column; - /** Which aggregate options are requested. See {@link PrimitiveAggregator}. */ - @JsonProperty private final AggregationOp[] aggOps; - - /** - * A wrapper for the {@link PrimitiveAggregator} implementations like {@link IntegerAggregator}. - * - * @param column which column of the input to aggregate over. - * @param aggOps which aggregate operations are requested. See {@link PrimitiveAggregator}. - */ - @JsonCreator - public SingleColumnAggregatorFactory( - @JsonProperty(value = "column", required = true) final Integer column, - @JsonProperty(value = "aggOps", required = true) final AggregationOp... aggOps) { - this.column = Objects.requireNonNull(column, "column").intValue(); - this.aggOps = Objects.requireNonNull(aggOps, "aggOps"); - Preconditions.checkArgument(aggOps.length > 0, "no aggregation operators selected"); - for (int i = 0; i < aggOps.length; ++i) { - Preconditions.checkNotNull(aggOps[i], "aggregation operator %s cannot be null", i); - } - } - - @Override - public Aggregator get(final Schema inputSchema) { - Objects.requireNonNull(inputSchema, "inputSchema"); - Objects.requireNonNull(aggOps, "aggOps"); - String inputName = inputSchema.getColumnName(column); - Type type = inputSchema.getColumnType(column); - switch (type) { - case BOOLEAN_TYPE: - return new BooleanAggregator(inputName, aggOps, column); - case DATETIME_TYPE: - return new DateTimeAggregator(inputName, aggOps, column); - case DOUBLE_TYPE: - return new DoubleAggregator(inputName, aggOps, column); - case FLOAT_TYPE: - return new FloatAggregator(inputName, aggOps, column); - case INT_TYPE: - return new IntegerAggregator(inputName, aggOps, column); - case LONG_TYPE: - return new LongAggregator(inputName, aggOps, column); - case STRING_TYPE: - return new StringAggregator(inputName, aggOps, column); - } - throw new IllegalArgumentException("Unknown column type: " + type); - } - - @Override - public Aggregator get(final Schema inputSchema, final PythonFunctionRegistrar pyFuncReg) - throws DbException { - return get(inputSchema); - } -} diff --git a/src/edu/washington/escience/myria/operator/agg/SingleGroupByAggregate.java b/src/edu/washington/escience/myria/operator/agg/SingleGroupByAggregate.java deleted file mode 100644 index dee18262e..000000000 --- a/src/edu/washington/escience/myria/operator/agg/SingleGroupByAggregate.java +++ /dev/null @@ -1,641 +0,0 @@ -package edu.washington.escience.myria.operator.agg; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.BitSet; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; - -import javax.annotation.Nullable; - -import org.joda.time.DateTime; - -import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableMap; -import com.gs.collections.impl.map.mutable.primitive.DoubleObjectHashMap; -import com.gs.collections.impl.map.mutable.primitive.FloatObjectHashMap; -import com.gs.collections.impl.map.mutable.primitive.IntObjectHashMap; -import com.gs.collections.impl.map.mutable.primitive.LongObjectHashMap; - -import edu.washington.escience.myria.DbException; -import edu.washington.escience.myria.Schema; -import edu.washington.escience.myria.Type; -import edu.washington.escience.myria.operator.Operator; -import edu.washington.escience.myria.operator.UnaryOperator; -import edu.washington.escience.myria.storage.ReadableTable; -import 
edu.washington.escience.myria.storage.TupleBatch; -import edu.washington.escience.myria.storage.TupleBatchBuffer; - -/** - * The Aggregation operator that computes an aggregate (e.g., sum, avg, max, min) with a single group by column. - */ -public class SingleGroupByAggregate extends UnaryOperator { - - /** - * The Logger. - */ - private static final org.slf4j.Logger LOGGER = - org.slf4j.LoggerFactory.getLogger(SingleGroupByAggregate.class); - - /** - * default serialization ID. - */ - private static final long serialVersionUID = 1L; - - /** - * Factories to create the {@link Aggregator}s. - */ - private final AggregatorFactory[] factories; - - /** - * The group by column. - */ - private final int gColumn; - - /** - * A cache of the group-by column type. - */ - private Type gColumnType; - - /** - * The buffer storing in-progress group by results. {groupby-column-value -> Aggregator Array} when the group key is - * String - */ - private transient HashMap stringAggState; - - /** - * The buffer storing in-progress group by results. {groupby-column-value -> Aggregator Array} when the group key is - * DateTime. - */ - private transient HashMap datetimeAggState; - - /** - * The buffer storing in-progress group by results when the group key is int. - */ - private transient IntObjectHashMap intAggState; - /** - * The buffer storing in-progress group by results when the group key is boolean. - */ - private transient Object[][] booleanAggState; - /** - * The buffer storing in-progress group by results when the group key is long. - */ - private transient LongObjectHashMap longAggState; - /** - * The buffer storing in-progress group by results when the group key is float. - */ - private transient FloatObjectHashMap floatAggState; - /** - * The buffer storing in-progress group by results when the group key is double. - */ - private transient DoubleObjectHashMap doubleAggState; - /** - * HashMap containing the tuplebatches. - */ - private transient HashMap> ltb; - /** - * Hashmap containing the bitset. - */ - private transient HashMap tbbs; - /** - * The aggregators that will initialize and update the state. - */ - private Aggregator[] aggregators; - - /** - * The buffer storing results after group by is done. - */ - private transient TupleBatchBuffer resultBuffer; - - /** - * Constructor. - * - * @param child The Operator that is feeding us tuples. - * @param gfield The column over which we are grouping the result. - * @param factories Factories for the aggregation operators to use. - */ - public SingleGroupByAggregate( - @Nullable final Operator child, final int gfield, final AggregatorFactory... factories) { - super(child); - gColumn = Objects.requireNonNull(gfield, "gfield"); - this.factories = Objects.requireNonNull(factories, "factories"); - } - - @Override - protected final void cleanup() throws DbException { - stringAggState = null; - datetimeAggState = null; - doubleAggState = null; - booleanAggState = null; - floatAggState = null; - intAggState = null; - longAggState = null; - resultBuffer = null; - } - - /** - * Utility function to fetch or create/initialize the aggregation state for the group corresponding to the data in the - * specified table and row. - * - * @param table the data to be aggregated. - * @param row which row of the table is to be aggregated. - * @return the aggregation state for that row. - * @throws DbException if there is an error. 
- */ - private Object[] getAggState(final ReadableTable table, final int row) throws DbException { - Object[] aggState = null; - switch (gColumnType) { - case BOOLEAN_TYPE: - boolean groupByBool = table.getBoolean(gColumn, row); - if (groupByBool) { - aggState = booleanAggState[0]; - } else { - aggState = booleanAggState[1]; - } - if (aggState == null) { - aggState = AggUtils.allocateAggStates(aggregators); - if (groupByBool) { - booleanAggState[0] = aggState; - } else { - booleanAggState[1] = aggState; - } - } - break; - case STRING_TYPE: - String groupByString = table.getString(gColumn, row); - aggState = stringAggState.get(groupByString); - if (aggState == null) { - aggState = AggUtils.allocateAggStates(aggregators); - stringAggState.put(groupByString, aggState); - } - break; - case DATETIME_TYPE: - DateTime groupByDateTime = table.getDateTime(gColumn, row); - aggState = datetimeAggState.get(groupByDateTime); - if (aggState == null) { - aggState = AggUtils.allocateAggStates(aggregators); - datetimeAggState.put(groupByDateTime, aggState); - } - break; - case INT_TYPE: - int groupByInt = table.getInt(gColumn, row); - aggState = intAggState.get(groupByInt); - if (aggState == null) { - aggState = AggUtils.allocateAggStates(aggregators); - intAggState.put(groupByInt, aggState); - } - break; - case LONG_TYPE: - long groupByLong = table.getLong(gColumn, row); - aggState = longAggState.get(groupByLong); - if (aggState == null) { - aggState = AggUtils.allocateAggStates(aggregators); - longAggState.put(groupByLong, aggState); - } - break; - case FLOAT_TYPE: - float groupByFloat = table.getFloat(gColumn, row); - aggState = floatAggState.get(groupByFloat); - if (aggState == null) { - aggState = AggUtils.allocateAggStates(aggregators); - floatAggState.put(groupByFloat, aggState); - } - break; - case DOUBLE_TYPE: - double groupByDouble = table.getDouble(gColumn, row); - aggState = doubleAggState.get(groupByDouble); - if (aggState == null) { - aggState = AggUtils.allocateAggStates(aggregators); - doubleAggState.put(groupByDouble, aggState); - } - break; - } - if (aggState == null) { - throw new IllegalStateException("Aggregating values of unknown type."); - } - return aggState; - } - - /** - * @param tb the TupleBatch to be processed. - * @throws DbException if there is an error. - */ - private void processTupleBatch(final TupleBatch tb) throws DbException { - for (int agg = 0; agg < aggregators.length; ++agg) { - - for (int i = 0; i < tb.numTuples(); ++i) { - Object[] groupAgg = getAggState(tb, i); - if (aggregators[agg] instanceof StatefulUserDefinedAggregator) { - setBitSet(tb, i); - } else { - aggregators[agg].addRow(tb, i, groupAgg[agg]); - } - } - if (aggregators[agg] instanceof StatefulUserDefinedAggregator) { - updateltbGroups(tb); - } - } - } - /** - * Update the list of tuple batch. - * @param table tb to be added. - * @throws DbException if there is an error. 
- */ - private void updateltbGroups(final TupleBatch table) throws DbException { - - switch (gColumnType) { - case BOOLEAN_TYPE: - for (int key = 0; key < 2; key++) { - BitSet bs = tbbs.get(key); - if (bs != null && !(bs.isEmpty())) { - - List listTb = ltb.get(key); - if (listTb == null) { - List nlTb = new ArrayList(); - nlTb.add(table.filter(bs)); - ltb.put(key, nlTb); - } else { - listTb.add(table.filter(bs)); - } - bs.clear(); - } - } - break; - case STRING_TYPE: - for (String key : stringAggState.keySet()) { - BitSet bs = tbbs.get(key); - if (bs != null && !(bs.isEmpty())) { - - List listTb = ltb.get(key); - if (listTb == null) { - List nlTb = new ArrayList(); - nlTb.add(table.filter(bs)); - ltb.put(key, nlTb); - } else { - listTb.add(table.filter(bs)); - } - bs.clear(); - } - } - break; - case DATETIME_TYPE: - for (DateTime key : datetimeAggState.keySet()) { - BitSet bs = tbbs.get(key); - if (bs != null && !(bs.isEmpty())) { - - List listTb = ltb.get(key); - if (listTb == null) { - List nlTb = new ArrayList(); - nlTb.add(table.filter(bs)); - ltb.put(key, nlTb); - } else { - listTb.add(table.filter(bs)); - } - bs.clear(); - } - } - break; - case INT_TYPE: - for (Integer key : intAggState.keySet().toArray()) { - BitSet bs = tbbs.get(key); - if (bs != null && !(bs.isEmpty())) { - - List listTb = ltb.get(key); - if (listTb == null) { - List nlTb = new ArrayList(); - nlTb.add(table.filter(bs)); - ltb.put(key, nlTb); - } else { - listTb.add(table.filter(bs)); - } - bs.clear(); - } - } - break; - case LONG_TYPE: - for (Long key : longAggState.keySet().toArray()) { - BitSet bs = tbbs.get(key); - if (bs != null && !(bs.isEmpty())) { - - List listTb = ltb.get(key); - if (listTb == null) { - List nlTb = new ArrayList(); - nlTb.add(table.filter(bs)); - ltb.put(key, nlTb); - } else { - listTb.add(table.filter(bs)); - } - bs.clear(); - } - } - break; - case FLOAT_TYPE: - for (Float key : floatAggState.keySet().toArray()) { - BitSet bs = tbbs.get(key); - if (bs != null && !(bs.isEmpty())) { - - List listTb = ltb.get(key); - if (listTb == null) { - List nlTb = new ArrayList(); - nlTb.add(table.filter(bs)); - ltb.put(key, nlTb); - } else { - listTb.add(table.filter(bs)); - } - bs.clear(); - } - } - break; - case DOUBLE_TYPE: - for (Double key : doubleAggState.keySet().toArray()) { - BitSet bs = tbbs.get(key); - if (bs != null && !(bs.isEmpty())) { - - List listTb = ltb.get(key); - if (listTb == null) { - List nlTb = new ArrayList(); - nlTb.add(table.filter(bs)); - ltb.put(key, nlTb); - } else { - listTb.add(table.filter(bs)); - } - bs.clear(); - } - } - break; - default: - throw new DbException("type not supported for SingleColumnGroupby"); - } - } - - /** - * Private method to update bitset. - * @param table tb containing the tuple. - * @param row row to be update. 
- * @throws DbException in case of error - */ - private void setBitSet(final ReadableTable table, final int row) throws DbException { - - BitSet bs; - switch (gColumnType) { - case BOOLEAN_TYPE: - boolean groupByBool = table.getBoolean(gColumn, row); - bs = tbbs.get(groupByBool); - if (bs == null) { - bs = new BitSet(table.numTuples()); - bs.set(row); - tbbs.put(groupByBool, bs); - } else { - bs.set(row); - } - break; - case STRING_TYPE: - String groupByString = table.getString(gColumn, row); - bs = tbbs.get(groupByString); - if (bs == null) { - bs = new BitSet(table.numTuples()); - bs.set(row); - tbbs.put(groupByString, bs); - } else { - bs.set(row); - } - break; - case DATETIME_TYPE: - DateTime groupByDateTime = table.getDateTime(gColumn, row); - bs = tbbs.get(groupByDateTime); - if (bs == null) { - bs = new BitSet(table.numTuples()); - bs.set(row); - tbbs.put(groupByDateTime, bs); - } else { - bs.set(row); - } - break; - case INT_TYPE: - int groupByInt = table.getInt(gColumn, row); - bs = tbbs.get(groupByInt); - if (bs == null) { - bs = new BitSet(table.numTuples()); - bs.set(row); - tbbs.put(groupByInt, bs); - } else { - bs.set(row); - } - break; - case LONG_TYPE: - long groupByLong = table.getLong(gColumn, row); - bs = tbbs.get(groupByLong); - if (bs == null) { - bs = new BitSet(table.numTuples()); - bs.set(row); - tbbs.put(groupByLong, bs); - } else { - bs.set(row); - } - break; - case FLOAT_TYPE: - float groupByFloat = table.getFloat(gColumn, row); - bs = tbbs.get(groupByFloat); - if (bs == null) { - bs = new BitSet(table.numTuples()); - bs.set(row); - tbbs.put(groupByFloat, bs); - } else { - bs.set(row); - } - break; - case DOUBLE_TYPE: - double groupByDouble = table.getDouble(gColumn, row); - bs = tbbs.get(groupByDouble); - if (bs == null) { - bs = new BitSet(table.numTuples()); - bs.set(row); - tbbs.put(groupByDouble, bs); - } else { - bs.set(row); - } - break; - default: - throw new DbException("type not supported for SingleColumnGroupby"); - } - } - - /** - * Helper function for appending results to an output tuple buffer. By convention, the single-column aggregation key - * goes in column 0, and the aggregates are appended starting at column 1. - * - * @param resultBuffer where the tuples will be appended. - * @param aggState the states corresponding to all aggregators. - * @throws DbException if there is an error. - * @throws IOException - */ - private void concatResults( - final TupleBatchBuffer resultBuffer, final Object[] aggState, final Object key) - throws DbException, IOException { - int index = 1; - for (int agg = 0; agg < aggregators.length; ++agg) { - - if (aggregators[agg] instanceof StatefulUserDefinedAggregator) { - - List listTb = ltb.get(key); - if (listTb.size() > 0) { - aggregators[agg].add(listTb); - } - } - aggregators[agg].getResult(resultBuffer, index, aggState[agg]); - index += aggregators[agg].getResultSchema().numColumns(); - } - } - - /** - * @param resultBuffer where the results are stored. - * @throws DbException if there is an error. - * @throws IOException if there is an error. - */ - private void generateResult(final TupleBatchBuffer resultBuffer) throws DbException, IOException { - - switch (gColumnType) { - case BOOLEAN_TYPE: - for (int boolBucket = 0; boolBucket < 2; ++boolBucket) { - Object[] aggState = booleanAggState[boolBucket]; - if (aggState != null) { - /* True is index 0 in booleanAggState, False is index 1. 
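
`setBitSet` above records, for each group key, which rows of the current batch belong to that group; `updateltbGroups` then filters the batch once per non-empty BitSet so a stateful Python UDA can later consume whole row sets per group. A compact standalone sketch of that mark-then-filter idea (illustrative names, int keys instead of typed columns):

```java
import java.util.BitSet;
import java.util.HashMap;
import java.util.Map;

/** Sketch of per-group BitSet marking, then one filter pass per group. */
public class BitSetGroupFilterSketch {
  public static void main(String[] args) {
    int[] groupKey = {7, 3, 7, 7, 3};          // grouping column of one batch
    Map<Integer, BitSet> rowsByGroup = new HashMap<>();
    for (int row = 0; row < groupKey.length; row++) {
      rowsByGroup.computeIfAbsent(groupKey[row], k -> new BitSet(groupKey.length)).set(row);
    }
    // "Filter" the batch once per group by streaming the set bits.
    for (Map.Entry<Integer, BitSet> e : rowsByGroup.entrySet()) {
      System.out.println("group " + e.getKey() + " -> rows " + e.getValue());
    }
    // e.g. group 3 -> rows {1, 4} and group 7 -> rows {0, 2, 3} (map order may vary)
  }
}
```

The payoff is one filter pass per group per batch rather than one row copy per tuple, at the cost of keeping a BitSet alive per group while the batch is being scanned.
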
*/ - resultBuffer.putBoolean(0, boolBucket == 0); - concatResults(resultBuffer, aggState, boolBucket); - } - } - break; - case STRING_TYPE: - for (final Map.Entry e : stringAggState.entrySet()) { - resultBuffer.putString(0, e.getKey()); - concatResults(resultBuffer, e.getValue(), e.getKey()); - } - break; - case DATETIME_TYPE: - for (final Map.Entry e : datetimeAggState.entrySet()) { - resultBuffer.putDateTime(0, e.getKey()); - concatResults(resultBuffer, e.getValue(), e.getKey()); - } - break; - case INT_TYPE: - for (int key : intAggState.keySet().toArray()) { - resultBuffer.putInt(0, key); - concatResults(resultBuffer, intAggState.get(key), key); - } - break; - case LONG_TYPE: - for (long key : longAggState.keySet().toArray()) { - resultBuffer.putLong(0, key); - concatResults(resultBuffer, longAggState.get(key), key); - } - break; - case FLOAT_TYPE: - for (float key : floatAggState.keySet().toArray()) { - resultBuffer.putFloat(0, key); - concatResults(resultBuffer, floatAggState.get(key), key); - } - break; - case DOUBLE_TYPE: - for (double key : doubleAggState.keySet().toArray()) { - resultBuffer.putDouble(0, key); - concatResults(resultBuffer, doubleAggState.get(key), key); - } - break; - } - } - - @Override - protected final TupleBatch fetchNextReady() throws DbException, IOException { - TupleBatch tb = null; - final Operator child = getChild(); - - if (resultBuffer.numTuples() > 0) { - return resultBuffer.popAny(); - } - - if (child.eos()) { - return null; - } - - while ((tb = child.nextReady()) != null) { - - processTupleBatch(tb); - } - - if (child.eos()) { - generateResult(resultBuffer); - } - return resultBuffer.popAny(); - } - - /** - * @return the group by column. - */ - public final int getGroupByColumn() { - return gColumn; - } - - @Override - protected final void init(final ImmutableMap execEnvVars) throws DbException { - Preconditions.checkState(getSchema() != null, "unable to determine schema in init"); - - aggregators = - AggUtils.allocateAggs(factories, getChild().getSchema(), getPythonFunctionRegistrar()); - resultBuffer = new TupleBatchBuffer(getSchema()); - ltb = new HashMap>(); - tbbs = new HashMap(); - - switch (gColumnType) { - case BOOLEAN_TYPE: - booleanAggState = new Object[2][]; - break; - case INT_TYPE: - intAggState = new IntObjectHashMap(); - break; - case LONG_TYPE: - longAggState = new LongObjectHashMap(); - break; - case FLOAT_TYPE: - floatAggState = new FloatObjectHashMap(); - break; - case DOUBLE_TYPE: - doubleAggState = new DoubleObjectHashMap(); - break; - case STRING_TYPE: - stringAggState = new HashMap(); - break; - case DATETIME_TYPE: - datetimeAggState = new HashMap(); - break; - } - } - - @Override - protected Schema generateSchema() { - Operator child = getChild(); - if (child == null) { - return null; - } - Schema inputSchema = child.getSchema(); - if (inputSchema == null) { - return null; - } - - Preconditions.checkElementIndex(gColumn, inputSchema.numColumns(), "group column"); - - Schema outputSchema = - Schema.ofFields(inputSchema.getColumnType(gColumn), inputSchema.getColumnName(gColumn)); - - gColumnType = inputSchema.getColumnType(gColumn); - try { - for (Aggregator a : AggUtils.allocateAggs(factories, inputSchema, null)) { - outputSchema = Schema.merge(outputSchema, a.getResultSchema()); - } - } catch (DbException e) { - throw new RuntimeException("unable to allocate aggregators to determine output schema", e); - } - return outputSchema; - } -} diff --git 
a/src/edu/washington/escience/myria/operator/agg/StatefulUserDefinedAggregator.java b/src/edu/washington/escience/myria/operator/agg/StatefulUserDefinedAggregator.java deleted file mode 100644 index d074588a5..000000000 --- a/src/edu/washington/escience/myria/operator/agg/StatefulUserDefinedAggregator.java +++ /dev/null @@ -1,120 +0,0 @@ -/** - * - */ -package edu.washington.escience.myria.operator.agg; - -import java.io.IOException; -import java.lang.reflect.InvocationTargetException; -import java.util.ArrayList; -import java.util.List; - -import edu.washington.escience.myria.DbException; -import edu.washington.escience.myria.Schema; -import edu.washington.escience.myria.expression.evaluate.GenericEvaluator; -import edu.washington.escience.myria.expression.evaluate.PythonUDFEvaluator; -import edu.washington.escience.myria.expression.evaluate.ScriptEvalInterface; -import edu.washington.escience.myria.storage.AppendableTable; -import edu.washington.escience.myria.storage.ReadableTable; -import edu.washington.escience.myria.storage.Tuple; -import edu.washington.escience.myria.storage.TupleBatch; - -/** - * - */ -public class StatefulUserDefinedAggregator extends UserDefinedAggregator { - - /** Required for Java serialization. */ - private static final long serialVersionUID = 1L; - /** logger for this class. */ - private static final org.slf4j.Logger LOGGER = - org.slf4j.LoggerFactory.getLogger(StatefulUserDefinedAggregator.class); - - /** - * list for holding state. - */ - private List ltb; - /** - * @param state the initialized state of the tuple - * @param updateEvaluator updates the state given an input row - * @param pyUDFEvaluators python expression evaluators. - * @param emitEvaluators the evaluators that finalize the state - * @param resultSchema the schema of the tuples produced by this aggregator - */ - public StatefulUserDefinedAggregator( - final Tuple state, - final ScriptEvalInterface updateEvaluator, - final List pyUDFEvaluators, - final List emitEvaluators, - final Schema resultSchema) { - super(state, updateEvaluator, pyUDFEvaluators, emitEvaluators, resultSchema); - } - - @Override - public void add(final ReadableTable from, final Object state) throws DbException { - throw new DbException(" method not implemented"); - } - - @Override - public void add(final List from) throws DbException { - ltb = from; - } - - @Override - public void addRow(final ReadableTable from, final int row, final Object state) - throws DbException { - Tuple stateTuple = (Tuple) state; - try { - if (updateEvaluator != null) { - updateEvaluator.evaluate(from, row, stateTuple, stateTuple); - } - } catch (Exception e) { - - throw new DbException("Error updating UDA state", e); - } - } - - @Override - public void getResult(final AppendableTable dest, final int destColumn, final Object state) - throws DbException { - - Tuple stateTuple = (Tuple) state; - - // compute results over the tuplebatch list - if (pyUDFEvaluators.size() > 0) { - for (int i = 0; i < pyUDFEvaluators.size(); i++) { - - if (ltb != null && ltb.size() != 0) { - try { - pyUDFEvaluators.get(i).evalBatch(ltb, stateTuple, stateTuple); - } catch (Exception e) { - throw new DbException(e); - } - - } else { - throw new DbException("cannot get results!!"); - } - } - } - // emit results - - for (int index = 0; index < emitEvaluators.size(); index++) { - final GenericEvaluator evaluator = emitEvaluators.get(index); - try { - evaluator.eval(null, 0, null, dest.asWritableColumn(destColumn + index), stateTuple); - } catch (InvocationTargetException e) 
{ - throw new DbException("Error finalizing aggregate", e); - } - } - } - - @Override - public Schema getResultSchema() { - return resultSchema; - } - - @Override - public Object getInitialState() { - - return initialState.clone(); - } -} diff --git a/src/edu/washington/escience/myria/operator/agg/StreamingAggregate.java b/src/edu/washington/escience/myria/operator/agg/StreamingAggregate.java index 9917a5bc7..51f9137db 100644 --- a/src/edu/washington/escience/myria/operator/agg/StreamingAggregate.java +++ b/src/edu/washington/escience/myria/operator/agg/StreamingAggregate.java @@ -1,65 +1,28 @@ package edu.washington.escience.myria.operator.agg; import java.io.IOException; -import java.util.Objects; +import java.util.List; import javax.annotation.Nonnull; import javax.annotation.Nullable; -import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableMap; +import com.gs.collections.api.iterator.IntIterator; import edu.washington.escience.myria.DbException; -import edu.washington.escience.myria.Schema; +import edu.washington.escience.myria.column.Column; import edu.washington.escience.myria.operator.Operator; -import edu.washington.escience.myria.operator.UnaryOperator; -import edu.washington.escience.myria.storage.Tuple; import edu.washington.escience.myria.storage.TupleBatch; -import edu.washington.escience.myria.storage.TupleBatchBuffer; -import edu.washington.escience.myria.storage.TupleUtils; /** * This aggregate operator computes the aggregation in streaming manner (requires input sorted on grouping column(s)). - * This supports aggregation over multiple columns, with one or more group by columns. Intend to substitute - * SingleGroupByAggregate and MultiGroupByAggregate when input is known to be sorted. + * Intended to substitute for Aggregate when input is known to be sorted. * * @see Aggregate - * @see SingleGroupByAggregate - * @see MultiGroupByAggregate */ -public class StreamingAggregate extends UnaryOperator { - +public class StreamingAggregate extends Aggregate { /** Required for Java serialization. */ private static final long serialVersionUID = 1L; - /** The schema of the result. */ - private Schema resultSchema; - /** The schema of the columns indicated by the group keys. */ - private Schema groupSchema; - - /** Holds the current grouping key. */ - private Tuple curGroupKey; - /** Child of this aggregator. **/ - private final Operator child = getChild(); - /** Currently processing input tuple batch. **/ - private TupleBatch tb; - /** Current row in the input tuple batch. **/ - private int row; - - /** Group fields. **/ - private final int[] gFields; - /** An array [0, 1, .., gFields.length-1] used for comparing tuples. */ - private final int[] gRange; - /** Factories to make the Aggregators. **/ - private final AggregatorFactory[] factories; - /** The actual Aggregators. **/ - private Aggregator[] aggregators; - /** The state of the aggregators. */ - private Object[] aggregatorStates; - - /** Buffer for holding intermediate results. */ - private transient TupleBatchBuffer resultBuffer; - /** * Groups the input tuples according to the specified grouping fields, then produces the specified aggregates. * @@ -71,16 +34,7 @@ public StreamingAggregate( @Nullable final Operator child, @Nonnull final int[] gfields, @Nonnull final AggregatorFactory... 
factories) { - super(child); - gFields = Objects.requireNonNull(gfields, "gfields"); - this.factories = Objects.requireNonNull(factories, "factories"); - Preconditions.checkArgument(gfields.length > 0, " must have at least one group by field"); - Preconditions.checkArgument( - factories.length > 0, "to use StreamingAggregate, must specify some aggregates"); - gRange = new int[gfields.length]; - for (int i = 0; i < gRange.length; ++i) { - gRange[i] = i; - } + super(child, gfields, factories); } /** @@ -93,118 +47,46 @@ public StreamingAggregate( */ @Override protected TupleBatch fetchNextReady() throws DbException { - if (child.eos()) { - return null; - } - if (tb == null) { - tb = child.nextReady(); - row = 0; - } + final Operator child = getChild(); + TupleBatch tb = child.nextReady(); while (tb != null) { - while (row < tb.numTuples()) { - if (curGroupKey == null) { - /* First time accessing this tb, no aggregation performed previously. */ - // store current group key as a tuple - curGroupKey = new Tuple(groupSchema); - for (int gKey = 0; gKey < gFields.length; ++gKey) { - TupleUtils.copyValue(tb, gFields[gKey], row, curGroupKey, gKey); - } - } else if (!TupleUtils.tupleEquals(tb, gFields, row, curGroupKey, gRange, 0)) { - /* Different grouping key than current one, flush current agg result to result buffer. */ - addToResult(); - // store current group key as a tuple - for (int gKey = 0; gKey < gFields.length; ++gKey) { - TupleUtils.copyValue(tb, gFields[gKey], row, curGroupKey, gKey); - } - aggregatorStates = AggUtils.allocateAggStates(aggregators); - if (resultBuffer.hasFilledTB()) { - return resultBuffer.popFilled(); + for (int row = 0; row < tb.numTuples(); ++row) { + IntIterator iter = groupStates.getIndices(tb, gfields, row).intIterator(); + int index; + if (!iter.hasNext()) { + /* A new group is encountered. Since input tuples are sorted on the grouping key, the previous group must be + * finished so we can add its state to the result. */ + generateResult(); + groupStates.addTuple(tb, gfields, row, true); + int offset = gfields.length; + for (Aggregator agg : internalAggs) { + agg.initState(groupStates.getData(), offset); + offset += agg.getStateSize(); } + index = groupStates.getData().numTuples() - 1; + } else { + index = iter.next(); } - // update aggregator states with current tuple - for (int agg = 0; agg < aggregators.length; ++agg) { - aggregators[agg].addRow(tb, row, aggregatorStates[agg]); + int offset = gfields.length; + for (Aggregator agg : internalAggs) { + agg.addRow(tb, row, groupStates.getData(), index, offset); + offset += agg.getStateSize(); } - row++; + } + if (resultBuffer.hasFilledTB()) { + return resultBuffer.popFilled(); } tb = child.nextReady(); - row = 0; } - - /* - * We know that child.nextReady() has returned null, so we have processed all tuple we can. Child is - * either EOS or we have to wait for more data. - */ if (child.eos()) { - addToResult(); + generateResult(); return resultBuffer.popAny(); } return null; } - /** - * Add aggregate results with previous grouping key to result buffer. 
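
Because StreamingAggregate may assume its input is sorted on the grouping key, a group is complete the moment the key changes, so its state can be emitted immediately instead of being held until EOS. A minimal sketch of that run-based strategy with a single SUM (standalone Java, not the operator's actual buffers or state table):

```java
/** Sketch of streaming SUM-per-group over input sorted on the grouping key. */
public class StreamingGroupBySketch {
  public static void main(String[] args) {
    int[][] rows = {{1, 10}, {1, 5}, {2, 7}, {3, 1}, {3, 2}}; // sorted on column 0
    Integer curKey = null;
    long sum = 0;
    for (int[] row : rows) {
      if (curKey != null && row[0] != curKey) {
        System.out.println(curKey + " -> " + sum); // key changed: previous group is final
        sum = 0;
      }
      curKey = row[0];
      sum += row[1];
    }
    if (curKey != null) {
      System.out.println(curKey + " -> " + sum);   // flush the last group at end of input
    }
    // prints: 1 -> 15, 2 -> 7, 3 -> 3
  }
}
```

This is why the operator can produce output incrementally and keep only the live group's state, whereas the hash-based Aggregate must buffer every group until its child reaches EOS.
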
- * - * @throws DbException if any error - */ - private void addToResult() throws DbException { - int fromIndex = 0; - for (; fromIndex < curGroupKey.numColumns(); ++fromIndex) { - TupleUtils.copyValue(curGroupKey, fromIndex, 0, resultBuffer, fromIndex); - } - for (int agg = 0; agg < aggregators.length; ++agg) { - try { - aggregators[agg].getResult(resultBuffer, fromIndex, aggregatorStates[agg]); - } catch (Exception e) { - throw new DbException(e); - } - fromIndex += aggregators[agg].getResultSchema().numColumns(); - } - } - - /** - * The schema of the aggregate output. Grouping fields first and then aggregate fields. - * - * @return the resulting schema - */ - @Override - protected Schema generateSchema() { - Operator child = getChild(); - if (child == null) { - return null; - } - Schema inputSchema = child.getSchema(); - if (inputSchema == null) { - return null; - } - - groupSchema = inputSchema.getSubSchema(gFields); - resultSchema = Schema.of(groupSchema.getColumnTypes(), groupSchema.getColumnNames()); - try { - for (Aggregator agg : - AggUtils.allocateAggs(factories, inputSchema, getPythonFunctionRegistrar())) { - Schema curAggSchema = agg.getResultSchema(); - resultSchema = Schema.merge(resultSchema, curAggSchema); - } - } catch (DbException e) { - throw new RuntimeException("unable to allocate aggregators to determine output schema", e); - } - return resultSchema; - } - - @Override - protected void init(final ImmutableMap<String, Object> execEnvVars) throws DbException { - Preconditions.checkState(getSchema() != null, "unable to determine schema in init"); - aggregators = - AggUtils.allocateAggs(factories, getChild().getSchema(), getPythonFunctionRegistrar()); - aggregatorStates = AggUtils.allocateAggStates(aggregators); - resultBuffer = new TupleBatchBuffer(getSchema()); - } - @Override - protected void cleanup() throws DbException { - aggregatorStates = null; - curGroupKey = null; - resultBuffer = null; + protected void addToResult(List<Column<?>> columns) { + resultBuffer.absorb(new TupleBatch(getSchema(), columns), false); } } diff --git a/src/edu/washington/escience/myria/operator/agg/StringAggregator.java b/src/edu/washington/escience/myria/operator/agg/StringAggregator.java index aeb878ef7..40cc0e123 100644 --- a/src/edu/washington/escience/myria/operator/agg/StringAggregator.java +++ b/src/edu/washington/escience/myria/operator/agg/StringAggregator.java @@ -1,141 +1,95 @@ package edu.washington.escience.myria.operator.agg; -import java.util.Objects; -import java.util.Set; - import com.google.common.collect.ImmutableSet; -import com.google.common.math.LongMath; import edu.washington.escience.myria.Type; import edu.washington.escience.myria.storage.AppendableTable; -import edu.washington.escience.myria.storage.ReadableTable; +import edu.washington.escience.myria.storage.MutableTupleBuffer; +import edu.washington.escience.myria.storage.ReadableColumn; +import edu.washington.escience.myria.storage.ReplaceableColumn; +import edu.washington.escience.myria.storage.TupleBatch; /** * Knows how to compute some aggregate over a StringColumn. */ public final class StringAggregator extends PrimitiveAggregator { - /** Required for Java serialization. */ - private static final long serialVersionUID = 1L; - /** Which column of the input this aggregator operates over. */ - private final int fromColumn; - - /** - * Aggregate operations applicable for string columns. - */ - public static final Set<AggregationOp> AVAILABLE_AGG = - ImmutableSet.of(AggregationOp.COUNT, AggregationOp.MAX, AggregationOp.MIN); - - /** - * @param aFieldName aggregate field name for use in output schema. - * @param aggOps the aggregate operation to simultaneously compute. - * @param column the column being aggregated over. - */ - public StringAggregator(final String aFieldName, final AggregationOp[] aggOps, final int column) { - super(aFieldName, aggOps); - fromColumn = column; + protected StringAggregator(final String inputName, final int column, final AggregationOp aggOp) { + super(inputName, column, aggOp); } - @Override - public void add(final ReadableTable from, final Object state) { - Objects.requireNonNull(from, "from"); - StringAggState sstate = (StringAggState) state; - final int numTuples = from.numTuples(); - if (numTuples == 0) { - return; - } - if (needsCount) { - sstate.count = LongMath.checkedAdd(sstate.count, numTuples); - } - if (needsStats) { - for (int i = 0; i < numTuples; ++i) { - addStringStats(from.getString(fromColumn, i), sstate); - } - } - } - - @Override - public void addRow(final ReadableTable table, final int row, final Object state) { - Objects.requireNonNull(table, "table"); - StringAggState sstate = (StringAggState) state; - if (needsCount) { - sstate.count = LongMath.checkedAdd(sstate.count, 1); - } - if (needsStats) { - addStringStats(table.getString(fromColumn, row), sstate); - } - } + /** Required for Java serialization. */ + private static final long serialVersionUID = 1L; - /** - * Helper function to add value to this aggregator. Note this does NOT update count. - * - * @param value the value to be added - * @param state the state of the aggregate, which will be mutated. - */ - private void addStringStats(final String value, final StringAggState state) { - Objects.requireNonNull(value, "value"); - if (needsMin) { - if ((state.min == null) || (state.min.compareTo(value) > 0)) { - state.min = value; - } - } - if (needsMax) { - if (state.max == null || state.max.compareTo(value) < 0) { - state.max = value; - } - } - } + /** Placeholder as MIN/MAX value of String.
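The null placeholder works because the new addRow comparisons treat null as "no value seen yet": the first real string always replaces it, after which ordinary compareTo ordering applies. A small standalone sketch of that reasoning for the MAX branch (it mirrors, but does not reuse, the operator code):

```java
/**
 * Sketch of why null works as the MIN/MAX placeholder for strings: null means
 * "no value seen yet", so the first real value always wins; afterwards plain
 * compareTo ordering decides.
 */
static String maxOf(final String state, final String value) {
  if (state == null || state.compareTo(value) < 0) {
    return value; // first value, or a larger one: replace the state
  }
  return state;
}
```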
*/ + private static final String STRING_INIT_VALUE = null; @Override - public void getResult(final AppendableTable dest, final int destColumn, final Object state) { - StringAggState sstate = (StringAggState) state; - int idx = destColumn; - for (AggregationOp op : aggOps) { - switch (op) { - case COUNT: - dest.putLong(idx, sstate.count); - break; - case MAX: - dest.putString(idx, sstate.max); + public void addRow( + final TupleBatch from, + final int fromRow, + final MutableTupleBuffer to, + final int toRow, + final int offset) { + ReadableColumn fromCol = from.asColumn(column); + ReplaceableColumn toCol = to.getColumn(offset, toRow); + final int inColumnRow = to.getInColumnIndex(toRow); + switch (aggOp) { + case COUNT: + toCol.replaceLong(toCol.getLong(inColumnRow) + 1, inColumnRow); + break; + case MAX: + { + String value = toCol.getString(inColumnRow); + if (value == null || value.compareTo(fromCol.getString(fromRow)) < 0) { + toCol.replaceString(fromCol.getString(fromRow), inColumnRow); + } break; - case MIN: - dest.putString(idx, sstate.min); + } + case MIN: + { + String value = toCol.getString(inColumnRow); + if (value == null || value.compareTo(fromCol.getString(fromRow)) > 0) { + toCol.replaceString(fromCol.getString(fromRow), inColumnRow); + } break; - case AVG: - case STDEV: - case SUM: - throw new UnsupportedOperationException("Aggregate " + op + " on type String"); - } + } + default: + throw new IllegalArgumentException(aggOp + " is invalid"); } } @Override - protected Type getSumType() { - throw new UnsupportedOperationException("SUM of String values"); - } - - @Override - public Type getType() { - return Type.STRING_TYPE; + protected boolean isSupported(final AggregationOp aggOp) { + return ImmutableSet.of(AggregationOp.COUNT, AggregationOp.MIN, AggregationOp.MAX) + .contains(aggOp); } @Override - protected Set getAvailableAgg() { - return AVAILABLE_AGG; - } + protected Type getOutputType() { + switch (aggOp) { + case COUNT: + return Type.LONG_TYPE; + case MAX: + case MIN: + return Type.STRING_TYPE; + default: + throw new IllegalArgumentException("Type " + aggOp + " is invalid"); + } + }; @Override - public Object getInitialState() { - return new StringAggState(); - } - - /** Private internal class that wraps the state required by this Aggregator as an object. */ - private final class StringAggState { - /** The number of tuples seen so far. */ - private long count = 0; - /** The minimum value in the aggregated column. */ - private String min = null; - /** The maximum value in the aggregated column. 
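Unlike the old accumulate-into-a-private-object design being removed here, the new addRow mutates aggregate state stored in a column of a shared mutable buffer, addressed by a column offset and the group's row. A sketch of that pattern for COUNT, with a hypothetical Buffer interface standing in for Myria's MutableTupleBuffer:

```java
/**
 * Sketch of the in-place update pattern: the aggregator's state lives in a
 * column of a shared mutable buffer, and addRow replaces the cell instead of
 * appending. The Buffer interface is a hypothetical stand-in for Myria's
 * MutableTupleBuffer.
 */
interface Buffer {
  long getLong(int col, int row);

  void replaceLong(int col, int row, long value);
}

/** COUNT: read the group's current state cell and write back state + 1. */
static void addRowCount(final Buffer state, final int stateCol, final int stateRow) {
  state.replaceLong(stateCol, stateRow, state.getLong(stateCol, stateRow) + 1);
}
```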
*/ - private String max = null; + public void appendInitValue(AppendableTable data, final int column) { + switch (aggOp) { + case COUNT: + data.putLong(column, 0); + break; + case MIN: + case MAX: + data.putString(column, STRING_INIT_VALUE); + break; + default: + throw new IllegalArgumentException("Type " + aggOp + " is invalid"); + } } } diff --git a/src/edu/washington/escience/myria/operator/agg/UserDefinedAggregator.java b/src/edu/washington/escience/myria/operator/agg/UserDefinedAggregator.java index 1456747c3..92812166e 100644 --- a/src/edu/washington/escience/myria/operator/agg/UserDefinedAggregator.java +++ b/src/edu/washington/escience/myria/operator/agg/UserDefinedAggregator.java @@ -1,17 +1,12 @@ package edu.washington.escience.myria.operator.agg; -import java.io.IOException; -import java.lang.reflect.InvocationTargetException; import java.util.List; import edu.washington.escience.myria.DbException; import edu.washington.escience.myria.Schema; import edu.washington.escience.myria.expression.evaluate.GenericEvaluator; import edu.washington.escience.myria.expression.evaluate.PythonUDFEvaluator; -import edu.washington.escience.myria.expression.evaluate.ScriptEvalInterface; -import edu.washington.escience.myria.storage.AppendableTable; -import edu.washington.escience.myria.storage.ReadableTable; -import edu.washington.escience.myria.storage.Tuple; +import edu.washington.escience.myria.storage.MutableTupleBuffer; import edu.washington.escience.myria.storage.TupleBatch; /** @@ -24,100 +19,64 @@ public class UserDefinedAggregator implements Aggregator { private static final org.slf4j.Logger LOGGER = org.slf4j.LoggerFactory.getLogger(UserDefinedAggregator.class); - /** - * The state of the aggregate variables. - */ - protected final Tuple initialState; - /** - * Evaluators that update the {@link #state}. One evaluator for each expression in {@link #updateExpressions}. - */ - protected final ScriptEvalInterface updateEvaluator; - /** - * One evaluator for each expression in {@link #emitExpressions}. - */ - protected final List<GenericEvaluator> emitEvaluators; - /** - * One evaluator for each python expression. - */ - protected final List<PythonUDFEvaluator> pyUDFEvaluators; + /** Evaluators that initialize the state. */ + protected final List<GenericEvaluator> initEvaluators; + /** Evaluators that update the {@link #state}. One evaluator for each expression in {@link #updateExpressions}. */ + protected final List<GenericEvaluator> updateEvaluators; + /** The Schema of the state. */ + private final Schema stateSchema; /** - * The Schema of the tuples produced by this aggregator. - */ - protected final Schema resultSchema; - - /** - * @param state the initialized state of the tuple - * @param updateEvaluator updates the state given an input row - * @param pyUDFEvaluators for python expression evaluation. + * @param initEvaluators initialize the state + * @param updateEvaluators updates the state given an input row + * @param pyUpdateEvaluators for python expression evaluation.
* @param emitEvaluators the evaluators that finalize the state * @param resultSchema the schema of the tuples produced by this aggregator + * @param stateSchema the schema of the state */ public UserDefinedAggregator( - final Tuple state, - final ScriptEvalInterface updateEvaluator, - final List pyUDFEvaluators, - final List emitEvaluators, - final Schema resultSchema) { - initialState = state; - this.updateEvaluator = updateEvaluator; - this.emitEvaluators = emitEvaluators; - this.pyUDFEvaluators = pyUDFEvaluators; - this.resultSchema = resultSchema; + final List initEvaluators, + final List updateEvaluators, + final Schema resultSchema, + final Schema stateSchema) { + this.initEvaluators = initEvaluators; + this.updateEvaluators = updateEvaluators; + this.stateSchema = stateSchema; } @Override - public void add(final ReadableTable from, final Object state) throws DbException { - for (int row = 0; row < from.numTuples(); ++row) { - addRow(from, row, state); - } + public int getStateSize() { + return stateSchema.numColumns(); } @Override - public void addRow(final ReadableTable from, final int row, final Object state) + public void addRow( + TupleBatch input, int inputRow, MutableTupleBuffer state, int stateRow, final int offset) throws DbException { - if (pyUDFEvaluators.size() > 0) { - throw new DbException("this aggregate has python UDF, StatefulAggreagte should be called!"); + for (GenericEvaluator eval : updateEvaluators) { + eval.updateState(input, inputRow, state, stateRow, offset); } - Tuple stateTuple = (Tuple) state; - - try { - if (updateEvaluator != null) { - updateEvaluator.evaluate(from, row, stateTuple, stateTuple); - } + } - } catch (Exception e) { - throw new DbException("Error updating UDA state", e); + @Override + public void initState(final MutableTupleBuffer state, final int offset) throws DbException { + for (GenericEvaluator eval : initEvaluators) { + eval.updateState(null, 0, state, 0, offset); } } - @Override - public void getResult(final AppendableTable dest, final int destColumn, final Object state) + /** + * @param tb + * @param offset + * @throws DbException + */ + public void finalizePythonUpdaters(final MutableTupleBuffer tb, final int offset) throws DbException { - Tuple stateTuple = (Tuple) state; - for (int index = 0; index < emitEvaluators.size(); index++) { - final GenericEvaluator evaluator = emitEvaluators.get(index); - try { - evaluator.eval(null, 0, null, dest.asWritableColumn(destColumn + index), stateTuple); - } catch (InvocationTargetException e) { - throw new DbException("Error finalizing aggregate", e); + for (int i = 0; i < updateEvaluators.size(); ++i) { + GenericEvaluator eval = updateEvaluators.get(i); + if (eval instanceof PythonUDFEvaluator) { + ((PythonUDFEvaluator) eval).evalGroups(tb, offset + i); } } } - - @Override - public Schema getResultSchema() { - return resultSchema; - } - - @Override - public Object getInitialState() { - - return initialState.clone(); - } - - @Override - public void add(final List from) throws DbException { - throw new DbException(" method not implemented"); - } } diff --git a/src/edu/washington/escience/myria/operator/agg/UserDefinedAggregatorFactory.java b/src/edu/washington/escience/myria/operator/agg/UserDefinedAggregatorFactory.java index 096a6895c..8f689a16d 100644 --- a/src/edu/washington/escience/myria/operator/agg/UserDefinedAggregatorFactory.java +++ b/src/edu/washington/escience/myria/operator/agg/UserDefinedAggregatorFactory.java @@ -6,15 +6,12 @@ import javax.annotation.Nonnull; -import 
org.codehaus.commons.compiler.CompileException; import org.codehaus.commons.compiler.CompilerFactoryFactory; import org.codehaus.commons.compiler.IScriptEvaluator; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; -import com.google.common.collect.Lists; import edu.washington.escience.myria.DbException; import edu.washington.escience.myria.MyriaConstants; @@ -24,9 +21,7 @@ import edu.washington.escience.myria.expression.evaluate.ExpressionOperatorParameter; import edu.washington.escience.myria.expression.evaluate.GenericEvaluator; import edu.washington.escience.myria.expression.evaluate.PythonUDFEvaluator; -import edu.washington.escience.myria.expression.evaluate.ScriptEvalInterface; import edu.washington.escience.myria.functions.PythonFunctionRegistrar; -import edu.washington.escience.myria.storage.Tuple; /** * Apply operator that has to be initialized and carries a state while new tuples are generated. @@ -44,32 +39,13 @@ public class UserDefinedAggregatorFactory implements AggregatorFactory { @JsonProperty private final List updaters; /** Expressions that emit the final aggregation result from the state. */ @JsonProperty private final List emitters; + /** Evaluators that initialize the {@link #state}. */ + private List initEvaluators; + /** Evaluators that update the {@link #state}. One evaluator for each expression in {@link #updaters}. */ + private List updateEvaluators; + /** The schema of the result tuples. */ + private Schema resultSchema; - /** - * The states that are passed during execution. - */ - private transient Tuple state; - /** - * Evaluators that update the {@link #state}. One evaluator for each expression in {@link #updaters}. - */ - private transient ScriptEvalInterface updateEvaluator; - /** - * Evaluators for python expressions. One evaluator for each expression in {@link #updaters}. - */ - private transient ArrayList pyUpdateEvaluators; - /** - * One evaluator for each expression in {@link #emitters}. - */ - private transient ArrayList emitEvaluators; - /** - * Does the UDA have a pythonExpression? - */ - private transient boolean bHasPyEval = false; - - /** - * The schema of the result tuples. - */ - private transient Schema resultSchema; /** * Construct a new user-defined aggregate. The initializers set the initial state of the aggregate; the updaters * update this state for every new tuple. The emitters produce the final value of the aggregate. 
Note that there must @@ -87,146 +63,110 @@ public UserDefinedAggregatorFactory( this.initializers = Objects.requireNonNull(initializers, "initializers"); this.updaters = Objects.requireNonNull(updaters, "updaters"); this.emitters = Objects.requireNonNull(emitters, "emitters"); - state = null; - updateEvaluator = null; - pyUpdateEvaluators = null; - emitEvaluators = null; resultSchema = null; + updateEvaluators = new ArrayList(); + initEvaluators = new ArrayList(); } @Override - public Aggregator get(final Schema inputSchema) throws DbException { - throw new DbException("should call get with pyFuncReg"); + public List generateEmitExpressions(final Schema inputSchema) throws DbException { + return emitters; } @Override - @Nonnull - public Aggregator get(final Schema inputSchema, final PythonFunctionRegistrar pyFuncReg) - throws DbException { - - if (state == null) { - Objects.requireNonNull(inputSchema, "inputSchema"); - Preconditions.checkArgument( - initializers.size() == updaters.size(), - "must have the same number of aggregate state initializers (%s) and updaters (%s)", - initializers.size(), - updaters.size()); - // Verify that initializers and updaters have compatible names - for (int i = 0; i < initializers.size(); i++) { - Preconditions.checkArgument( - Objects.equals(initializers.get(i).getOutputName(), updaters.get(i).getOutputName()), - "initializers[i] and updaters[i] have different names (%s) != (%s)", - initializers.get(i).getOutputName(), - updaters.get(i).getOutputName()); - } - - /* Initialize the state. */ - Schema stateSchema = generateStateSchema(inputSchema); - state = new Tuple(stateSchema); - ScriptEvalInterface stateEvaluator = - getEvalScript(initializers, new ExpressionOperatorParameter(inputSchema), null); - stateEvaluator.evaluate(null, 0, state, null); - - /* Set up the updaters. */ - - pyUpdateEvaluators = new ArrayList<>(); - updateEvaluator = - getEvalScript( - updaters, new ExpressionOperatorParameter(inputSchema, stateSchema), pyFuncReg); - - /* Set up the emitters. */ - emitEvaluators = new ArrayList<>(); - emitEvaluators.ensureCapacity(emitters.size()); - - for (Expression expr : emitters) { - GenericEvaluator evaluator; - if (expr.isRegisteredUDF()) { - evaluator = - new PythonUDFEvaluator( - expr, new ExpressionOperatorParameter(inputSchema, stateSchema), pyFuncReg); - bHasPyEval = true; - } else { - - evaluator = - new GenericEvaluator(expr, new ExpressionOperatorParameter(null, stateSchema)); - } - - evaluator.compile(); - emitEvaluators.add(evaluator); - } + public List generateInternalAggs(final Schema inputSchema) throws DbException { + /* Initialize the state. */ + for (int i = 0; i < initializers.size(); ++i) { + initEvaluators.add( + getEvaluator(initializers.get(i), new ExpressionOperatorParameter(inputSchema), i)); + } - /* Compute the result schema. */ - ExpressionOperatorParameter emitParams = new ExpressionOperatorParameter(null, stateSchema); - ImmutableList.Builder types = ImmutableList.builder(); - ImmutableList.Builder names = ImmutableList.builder(); - for (Expression e : emitters) { - types.add(e.getOutputType(emitParams)); - names.add(e.getOutputName()); + /* Set up the updaters. 
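The factory now splits a UDA into three expression lists with distinct roles: initializers create the state columns, updaters fold each input row into them, and emitters (returned by generateEmitExpressions above) project the final answer out of the state. A sketch of that contract with plain Java functions in place of compiled Expressions, using AVG over a (sum, count) state as the example:

```java
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;

/**
 * Sketch of the three-part UDA contract with plain Java functions standing in
 * for compiled Expressions: init builds the state, update folds in each input
 * row, emit projects the final answer.
 */
public class UdaSketch {
  public static void main(String[] args) {
    Supplier<long[]> init = () -> new long[] {0, 0}; // initializers
    BiFunction<long[], Long, long[]> update = // updaters
        (s, v) -> {
          s[0] += v; // running sum
          s[1] += 1; // running count
          return s;
        };
    Function<long[], Double> emit = s -> (double) s[0] / s[1]; // emitters

    long[] state = init.get();
    for (long v : new long[] {1, 2, 3, 6}) {
      state = update.apply(state, v);
    }
    System.out.println(emit.apply(state)); // 3.0
  }
}
```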
*/ + Schema stateSchema = generateStateSchema(inputSchema); + ExpressionOperatorParameter para = + new ExpressionOperatorParameter(inputSchema, stateSchema, pyFuncReg); + for (int i = 0; i < updaters.size(); ++i) { + if (updaters.get(i).isRegisteredPythonUDF()) { + updateEvaluators.add(new PythonUDFEvaluator(updaters.get(i), para)); + } else { + updateEvaluators.add(getEvaluator(updaters.get(i), para, i)); } - resultSchema = new Schema(types, names); } - if (bHasPyEval) { - return new StatefulUserDefinedAggregator( - state.clone(), updateEvaluator, pyUpdateEvaluators, emitEvaluators, resultSchema); - } else { - return new UserDefinedAggregator( - state.clone(), updateEvaluator, pyUpdateEvaluators, emitEvaluators, resultSchema); + + /* Compute the result schema. */ + ExpressionOperatorParameter emitParams = new ExpressionOperatorParameter(null, stateSchema); + ImmutableList.Builder types = ImmutableList.builder(); + ImmutableList.Builder names = ImmutableList.builder(); + for (Expression e : emitters) { + types.add(e.getOutputType(emitParams)); + names.add(e.getOutputName()); } + resultSchema = new Schema(types, names); + return ImmutableList.of( + new UserDefinedAggregator(initEvaluators, updateEvaluators, resultSchema, stateSchema)); } /** - * Produce a {@link ScriptEvalInterface} from {@link Expression}s and {@link ExpressionOperatorParameter}s. This - * function produces the code for a Java script that executes all expressions in turn and appends the calculated - * values to the result. The values to be output are calculated completely before they are stored to the output, thus - * it is safe to pass the same object as input and output, e.g., in the case of updating state in an Aggregate. + * Produce a {@link GenericEvaluator} from {@link Expression} and {@link ExpressionOperatorParameter}s. This function + * produces the code for a Java script that executes all expressions in turn and appends the calculated values to the + * result. The values to be output are calculated completely before they are stored to the output, thus it is safe to + * pass the same object as input and output, e.g., in the case of updating state in an Aggregate. * * @param expressions one expression for each output column. * @param param the inputs that expressions may use, including the {@link Schema} of the expression inputs and - * worker-local variables. - * @param pyFuncReg python function registrar. - * + * worker-local variables. + * @param col the column index of the expression. * @return a compiled object that will run all the expressions and store them into the output. * @throws DbException if there is an error compiling the expressions. */ - private ScriptEvalInterface getEvalScript( - @Nonnull final List expressions, + private GenericEvaluator getEvaluator( + @Nonnull final Expression expr, @Nonnull final ExpressionOperatorParameter param, - final PythonFunctionRegistrar pyFuncReg) + final int col) throws DbException { - StringBuilder compute = new StringBuilder(); - StringBuilder output = new StringBuilder(); - for (int varCount = 0; varCount < expressions.size(); ++varCount) { - Expression expr = expressions.get(varCount); - - if (!expr.isRegisteredUDF()) { - Type type = expr.getOutputType(param); - compute - .append(type.toJavaType().getName()) - .append(" val") - .append(varCount) - .append(" = ") - .append(expr.getJavaExpression(param)) - .append(";\n"); - - output - .append(Expression.RESULT) - .append(".put") - .append((type != Type.BLOB_TYPE) ? 
type.toJavaObjectType().getSimpleName() : "Blob") - .append("(") - .append(varCount) - .append(", val") - .append(varCount) - .append(");\n"); - } else { - bHasPyEval = true; - PythonUDFEvaluator evaluator = new PythonUDFEvaluator(expr, param, pyFuncReg); - pyUpdateEvaluators.add(evaluator); - } + Type type = expr.getOutputType(param); + // type valI = expression; + compute + .append(type.toJavaType().getName()) + .append(" val") + .append(col) + .append(" = ") + .append(expr.getJavaExpression(param)) + .append(";\n"); + + if (param.getStateSchema() == null) { + // state.putType(I, valI); + compute + .append(Expression.STATE) + .append(".put") + .append(type == Type.BLOB_TYPE ? "Blob" : type.toJavaObjectType().getSimpleName()) + .append("(") + .append(col) + .append("+") + .append(Expression.STATECOLOFFSET) + .append(", val") + .append(col) + .append(");\n"); + } else { + // state.replaceType(I, stateRow, valI); + compute + .append(Expression.STATE) + .append(".replace") + .append(type == Type.BLOB_TYPE ? "Blob" : type.toJavaObjectType().getSimpleName()) + .append("(") + .append(col) + .append("+") + .append(Expression.STATECOLOFFSET) + .append(", ") + .append(Expression.STATEROW) + .append(", val") + .append(col) + .append(");\n"); } - String script = compute.append(output).toString(); + String script = compute.toString(); LOGGER.debug("Compiling UDA {}", script); IScriptEvaluator se; @@ -237,37 +177,38 @@ private ScriptEvalInterface getEvalScript( throw new DbException("Could not create scriptevaluator", e); } se.setDefaultImports(MyriaConstants.DEFAULT_JANINO_IMPORTS); + GenericEvaluator eval = new GenericEvaluator(expr, script, param); + eval.compile(); + return eval; + } - try { - if (script.length() > 1) { - return (ScriptEvalInterface) - se.createFastEvaluator( - script, - ScriptEvalInterface.class, - new String[] {Expression.TB, Expression.ROW, Expression.RESULT, Expression.STATE}); - } else { - return null; - } - } catch (CompileException e) { - LOGGER.debug("Error when compiling expression {}: {}", script, e); - throw new DbException("Error when compiling expression: " + script, e); + @Override + public Schema generateSchema(final Schema inputSchema) { + Schema stateSchema = generateStateSchema(inputSchema); + ImmutableList.Builder typesBuilder = ImmutableList.builder(); + ImmutableList.Builder namesBuilder = ImmutableList.builder(); + for (Expression expr : emitters) { + typesBuilder.add( + expr.getOutputType(new ExpressionOperatorParameter(inputSchema, stateSchema))); + namesBuilder.add(expr.getOutputName()); } + return Schema.of(typesBuilder.build(), namesBuilder.build()); } - /** - * Generate the schema of the state. - * - * @param inputSchema the {@link Schema} of the input tuples. - * @return the {@link Schema} of the state assuming the specified input types. 
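So getEvaluator emits one small Janino script per expression, in two shapes: when no state schema is present yet (the initializer case) it appends with put, otherwise (the updater case) it replaces the group's existing state cell in place. Roughly, the generated source for a LONG expression in state column 2 could look like the following; this is illustrative only, with state, stateRow, stateColOffset, tb, and row standing in for Myria's Expression.STATE, STATEROW, STATECOLOFFSET, TB, and ROW placeholders, and the right-hand sides produced by getJavaExpression(param):

```java
// Initializer shape (no state schema yet): append the computed value.
//   long val2 = 0L;
//   state.putLong(2 + stateColOffset, val2);

// Updater shape (state schema present): recompute and replace the group's cell.
//   long val2 = state.getLong(2 + stateColOffset, stateRow) + tb.getLong(0, row);
//   state.replaceLong(2 + stateColOffset, stateRow, val2);
```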
- */ - private Schema generateStateSchema(final Schema inputSchema) { + @Override + public Schema generateStateSchema(final Schema inputSchema) { ImmutableList.Builder typesBuilder = ImmutableList.builder(); ImmutableList.Builder namesBuilder = ImmutableList.builder(); - for (Expression expr : initializers) { typesBuilder.add(expr.getOutputType(new ExpressionOperatorParameter(inputSchema))); namesBuilder.add(expr.getOutputName()); } - return new Schema(typesBuilder.build(), namesBuilder.build()); + return Schema.of(typesBuilder.build(), namesBuilder.build()); + } + + PythonFunctionRegistrar pyFuncReg; + + public void setPyFuncReg(PythonFunctionRegistrar pyFuncReg) { + this.pyFuncReg = pyFuncReg; } } diff --git a/src/edu/washington/escience/myria/parallel/Server.java b/src/edu/washington/escience/myria/parallel/Server.java index 7335bcb5b..4d5f8c3c3 100644 --- a/src/edu/washington/escience/myria/parallel/Server.java +++ b/src/edu/washington/escience/myria/parallel/Server.java @@ -107,10 +107,8 @@ import edu.washington.escience.myria.operator.RootOperator; import edu.washington.escience.myria.operator.TupleSink; import edu.washington.escience.myria.operator.agg.Aggregate; -import edu.washington.escience.myria.operator.agg.MultiGroupByAggregate; import edu.washington.escience.myria.operator.agg.PrimitiveAggregator.AggregationOp; -import edu.washington.escience.myria.operator.agg.SingleColumnAggregatorFactory; -import edu.washington.escience.myria.operator.agg.SingleGroupByAggregate; +import edu.washington.escience.myria.operator.agg.PrimitiveAggregatorFactory; import edu.washington.escience.myria.operator.network.CollectProducer; import edu.washington.escience.myria.operator.network.Consumer; import edu.washington.escience.myria.operator.network.GenericShuffleProducer; @@ -145,6 +143,7 @@ import edu.washington.escience.myria.util.IPCUtils; import edu.washington.escience.myria.util.concurrent.ErrorLoggingTimerTask; import edu.washington.escience.myria.util.concurrent.RenamingThreadFactory; + /** * The master entrance. */ @@ -919,7 +918,7 @@ public DatasetStatus parallelIngestDataset( } /** - * Helper method for parallel ingest. + * Helper method for parallel ingest. * * @param fileSize the size of the file to ingest * @param allWorkers all workers considered for ingest @@ -1123,7 +1122,7 @@ public long createView( * Create a function and register it in the catalog. * * @param name the name of the function - * @param definition the function definition - this is postgres specific for postgres and function text for python. + * @param definition the function definition - this is postgres specific for postgres and function text for python. * @param outputSchema the output schema of the function * @param isMultiValued indicates if the function returns multiple tuples. * @param lang this is the language of the function. @@ -1174,7 +1173,7 @@ public String createFunction( try { qf.get().getQueryId(); } catch (ExecutionException e) { - throw new DbException("Error executing query", e.getCause()); + throw new DbException("Error executing query", e); } } catch (CatalogException e) { throw new DbException(e); @@ -1187,8 +1186,8 @@ public String createFunction( } return response; } + /** - * * @return list of functions from the catalog * @throws DbException in case of error. */ @@ -1199,8 +1198,8 @@ public List getFunctions() throws DbException { throw new DbException(e); } } + /** - * * @param functionName : name of the function to retrieve. 
* @return functiondetails for the function * @throws DbException in case of error. @@ -1591,9 +1590,9 @@ public ListenableFuture startSentLogDataStream( final Consumer consumer = new Consumer(addWorkerId.getSchema(), operatorId, ImmutableSet.copyOf(actualWorkers)); - final MultiGroupByAggregate aggregate = - new MultiGroupByAggregate( - consumer, new int[] {0, 1, 2}, new SingleColumnAggregatorFactory(3, AggregationOp.SUM)); + final Aggregate aggregate = + new Aggregate( + consumer, new int[] {0, 1, 2}, new PrimitiveAggregatorFactory(3, AggregationOp.SUM)); // rename columns ImmutableList.Builder renameExpressions = ImmutableList.builder(); @@ -1733,13 +1732,13 @@ public ListenableFuture startAggregatedSentLogDataStream( final Consumer consumer = new Consumer(scan.getSchema(), operatorId, ImmutableSet.copyOf(actualWorkers)); - final SingleGroupByAggregate aggregate = - new SingleGroupByAggregate( + final Aggregate aggregate = + new Aggregate( consumer, - 0, - new SingleColumnAggregatorFactory(1, AggregationOp.SUM), - new SingleColumnAggregatorFactory(2, AggregationOp.MIN), - new SingleColumnAggregatorFactory(3, AggregationOp.MAX)); + new int[] {0}, + new PrimitiveAggregatorFactory(1, AggregationOp.SUM), + new PrimitiveAggregatorFactory(2, AggregationOp.MIN), + new PrimitiveAggregatorFactory(3, AggregationOp.MAX)); // rename columns ImmutableList.Builder renameExpressions = ImmutableList.builder(); @@ -1988,9 +1987,9 @@ public QueryFuture startHistogramDataStream( new Consumer(scan.getSchema(), operatorId, ImmutableSet.copyOf(actualWorkers)); // sum up the number of workers working - final MultiGroupByAggregate sumAggregate = - new MultiGroupByAggregate( - consumer, new int[] {0, 1}, new SingleColumnAggregatorFactory(1, AggregationOp.COUNT)); + final Aggregate sumAggregate = + new Aggregate( + consumer, new int[] {0, 1}, new PrimitiveAggregatorFactory(1, AggregationOp.COUNT)); // rename columns ImmutableList.Builder renameExpressions = ImmutableList.builder(); renameExpressions.add(new Expression("opId", new VariableExpression(0))); @@ -2072,8 +2071,9 @@ public QueryFuture startRangeDataStream( final Aggregate sumAggregate = new Aggregate( consumer, - new SingleColumnAggregatorFactory(0, AggregationOp.MIN), - new SingleColumnAggregatorFactory(1, AggregationOp.MAX)); + new int[] {}, + new PrimitiveAggregatorFactory(0, AggregationOp.MIN), + new PrimitiveAggregatorFactory(1, AggregationOp.MAX)); TupleSink output = new TupleSink(sumAggregate, writer, dataSink); final SubQueryPlan masterPlan = new SubQueryPlan(output); @@ -2148,9 +2148,9 @@ public QueryFuture startContributionsStream( new Consumer(scan.getSchema(), operatorId, ImmutableSet.copyOf(actualWorkers)); // sum up contributions - final SingleGroupByAggregate sumAggregate = - new SingleGroupByAggregate( - consumer, 0, new SingleColumnAggregatorFactory(1, AggregationOp.AVG)); + final Aggregate sumAggregate = + new Aggregate( + consumer, new int[] {0}, new PrimitiveAggregatorFactory(1, AggregationOp.AVG)); // rename columns ImmutableList.Builder renameExpressions = ImmutableList.builder(); diff --git a/src/edu/washington/escience/myria/storage/Field.java b/src/edu/washington/escience/myria/storage/Field.java deleted file mode 100644 index 7b34c51b8..000000000 --- a/src/edu/washington/escience/myria/storage/Field.java +++ /dev/null @@ -1,85 +0,0 @@ -package edu.washington.escience.myria.storage; - -import java.io.Serializable; -import java.nio.BufferOverflowException; - -import org.joda.time.DateTime; -import java.nio.ByteBuffer; -import 
edu.washington.escience.myria.column.builder.WritableColumn; -import edu.washington.escience.myria.util.MyriaUtils; - -/** - * A field used in {@link Tuple}. - * - * @param <T> the type. - */ -public class Field<T extends Comparable<?>> implements WritableColumn, Serializable { - /***/ - private static final long serialVersionUID = 1L; - - /** - * The value of this field. - */ - private Object value; - - @Override - public WritableColumn appendBoolean(final boolean value) throws BufferOverflowException { - this.value = value; - return this; - } - - @Override - public WritableColumn appendDateTime(final DateTime value) throws BufferOverflowException { - this.value = value; - return this; - } - - @Override - public WritableColumn appendDouble(final double value) throws BufferOverflowException { - this.value = value; - return this; - } - - @Override - public WritableColumn appendFloat(final float value) throws BufferOverflowException { - this.value = value; - return this; - } - - @Override - public WritableColumn appendInt(final int value) throws BufferOverflowException { - this.value = value; - return this; - } - - @Override - public WritableColumn appendLong(final long value) throws BufferOverflowException { - this.value = value; - return this; - } - - @Override - public WritableColumn appendObject(final Object value) throws BufferOverflowException { - this.value = MyriaUtils.ensureObjectIsValidType(value); - return this; - } - - @Override - public WritableColumn appendString(final String value) throws BufferOverflowException { - this.value = value; - return this; - } - - @Override - public WritableColumn appendBlob(final ByteBuffer value) throws BufferOverflowException { - this.value = value; - return this; - } - - /** - * @return the value - */ - public Object getObject() { - return value; - } -} diff --git a/src/edu/washington/escience/myria/storage/MutableTupleBuffer.java b/src/edu/washington/escience/myria/storage/MutableTupleBuffer.java index 4a3d20748..75636696c 100644 --- a/src/edu/washington/escience/myria/storage/MutableTupleBuffer.java +++ b/src/edu/washington/escience/myria/storage/MutableTupleBuffer.java @@ -1,12 +1,13 @@ package edu.washington.escience.myria.storage; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.BitSet; import java.util.List; import java.util.Objects; import org.joda.time.DateTime; -import java.nio.ByteBuffer; + import com.google.common.base.Preconditions; import edu.washington.escience.myria.Schema; @@ -34,10 +35,10 @@ public class MutableTupleBuffer implements ReadableTable, AppendableTable, Clone private int numColumnsReady; /** Internal state representing the number of tuples in the in-progress TupleBatch. */ private int currentInProgressTuples; - /** BatchSize.*/ + /** Batch Size. */ private int batchSize; + /** - * * @return batchSize for the tuple batch. */ public int getBatchSize() { return batchSize; } @@ -74,7 +75,6 @@ public final void clear() { /** * Makes a batch of any tuples in the buffer and appends it to the internal list.
- * */ private void finishBatch() { Preconditions.checkArgument(numColumnsReady == 0); @@ -103,42 +103,42 @@ public final int numTuples() { @Override @Deprecated public final Object getObject(final int col, final int row) { - return getColumn(col, row).getObject(row % batchSize); + return getColumn(col, row).getObject(getInColumnIndex(row)); } @Override public final boolean getBoolean(final int col, final int row) { - return getColumn(col, row).getBoolean(row % batchSize); + return getColumn(col, row).getBoolean(getInColumnIndex(row)); } @Override public final double getDouble(final int col, final int row) { - return getColumn(col, row).getDouble(row % batchSize); + return getColumn(col, row).getDouble(getInColumnIndex(row)); } @Override public final float getFloat(final int col, final int row) { - return getColumn(col, row).getFloat(row % batchSize); + return getColumn(col, row).getFloat(getInColumnIndex(row)); } @Override public final long getLong(final int col, final int row) { - return getColumn(col, row).getLong(row % batchSize); + return getColumn(col, row).getLong(getInColumnIndex(row)); } @Override public final int getInt(final int col, final int row) { - return getColumn(col, row).getInt(row % batchSize); + return getColumn(col, row).getInt(getInColumnIndex(row)); } @Override public final String getString(final int col, final int row) { - return getColumn(col, row).getString(row % batchSize); + return getColumn(col, row).getString(getInColumnIndex(row)); } @Override public final DateTime getDateTime(final int col, final int row) { - return getColumn(col, row).getDateTime(row % batchSize); + return getColumn(col, row).getDateTime(getInColumnIndex(row)); } @Override @@ -262,7 +262,8 @@ private void columnPut(final int column) { * @param sourceColumn the column from which data will be retrieved. * @param sourceRow the row in the source column from which data will be retrieved. */ - public final void put(final int destColumn, final Column sourceColumn, final int sourceRow) { + public final void put( + final int destColumn, final ReadableColumn sourceColumn, final int sourceRow) { TupleUtils.copyValue(sourceColumn, sourceRow, this, destColumn); } @@ -380,7 +381,7 @@ public final ReplaceableColumn getColumn(final int column, final int row) { * @param value the replacement. */ public final void replaceInt(final int destColumn, final int destRow, final int value) { - getColumn(destColumn, destRow).replaceInt(value, destRow % batchSize); + getColumn(destColumn, destRow).replaceInt(value, getInColumnIndex(destRow)); } /** @@ -389,7 +390,7 @@ public final void replaceInt(final int destColumn, final int destRow, final int * @param value the replacement. */ public final void replaceLong(final int destColumn, final int destRow, final long value) { - getColumn(destColumn, destRow).replaceLong(value, destRow % batchSize); + getColumn(destColumn, destRow).replaceLong(value, getInColumnIndex(destRow)); } /** @@ -398,7 +399,7 @@ public final void replaceLong(final int destColumn, final int destRow, final lon * @param value the replacement. */ public final void replaceFloat(final int destColumn, final int destRow, final float value) { - getColumn(destColumn, destRow).replaceFloat(value, destRow % batchSize); + getColumn(destColumn, destRow).replaceFloat(value, getInColumnIndex(destRow)); } /** @@ -407,7 +408,7 @@ public final void replaceFloat(final int destColumn, final int destRow, final fl * @param value the replacement. 
*/ public final void replaceDouble(final int destColumn, final int destRow, final double value) { - getColumn(destColumn, destRow).replaceDouble(value, destRow % batchSize); + getColumn(destColumn, destRow).replaceDouble(value, getInColumnIndex(destRow)); } /** @@ -416,7 +417,7 @@ public final void replaceDouble(final int destColumn, final int destRow, final d * @param value the replacement. */ public final void replaceString(final int destColumn, final int destRow, final String value) { - getColumn(destColumn, destRow).replaceString(value, destRow % batchSize); + getColumn(destColumn, destRow).replaceString(value, getInColumnIndex(destRow)); } /** @@ -426,15 +427,7 @@ public final void replaceString(final int destColumn, final int destRow, final S */ public final void replaceByteBuffer( final int destColumn, final int destRow, final ByteBuffer value) { - getColumn(destColumn, destRow).replaceBlob(value, destRow % batchSize); - } - /** - * @param destColumn the destination column. - * @param destRow the row. - * @param value the replacement. - */ - public final void replaceDateTime(final int destColumn, final int destRow, final DateTime value) { - getColumn(destColumn, destRow).replaceDateTime(value, destRow % batchSize); + getColumn(destColumn, destRow).replaceBlob(value, getInColumnIndex(destRow)); } /** @@ -446,9 +439,12 @@ public final void replaceDateTime(final int destColumn, final int destRow, final * @param sourceRow the row in the source column from which data will be retrieved. */ public final void replace( - final int destColumn, final int destRow, final Column sourceColumn, final int sourceRow) { + final int destColumn, + final int destRow, + final ReadableColumn sourceColumn, + final int sourceRow) { checkRowIndex(destRow); - int tupleIndex = destRow % batchSize; + int tupleIndex = getInColumnIndex(destRow); ReplaceableColumn dest = getColumn(destColumn, destRow); switch (dest.getType()) { case BOOLEAN_TYPE: @@ -539,4 +535,14 @@ public ReadableColumn asColumn(final int column) { public WritableColumn asWritableColumn(final int column) { return new WritableSubColumn(this, column); } + + /** + * Get the in-column row index of the given row index of the whole tuple buffer. + * + * @param row + * @return the in-column row index + */ + public int getInColumnIndex(final int row) { + return row % batchSize; + } } diff --git a/src/edu/washington/escience/myria/storage/Tuple.java b/src/edu/washington/escience/myria/storage/Tuple.java index 66470b50e..e5f6f049c 100644 --- a/src/edu/washington/escience/myria/storage/Tuple.java +++ b/src/edu/washington/escience/myria/storage/Tuple.java @@ -1,19 +1,17 @@ package edu.washington.escience.myria.storage; import java.io.Serializable; -import java.util.List; +import java.nio.ByteBuffer; import javax.annotation.Nonnull; import org.joda.time.DateTime; -import java.nio.ByteBuffer; + import com.google.common.base.Preconditions; -import com.google.common.collect.Lists; import edu.washington.escience.myria.Schema; -import edu.washington.escience.myria.Type; import edu.washington.escience.myria.column.builder.WritableColumn; -import edu.washington.escience.myria.util.MyriaUtils; +import edu.washington.escience.myria.util.MyriaArrayUtils; /** * A single row relation. @@ -30,16 +28,26 @@ public class Tuple implements Cloneable, AppendableTable, ReadableTable, Seriali /** * The data of the tuple. 
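The new getInColumnIndex helper makes the buffer's addressing scheme explicit: storage is a list of fixed-size batches, so a logical row index splits into a batch number (row / batchSize) and an offset inside that batch (row % batchSize). A standalone sketch of the mapping:

```java
/**
 * Sketch of the addressing scheme getInColumnIndex makes explicit: a logical
 * row index splits into a batch number and an offset inside that batch.
 */
static int batchIndex(final int row, final int batchSize) {
  return row / batchSize; // which internal batch holds the row
}

static int inColumnIndex(final int row, final int batchSize) {
  return row % batchSize; // offset of the row within that batch
}
// e.g. with batchSize = 10000, row 25003 lives in batch 2 at offset 5003.
```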
*/ - private final List<Field<?>> data; + private final MutableTupleBuffer data; /** * @param schema the schema of the tuple */ public Tuple(final Schema schema) { + data = new MutableTupleBuffer(schema); this.schema = schema; - data = Lists.newArrayListWithCapacity(numColumns()); - for (int i = 0; i < numColumns(); i++) { - data.add(new Field<>()); + } + + /** + * @param data the table to copy values from + * @param row the row index in {@code data} + * @param cols the columns of {@code data} to project into this tuple + */ + public Tuple(final ReadableTable data, final int row, final int[] cols) { + this.schema = data.getSchema().getSubSchema(cols); + this.data = new MutableTupleBuffer(schema); + for (int i = 0; i < cols.length; ++i) { + this.data.put(i, data.asColumn(cols[i]), row); } } @@ -51,70 +59,49 @@ public Schema getSchema() { return schema; } - /** - * Returns a value and checks arguments. - * - * @param column the column index. - * @param row the row index. - * @return the value at the desired position. - */ - private Object getValue(final int column, final int row) { - Preconditions.checkArgument(row == 0); - Preconditions.checkElementIndex(column, numColumns()); - return data.get(column).getObject(); - } - @Override public boolean getBoolean(final int column, final int row) { - Preconditions.checkArgument(getSchema().getColumnType(column) == Type.BOOLEAN_TYPE); - return (boolean) getValue(column, row); + return data.getBoolean(column, row); } @Override public double getDouble(final int column, final int row) { - Preconditions.checkArgument(getSchema().getColumnType(column) == Type.DOUBLE_TYPE); - return (double) getValue(column, row); + return data.getDouble(column, row); } @Override public float getFloat(final int column, final int row) { - Preconditions.checkArgument(getSchema().getColumnType(column) == Type.FLOAT_TYPE); - return (float) getValue(column, row); + return data.getFloat(column, row); } @Override public int getInt(final int column, final int row) { - Preconditions.checkArgument(getSchema().getColumnType(column) == Type.INT_TYPE); - return (int) getValue(column, row); + return data.getInt(column, row); } @Override public long getLong(final int column, final int row) { - Preconditions.checkArgument(getSchema().getColumnType(column) == Type.LONG_TYPE); - return (long) getValue(column, row); + return data.getLong(column, row); } @Override public String getString(final int column, final int row) { - Preconditions.checkArgument(getSchema().getColumnType(column) == Type.STRING_TYPE); - return (String) getValue(column, row); + return data.getString(column, row); } @Override public DateTime getDateTime(final int column, final int row) { - Preconditions.checkArgument(getSchema().getColumnType(column) == Type.DATETIME_TYPE); - return (DateTime) getValue(column, row); + return data.getDateTime(column, row); } @Override public ByteBuffer getBlob(final int column, final int row) { - Preconditions.checkArgument(getSchema().getColumnType(column) == Type.BLOB_TYPE); - return (ByteBuffer) getValue(column, row); + return data.getBlob(column, row); } @Override public Object getObject(final int column, final int row) { - return getValue(column, row); + return data.getObject(column, row); } @Override @@ -127,24 +114,6 @@ public int numTuples() { return 1; } - /** - * @param columnIdx the column index - * @return the field at the index - */ - public Field<?> getColumn(final int columnIdx) { - return data.get(columnIdx); - } - - /** - * Set value.
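The new copy-constructor above turns one row of any ReadableTable into a single-row Tuple over a chosen subset of columns, which is exactly what grouping code needs for a key. The same projection idea in a self-contained sketch over a plain 2-D array:

```java
/**
 * Standalone sketch of what the new Tuple(ReadableTable, row, cols) constructor
 * does: project selected columns of one row into a single-row record, e.g. to
 * capture a grouping key.
 */
static Object[] projectRow(final Object[][] table, final int row, final int[] cols) {
  Object[] key = new Object[cols.length];
  for (int i = 0; i < cols.length; ++i) {
    key[i] = table[row][cols[i]]; // copy each requested column of the row
  }
  return key;
}
```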
- * - * @param columnIdx the column index - * @param value the value to set - */ - public void set(final int columnIdx, final Object value) { - getColumn(columnIdx).appendObject(MyriaUtils.ensureObjectIsValidType(value)); - } - @Override public ReadableColumn asColumn(final int column) { return new ReadableSubColumn( @@ -153,60 +122,62 @@ public ReadableColumn asColumn(final int column) { @Override public Tuple clone() { - Tuple t = new Tuple(getSchema()); - for (int i = 0; i < numColumns(); ++i) { - t.set(i, getObject(i, 0)); - } - return t; + return new Tuple(data.clone(), 0, MyriaArrayUtils.range(0, numColumns())); } @Override public void putBoolean(final int column, final boolean value) { - set(column, value); + data.putBoolean(column, value); } @Override public void putDateTime(final int column, @Nonnull final DateTime value) { - set(column, value); + data.putDateTime(column, value); } @Override public void putDouble(final int column, final double value) { - set(column, value); + data.putDouble(column, value); } @Override public void putFloat(final int column, final float value) { - set(column, value); + data.putFloat(column, value); } @Override public void putInt(final int column, final int value) { - set(column, value); + data.putInt(column, value); } @Override public void putLong(final int column, final long value) { - set(column, value); + data.putLong(column, value); } @Override public void putString(final int column, final @Nonnull String value) { - set(column, value); + data.putString(column, value); } @Override + @Deprecated public void putObject(final int column, final @Nonnull Object value) { - set(column, value); + data.putObject(column, value); } @Override public void putBlob(final int column, @Nonnull final ByteBuffer value) { - set(column, value); + data.putBlob(column, value); } @Override public WritableColumn asWritableColumn(final int column) { - return data.get(column); + return data.asWritableColumn(column); + } + + @Override + public final String toString() { + return data.getAll().get(0).toString(); } } diff --git a/src/edu/washington/escience/myria/storage/TupleBatchBuffer.java b/src/edu/washington/escience/myria/storage/TupleBatchBuffer.java index 821a585b2..22f63c3ee 100644 --- a/src/edu/washington/escience/myria/storage/TupleBatchBuffer.java +++ b/src/edu/washington/escience/myria/storage/TupleBatchBuffer.java @@ -21,8 +21,6 @@ /** * Used for creating TupleBatch objects on the fly. A helper class used in, e.g., the Scatter operator. Currently it * doesn't support random access to a specific cell. Use TupleBuffer instead. - * - * */ public class TupleBatchBuffer implements AppendableTable { /** Format of the emitted tuples. */ @@ -75,10 +73,8 @@ public int getBatchSize() { * @param tb the TB. */ public final void appendTB(final TupleBatch tb) { - /* - * If we're currently building a batch, we better finish it before we append this one to the list. Otherwise - * reordering will happen. - */ + /* If we're currently building a batch, we better finish it before we append this one to the list. Otherwise + * reordering will happen. 
*/ finishBatch(); readyTuplesNum += tb.numTuples(); @@ -331,7 +327,6 @@ public final void put(final int column, final Object value) { * @param rightTb the right tuple batch * @param rightIdx the index of the right tuple in the tuple batch * @param rightAnswerColumns an array that specifies which columns from the right tuple batch - * */ public final void put( final TupleBatch leftTb, @@ -493,9 +488,10 @@ private void updateLastPoppedTime() { * size of the TupleBatch because it is a full copy. * * @param tupleBatch the tuple data to be added to this buffer. + * @param shallowCopy shallow or deep copy of tupleBatch elements. */ - public void absorb(final TupleBatch tupleBatch) { - if (currentInProgressTuples == 0) { + public void absorb(final TupleBatch tupleBatch, final boolean shallowCopy) { + if (shallowCopy) { appendTB(tupleBatch); } else { tupleBatch.compactInto(this); diff --git a/src/edu/washington/escience/myria/storage/TupleBuffer.java b/src/edu/washington/escience/myria/storage/TupleBuffer.java index 0d9020a42..7c706a731 100644 --- a/src/edu/washington/escience/myria/storage/TupleBuffer.java +++ b/src/edu/washington/escience/myria/storage/TupleBuffer.java @@ -1,12 +1,13 @@ package edu.washington.escience.myria.storage; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.BitSet; import java.util.List; import java.util.Objects; import org.joda.time.DateTime; -import java.nio.ByteBuffer; + import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -44,9 +45,12 @@ public class TupleBuffer implements ReadableTable, AppendableTable { private ImmutableList finalBatches; /** The number of tuples in this buffer. */ private int numTuples; - /** BatchSize.*/ + /** Batch size. */ private int batchSize; + /** + * @return the size of the batches. + */ public int getBatchSize() { return batchSize; } @@ -69,6 +73,7 @@ public TupleBuffer(final Schema schema) { numTuples = 0; batchSize = TupleUtils.getBatchSize(schema); } + /** * Constructs an empty TupleBuffer to hold tuples matching the specified Schema. * @@ -81,7 +86,6 @@ public TupleBuffer(final Schema schema, int batchSize) { /** * Makes a batch of any tuples in the buffer and appends it to the internal list. 
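The new shallowCopy flag makes the caller choose the absorb strategy explicitly: a shallow absorb appends the incoming TupleBatch by reference (constant time, but the batch must not be mutated afterwards), while a deep absorb compacts the tuples into the buffer's own columns, as StreamingAggregate.addToResult does with absorb(..., false). A list-based sketch of the tradeoff:

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Sketch of the shallow-vs-deep tradeoff using plain lists in place of
 * TupleBatches: shallow shares the incoming batch, deep copies its elements.
 */
static void absorb(
    final List<List<Integer>> buffer, final List<Integer> batch, final boolean shallowCopy) {
  if (shallowCopy) {
    buffer.add(batch); // O(1), but the caller must not mutate the batch afterwards
  } else {
    buffer.add(new ArrayList<>(batch)); // O(n), compacts a private copy
  }
}
```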
- */ private void finishBatch() { Preconditions.checkState( diff --git a/src/edu/washington/escience/myria/util/HashUtils.java b/src/edu/washington/escience/myria/util/HashUtils.java index c733e5888..503388c74 100644 --- a/src/edu/washington/escience/myria/util/HashUtils.java +++ b/src/edu/washington/escience/myria/util/HashUtils.java @@ -155,9 +155,8 @@ private static Hasher addValue(final Hasher hasher, final ReadableColumn column, return hasher.putLong(column.getLong(row)); case STRING_TYPE: return hasher.putObject(column.getString(row), TypeFunnel.INSTANCE); - case BLOB_TYPE: - return hasher.putObject(column.getBlob(row), TypeFunnel.INSTANCE); + default: + throw new UnsupportedOperationException("Hashing a column of type " + column.getType()); } - throw new UnsupportedOperationException("Hashing a column of type " + column.getType()); } } diff --git a/src/edu/washington/escience/myria/util/MyriaArrayUtils.java b/src/edu/washington/escience/myria/util/MyriaArrayUtils.java index 2f50b7f0d..efcdf1b62 100644 --- a/src/edu/washington/escience/myria/util/MyriaArrayUtils.java +++ b/src/edu/washington/escience/myria/util/MyriaArrayUtils.java @@ -4,6 +4,7 @@ import java.util.Arrays; import java.util.List; import java.util.Set; +import java.util.stream.IntStream; import org.slf4j.LoggerFactory; @@ -153,4 +154,15 @@ public static int[] checkPositionIndices(final int[] arrayOfIndices, final int s } return arrayOfIndices; } + + /** + * Helper function that generates an array of the numbers in [start, start+length). + * + * @param start the first number in the array. + * @param length the length of the array. + * @return an array of the numbers [start, start+length). + */ + public static int[] range(final int start, final int length) { + return IntStream.range(start, start + length).toArray(); + } } diff --git a/test/edu/washington/escience/myria/operator/AggregateTest.java b/test/edu/washington/escience/myria/operator/AggregateTest.java index 9f59cd790..a10370667 100644 --- a/test/edu/washington/escience/myria/operator/AggregateTest.java +++ b/test/edu/washington/escience/myria/operator/AggregateTest.java @@ -36,10 +36,8 @@ import edu.washington.escience.myria.column.builder.StringColumnBuilder; import edu.washington.escience.myria.operator.agg.Aggregate; import edu.washington.escience.myria.operator.agg.AggregatorFactory; -import edu.washington.escience.myria.operator.agg.MultiGroupByAggregate; import edu.washington.escience.myria.operator.agg.PrimitiveAggregator.AggregationOp; -import edu.washington.escience.myria.operator.agg.SingleColumnAggregatorFactory; -import edu.washington.escience.myria.operator.agg.SingleGroupByAggregate; +import edu.washington.escience.myria.operator.agg.PrimitiveAggregatorFactory; import edu.washington.escience.myria.storage.TupleBatch; import edu.washington.escience.myria.storage.TupleBatchBuffer; import edu.washington.escience.myria.storage.TupleBuffer; @@ -84,7 +82,7 @@ private void allNumericAggsTestSchema(final Schema schema, final Type type) { } /** - * For ensure that the given Schema matches the expected non-numeric aggregate types for the given Type. + * Ensure that the given Schema matches the expected non-numeric aggregate types for the given Type.
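For a concrete reading of the new helper: range(start, length) enumerates length consecutive integers beginning at start. A one-method equivalent plus the worked example:

```java
import java.util.stream.IntStream;

/** Equivalent of the new MyriaArrayUtils.range: the numbers in [start, start + length). */
static int[] range(final int start, final int length) {
  return IntStream.range(start, start + length).toArray();
}
// range(2, 3) -> {2, 3, 4}; Tuple.clone() above uses range(0, numColumns())
// to select every column of its single-row buffer.
```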
* * All non-numeric aggs, in order: COUNT, MIN, MAX * @@ -120,7 +118,7 @@ private TupleBatch makeTrivialTupleBatch(final ColumnBuilder builder) { * @param builder the tuples to be aggregated * @param aggOps the aggregate operations over the column * @param noColumns whether to group by no columns (if true) or to append a constant value single column and group by - * it (if false). + * it (if false). * @return a single TupleBatch containing the results of the aggregation * @throws Exception if there is an error */ @@ -131,11 +129,8 @@ private TupleBatch doAggOpsToCol( return doAggOpsToSingleCol(builder, aggOps); } BatchTupleSource source = new BatchTupleSource(makeTrivialTupleBatch(builder)); - AggregatorFactory[] aggs = new AggregatorFactory[aggOps.length]; - for (int i = 0; i < aggs.length; ++i) { - aggs[i] = new SingleColumnAggregatorFactory(0, aggOps[i]); - } - Aggregate agg = new Aggregate(source, aggs); + AggregatorFactory aggs = new PrimitiveAggregatorFactory(0, aggOps); + Aggregate agg = new Aggregate(source, new int[] {}, aggs); /* Do it -- this should cause an error. */ agg.open(TestEnvVars.get()); TupleBatch tb = agg.nextReady(); @@ -147,9 +142,6 @@ private TupleBatch doAggOpsToCol( * Helper function to instantiate an aggregator and do the aggregation. Do not use if more than one TupleBatch are * expected. * - * This variant uses a SingleGroupByAggregate in order to do extra testing of the Aggregators by hitting a function - * that Aggregate does not use. - * * @param builder the tuples to be aggregated * @param aggOps the aggregate operations over the column * @return a single TupleBatch containing the results of the aggregation @@ -166,9 +158,9 @@ private TupleBatch doAggOpsToSingleCol( BatchTupleSource source = new BatchTupleSource(new TupleBatch(newSchema, columns)); AggregatorFactory[] aggs = new AggregatorFactory[aggOps.length]; for (int i = 0; i < aggs.length; ++i) { - aggs[i] = new SingleColumnAggregatorFactory(0, aggOps[i]); + aggs[i] = new PrimitiveAggregatorFactory(0, aggOps[i]); } - SingleGroupByAggregate agg = new SingleGroupByAggregate(source, trivialTb.numColumns(), aggs); + Aggregate agg = new Aggregate(source, new int[] {trivialTb.numColumns()}, aggs); /* Do it -- this should cause an error. 
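The test changes that follow all apply one mechanical recipe, worth stating once: the unified Aggregate always takes an int[] of group fields, so the deleted variants map onto it directly. The mappings below are collected from the edits in this patch:

```java
// Migration recipe for the unified Aggregate:
//   new SingleGroupByAggregate(src, 1, f)               -> new Aggregate(src, new int[] {1}, f)
//   new MultiGroupByAggregate(src, new int[] {0, 1}, f) -> new Aggregate(src, new int[] {0, 1}, f)
//   new Aggregate(src, f) /* global aggregate */        -> new Aggregate(src, new int[] {}, f)
//
//   new SingleColumnAggregatorFactory(col, op)          -> new PrimitiveAggregatorFactory(col, op)
//   new SingleColumnAggregatorFactory(col, op1, op2)    ->
//       new PrimitiveAggregatorFactory(col, new AggregationOp[] {op1, op2})
```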
     agg.open(TestEnvVars.get());
     TupleBatch tb = agg.nextReady();
@@ -398,11 +390,11 @@ public void testSingleGroupAvg() throws DbException, InterruptedException {
     final TupleBatchBuffer testBase = generateRandomTuples(numTuples);

     // group by name, aggregate on id
-    final SingleGroupByAggregate agg =
-        new SingleGroupByAggregate(
+    final Aggregate agg =
+        new Aggregate(
             new BatchTupleSource(testBase),
-            1,
-            new SingleColumnAggregatorFactory(0, AggregationOp.AVG));
+            new int[] {1},
+            new PrimitiveAggregatorFactory(0, AggregationOp.AVG));
     agg.open(TestEnvVars.get());
     TupleBatch tb = null;
     final TupleBatchBuffer result = new TupleBatchBuffer(agg.getSchema());
@@ -423,11 +415,11 @@ public void testSingleGroupMax() throws DbException, InterruptedException {
     final TupleBatchBuffer testBase = generateRandomTuples(numTuples);

     // group by name, aggregate on id
-    SingleGroupByAggregate agg =
-        new SingleGroupByAggregate(
+    Aggregate agg =
+        new Aggregate(
             new BatchTupleSource(testBase),
-            1,
-            new SingleColumnAggregatorFactory(0, AggregationOp.MAX));
+            new int[] {1},
+            new PrimitiveAggregatorFactory(0, AggregationOp.MAX));
     agg.open(TestEnvVars.get());
     TupleBatch tb = null;
     TupleBatchBuffer result = new TupleBatchBuffer(agg.getSchema());
@@ -442,10 +434,10 @@
     TestUtils.assertTupleBagEqual(TestUtils.groupByMax(testBase, 1, 0), actualResult);

     agg =
-        new SingleGroupByAggregate(
+        new Aggregate(
             new BatchTupleSource(testBase),
-            0,
-            new SingleColumnAggregatorFactory(1, AggregationOp.MAX));
+            new int[] {0},
+            new PrimitiveAggregatorFactory(1, AggregationOp.MAX));
     agg.open(TestEnvVars.get());
     tb = null;
     result = new TupleBatchBuffer(agg.getSchema());
@@ -466,11 +458,11 @@ public void testSingleGroupMin() throws DbException, InterruptedException {
     final TupleBatchBuffer testBase = generateRandomTuples(numTuples);

     // group by name, aggregate on id
-    SingleGroupByAggregate agg =
-        new SingleGroupByAggregate(
+    Aggregate agg =
+        new Aggregate(
             new BatchTupleSource(testBase),
-            1,
-            new SingleColumnAggregatorFactory(0, AggregationOp.MIN));
+            new int[] {1},
+            new PrimitiveAggregatorFactory(0, AggregationOp.MIN));
     agg.open(TestEnvVars.get());
     TupleBatch tb = null;
     TupleBatchBuffer result = new TupleBatchBuffer(agg.getSchema());
@@ -485,10 +477,10 @@
     TestUtils.assertTupleBagEqual(TestUtils.groupByMin(testBase, 1, 0), actualResult);

     agg =
-        new SingleGroupByAggregate(
+        new Aggregate(
             new BatchTupleSource(testBase),
-            0,
-            new SingleColumnAggregatorFactory(1, AggregationOp.MIN));
+            new int[] {0},
+            new PrimitiveAggregatorFactory(1, AggregationOp.MIN));
     agg.open(TestEnvVars.get());
     tb = null;
     result = new TupleBatchBuffer(agg.getSchema());
@@ -509,11 +501,11 @@ public void testSingleGroupSum() throws DbException, InterruptedException {
     final TupleBatchBuffer testBase = generateRandomTuples(numTuples);

     // group by name, aggregate on id
-    final SingleGroupByAggregate agg =
-        new SingleGroupByAggregate(
+    final Aggregate agg =
+        new Aggregate(
             new BatchTupleSource(testBase),
-            1,
-            new SingleColumnAggregatorFactory(0, AggregationOp.SUM));
+            new int[] {1},
+            new PrimitiveAggregatorFactory(0, AggregationOp.SUM));
     agg.open(TestEnvVars.get());
     TupleBatch tb = null;
     final TupleBatchBuffer result = new TupleBatchBuffer(agg.getSchema());
@@ -556,11 +548,11 @@ public void testSingleGroupStd() throws Exception {
     double expectedStdev = Math.sqrt(diffSquared / n);

     /* Group by group, aggregate on value */
-    final SingleGroupByAggregate agg =
-        new SingleGroupByAggregate(
+    final Aggregate agg =
+        new Aggregate(
             new BatchTupleSource(testBase),
-            0,
-            new SingleColumnAggregatorFactory(1, AggregationOp.STDEV));
+            new int[] {0},
+            new PrimitiveAggregatorFactory(1, AggregationOp.STDEV));
     agg.open(TestEnvVars.get());
     TupleBatch tb = null;
     final TupleBatchBuffer result = new TupleBatchBuffer(agg.getSchema());
@@ -604,11 +596,11 @@ public void testMultiGroupSum() throws DbException {

     // test for grouping at the first and second column
     // expected all the i / 2 to be sum up
-    MultiGroupByAggregate mga =
-        new MultiGroupByAggregate(
+    Aggregate mga =
+        new Aggregate(
             new BatchTupleSource(tbb),
             new int[] {0, 1},
-            new SingleColumnAggregatorFactory(3, AggregationOp.SUM));
+            new PrimitiveAggregatorFactory(3, AggregationOp.SUM));
     mga.open(TestEnvVars.get());
     TupleBatch result = mga.nextReady();
     assertNotNull(result);
@@ -618,11 +610,11 @@ public void testMultiGroupSum() throws DbException {

     // test for grouping at the first, second and third column
     // expecting half of i / 2 to be sum up on each group
-    MultiGroupByAggregate mgaTwo =
-        new MultiGroupByAggregate(
+    Aggregate mgaTwo =
+        new Aggregate(
             new BatchTupleSource(tbb),
             new int[] {0, 1, 2},
-            new SingleColumnAggregatorFactory(3, AggregationOp.SUM));
+            new PrimitiveAggregatorFactory(3, AggregationOp.SUM));
     mgaTwo.open(TestEnvVars.get());
     TupleBatch resultTwo = mgaTwo.nextReady();
     assertNotNull(result);
@@ -654,11 +646,11 @@ public void testMultiGroupAvg() throws DbException {
       tbb.putLong(3, i / 2);
     }
     expected /= numTuples;
-    MultiGroupByAggregate mga =
-        new MultiGroupByAggregate(
+    Aggregate mga =
+        new Aggregate(
            new BatchTupleSource(tbb),
            new int[] {0, 1, 2},
-           new SingleColumnAggregatorFactory(3, AggregationOp.AVG));
+           new PrimitiveAggregatorFactory(3, AggregationOp.AVG));
     mga.open(TestEnvVars.get());
     TupleBatch result = mga.nextReady();
     assertNotNull(result);
@@ -687,11 +679,11 @@ public void testMultiGroupMin() throws DbException {
       }
       tbb.putLong(3, i / 2);
     }
-    MultiGroupByAggregate mga =
-        new MultiGroupByAggregate(
+    Aggregate mga =
+        new Aggregate(
             new BatchTupleSource(tbb),
             new int[] {0, 1},
-            new SingleColumnAggregatorFactory(3, AggregationOp.MIN));
+            new PrimitiveAggregatorFactory(3, AggregationOp.MIN));
     mga.open(TestEnvVars.get());
     TupleBatch result = mga.nextReady();
     assertNotNull(result);
@@ -720,11 +712,11 @@ public void testMultiGroupMax() throws DbException {
       }
       tbb.putLong(3, i);
     }
-    MultiGroupByAggregate mga =
-        new MultiGroupByAggregate(
+    Aggregate mga =
+        new Aggregate(
             new BatchTupleSource(tbb),
             new int[] {0, 1},
-            new SingleColumnAggregatorFactory(3, AggregationOp.MAX));
+            new PrimitiveAggregatorFactory(3, AggregationOp.MAX));
     mga.open(TestEnvVars.get());
     TupleBatch result = mga.nextReady();
     assertNotNull(result);
@@ -754,11 +746,12 @@ public void testMultiGroupMaxAndMin() throws DbException {
       }
       tbb.putLong(3, i);
     }
-    MultiGroupByAggregate mga =
-        new MultiGroupByAggregate(
+    Aggregate mga =
+        new Aggregate(
             new BatchTupleSource(tbb),
             new int[] {0, 1},
-            new SingleColumnAggregatorFactory(3, AggregationOp.MAX, AggregationOp.MIN));
+            new PrimitiveAggregatorFactory(
+                3, new AggregationOp[] {AggregationOp.MAX, AggregationOp.MIN}));
     mga.open(TestEnvVars.get());
     TupleBatch result = mga.nextReady();
     assertNotNull(result);
@@ -791,12 +784,12 @@ public void testMultiGroupMaxMultiColumn() throws DbException {
       }
       tbb.putLong(3, i);
     }
-    MultiGroupByAggregate mga =
-        new MultiGroupByAggregate(
+    Aggregate mga =
+        new Aggregate(
             new BatchTupleSource(tbb),
             new int[] {0, 1},
-            new SingleColumnAggregatorFactory(3, AggregationOp.MAX),
-            new SingleColumnAggregatorFactory(3, AggregationOp.MIN));
+            new PrimitiveAggregatorFactory(
+                3, new AggregationOp[] {AggregationOp.MAX, AggregationOp.MIN}));
     mga.open(TestEnvVars.get());
     TupleBatch result = mga.nextReady();
     assertNotNull(result);
@@ -826,11 +819,11 @@ public void testMultiGroupCountMultiColumn() throws DbException {
       }
       tbb.putLong(3, i);
     }
-    MultiGroupByAggregate mga =
-        new MultiGroupByAggregate(
+    Aggregate mga =
+        new Aggregate(
             new BatchTupleSource(tbb),
             new int[] {0, 1},
-            new SingleColumnAggregatorFactory(0, AggregationOp.COUNT));
+            new PrimitiveAggregatorFactory(0, AggregationOp.COUNT));
     mga.open(TestEnvVars.get());
     TupleBatch result = mga.nextReady();
     assertNotNull(result);
@@ -925,9 +918,8 @@ public void testMultiGroupCountHashCollision() throws DbException {
         HashUtils.hashSubRow(buffer, groupCols, 0),
         HashUtils.hashSubRow(buffer, groupCols, 1));
     BatchTupleSource source = new BatchTupleSource(buffer.finalResult());
-    MultiGroupByAggregate mga =
-        new MultiGroupByAggregate(
-            source, groupCols, new SingleColumnAggregatorFactory(1, AggregationOp.COUNT));
+    Aggregate mga =
+        new Aggregate(source, groupCols, new PrimitiveAggregatorFactory(1, AggregationOp.COUNT));
     mga.open(TestEnvVars.get());
     TupleBatch result = mga.nextReady();
     assertNotNull(result);
@@ -956,11 +948,11 @@ public void testMultiGroupCountMultiColumnEmpty() throws DbException {
             ImmutableList.of("a", "b", "c", "d"));

     final TupleBatchBuffer tbb = new TupleBatchBuffer(schema);
-    MultiGroupByAggregate mga =
-        new MultiGroupByAggregate(
+    Aggregate mga =
+        new Aggregate(
             new BatchTupleSource(tbb),
             new int[] {0, 1},
-            new SingleColumnAggregatorFactory(0, AggregationOp.COUNT));
+            new PrimitiveAggregatorFactory(0, AggregationOp.COUNT));
     mga.open(TestEnvVars.get());
     TupleBatch result = mga.nextReady();
     assertNull(result);
diff --git a/test/edu/washington/escience/myria/operator/StreamingAggTest.java b/test/edu/washington/escience/myria/operator/StreamingAggTest.java
index 212d0ac5f..cdc6e34df 100644
--- a/test/edu/washington/escience/myria/operator/StreamingAggTest.java
+++ b/test/edu/washington/escience/myria/operator/StreamingAggTest.java
@@ -12,10 +12,9 @@
 import edu.washington.escience.myria.DbException;
 import edu.washington.escience.myria.Schema;
 import edu.washington.escience.myria.Type;
-import edu.washington.escience.myria.operator.agg.MultiGroupByAggregate;
+import edu.washington.escience.myria.operator.agg.Aggregate;
 import edu.washington.escience.myria.operator.agg.PrimitiveAggregator.AggregationOp;
-import edu.washington.escience.myria.operator.agg.SingleColumnAggregatorFactory;
-import edu.washington.escience.myria.operator.agg.SingleGroupByAggregate;
+import edu.washington.escience.myria.operator.agg.PrimitiveAggregatorFactory;
 import edu.washington.escience.myria.operator.agg.StreamingAggregate;
 import edu.washington.escience.myria.storage.TupleBatch;
 import edu.washington.escience.myria.storage.TupleBatchBuffer;
@@ -24,7 +23,7 @@

 /**
  * Test cases for {@link StreamingAggregate} class. Source tuples are generated in sorted order on group keys, if any.
- * Some of the tests are taken from those for {@link SingleGroupByAggregate} and {@link MultiGroupByAggregate} since
+ * Some of the tests are taken from those for the former SingleGroupByAggregate and MultiGroupByAggregate operators (now unified in {@link Aggregate}) since
  * StreamingAggregate is expected to behave the same way they do if input is sorted.
  */
 public class StreamingAggTest {
@@ -73,10 +72,8 @@ private TupleBatchBuffer fillInputTbb(final int numTuples) {
   @Test
   public void testSingleGroupKeySingleColumnCount() throws DbException {
     final int numTuples = 50;
-    /*
-     * col0: Int, col1: Double, col2: Float, col3: Long, col4: Datetime, col5: String, col6: Boolean, (first 7 columns
-     * used for grouping) col7: Long (to agg over).
-     */
+    /* col0: Int, col1: Double, col2: Float, col3: Long, col4: Datetime, col5: String, col6: Boolean, (first 7 columns
+     * used for grouping) col7: Long (to agg over). */
     TupleBatchBuffer source = fillInputTbb(numTuples);

     // group by col0
@@ -84,7 +81,7 @@ public void testSingleGroupKeySingleColumnCount() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {0},
-            new SingleColumnAggregatorFactory(7, AggregationOp.COUNT));
+            new PrimitiveAggregatorFactory(7, AggregationOp.COUNT));
     agg.open(TestEnvVars.get());
     TupleBatch result = agg.nextReady();
     assertNotNull(result);
@@ -100,7 +97,7 @@ public void testSingleGroupKeySingleColumnCount() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {1},
-            new SingleColumnAggregatorFactory(7, AggregationOp.COUNT));
+            new PrimitiveAggregatorFactory(7, AggregationOp.COUNT));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -116,7 +113,7 @@ public void testSingleGroupKeySingleColumnCount() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {2},
-            new SingleColumnAggregatorFactory(7, AggregationOp.COUNT));
+            new PrimitiveAggregatorFactory(7, AggregationOp.COUNT));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -132,7 +129,7 @@ public void testSingleGroupKeySingleColumnCount() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {3},
-            new SingleColumnAggregatorFactory(7, AggregationOp.COUNT));
+            new PrimitiveAggregatorFactory(7, AggregationOp.COUNT));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -148,7 +145,7 @@ public void testSingleGroupKeySingleColumnCount() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {4},
-            new SingleColumnAggregatorFactory(7, AggregationOp.COUNT));
+            new PrimitiveAggregatorFactory(7, AggregationOp.COUNT));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -164,7 +161,7 @@ public void testSingleGroupKeySingleColumnCount() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {5},
-            new SingleColumnAggregatorFactory(7, AggregationOp.COUNT));
+            new PrimitiveAggregatorFactory(7, AggregationOp.COUNT));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -180,7 +177,7 @@ public void testSingleGroupKeySingleColumnCount() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {6},
-            new SingleColumnAggregatorFactory(7, AggregationOp.COUNT));
+            new PrimitiveAggregatorFactory(7, AggregationOp.COUNT));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -195,10 +192,8 @@
   @Test
   public void testSingleGroupKeySingleColumnSum() throws DbException {
     final int numTuples = 50;
-    /*
-     * col0: Int, col1: Double, col2: Float, col3: Long, col4: Datetime, col5: String, col6: Boolean, (first 7 columns
-     * used for grouping) col7: Long (to agg over).
-     */
+    /* col0: Int, col1: Double, col2: Float, col3: Long, col4: Datetime, col5: String, col6: Boolean, (first 7 columns
+     * used for grouping) col7: Long (to agg over). */
     TupleBatchBuffer source = fillInputTbb(numTuples);

     // group by col0
@@ -206,7 +201,7 @@ public void testSingleGroupKeySingleColumnSum() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {0},
-            new SingleColumnAggregatorFactory(7, AggregationOp.SUM));
+            new PrimitiveAggregatorFactory(7, AggregationOp.SUM));
     agg.open(TestEnvVars.get());
     TupleBatch result = agg.nextReady();
     assertNotNull(result);
@@ -222,7 +217,7 @@ public void testSingleGroupKeySingleColumnSum() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {1},
-            new SingleColumnAggregatorFactory(7, AggregationOp.SUM));
+            new PrimitiveAggregatorFactory(7, AggregationOp.SUM));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -238,7 +233,7 @@ public void testSingleGroupKeySingleColumnSum() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {2},
-            new SingleColumnAggregatorFactory(7, AggregationOp.SUM));
+            new PrimitiveAggregatorFactory(7, AggregationOp.SUM));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -254,7 +249,7 @@ public void testSingleGroupKeySingleColumnSum() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {3},
-            new SingleColumnAggregatorFactory(7, AggregationOp.SUM));
+            new PrimitiveAggregatorFactory(7, AggregationOp.SUM));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -270,7 +265,7 @@ public void testSingleGroupKeySingleColumnSum() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {4},
-            new SingleColumnAggregatorFactory(7, AggregationOp.SUM));
+            new PrimitiveAggregatorFactory(7, AggregationOp.SUM));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -286,7 +281,7 @@ public void testSingleGroupKeySingleColumnSum() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {5},
-            new SingleColumnAggregatorFactory(7, AggregationOp.SUM));
+            new PrimitiveAggregatorFactory(7, AggregationOp.SUM));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -302,7 +297,7 @@ public void testSingleGroupKeySingleColumnSum() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {6},
-            new SingleColumnAggregatorFactory(7, AggregationOp.SUM));
+            new PrimitiveAggregatorFactory(7, AggregationOp.SUM));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -317,10 +312,8 @@
   @Test
   public void testSingleGroupKeySingleColumnAvg() throws DbException {
     final int numTuples = 50;
-    /*
-     * col0: Int, col1: Double, col2: Float, col3: Long, col4: Datetime, col5: String, col6: Boolean, (first 7 columns
-     * used for grouping) col7: Long (to agg over).
-     */
+    /* col0: Int, col1: Double, col2: Float, col3: Long, col4: Datetime, col5: String, col6: Boolean, (first 7 columns
+     * used for grouping) col7: Long (to agg over). */
     TupleBatchBuffer source = fillInputTbb(numTuples);

     // group by col0
@@ -328,7 +321,7 @@ public void testSingleGroupKeySingleColumnAvg() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {0},
-            new SingleColumnAggregatorFactory(7, AggregationOp.AVG));
+            new PrimitiveAggregatorFactory(7, AggregationOp.AVG));
     agg.open(TestEnvVars.get());
     TupleBatch result = agg.nextReady();
     assertNotNull(result);
@@ -344,7 +337,7 @@ public void testSingleGroupKeySingleColumnAvg() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {1},
-            new SingleColumnAggregatorFactory(7, AggregationOp.AVG));
+            new PrimitiveAggregatorFactory(7, AggregationOp.AVG));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -360,7 +353,7 @@ public void testSingleGroupKeySingleColumnAvg() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {2},
-            new SingleColumnAggregatorFactory(7, AggregationOp.AVG));
+            new PrimitiveAggregatorFactory(7, AggregationOp.AVG));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -376,7 +369,7 @@ public void testSingleGroupKeySingleColumnAvg() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {3},
-            new SingleColumnAggregatorFactory(7, AggregationOp.AVG));
+            new PrimitiveAggregatorFactory(7, AggregationOp.AVG));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -392,7 +385,7 @@ public void testSingleGroupKeySingleColumnAvg() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {4},
-            new SingleColumnAggregatorFactory(7, AggregationOp.AVG));
+            new PrimitiveAggregatorFactory(7, AggregationOp.AVG));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -408,7 +401,7 @@ public void testSingleGroupKeySingleColumnAvg() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {5},
-            new SingleColumnAggregatorFactory(7, AggregationOp.AVG));
+            new PrimitiveAggregatorFactory(7, AggregationOp.AVG));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -424,7 +417,7 @@ public void testSingleGroupKeySingleColumnAvg() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {6},
-            new SingleColumnAggregatorFactory(7, AggregationOp.AVG));
+            new PrimitiveAggregatorFactory(7, AggregationOp.AVG));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -439,10 +432,8 @@
   @Test
   public void testSingleGroupKeySingleColumnStdev() throws DbException {
     final int numTuples = 50;
-    /*
-     * col0: Int, col1: Double, col2: Float, col3: Long, col4: Datetime, col5: String, col6: Boolean, (first 7 columns
-     * used for grouping) col7: Long (to agg over).
-     */
+    /* col0: Int, col1: Double, col2: Float, col3: Long, col4: Datetime, col5: String, col6: Boolean, (first 7 columns
+     * used for grouping) col7: Long (to agg over). */
     TupleBatchBuffer source = fillInputTbb(numTuples);

     // group by col0
@@ -450,7 +441,7 @@ public void testSingleGroupKeySingleColumnStdev() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {0},
-            new SingleColumnAggregatorFactory(7, AggregationOp.STDEV));
+            new PrimitiveAggregatorFactory(7, AggregationOp.STDEV));
     agg.open(TestEnvVars.get());
     TupleBatch result = agg.nextReady();
     assertNotNull(result);
@@ -466,7 +457,7 @@ public void testSingleGroupKeySingleColumnStdev() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {1},
-            new SingleColumnAggregatorFactory(7, AggregationOp.STDEV));
+            new PrimitiveAggregatorFactory(7, AggregationOp.STDEV));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -482,7 +473,7 @@ public void testSingleGroupKeySingleColumnStdev() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {2},
-            new SingleColumnAggregatorFactory(7, AggregationOp.STDEV));
+            new PrimitiveAggregatorFactory(7, AggregationOp.STDEV));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -498,7 +489,7 @@ public void testSingleGroupKeySingleColumnStdev() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {3},
-            new SingleColumnAggregatorFactory(7, AggregationOp.STDEV));
+            new PrimitiveAggregatorFactory(7, AggregationOp.STDEV));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -514,7 +505,7 @@ public void testSingleGroupKeySingleColumnStdev() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {4},
-            new SingleColumnAggregatorFactory(7, AggregationOp.STDEV));
+            new PrimitiveAggregatorFactory(7, AggregationOp.STDEV));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -530,7 +521,7 @@ public void testSingleGroupKeySingleColumnStdev() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {5},
-            new SingleColumnAggregatorFactory(7, AggregationOp.STDEV));
+            new PrimitiveAggregatorFactory(7, AggregationOp.STDEV));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -546,7 +537,7 @@ public void testSingleGroupKeySingleColumnStdev() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {6},
-            new SingleColumnAggregatorFactory(7, AggregationOp.STDEV));
+            new PrimitiveAggregatorFactory(7, AggregationOp.STDEV));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -561,10 +552,8 @@
   @Test
   public void testSingleGroupKeySingleColumnMin() throws DbException {
     final int numTuples = 50;
-    /*
-     * col0: Int, col1: Double, col2: Float, col3: Long, col4: Datetime, col5: String, col6: Boolean, (first 7 columns
-     * used for grouping) col7: Long (to agg over) constant value of 2L.
-     */
+    /* col0: Int, col1: Double, col2: Float, col3: Long, col4: Datetime, col5: String, col6: Boolean, (first 7 columns
+     * used for grouping) col7: Long (to agg over) constant value of 2L. */
     TupleBatchBuffer source = fillInputTbb(numTuples);

     // group by col7, agg over col0
@@ -572,7 +561,7 @@ public void testSingleGroupKeySingleColumnMin() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {7},
-            new SingleColumnAggregatorFactory(0, AggregationOp.MIN));
+            new PrimitiveAggregatorFactory(0, AggregationOp.MIN));
     agg.open(TestEnvVars.get());
     TupleBatch result = agg.nextReady();
     assertNotNull(result);
@@ -588,7 +577,7 @@ public void testSingleGroupKeySingleColumnMin() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {7},
-            new SingleColumnAggregatorFactory(1, AggregationOp.MIN));
+            new PrimitiveAggregatorFactory(1, AggregationOp.MIN));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -604,7 +593,7 @@ public void testSingleGroupKeySingleColumnMin() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {7},
-            new SingleColumnAggregatorFactory(2, AggregationOp.MIN));
+            new PrimitiveAggregatorFactory(2, AggregationOp.MIN));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -620,7 +609,7 @@ public void testSingleGroupKeySingleColumnMin() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {7},
-            new SingleColumnAggregatorFactory(3, AggregationOp.MIN));
+            new PrimitiveAggregatorFactory(3, AggregationOp.MIN));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -636,7 +625,7 @@ public void testSingleGroupKeySingleColumnMin() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {7},
-            new SingleColumnAggregatorFactory(4, AggregationOp.MIN));
+            new PrimitiveAggregatorFactory(4, AggregationOp.MIN));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -652,7 +641,7 @@ public void testSingleGroupKeySingleColumnMin() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {7},
-            new SingleColumnAggregatorFactory(5, AggregationOp.MIN));
+            new PrimitiveAggregatorFactory(5, AggregationOp.MIN));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -669,10 +658,8 @@
   @Test
   public void testSingleGroupKeySingleColumnMax() throws DbException {
     final int numTuples = 50;
-    /*
-     * col0: Int, col1: Double, col2: Float, col3: Long, col4: Datetime, col5: String, col6: Boolean, (first 7 columns
-     * used for grouping) col7: Long (to agg over) constant value of 2L.
-     */
+    /* col0: Int, col1: Double, col2: Float, col3: Long, col4: Datetime, col5: String, col6: Boolean, (first 7 columns
+     * used for grouping) col7: Long (to agg over) constant value of 2L. */
     TupleBatchBuffer source = fillInputTbb(numTuples);

     // group by col7, agg over col0
@@ -680,7 +667,7 @@ public void testSingleGroupKeySingleColumnMax() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {7},
-            new SingleColumnAggregatorFactory(0, AggregationOp.MAX));
+            new PrimitiveAggregatorFactory(0, AggregationOp.MAX));
     agg.open(TestEnvVars.get());
     TupleBatch result = agg.nextReady();
     assertNotNull(result);
@@ -696,7 +683,7 @@ public void testSingleGroupKeySingleColumnMax() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {7},
-            new SingleColumnAggregatorFactory(1, AggregationOp.MAX));
+            new PrimitiveAggregatorFactory(1, AggregationOp.MAX));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -712,7 +699,7 @@ public void testSingleGroupKeySingleColumnMax() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {7},
-            new SingleColumnAggregatorFactory(2, AggregationOp.MAX));
+            new PrimitiveAggregatorFactory(2, AggregationOp.MAX));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -728,7 +715,7 @@ public void testSingleGroupKeySingleColumnMax() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {7},
-            new SingleColumnAggregatorFactory(3, AggregationOp.MAX));
+            new PrimitiveAggregatorFactory(3, AggregationOp.MAX));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -744,7 +731,7 @@ public void testSingleGroupKeySingleColumnMax() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {7},
-            new SingleColumnAggregatorFactory(4, AggregationOp.MAX));
+            new PrimitiveAggregatorFactory(4, AggregationOp.MAX));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -760,7 +747,7 @@ public void testSingleGroupKeySingleColumnMax() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(source),
             new int[] {7},
-            new SingleColumnAggregatorFactory(5, AggregationOp.MAX));
+            new PrimitiveAggregatorFactory(5, AggregationOp.MAX));
     agg.open(TestEnvVars.get());
     result = agg.nextReady();
     assertNotNull(result);
@@ -790,7 +777,7 @@ public void testMultiGroupSingleColumnCount() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(tbb),
             new int[] {0, 1},
-            new SingleColumnAggregatorFactory(2, AggregationOp.COUNT));
+            new PrimitiveAggregatorFactory(2, AggregationOp.COUNT));
     agg.open(TestEnvVars.get());
     TupleBatch result = agg.nextReady();
     assertNotNull(result);
@@ -818,7 +805,7 @@ public void testMultiGroupSingleColumnMin() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(tbb),
             new int[] {0, 1},
-            new SingleColumnAggregatorFactory(2, AggregationOp.MIN));
+            new PrimitiveAggregatorFactory(2, AggregationOp.MIN));
     agg.open(TestEnvVars.get());
     TupleBatch result = agg.nextReady();
     assertNotNull(result);
@@ -848,7 +835,7 @@ public void testMultiGroupSingleColumnMax() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(tbb),
             new int[] {0, 1},
-            new SingleColumnAggregatorFactory(2, AggregationOp.MAX));
+            new PrimitiveAggregatorFactory(2, AggregationOp.MAX));
     agg.open(TestEnvVars.get());
     TupleBatch result = agg.nextReady();
     assertNotNull(result);
@@ -878,7 +865,7 @@ public void testMultiGroupSingleColumnSum() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(tbb),
             new int[] {0, 1},
-            new SingleColumnAggregatorFactory(2, AggregationOp.SUM));
+            new PrimitiveAggregatorFactory(2, AggregationOp.SUM));
     agg.open(TestEnvVars.get());
     TupleBatch result = agg.nextReady();
     assertNotNull(result);
@@ -906,7 +893,7 @@ public void testMultiGroupSingleColumnAvg() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(tbb),
             new int[] {0, 1},
-            new SingleColumnAggregatorFactory(2, AggregationOp.AVG));
+            new PrimitiveAggregatorFactory(2, AggregationOp.AVG));
     agg.open(TestEnvVars.get());
     TupleBatch result = agg.nextReady();
     assertNotNull(result);
@@ -934,7 +921,7 @@ public void testMultiGroupSingleColumnStdev() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(tbb),
             new int[] {0, 1},
-            new SingleColumnAggregatorFactory(2, AggregationOp.STDEV));
+            new PrimitiveAggregatorFactory(2, AggregationOp.STDEV));
     agg.open(TestEnvVars.get());
     TupleBatch result = agg.nextReady();
     assertNotNull(result);
@@ -961,12 +948,13 @@ public void testSingleGroupKeyMultiColumnAllAgg() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(tbb),
             new int[] {0},
-            new SingleColumnAggregatorFactory(0, AggregationOp.MIN),
-            new SingleColumnAggregatorFactory(0, AggregationOp.MAX),
-            new SingleColumnAggregatorFactory(1, AggregationOp.COUNT),
-            new SingleColumnAggregatorFactory(1, AggregationOp.SUM),
-            new SingleColumnAggregatorFactory(1, AggregationOp.AVG),
-            new SingleColumnAggregatorFactory(1, AggregationOp.STDEV));
+            new PrimitiveAggregatorFactory(
+                0, new AggregationOp[] {AggregationOp.MIN, AggregationOp.MAX}),
+            new PrimitiveAggregatorFactory(
+                1,
+                new AggregationOp[] {
+                  AggregationOp.COUNT, AggregationOp.SUM, AggregationOp.AVG, AggregationOp.STDEV
+                }));
     agg.open(TestEnvVars.get());
     TupleBatch result = agg.nextReady();
     assertNotNull(result);
@@ -1032,12 +1020,16 @@ public void testMultiGroupMultiColumn() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(tbb),
             new int[] {0, 1},
-            new SingleColumnAggregatorFactory(2, AggregationOp.MIN),
-            new SingleColumnAggregatorFactory(2, AggregationOp.MAX),
-            new SingleColumnAggregatorFactory(2, AggregationOp.COUNT),
-            new SingleColumnAggregatorFactory(2, AggregationOp.SUM),
-            new SingleColumnAggregatorFactory(2, AggregationOp.AVG),
-            new SingleColumnAggregatorFactory(2, AggregationOp.STDEV));
+            new PrimitiveAggregatorFactory(
+                2,
+                new AggregationOp[] {
+                  AggregationOp.MIN,
+                  AggregationOp.MAX,
+                  AggregationOp.COUNT,
+                  AggregationOp.SUM,
+                  AggregationOp.AVG,
+                  AggregationOp.STDEV
+                }));
     agg.open(TestEnvVars.get());
     TupleBatch result = agg.nextReady();
     assertNotNull(result);
@@ -1093,12 +1085,16 @@ public void testSingleGroupAllAggLargeInput() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(tbb),
             new int[] {0},
-            new SingleColumnAggregatorFactory(1, AggregationOp.MIN),
-            new SingleColumnAggregatorFactory(1, AggregationOp.MAX),
-            new SingleColumnAggregatorFactory(1, AggregationOp.COUNT),
-            new SingleColumnAggregatorFactory(1, AggregationOp.SUM),
-            new SingleColumnAggregatorFactory(1, AggregationOp.AVG),
-            new SingleColumnAggregatorFactory(1, AggregationOp.STDEV));
+            new PrimitiveAggregatorFactory(
+                1,
+                new AggregationOp[] {
+                  AggregationOp.MIN,
+                  AggregationOp.MAX,
+                  AggregationOp.COUNT,
+                  AggregationOp.SUM,
+                  AggregationOp.AVG,
+                  AggregationOp.STDEV
+                }));
     agg.open(TestEnvVars.get());
     TupleBatch result = agg.nextReady();
     assertNotNull(result);
@@ -1184,12 +1180,16 @@ public void testMultiGroupAllAggLargeInput() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(tbb),
             new int[] {0, 1},
-            new SingleColumnAggregatorFactory(2, AggregationOp.MIN),
-            new SingleColumnAggregatorFactory(2, AggregationOp.MAX),
-            new SingleColumnAggregatorFactory(2, AggregationOp.COUNT),
-            new SingleColumnAggregatorFactory(2, AggregationOp.SUM),
-            new SingleColumnAggregatorFactory(2, AggregationOp.AVG),
-            new SingleColumnAggregatorFactory(2, AggregationOp.STDEV));
+            new PrimitiveAggregatorFactory(
+                2,
+                new AggregationOp[] {
+                  AggregationOp.MIN,
+                  AggregationOp.MAX,
+                  AggregationOp.COUNT,
+                  AggregationOp.SUM,
+                  AggregationOp.AVG,
+                  AggregationOp.STDEV
+                }));
     agg.open(TestEnvVars.get());
     TupleBatch result = agg.nextReady();
     assertNotNull(result);
@@ -1245,7 +1245,7 @@ public void testMultiBatchResult() throws DbException {
         new StreamingAggregate(
             new BatchTupleSource(tbb),
             new int[] {0},
-            new SingleColumnAggregatorFactory(1, AggregationOp.COUNT));
+            new PrimitiveAggregatorFactory(1, AggregationOp.COUNT));
     agg.open(TestEnvVars.get());
     TupleBatch result = agg.nextReady();
     assertNotNull(result);
diff --git a/test/edu/washington/escience/myria/operator/SymmetricHashJoinTest.java b/test/edu/washington/escience/myria/operator/SymmetricHashJoinTest.java
index 706947a8f..24c250e3a 100644
--- a/test/edu/washington/escience/myria/operator/SymmetricHashJoinTest.java
+++ b/test/edu/washington/escience/myria/operator/SymmetricHashJoinTest.java
@@ -16,10 +16,15 @@ public class SymmetricHashJoinTest {
   public void testSymmetricHashJoin() throws DbException {
     BatchTupleSource left = new BatchTupleSource(JoinTestUtils.leftInput);
     BatchTupleSource right = new BatchTupleSource(JoinTestUtils.rightInput);
-    Operator join = new SymmetricHashJoin(left, right, new int[] {1, 0, 2}, new int[] {2, 1, 0});
+    Operator join =
+        new SymmetricHashJoin(
+            left, right, new int[] {1, 0, 2}, new int[] {2, 1, 0}, new int[] {0}, new int[] {0});
     join.open(TestEnvVars.get());
     assertEquals(
-        Schema.merge(JoinTestUtils.leftSchema, JoinTestUtils.rightSchema), join.getSchema());
+        Schema.merge(
+            JoinTestUtils.leftSchema.getSubSchema(new int[] {0}),
+            JoinTestUtils.rightSchema.getSubSchema(new int[] {0})),
+        join.getSchema());
     long count = 0;
     while (!join.eos()) {
       TupleBatch tb = join.nextReady();
@@ -36,7 +41,9 @@
   public void testIncompatibleJoinKeys() throws DbException {
     BatchTupleSource left = new BatchTupleSource(JoinTestUtils.leftInput);
     BatchTupleSource right = new BatchTupleSource(JoinTestUtils.rightInput);
-    Operator join = new SymmetricHashJoin(left, right, new int[] {0}, new int[] {0});
+    Operator join =
+        new SymmetricHashJoin(
+            left, right, new int[] {0}, new int[] {0}, new int[] {0}, new int[] {0});
     join.open(TestEnvVars.get());
   }
 }
diff --git a/test/edu/washington/escience/myria/operator/UserDefinedAggregatorTest.java b/test/edu/washington/escience/myria/operator/UserDefinedAggregatorTest.java
index b36a81343..2de542775 100644
--- a/test/edu/washington/escience/myria/operator/UserDefinedAggregatorTest.java
+++ b/test/edu/washington/escience/myria/operator/UserDefinedAggregatorTest.java
@@ -68,7 +68,7 @@ public void testCount() throws Exception {
         new UserDefinedAggregatorFactory(Initializers.build(), Updaters.build(), Emitters.build());
     factory = reader.readValue(writer.writeValueAsString(factory));

-    Aggregate agg = new Aggregate(new BatchTupleSource(tbb), factory);
+    Aggregate agg = new Aggregate(new BatchTupleSource(tbb), new int[] {}, factory);
     agg.open(TestEnvVars.get());
     TupleBatch result;
     int resultSize = 0;
@@ -121,7 +121,7 @@ public void testCountAndConst() throws Exception {
         new UserDefinedAggregatorFactory(Initializers.build(), Updaters.build(), Emitters.build());
     factory = reader.readValue(writer.writeValueAsString(factory));

-    Aggregate agg = new Aggregate(new BatchTupleSource(tbb), factory);
+    Aggregate agg = new Aggregate(new BatchTupleSource(tbb), new int[] {}, factory);
     agg.open(TestEnvVars.get());
     TupleBatch result;
     int resultSize = 0;
@@ -186,7 +186,7 @@ public void testRowOfMax() throws Exception {
         new UserDefinedAggregatorFactory(Initializers.build(), Updaters.build(), Emitters.build());
     factory = reader.readValue(writer.writeValueAsString(factory));

-    Aggregate agg = new Aggregate(new BatchTupleSource(tbb), factory);
+    Aggregate agg = new Aggregate(new BatchTupleSource(tbb), new int[] {}, factory);
     agg.open(TestEnvVars.get());
     TupleBatch result;
     int resultSize = 0;
@@ -197,7 +197,7 @@
         assertEquals(2, result.numColumns());
         assertEquals(Type.LONG_TYPE, result.getSchema().getColumnType(0));
         assertEquals(Type.STRING_TYPE, result.getSchema().getColumnType(1));
-        assertEquals(9999, result.getLong(0, 0));
+        assertEquals(10000, result.getLong(0, 0));
         assertEquals("Foo9999", result.getString(1, 0));
         resultSize += result.numTuples();
       }
diff --git a/test/edu/washington/escience/myria/operator/apply/ApplyDownloadBlobTest.java b/test/edu/washington/escience/myria/operator/apply/ApplyDownloadBlobTest.java
index ad7728735..de5f0be1c 100644
--- a/test/edu/washington/escience/myria/operator/apply/ApplyDownloadBlobTest.java
+++ b/test/edu/washington/escience/myria/operator/apply/ApplyDownloadBlobTest.java
@@ -41,7 +41,6 @@ public void ApplyTest() throws DbException {
     ImmutableList.Builder Expressions = ImmutableList.builder();

     ExpressionOperator filename = new VariableExpression(0);
-    ;
     ExpressionOperator db = new DownloadBlobExpression(filename);

     Expression expr = new Expression("blobs", db);
@@ -61,7 +60,6 @@
         assertEquals(expectedResultSchema, result.getSchema());
       }
     }
-    apply.close();
   }
 }
diff --git a/test/edu/washington/escience/myria/operator/apply/ApplyTest.java b/test/edu/washington/escience/myria/operator/apply/ApplyTest.java
index b89cda6fd..0f866561b 100644
--- a/test/edu/washington/escience/myria/operator/apply/ApplyTest.java
+++ b/test/edu/washington/escience/myria/operator/apply/ApplyTest.java
@@ -193,7 +193,6 @@ public void testApply() throws DbException {
       Expression expr = new Expression("copy", vara);
       GenericEvaluator eval = new GenericEvaluator(expr, parameters);
-      assertTrue(!eval.needsCompiling());

       Expressions.add(expr);
     }
@@ -202,7 +201,6 @@
       Expression expr = new Expression("constant5", new ConstantExpression(5));
       GenericEvaluator eval = new ConstantEvaluator(expr, parameters);
-      assertTrue(!eval.needsCompiling());

       Expressions.add(expr);
     }
@@ -211,7 +209,6 @@
       Expression expr = new Expression("constant5f", new ConstantExpression(5.0f));
       GenericEvaluator eval = new ConstantEvaluator(expr, parameters);
-      assertTrue(!eval.needsCompiling());

       Expressions.add(expr);
     }
@@ -220,7 +217,6 @@
       Expression expr = new Expression("constant5d", new ConstantExpression(5d));
       GenericEvaluator eval = new ConstantEvaluator(expr, parameters);
-      assertTrue(!eval.needsCompiling());

       Expressions.add(expr);
     }
@@ -229,7 +225,6 @@
       Expression expr = new Expression("random", new RandomExpression());
       GenericEvaluator eval = new GenericEvaluator(expr, parameters);
-      assertTrue(eval.needsCompiling());

       Expressions.add(expr);
     }
@@ -240,7 +235,6 @@
               "modulo", new ModuloExpression(new VariableExpression(2), new VariableExpression(1)));
       GenericEvaluator eval = new GenericEvaluator(expr, parameters);
-      assertTrue(eval.needsCompiling());

       Expressions.add(expr);
     }
@@ -253,7 +247,6 @@
                   new VariableExpression(4), new VariableExpression(0), new VariableExpression(2)));
       GenericEvaluator eval = new GenericEvaluator(expr, parameters);
-      assertTrue(eval.needsCompiling());

       Expressions.add(expr);
     }
@@ -274,7 +267,6 @@
                   new VariableExpression(2)));
       GenericEvaluator eval = new GenericEvaluator(expr, parameters);
-      assertTrue(eval.needsCompiling());

       Expressions.add(expr);
     }
@@ -284,8 +276,7 @@
       GenericEvaluator eval =
           new ConstantEvaluator(expr, new ExpressionOperatorParameter(tbb.getSchema(), 42));
-      assertTrue(!eval.needsCompiling());
-      assertEquals(eval.getScript(), "result.appendInt(42);");
+      assertEquals(eval.getScript(), "result.appendInt(42);");

      Expressions.add(expr);
     }
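
Reviewer note (not part of the patch): below is a minimal sketch of the unified Aggregate call pattern that the test hunks above migrate to, assuming only the constructor shapes visible in this diff -- Aggregate(child, groupFields, factories...) plus PrimitiveAggregatorFactory(column, op) or PrimitiveAggregatorFactory(column, ops[]). The source name tbb is a placeholder for any TupleBatchBuffer.

    // Grand total over all input tuples: an empty group-field array means no grouping,
    // replacing the old zero-argument Aggregate constructor usage.
    Aggregate total =
        new Aggregate(
            new BatchTupleSource(tbb),
            new int[] {},
            new PrimitiveAggregatorFactory(0, AggregationOp.COUNT));

    // Group by columns 0 and 1 (formerly MultiGroupByAggregate); a single factory
    // now emits several aggregate ops over one column.
    Aggregate grouped =
        new Aggregate(
            new BatchTupleSource(tbb),
            new int[] {0, 1},
            new PrimitiveAggregatorFactory(
                3, new AggregationOp[] {AggregationOp.MAX, AggregationOp.MIN}));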