From 33609eaa0bf0b1e61cfac0b6a60ae68d709f824b Mon Sep 17 00:00:00 2001
From: Yuqing Guo
Date: Fri, 15 May 2015 00:37:24 -0700
Subject: [PATCH 01/29] adding files and cleaning up

---
 .../escience/myria/operator/Limit.java        |   34 +-
 .../operator/agg/StreamingAggregate.java      |  277 ++++
 .../escience/myria/operator/LimitTest.java    |   62 +-
 .../myria/operator/StreamingAggTest.java      | 1113 +++++++++++++++++
 4 files changed, 1473 insertions(+), 13 deletions(-)
 create mode 100644 src/edu/washington/escience/myria/operator/agg/StreamingAggregate.java
 create mode 100644 test/edu/washington/escience/myria/operator/StreamingAggTest.java

diff --git a/src/edu/washington/escience/myria/operator/Limit.java b/src/edu/washington/escience/myria/operator/Limit.java
index 03c58aef0..935eea415 100644
--- a/src/edu/washington/escience/myria/operator/Limit.java
+++ b/src/edu/washington/escience/myria/operator/Limit.java
@@ -12,9 +12,9 @@ import edu.washington.escience.myria.storage.TupleBatch;
 
 /**
- * A poor implementation of a Limit operator, which emits the first N tuples then drops the rest on the floor.
+ * A poor implementation of a Limit operator, which emits the first N tuples and then closes the child operator so
+ * that no further tuples are produced.
  *
- * We would prefer one that stops the incoming stream, but this is not currently supported.
  */
 public final class Limit extends UnaryOperator {
 
@@ -27,6 +27,9 @@ public final class Limit extends UnaryOperator {
   /** The number of tuples left to emit. */
   private long toEmit;
 
+  /** If number of emitted tuples reached limit. */
+  private boolean done;
+
   /**
    * A limit operator keeps the first limit tuples produced by its child.
    *
@@ -38,23 +41,30 @@ public Limit(@Nonnull final Long limit, final Operator child) {
     this.limit = Objects.requireNonNull(limit, "limit");
     Preconditions.checkArgument(limit >= 0L, "limit must be non-negative");
     toEmit = this.limit;
+    done = false;
   }
 
   @Override
   protected TupleBatch fetchNextReady() throws DbException {
     Operator child = getChild();
-    for (TupleBatch tb = child.nextReady(); tb != null; tb = child.nextReady()) {
-      if (tb.numTuples() <= toEmit) {
-        toEmit -= tb.numTuples();
-        return tb;
-      } else if (toEmit > 0) {
-        tb = tb.prefix(Ints.checkedCast(toEmit));
-        toEmit = 0;
-        return tb;
+    if (done) {
+      return null;
+    } else {
+      TupleBatch tb = child.nextReady();
+      TupleBatch result = null;
+      if (tb != null) {
+        if (tb.numTuples() <= toEmit) {
+          toEmit -= tb.numTuples();
+          result = tb;
+        } else if (toEmit > 0) {
+          result = tb.prefix(Ints.checkedCast(toEmit));
+          toEmit = 0;
+          child.close();
+          done = true;
+        }
       }
-      /* Else, drop on the floor. */
+      return result;
    }
-    return null;
   }
 
   @Override
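With this patch, Limit stops the incoming stream by closing its child once the limit is reached, rather than draining and discarding the remainder. A minimal usage sketch, using the TupleSource, TestUtils, and TestEnvVars helpers that appear in the tests below (illustrative only, not part of the patch):

    TupleSource source = new TupleSource(TestUtils.range(1000)); // 1,000 input tuples
    Limit limiter = new Limit(10L, source);
    limiter.open(TestEnvVars.get());
    TupleBatch first = limiter.nextReady(); // the first 10 tuples
    // The child has been closed, so subsequent calls return null.
    limiter.close();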
diff --git a/src/edu/washington/escience/myria/operator/agg/StreamingAggregate.java b/src/edu/washington/escience/myria/operator/agg/StreamingAggregate.java
new file mode 100644
index 000000000..65875e650
--- /dev/null
+++ b/src/edu/washington/escience/myria/operator/agg/StreamingAggregate.java
@@ -0,0 +1,277 @@
+package edu.washington.escience.myria.operator.agg;
+
+import java.util.Objects;
+
+import javax.annotation.Nullable;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+
+import edu.washington.escience.myria.DbException;
+import edu.washington.escience.myria.Schema;
+import edu.washington.escience.myria.Type;
+import edu.washington.escience.myria.operator.Operator;
+import edu.washington.escience.myria.operator.UnaryOperator;
+import edu.washington.escience.myria.storage.Tuple;
+import edu.washington.escience.myria.storage.TupleBatch;
+import edu.washington.escience.myria.storage.TupleBuffer;
+import edu.washington.escience.myria.storage.TupleUtils;
+
+/**
+ * This aggregate operator computes its aggregates in a streaming manner, which requires the input to be sorted on the
+ * grouping column(s). It supports aggregation over multiple columns, with one or more group-by columns, and is
+ * intended to replace SingleGroupByAggregate and MultiGroupByAggregate when the input is known to be sorted.
+ *
+ * @see Aggregate
+ * @see SingleGroupByAggregate
+ * @see MultiGroupByAggregate
+ */
+public class StreamingAggregate extends UnaryOperator {
+
+  /** Required for Java serialization. */
+  private static final long serialVersionUID = 1L;
+
+  /** The schema of the aggregation result. */
+  private Schema aggSchema;
+  /** The schema of the columns indicated by the group keys. */
+  private Schema groupSchema;
+  /** Holds the current grouping key. */
+  private Tuple curGroupKey;
+
+  /** Group fields. **/
+  private final int[] gFields;
+  /** Group field types. **/
+  private final Type[] gTypes;
+  /** An array [0, 1, .., gFields.length-1] used for comparing tuples. */
+  private final int[] gRange;
+  /** Factories to make the Aggregators. **/
+  private final AggregatorFactory[] factories;
+  /** The actual Aggregators. **/
+  private Aggregator[] aggregators;
+  /** The state of the aggregators. */
+  private Object[] aggregatorStates;
+
+  /** Buffer for holding intermediate results. */
+  private transient TupleBuffer resultBuffer;
+  /** Buffer for holding finished results as tuple batches. */
+  private transient ImmutableList<TupleBatch> finalBuffer;
+
+  /**
+   * Groups the input tuples according to the specified grouping fields, then produces the specified aggregates.
+   *
+   * @param child The Operator that is feeding us tuples.
+   * @param gfields The columns over which we are grouping the result.
+   * @param factories The factories that will produce the {@link Aggregator}s for each group.
+   */
+  public StreamingAggregate(@Nullable final Operator child, final int[] gfields, final AggregatorFactory... factories) {
+    super(child);
+    gFields = Objects.requireNonNull(gfields, "gfields");
+    gTypes = new Type[gfields.length];
+    this.factories = Objects.requireNonNull(factories, "factories");
+    Preconditions.checkArgument(gfields.length > 0, "must have at least one group by field");
+    Preconditions.checkArgument(factories.length != 0, "to use StreamingAggregate, must specify some aggregates");
+    gRange = new int[gfields.length];
+    for (int i = 0; i < gfields.length; ++i) {
+      gRange[i] = i;
+    }
+  }
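For orientation before the implementation details, a minimal usage sketch mirroring the tests added below. It assumes sortedInput is a TupleBatchBuffer already sorted on column 0; TupleSource, SingleColumnAggregatorFactory, and AggregationOp are the existing Myria classes used in those tests.

    // COUNT of column 1 per distinct value of column 0, computed in one streaming pass.
    StreamingAggregate agg =
        new StreamingAggregate(new TupleSource(sortedInput), new int[] { 0 },
            new SingleColumnAggregatorFactory(1, AggregationOp.COUNT));
    agg.open(null);
    TupleBatch result = agg.nextReady(); // group key column(s), then aggregate column(s)
    agg.close();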
+
+  /**
+   * Returns the next tuple batch containing the result of this aggregate: grouping field(s) followed by aggregate
+   * field(s).
+   *
+   * @throws DbException if any error occurs.
+   * @return result tuple batch
+   */
+  @Override
+  protected TupleBatch fetchNextReady() throws DbException {
+    final Operator child = getChild();
+    if (child.eos()) {
+      return getResultBatch();
+    }
+
+    TupleBatch tb = child.nextReady();
+    while (tb != null) {
+      for (int row = 0; row < tb.numTuples(); ++row) {
+        if (curGroupKey == null) {
+          /*
+           * First tuple seen: no aggregation has been performed yet, so store this row's grouping key as the current
+           * group key.
+           */
+          curGroupKey = new Tuple(groupSchema);
+          for (int gKey = 0; gKey < gFields.length; ++gKey) {
+            gTypes[gKey] = tb.getSchema().getColumnType(gFields[gKey]);
+            switch (gTypes[gKey]) {
+              case BOOLEAN_TYPE:
+                curGroupKey.set(gKey, tb.getBoolean(gFields[gKey], row));
+                break;
+              case STRING_TYPE:
+                curGroupKey.set(gKey, tb.getString(gFields[gKey], row));
+                break;
+              case DATETIME_TYPE:
+                curGroupKey.set(gKey, tb.getDateTime(gFields[gKey], row));
+                break;
+              case INT_TYPE:
+                curGroupKey.set(gKey, tb.getInt(gFields[gKey], row));
+                break;
+              case LONG_TYPE:
+                curGroupKey.set(gKey, tb.getLong(gFields[gKey], row));
+                break;
+              case FLOAT_TYPE:
+                curGroupKey.set(gKey, tb.getFloat(gFields[gKey], row));
+                break;
+              case DOUBLE_TYPE:
+                curGroupKey.set(gKey, tb.getDouble(gFields[gKey], row));
+                break;
+            }
+          }
+        } else if (!TupleUtils.tupleEquals(tb, gFields, row, curGroupKey, gRange, 0)) {
+          /*
+           * This row has a different grouping key than the current one: flush the current aggregate result to the
+           * result buffer, then store this row's grouping key as the current group key.
+           */
+          addToResult();
+          for (int gKey = 0; gKey < gFields.length; ++gKey) {
+            switch (gTypes[gKey]) {
+              case BOOLEAN_TYPE:
+                curGroupKey.set(gKey, tb.getBoolean(gFields[gKey], row));
+                break;
+              case STRING_TYPE:
+                curGroupKey.set(gKey, tb.getString(gFields[gKey], row));
+                break;
+              case DATETIME_TYPE:
+                curGroupKey.set(gKey, tb.getDateTime(gFields[gKey], row));
+                break;
+              case INT_TYPE:
+                curGroupKey.set(gKey, tb.getInt(gFields[gKey], row));
+                break;
+              case LONG_TYPE:
+                curGroupKey.set(gKey, tb.getLong(gFields[gKey], row));
+                break;
+              case FLOAT_TYPE:
+                curGroupKey.set(gKey, tb.getFloat(gFields[gKey], row));
+                break;
+              case DOUBLE_TYPE:
+                curGroupKey.set(gKey, tb.getDouble(gFields[gKey], row));
+                break;
+            }
+          }
+          reinitializeAggStates();
+        }
+        /* Update the aggregator states with the current tuple. */
+        for (int agg = 0; agg < aggregators.length; ++agg) {
+          aggregators[agg].addRow(tb, row, aggregatorStates[agg]);
+        }
+      }
+      tb = child.nextReady();
+    }
+
+    /*
+     * child.nextReady() has returned null, so we have processed all the tuples we can. The child is either at EOS or
+     * we have to wait for more data.
+     */
+    if (child.eos()) {
+      addToResult();
+      return getResultBatch();
+    }
+    return null;
+  }
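Concretely, with one group-by column and a single COUNT aggregate over sorted input, the loop above proceeds as follows (an illustrative trace of the code, not part of the file):

    // Input rows, shown as (group, value): (a,_) (a,_) (b,_) (b,_) (b,_)
    // row 0: curGroupKey == null       -> curGroupKey = a; COUNT state = 1
    // row 1: key equals curGroupKey    -> COUNT state = 2
    // row 2: key differs               -> addToResult() emits (a, 2);
    //                                     curGroupKey = b; reinitializeAggStates(); COUNT state = 1
    // rows 3-4: key equals curGroupKey -> COUNT state = 3
    // child reaches EOS                -> addToResult() emits (b, 3)
    // If the input were not sorted, interleaved a and b rows would each be emitted as a
    // separate group, which is why sorted input is a precondition of this operator.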
+
+  /**
+   * Re-initialize aggregator states for a new group key.
+   *
+   * @throws DbException if any error
+   */
+  private void reinitializeAggStates() throws DbException {
+    aggregatorStates = null;
+    aggregatorStates = AggUtils.allocateAggStates(aggregators);
+  }
+
+  /**
+   * Add aggregate results with the previous grouping key to the result buffer.
+   *
+   * @throws DbException if any error
+   */
+  private void addToResult() throws DbException {
+    int fromIndex = 0;
+    for (; fromIndex < curGroupKey.numColumns(); ++fromIndex) {
+      TupleUtils.copyValue(curGroupKey, fromIndex, 0, resultBuffer, fromIndex);
+    }
+    for (int agg = 0; agg < aggregators.length; ++agg) {
+      aggregators[agg].getResult(resultBuffer, fromIndex, aggregatorStates[agg]);
+      fromIndex += aggregators[agg].getResultSchema().numColumns();
+    }
+  }
+
+  /**
+   * @return A batch's worth of result tuples from this aggregate.
+   * @throws DbException if there is an error.
+   */
+  private TupleBatch getResultBatch() throws DbException {
+    Preconditions.checkState(getChild().eos(), "cannot extract results from an aggregate until child has reached EOS");
+    if (finalBuffer == null) {
+      finalBuffer = resultBuffer.finalResult();
+      if (resultBuffer.numTuples() == 0) {
+        throw new DbException("0 tuples in result buffer");
+      }
+      resultBuffer = null;
+    }
+    if (finalBuffer.isEmpty()) {
+      return null;
+    } else {
+      return finalBuffer.get(0);
+    }
+  }
+
+  /**
+   * The schema of the aggregate output: grouping fields first, then aggregate fields.
+   *
+   * @return the resulting schema
+   */
+  @Override
+  protected Schema generateSchema() {
+    Operator child = getChild();
+    if (child == null) {
+      return null;
+    }
+    Schema inputSchema = child.getSchema();
+    if (inputSchema == null) {
+      return null;
+    }
+
+    groupSchema = inputSchema.getSubSchema(gFields);
+
+    /* Build the output schema from the group schema and the aggregates. */
+    final ImmutableList.Builder<Type> aggTypes = ImmutableList.<Type>builder();
+    final ImmutableList.Builder<String> aggNames = ImmutableList.<String>
builder(); + + try { + for (Aggregator agg : AggUtils.allocateAggs(factories, inputSchema)) { + Schema curAggSchema = agg.getResultSchema(); + aggTypes.addAll(curAggSchema.getColumnTypes()); + aggNames.addAll(curAggSchema.getColumnNames()); + } + } catch (DbException e) { + throw new RuntimeException("unable to allocate aggregators to determine output schema", e); + } + aggSchema = new Schema(aggTypes, aggNames); + return Schema.merge(groupSchema, aggSchema); + } + + @Override + protected void init(final ImmutableMap execEnvVars) throws DbException { + Preconditions.checkState(getSchema() != null, "unable to determine schema in init"); + aggregators = AggUtils.allocateAggs(factories, getChild().getSchema()); + aggregatorStates = AggUtils.allocateAggStates(aggregators); + resultBuffer = new TupleBuffer(getSchema()); + } + + @Override + protected void cleanup() throws DbException { + aggregatorStates = null; + curGroupKey = null; + resultBuffer = null; + finalBuffer = null; + } +} diff --git a/test/edu/washington/escience/myria/operator/LimitTest.java b/test/edu/washington/escience/myria/operator/LimitTest.java index d53dafd4b..fe3ad9cf2 100644 --- a/test/edu/washington/escience/myria/operator/LimitTest.java +++ b/test/edu/washington/escience/myria/operator/LimitTest.java @@ -11,6 +11,66 @@ import edu.washington.escience.myria.util.TestUtils; public class LimitTest { + + @Test + public void testWithinBatchSizeLimit() throws DbException { + final int total = TupleBatch.BATCH_SIZE; + final long limit = 100; + assertTrue(limit < total); + TupleSource source = new TupleSource(TestUtils.range(total)); + Limit limiter = new Limit(limit, source); + limiter.open(TestEnvVars.get()); + long count = 0; + while (!limiter.eos()) { + TupleBatch tb = limiter.nextReady(); + if (tb == null) { + continue; + } + count += tb.numTuples(); + } + limiter.close(); + assertEquals(limit, count); + } + + @Test + public void testLimitZero() throws DbException { + final int total = 2 * TupleBatch.BATCH_SIZE + 2; + final long limit = 0; + assertTrue(limit < total); + TupleSource source = new TupleSource(TestUtils.range(total)); + Limit limiter = new Limit(limit, source); + limiter.open(TestEnvVars.get()); + long count = 0; + while (!limiter.eos()) { + TupleBatch tb = limiter.nextReady(); + if (tb == null) { + continue; + } + count += tb.numTuples(); + } + limiter.close(); + assertEquals(limit, count); + } + + @Test + public void testLimitNumTuples() throws DbException { + final int total = 2 * TupleBatch.BATCH_SIZE + 2; + final long limit = total; + TupleSource source = new TupleSource(TestUtils.range(total)); + Limit limiter = new Limit(limit, source); + limiter.open(TestEnvVars.get()); + long count = 0; + while (!limiter.eos()) { + TupleBatch tb = limiter.nextReady(); + if (tb == null) { + continue; + } + count += tb.numTuples(); + } + limiter.close(); + assertEquals(limit, count); + } + @Test public void testSimplePrefix() throws DbException { final int total = 2 * TupleBatch.BATCH_SIZE + 2; @@ -30,4 +90,4 @@ public void testSimplePrefix() throws DbException { limiter.close(); assertEquals(limit, count); } -} +} \ No newline at end of file diff --git a/test/edu/washington/escience/myria/operator/StreamingAggTest.java b/test/edu/washington/escience/myria/operator/StreamingAggTest.java new file mode 100644 index 000000000..a512eeef7 --- /dev/null +++ b/test/edu/washington/escience/myria/operator/StreamingAggTest.java @@ -0,0 +1,1113 @@ +package edu.washington.escience.myria.operator; + +import static 
org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import org.joda.time.DateTime; +import org.junit.Test; + +import com.google.common.collect.ImmutableList; + +import edu.washington.escience.myria.DbException; +import edu.washington.escience.myria.Schema; +import edu.washington.escience.myria.Type; +import edu.washington.escience.myria.operator.agg.MultiGroupByAggregate; +import edu.washington.escience.myria.operator.agg.PrimitiveAggregator.AggregationOp; +import edu.washington.escience.myria.operator.agg.SingleColumnAggregatorFactory; +import edu.washington.escience.myria.operator.agg.SingleGroupByAggregate; +import edu.washington.escience.myria.operator.agg.StreamingAggregate; +import edu.washington.escience.myria.storage.TupleBatch; +import edu.washington.escience.myria.storage.TupleBatchBuffer; + +/** + * Test cases for {@link StreamingAggregate} class. Source tuples are generated in sorted order on group keys, if any. + * Some of the tests are taken from those for {@link SingleGroupByAggregate} and {@link MultiGroupByAggregate} since + * StreamingAggregate is expected to behave the same way they do if input is sorted. + */ +public class StreamingAggTest { + + /** + * Construct a TupleBatchBuffer to be used as source of aggregate. Fixed schema and sorted on grouping columns. + * + * @param numTuples number of tuples to be added + * @return filled TupleBatchBuffer with each group key having (numTuples/10) tuples + */ + private TupleBatchBuffer fillInputTbb(final int numTuples) { + final Schema schema = + new Schema(ImmutableList.of(Type.INT_TYPE, Type.DOUBLE_TYPE, Type.FLOAT_TYPE, Type.LONG_TYPE, + Type.DATETIME_TYPE, Type.STRING_TYPE, Type.BOOLEAN_TYPE, Type.LONG_TYPE), ImmutableList.of("Int", "Double", + "Float", "Long", "Datetime", "String", "Boolean", "value")); + + final TupleBatchBuffer source = new TupleBatchBuffer(schema); + for (int i = 0; i < numTuples; i++) { + int value = i / (numTuples / 10); + source.putInt(0, value); + source.putDouble(1, value); + source.putFloat(2, value); + source.putLong(3, value); + source.putDateTime(4, new DateTime(2010 + value, 1, 1, 0, 0)); + source.putString(5, "" + value); + source.putBoolean(6, (i / (numTuples / 2) == 0)); + source.putLong(7, 2L); + } + return source; + } + + @Test + public void testSingleGroupKeySingleColumnCount() throws DbException { + final int numTuples = 50; + /* + * col0: Int, col1: Double, col2: Float, col3: Long, col4: Datetime, col5: String, col6: Boolean, (first 7 columns + * used for grouping) col7: Long (to agg over). 
+ */ + TupleBatchBuffer source = fillInputTbb(numTuples); + + // group by col0 + StreamingAggregate agg = + new StreamingAggregate(new TupleSource(source), new int[] { 0 }, new SingleColumnAggregatorFactory(7, + AggregationOp.COUNT)); + agg.open(null); + TupleBatch result = agg.nextReady(); + assertNotNull(result); + assertEquals(numTuples / (numTuples / 10), result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(5, result.getLong(result.numColumns() - 1, i)); + } + agg.close(); + + // group by col1 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 1 }, new SingleColumnAggregatorFactory(7, + AggregationOp.COUNT)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(numTuples / (numTuples / 10), result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(5, result.getLong(result.numColumns() - 1, i)); + } + agg.close(); + + // group by col2 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 2 }, new SingleColumnAggregatorFactory(7, + AggregationOp.COUNT)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(numTuples / (numTuples / 10), result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(5, result.getLong(result.numColumns() - 1, i)); + } + agg.close(); + + // group by col3 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 3 }, new SingleColumnAggregatorFactory(7, + AggregationOp.COUNT)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(numTuples / (numTuples / 10), result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(5, result.getLong(result.numColumns() - 1, i)); + } + agg.close(); + + // group by col4 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 4 }, new SingleColumnAggregatorFactory(7, + AggregationOp.COUNT)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(numTuples / (numTuples / 10), result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(5, result.getLong(result.numColumns() - 1, i)); + } + agg.close(); + + // group by col5 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 5 }, new SingleColumnAggregatorFactory(7, + AggregationOp.COUNT)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(numTuples / (numTuples / 10), result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(5, result.getLong(result.numColumns() - 1, i)); + } + agg.close(); + + // group by col6 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 6 }, new SingleColumnAggregatorFactory(7, + AggregationOp.COUNT)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(2, result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(25, result.getLong(result.numColumns() - 1, i)); + } + agg.close(); + } + + @Test + public void testSingleGroupKeySingleColumnSum() throws DbException { + final int numTuples = 50; + /* + * col0: Int, col1: Double, 
col2: Float, col3: Long, col4: Datetime, col5: String, col6: Boolean, (first 7 columns + * used for grouping) col7: Long (to agg over). + */ + TupleBatchBuffer source = fillInputTbb(numTuples); + + // group by col0 + StreamingAggregate agg = + new StreamingAggregate(new TupleSource(source), new int[] { 0 }, new SingleColumnAggregatorFactory(7, + AggregationOp.SUM)); + agg.open(null); + TupleBatch result = agg.nextReady(); + assertNotNull(result); + assertEquals(numTuples / (numTuples / 10), result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(10L, result.getLong(result.numColumns() - 1, i)); + } + agg.close(); + + // group by col1 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 1 }, new SingleColumnAggregatorFactory(7, + AggregationOp.SUM)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(numTuples / (numTuples / 10), result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(10L, result.getLong(result.numColumns() - 1, i)); + } + agg.close(); + + // group by col2 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 2 }, new SingleColumnAggregatorFactory(7, + AggregationOp.SUM)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(numTuples / (numTuples / 10), result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(10L, result.getLong(result.numColumns() - 1, i)); + } + agg.close(); + + // group by col3 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 3 }, new SingleColumnAggregatorFactory(7, + AggregationOp.SUM)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(numTuples / (numTuples / 10), result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(10L, result.getLong(result.numColumns() - 1, i)); + } + agg.close(); + + // group by col4 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 4 }, new SingleColumnAggregatorFactory(7, + AggregationOp.SUM)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(numTuples / (numTuples / 10), result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(10L, result.getLong(result.numColumns() - 1, i)); + } + agg.close(); + + // group by col5 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 5 }, new SingleColumnAggregatorFactory(7, + AggregationOp.SUM)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(numTuples / (numTuples / 10), result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(10L, result.getLong(result.numColumns() - 1, i)); + } + agg.close(); + + // group by col6 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 6 }, new SingleColumnAggregatorFactory(7, + AggregationOp.SUM)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(2, result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(50L, result.getLong(result.numColumns() - 1, i)); + } + agg.close(); + } + + 
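The seven per-column blocks in each of these tests are identical except for the grouping column; a hypothetical helper along the following lines (not part of this patch) could collapse them:

    /*
     * Hypothetical helper: run one StreamingAggregate over source grouped on groupCol,
     * aggregating column 7 with op, then check the number of groups and each group's
     * aggregate value (the last output column, read as a long).
     */
    private void assertGroupedLongAgg(final TupleBatchBuffer source, final int groupCol,
        final AggregationOp op, final int expectedGroups, final long expectedValue) throws DbException {
      StreamingAggregate agg =
          new StreamingAggregate(new TupleSource(source), new int[] { groupCol },
              new SingleColumnAggregatorFactory(7, op));
      agg.open(null);
      TupleBatch result = agg.nextReady();
      assertNotNull(result);
      assertEquals(expectedGroups, result.numTuples());
      for (int i = 0; i < result.numTuples(); i++) {
        assertEquals(expectedValue, result.getLong(result.numColumns() - 1, i));
      }
      agg.close();
    }
    // e.g. testSingleGroupKeySingleColumnCount would reduce to seven one-line calls such as
    // assertGroupedLongAgg(fillInputTbb(50), 0, AggregationOp.COUNT, 10, 5L);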
@Test + public void testSingleGroupKeySingleColumnAvg() throws DbException { + final int numTuples = 50; + /* + * col0: Int, col1: Double, col2: Float, col3: Long, col4: Datetime, col5: String, col6: Boolean, (first 7 columns + * used for grouping) col7: Long (to agg over). + */ + TupleBatchBuffer source = fillInputTbb(numTuples); + + // group by col0 + StreamingAggregate agg = + new StreamingAggregate(new TupleSource(source), new int[] { 0 }, new SingleColumnAggregatorFactory(7, + AggregationOp.AVG)); + agg.open(null); + TupleBatch result = agg.nextReady(); + assertNotNull(result); + assertEquals(numTuples / (numTuples / 10), result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(2L, result.getDouble(result.numColumns() - 1, i), 0.0001); + } + agg.close(); + + // group by col1 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 1 }, new SingleColumnAggregatorFactory(7, + AggregationOp.AVG)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(numTuples / (numTuples / 10), result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(2L, result.getDouble(result.numColumns() - 1, i), 0.0001); + } + agg.close(); + + // group by col2 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 2 }, new SingleColumnAggregatorFactory(7, + AggregationOp.AVG)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(numTuples / (numTuples / 10), result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(2L, result.getDouble(result.numColumns() - 1, i), 0.0001); + } + agg.close(); + + // group by col3 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 3 }, new SingleColumnAggregatorFactory(7, + AggregationOp.AVG)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(numTuples / (numTuples / 10), result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(2L, result.getDouble(result.numColumns() - 1, i), 0.0001); + } + agg.close(); + + // group by col4 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 4 }, new SingleColumnAggregatorFactory(7, + AggregationOp.AVG)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(numTuples / (numTuples / 10), result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(2L, result.getDouble(result.numColumns() - 1, i), 0.0001); + } + agg.close(); + + // group by col5 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 5 }, new SingleColumnAggregatorFactory(7, + AggregationOp.AVG)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(numTuples / (numTuples / 10), result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(2L, result.getDouble(result.numColumns() - 1, i), 0.0001); + } + agg.close(); + + // group by col6 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 6 }, new SingleColumnAggregatorFactory(7, + AggregationOp.AVG)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(2, result.numTuples()); + 
assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(2L, result.getDouble(result.numColumns() - 1, i), 0.0001); + } + agg.close(); + } + + @Test + public void testSingleGroupKeySingleColumnStdev() throws DbException { + final int numTuples = 50; + /* + * col0: Int, col1: Double, col2: Float, col3: Long, col4: Datetime, col5: String, col6: Boolean, (first 7 columns + * used for grouping) col7: Long (to agg over). + */ + TupleBatchBuffer source = fillInputTbb(numTuples); + + // group by col0 + StreamingAggregate agg = + new StreamingAggregate(new TupleSource(source), new int[] { 0 }, new SingleColumnAggregatorFactory(7, + AggregationOp.STDEV)); + agg.open(null); + TupleBatch result = agg.nextReady(); + assertNotNull(result); + assertEquals(numTuples / (numTuples / 10), result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(0L, result.getDouble(result.numColumns() - 1, i), 0.0001); + } + agg.close(); + + // group by col1 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 1 }, new SingleColumnAggregatorFactory(7, + AggregationOp.STDEV)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(numTuples / (numTuples / 10), result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(0L, result.getDouble(result.numColumns() - 1, i), 0.0001); + } + agg.close(); + + // group by col2 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 2 }, new SingleColumnAggregatorFactory(7, + AggregationOp.STDEV)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(numTuples / (numTuples / 10), result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(0L, result.getDouble(result.numColumns() - 1, i), 0.0001); + } + agg.close(); + + // group by col3 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 3 }, new SingleColumnAggregatorFactory(7, + AggregationOp.STDEV)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(numTuples / (numTuples / 10), result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(0L, result.getDouble(result.numColumns() - 1, i), 0.0001); + } + agg.close(); + + // group by col4 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 4 }, new SingleColumnAggregatorFactory(7, + AggregationOp.STDEV)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(numTuples / (numTuples / 10), result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(0L, result.getDouble(result.numColumns() - 1, i), 0.0001); + } + agg.close(); + + // group by col5 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 5 }, new SingleColumnAggregatorFactory(7, + AggregationOp.STDEV)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(numTuples / (numTuples / 10), result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(0L, result.getDouble(result.numColumns() - 1, i), 0.0001); + } + agg.close(); + + // group by col6 + agg = + new 
StreamingAggregate(new TupleSource(source), new int[] { 6 }, new SingleColumnAggregatorFactory(7, + AggregationOp.STDEV)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(2, result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(0L, result.getDouble(result.numColumns() - 1, i), 0.0001); + } + agg.close(); + } + + @Test + public void testSingleGroupKeySingleColumnMin() throws DbException { + final int numTuples = 50; + /* + * col0: Int, col1: Double, col2: Float, col3: Long, col4: Datetime, col5: String, col6: Boolean, (first 7 columns + * used for grouping) col7: Long (to agg over) constant value of 2L. + */ + TupleBatchBuffer source = fillInputTbb(numTuples); + + // group by col7, agg over col0 + StreamingAggregate agg = + new StreamingAggregate(new TupleSource(source), new int[] { 7 }, new SingleColumnAggregatorFactory(0, + AggregationOp.MIN)); + agg.open(null); + TupleBatch result = agg.nextReady(); + assertNotNull(result); + assertEquals(1, result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(0, result.getInt(result.numColumns() - 1, i)); + } + agg.close(); + + // group by col7, agg over col1 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 7 }, new SingleColumnAggregatorFactory(1, + AggregationOp.MIN)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(1, result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(0, result.getDouble(result.numColumns() - 1, i), 0.0001); + } + agg.close(); + + // group by col7, agg over col2 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 7 }, new SingleColumnAggregatorFactory(2, + AggregationOp.MIN)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(1, result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(0, result.getFloat(result.numColumns() - 1, i), 0.0001); + } + agg.close(); + + // group by col7, agg over col3 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 7 }, new SingleColumnAggregatorFactory(3, + AggregationOp.MIN)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(1, result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(0L, result.getLong(result.numColumns() - 1, i)); + } + agg.close(); + + // group by col4, agg over col4 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 7 }, new SingleColumnAggregatorFactory(4, + AggregationOp.MIN)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(1, result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(new DateTime(2010, 1, 1, 0, 0), result.getDateTime(result.numColumns() - 1, i)); + } + agg.close(); + + // group by col7, agg over col5 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 7 }, new SingleColumnAggregatorFactory(5, + AggregationOp.MIN)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(1, result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < 
result.numTuples(); i++) { + assertEquals("0", result.getString(result.numColumns() - 1, i)); + } + agg.close(); + + // Note: Min not applicable to Boolean type + } + + @Test + public void testSingleGroupKeySingleColumnMax() throws DbException { + final int numTuples = 50; + /* + * col0: Int, col1: Double, col2: Float, col3: Long, col4: Datetime, col5: String, col6: Boolean, (first 7 columns + * used for grouping) col7: Long (to agg over) constant value of 2L. + */ + TupleBatchBuffer source = fillInputTbb(numTuples); + + // group by col7, agg over col0 + StreamingAggregate agg = + new StreamingAggregate(new TupleSource(source), new int[] { 7 }, new SingleColumnAggregatorFactory(0, + AggregationOp.MAX)); + agg.open(null); + TupleBatch result = agg.nextReady(); + assertNotNull(result); + assertEquals(1, result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(9, result.getInt(result.numColumns() - 1, i)); + } + agg.close(); + + // group by col7, agg over col1 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 7 }, new SingleColumnAggregatorFactory(1, + AggregationOp.MAX)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(1, result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(9, result.getDouble(result.numColumns() - 1, i), 0.0001); + } + agg.close(); + + // group by col7, agg over col2 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 7 }, new SingleColumnAggregatorFactory(2, + AggregationOp.MAX)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(1, result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(9, result.getFloat(result.numColumns() - 1, i), 0.0001); + } + agg.close(); + + // group by col7, agg over col3 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 7 }, new SingleColumnAggregatorFactory(3, + AggregationOp.MAX)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(1, result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(9L, result.getLong(result.numColumns() - 1, i)); + } + agg.close(); + + // group by col4, agg over col4 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 7 }, new SingleColumnAggregatorFactory(4, + AggregationOp.MAX)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(1, result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(new DateTime(2019, 1, 1, 0, 0), result.getDateTime(result.numColumns() - 1, i)); + } + agg.close(); + + // group by col7, agg over col5 + agg = + new StreamingAggregate(new TupleSource(source), new int[] { 7 }, new SingleColumnAggregatorFactory(5, + AggregationOp.MAX)); + agg.open(null); + result = agg.nextReady(); + assertNotNull(result); + assertEquals(1, result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals("9", result.getString(result.numColumns() - 1, i)); + } + agg.close(); + + // Note: Max not applicable to type Boolean + } + + @Test + public void testMultiGroupSingleColumnCount() throws DbException { + final int numTuples = 50; + 
final Schema schema = + new Schema(ImmutableList.of(Type.LONG_TYPE, Type.LONG_TYPE, Type.LONG_TYPE), ImmutableList.of("g0", "g1", + "value")); + final TupleBatchBuffer tbb = new TupleBatchBuffer(schema); + // g0 same for all tuples, g1 split to 5 groups, g2 gets i + for (long i = 0; i < numTuples; i++) { + tbb.putLong(0, 0L); + tbb.putLong(1, i / (numTuples / 5)); + tbb.putLong(2, i); + } + StreamingAggregate agg = + new StreamingAggregate(new TupleSource(tbb), new int[] { 0, 1 }, new SingleColumnAggregatorFactory(2, + AggregationOp.COUNT)); + agg.open(null); + TupleBatch result = agg.nextReady(); + assertNotNull(result); + assertEquals(5, result.numTuples()); + assertEquals(3, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(10, result.getLong(result.numColumns() - 1, i)); + } + agg.close(); + } + + @Test + public void testMultiGroupSingleColumnMin() throws DbException { + final int numTuples = 50; + final Schema schema = + new Schema(ImmutableList.of(Type.LONG_TYPE, Type.LONG_TYPE, Type.LONG_TYPE), ImmutableList.of("g0", "g1", + "value")); + final TupleBatchBuffer tbb = new TupleBatchBuffer(schema); + // g0 same for all tuples, g1 split to 5 groups, g2 gets i + for (long i = 0; i < numTuples; i++) { + tbb.putLong(0, 0L); + tbb.putLong(1, i / (numTuples / 5)); + tbb.putLong(2, i); + } + StreamingAggregate agg = + new StreamingAggregate(new TupleSource(tbb), new int[] { 0, 1 }, new SingleColumnAggregatorFactory(2, + AggregationOp.MIN)); + agg.open(null); + TupleBatch result = agg.nextReady(); + assertNotNull(result); + assertEquals(5, result.numTuples()); + assertEquals(3, result.getSchema().numColumns()); + assertEquals(0, result.getLong(result.numColumns() - 1, 0)); + assertEquals(10, result.getLong(result.numColumns() - 1, 1)); + assertEquals(20, result.getLong(result.numColumns() - 1, 2)); + assertEquals(30, result.getLong(result.numColumns() - 1, 3)); + assertEquals(40, result.getLong(result.numColumns() - 1, 4)); + agg.close(); + } + + @Test + public void testMultiGroupSingleColumnMax() throws DbException { + final int numTuples = 50; + final Schema schema = + new Schema(ImmutableList.of(Type.LONG_TYPE, Type.LONG_TYPE, Type.LONG_TYPE), ImmutableList.of("g0", "g1", + "value")); + final TupleBatchBuffer tbb = new TupleBatchBuffer(schema); + // g0 same for all tuples, g1 split to 5 groups, g2 gets i + for (long i = 0; i < numTuples; i++) { + tbb.putLong(0, 0L); + tbb.putLong(1, i / (numTuples / 5)); + tbb.putLong(2, i); + } + StreamingAggregate agg = + new StreamingAggregate(new TupleSource(tbb), new int[] { 0, 1 }, new SingleColumnAggregatorFactory(2, + AggregationOp.MAX)); + agg.open(null); + TupleBatch result = agg.nextReady(); + assertNotNull(result); + assertEquals(5, result.numTuples()); + assertEquals(3, result.getSchema().numColumns()); + assertEquals(9, result.getLong(result.numColumns() - 1, 0)); + assertEquals(19, result.getLong(result.numColumns() - 1, 1)); + assertEquals(29, result.getLong(result.numColumns() - 1, 2)); + assertEquals(39, result.getLong(result.numColumns() - 1, 3)); + assertEquals(49, result.getLong(result.numColumns() - 1, 4)); + agg.close(); + } + + @Test + public void testMultiGroupSingleColumnSum() throws DbException { + final int numTuples = 50; + final Schema schema = + new Schema(ImmutableList.of(Type.LONG_TYPE, Type.LONG_TYPE, Type.LONG_TYPE), ImmutableList.of("g0", "g1", + "value")); + final TupleBatchBuffer tbb = new TupleBatchBuffer(schema); + // g0 same for all tuples, g1 split to 5 groups, g2 gets 
10 + for (long i = 0; i < numTuples; i++) { + tbb.putLong(0, 0L); + tbb.putLong(1, i / (numTuples / 5)); + tbb.putLong(2, 10L); + } + StreamingAggregate agg = + new StreamingAggregate(new TupleSource(tbb), new int[] { 0, 1 }, new SingleColumnAggregatorFactory(2, + AggregationOp.SUM)); + agg.open(null); + TupleBatch result = agg.nextReady(); + assertNotNull(result); + assertEquals(5, result.numTuples()); + assertEquals(3, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(100, result.getLong(result.numColumns() - 1, i)); + } + agg.close(); + } + + @Test + public void testMultiGroupSingleColumnAvg() throws DbException { + final int numTuples = 50; + final Schema schema = + new Schema(ImmutableList.of(Type.LONG_TYPE, Type.LONG_TYPE, Type.LONG_TYPE), ImmutableList.of("g0", "g1", + "value")); + final TupleBatchBuffer tbb = new TupleBatchBuffer(schema); + // g0 same for all tuples, g1 split to 5 groups, g2 gets 10 + for (long i = 0; i < numTuples; i++) { + tbb.putLong(0, 0L); + tbb.putLong(1, i / (numTuples / 5)); + tbb.putLong(2, 10L); + } + StreamingAggregate agg = + new StreamingAggregate(new TupleSource(tbb), new int[] { 0, 1 }, new SingleColumnAggregatorFactory(2, + AggregationOp.AVG)); + agg.open(null); + TupleBatch result = agg.nextReady(); + assertNotNull(result); + assertEquals(5, result.numTuples()); + assertEquals(3, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(10, result.getDouble(result.numColumns() - 1, i), 0.0001); + } + agg.close(); + } + + @Test + public void testMultiGroupSingleColumnStdev() throws DbException { + final int numTuples = 50; + final Schema schema = + new Schema(ImmutableList.of(Type.LONG_TYPE, Type.LONG_TYPE, Type.LONG_TYPE), ImmutableList.of("g0", "g1", + "value")); + final TupleBatchBuffer tbb = new TupleBatchBuffer(schema); + // g0 same for all tuples, g1 split to 5 groups, g2 gets 10 + for (long i = 0; i < numTuples; i++) { + tbb.putLong(0, 0L); + tbb.putLong(1, i / (numTuples / 5)); + tbb.putLong(2, 10L); + } + StreamingAggregate agg = + new StreamingAggregate(new TupleSource(tbb), new int[] { 0, 1 }, new SingleColumnAggregatorFactory(2, + AggregationOp.STDEV)); + agg.open(null); + TupleBatch result = agg.nextReady(); + assertNotNull(result); + assertEquals(5, result.numTuples()); + assertEquals(3, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + assertEquals(0, result.getDouble(result.numColumns() - 1, i), 0.0001); + } + agg.close(); + } + + @Test + public void testSingleGroupKeyMultiColumnAllAgg() throws DbException { + final int numTuples = 50; + final Schema schema = + new Schema(ImmutableList.of(Type.LONG_TYPE, Type.LONG_TYPE), ImmutableList.of("gkey", "value")); + final TupleBatchBuffer tbb = new TupleBatchBuffer(schema); + // gkey split to 5 groups, value gets 10 + for (long i = 0; i < numTuples; i++) { + tbb.putLong(0, i / (numTuples / 5)); + tbb.putLong(1, 10L); + } + // group by gkey; min on gkey, max on gkey, count on value, sum on value, avg on value, stdev on value + StreamingAggregate agg = + new StreamingAggregate(new TupleSource(tbb), new int[] { 0 }, new SingleColumnAggregatorFactory(0, + AggregationOp.MIN), new SingleColumnAggregatorFactory(0, AggregationOp.MAX), + new SingleColumnAggregatorFactory(1, AggregationOp.COUNT), new SingleColumnAggregatorFactory(1, + AggregationOp.SUM), new SingleColumnAggregatorFactory(1, AggregationOp.AVG), + new SingleColumnAggregatorFactory(1, AggregationOp.STDEV)); + 
agg.open(null); + TupleBatch result = agg.nextReady(); + assertNotNull(result); + assertEquals(5, result.numTuples()); + assertEquals(7, result.getSchema().numColumns()); + for (int i = 0; i < result.numTuples(); i++) { + // min + assertEquals(result.getLong(0, i), result.getLong(1, i)); + // max + assertEquals(result.getLong(0, i), result.getLong(2, i)); + // count + assertEquals(10, result.getLong(3, i)); + // sum + assertEquals(100, result.getLong(4, i)); + // avg + assertEquals(10, result.getDouble(5, i), 0.0001); + // stdev + assertEquals(0, result.getDouble(6, i), 0.0001); + } + agg.close(); + } + + @Test + public void testMultiGroupMultiColumn() throws DbException { + final int numTuples = 50; + final Schema schema = + new Schema(ImmutableList.of(Type.LONG_TYPE, Type.LONG_TYPE, Type.LONG_TYPE), ImmutableList + .of("g0", "g1", "val")); + final TupleBatchBuffer tbb = new TupleBatchBuffer(schema); + // {0, 2, i} on first half tuples, {0, 4, i} on the second half + int sumFirst = 0; + int sumSecond = 0; + for (long i = 0; i < numTuples; i++) { + tbb.putLong(0, 0L); + if (i / (numTuples / 2) == 0) { + tbb.putLong(1, 2L); + sumFirst += i; + } else { + tbb.putLong(1, 4L); + sumSecond += i; + } + tbb.putLong(2, i); + } + + /* Generate expected values for mean and stdev */ + double meanFirst = (double) sumFirst / (numTuples / 2); + double meanSecond = (double) sumSecond / (numTuples / 2); + double diffSquaredFirst = 0.0; + double diffSquaredSecond = 0.0; + for (int i = 0; i < numTuples; ++i) { + if (i / (numTuples / 2) == 0) { + double diff = i - meanFirst; + diffSquaredFirst += diff * diff; + } else { + double diff = i - meanSecond; + diffSquaredSecond += diff * diff; + } + } + double expectedFirstStdev = Math.sqrt(diffSquaredFirst / (numTuples / 2)); + double expectedSecondStdev = Math.sqrt(diffSquaredSecond / (numTuples / 2)); + + // group by col0 and col1, then min max count sum avg stdev + StreamingAggregate agg = + new StreamingAggregate(new TupleSource(tbb), new int[] { 0, 1 }, new SingleColumnAggregatorFactory(2, + AggregationOp.MIN), new SingleColumnAggregatorFactory(2, AggregationOp.MAX), + new SingleColumnAggregatorFactory(2, AggregationOp.COUNT), new SingleColumnAggregatorFactory(2, + AggregationOp.SUM), new SingleColumnAggregatorFactory(2, AggregationOp.AVG), + new SingleColumnAggregatorFactory(2, AggregationOp.STDEV)); + agg.open(null); + TupleBatch result = agg.nextReady(); + assertNotNull(result); + assertEquals(2, result.numTuples()); + assertEquals(8, result.getSchema().numColumns()); + // min + assertEquals(0, result.getLong(2, 0)); + assertEquals(25, result.getLong(2, 1)); + // max + assertEquals(24, result.getLong(3, 0)); + assertEquals(49, result.getLong(3, 1)); + // count + assertEquals(numTuples / 2, result.getLong(4, 0)); + assertEquals(numTuples / 2, result.getLong(4, 1)); + // sum + assertEquals(sumFirst, result.getLong(5, 0)); + assertEquals(sumSecond, result.getLong(5, 1)); + // avg + assertEquals(meanFirst, result.getDouble(6, 0), 0.0001); + assertEquals(meanSecond, result.getDouble(6, 1), 0.0001); + // stdev + assertEquals(expectedFirstStdev, result.getDouble(7, 0), 0.0001); + assertEquals(expectedSecondStdev, result.getDouble(7, 1), 0.0001); + agg.close(); + } + + @Test + public void testSingleGroupAllAggLargeInput() throws DbException { + final int numTuples = 2 * TupleBatch.BATCH_SIZE; + final Schema schema = + new Schema(ImmutableList.of(Type.LONG_TYPE, Type.LONG_TYPE), ImmutableList.of("gkey", "value")); + + final TupleBatchBuffer tbb = new 
TupleBatchBuffer(schema); + // {0, i} + int sum = 0; + for (long i = 0; i < numTuples; i++) { + tbb.putLong(0, 0L); + tbb.putLong(1, i); + sum += i; + } + + /* Generate expected values for mean and stdev */ + double mean = (double) sum / numTuples; + double diffSquared = 0.0; + for (int i = 0; i < numTuples; ++i) { + double diff = i - mean; + diffSquared += diff * diff; + } + double expectedStdev = Math.sqrt(diffSquared / numTuples); + + // group by gkey, then min max count sum avg stdev + StreamingAggregate agg = + new StreamingAggregate(new TupleSource(tbb), new int[] { 0 }, new SingleColumnAggregatorFactory(1, + AggregationOp.MIN), new SingleColumnAggregatorFactory(1, AggregationOp.MAX), + new SingleColumnAggregatorFactory(1, AggregationOp.COUNT), new SingleColumnAggregatorFactory(1, + AggregationOp.SUM), new SingleColumnAggregatorFactory(1, AggregationOp.AVG), + new SingleColumnAggregatorFactory(1, AggregationOp.STDEV)); + agg.open(null); + TupleBatch result = agg.nextReady(); + assertNotNull(result); + assertEquals(1, result.numTuples()); + assertEquals(7, result.getSchema().numColumns()); + // min + assertEquals(0, result.getLong(1, 0)); + // max + assertEquals(19999, result.getLong(2, 0)); + // count + assertEquals(numTuples, result.getLong(3, 0)); + // sum + assertEquals(sum, result.getLong(4, 0)); + // avg + assertEquals(mean, result.getDouble(5, 0), 0.0001); + // stdev + assertEquals(expectedStdev, result.getDouble(6, 0), 0.0001); + agg.close(); + } + + @Test + public void testMultiGroupAllAggLargeInput() throws DbException { + final int numTuples = 3 * TupleBatch.BATCH_SIZE; + final Schema schema = + new Schema(ImmutableList.of(Type.LONG_TYPE, Type.LONG_TYPE, Type.LONG_TYPE), ImmutableList + .of("g0", "g1", "val")); + final TupleBatchBuffer tbb = new TupleBatchBuffer(schema); + // split into 4 groups, each group may spread across different batches + // {0, 0, i} in first group, {0, 1, i} in second, {0, 2, i} in third, {0, 3, i} in fourth + int sumFirst = 0; + int sumSecond = 0; + int sumThird = 0; + int sumFourth = 0; + for (long i = 0; i < numTuples; i++) { + tbb.putLong(0, 0L); + if (i / (numTuples / 4) == 0) { + tbb.putLong(1, 0L); + sumFirst += i; + } else if (i / (numTuples / 4) == 1) { + tbb.putLong(1, 1L); + sumSecond += i; + } else if (i / (numTuples / 4) == 2) { + tbb.putLong(1, 2L); + sumThird += i; + } else { + tbb.putLong(1, 3L); + sumFourth += i; + } + tbb.putLong(2, i); + } + + /* Generate expected values for mean and stdev */ + double meanFirst = (double) sumFirst / (numTuples / 4); + double meanSecond = (double) sumSecond / (numTuples / 4); + double meanThird = (double) sumThird / (numTuples / 4); + double meanFourth = (double) sumFourth / (numTuples / 4); + double diffSquaredFirst = 0.0; + double diffSquaredSecond = 0.0; + double diffSquaredThird = 0.0; + double diffSquaredFourth = 0.0; + for (int i = 0; i < numTuples; ++i) { + if (i / (numTuples / 4) == 0) { + double diff = i - meanFirst; + diffSquaredFirst += diff * diff; + } else if (i / (numTuples / 4) == 1) { + double diff = i - meanSecond; + diffSquaredSecond += diff * diff; + } else if (i / (numTuples / 4) == 2) { + double diff = i - meanThird; + diffSquaredThird += diff * diff; + } else { + double diff = i - meanFourth; + diffSquaredFourth += diff * diff; + } + } + double expectedFirstStdev = Math.sqrt(diffSquaredFirst / (numTuples / 4)); + double expectedSecondStdev = Math.sqrt(diffSquaredSecond / (numTuples / 4)); + double expectedThirdStdev = Math.sqrt(diffSquaredThird / (numTuples / 4)); + double 
expectedFourthStdev = Math.sqrt(diffSquaredFourth / (numTuples / 4)); + + // group by col0 and col1, then min max count sum avg stdev + StreamingAggregate agg = + new StreamingAggregate(new TupleSource(tbb), new int[] { 0, 1 }, new SingleColumnAggregatorFactory(2, + AggregationOp.MIN), new SingleColumnAggregatorFactory(2, AggregationOp.MAX), + new SingleColumnAggregatorFactory(2, AggregationOp.COUNT), new SingleColumnAggregatorFactory(2, + AggregationOp.SUM), new SingleColumnAggregatorFactory(2, AggregationOp.AVG), + new SingleColumnAggregatorFactory(2, AggregationOp.STDEV)); + agg.open(null); + TupleBatch result = agg.nextReady(); + assertNotNull(result); + assertEquals(4, result.numTuples()); + assertEquals(8, result.getSchema().numColumns()); + // min + assertEquals(0, result.getLong(2, 0)); + assertEquals((numTuples / 4), result.getLong(2, 1)); + assertEquals(2 * (numTuples / 4), result.getLong(2, 2)); + assertEquals(3 * (numTuples / 4), result.getLong(2, 3)); + // max + assertEquals((numTuples / 4) - 1, result.getLong(3, 0)); + assertEquals(2 * (numTuples / 4) - 1, result.getLong(3, 1)); + assertEquals(3 * (numTuples / 4) - 1, result.getLong(3, 2)); + assertEquals(numTuples - 1, result.getLong(3, 3)); + // count + assertEquals(numTuples / 4, result.getLong(4, 0)); + assertEquals(numTuples / 4, result.getLong(4, 1)); + assertEquals(numTuples / 4, result.getLong(4, 2)); + assertEquals(numTuples - 3 * (numTuples / 4), result.getLong(4, 3)); + // sum + assertEquals(sumFirst, result.getLong(5, 0)); + assertEquals(sumSecond, result.getLong(5, 1)); + assertEquals(sumThird, result.getLong(5, 2)); + assertEquals(sumFourth, result.getLong(5, 3)); + // avg + assertEquals(meanFirst, result.getDouble(6, 0), 0.0001); + assertEquals(meanSecond, result.getDouble(6, 1), 0.0001); + assertEquals(meanThird, result.getDouble(6, 2), 0.0001); + assertEquals(meanFourth, result.getDouble(6, 3), 0.0001); + // stdev + assertEquals(expectedFirstStdev, result.getDouble(7, 0), 0.0001); + assertEquals(expectedSecondStdev, result.getDouble(7, 1), 0.0001); + assertEquals(expectedThirdStdev, result.getDouble(7, 2), 0.0001); + assertEquals(expectedFourthStdev, result.getDouble(7, 3), 0.0001); + agg.close(); + } +} \ No newline at end of file From f7a9f1043f5b568b545a87da7dde2a0314c6edf2 Mon Sep 17 00:00:00 2001 From: Yuqing Guo Date: Wed, 20 May 2015 22:10:25 -0700 Subject: [PATCH 02/29] fixed toEmit == 0 case in Limit.java --- .../escience/myria/operator/Limit.java | 35 +++++------ .../escience/myria/operator/LimitTest.java | 58 +++++++++---------- .../myria/operator/StreamingAggTest.java | 46 +++++---------- 3 files changed, 55 insertions(+), 84 deletions(-) diff --git a/src/edu/washington/escience/myria/operator/Limit.java b/src/edu/washington/escience/myria/operator/Limit.java index 935eea415..ca9bbd88c 100644 --- a/src/edu/washington/escience/myria/operator/Limit.java +++ b/src/edu/washington/escience/myria/operator/Limit.java @@ -27,9 +27,6 @@ public final class Limit extends UnaryOperator { /** The number of tuples left to emit. */ private long toEmit; - /** If number of emitted tuples reached limit. */ - private boolean done; - /** * A limit operator keeps the first limit tuples produced by its child. 
 *
@@ -41,30 +38,28 @@ public Limit(@Nonnull final Long limit, final Operator child) {
     this.limit = Objects.requireNonNull(limit, "limit");
     Preconditions.checkArgument(limit >= 0L, "limit must be non-negative");
     toEmit = this.limit;
-    done = false;
   }
 
   @Override
   protected TupleBatch fetchNextReady() throws DbException {
     Operator child = getChild();
-    if (done) {
-      return null;
-    } else {
-      TupleBatch tb = child.nextReady();
-      TupleBatch result = null;
-      if (tb != null) {
-        if (tb.numTuples() <= toEmit) {
-          toEmit -= tb.numTuples();
-          result = tb;
-        } else if (toEmit > 0) {
-          result = tb.prefix(Ints.checkedCast(toEmit));
-          toEmit = 0;
-          child.close();
-          done = true;
-        }
+    TupleBatch tb = child.nextReady();
+    TupleBatch result = null;
+    if (tb != null) {
+      if (tb.numTuples() <= toEmit) {
+        toEmit -= tb.numTuples();
+        result = tb;
+      } else if (toEmit > 0) {
+        result = tb.prefix(Ints.checkedCast(toEmit));
+        toEmit = 0;
+      }
+      if (toEmit == 0) {
+        /* Close child and self; no more input is needed. */
+        child.close();
+        close();
       }
-      return result;
     }
+    return result;
   }
 
   @Override
diff --git a/test/edu/washington/escience/myria/operator/LimitTest.java b/test/edu/washington/escience/myria/operator/LimitTest.java
index fe3ad9cf2..29cedf10e 100644
--- a/test/edu/washington/escience/myria/operator/LimitTest.java
+++ b/test/edu/washington/escience/myria/operator/LimitTest.java
@@ -1,12 +1,16 @@
 package edu.washington.escience.myria.operator;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 
+import java.util.LinkedList;
+
 import org.junit.Test;
 
 import edu.washington.escience.myria.DbException;
 import edu.washington.escience.myria.storage.TupleBatch;
+import edu.washington.escience.myria.storage.TupleBatchBuffer;
 import edu.washington.escience.myria.util.TestEnvVars;
 import edu.washington.escience.myria.util.TestUtils;
 
@@ -21,15 +25,11 @@ public void testWithinBatchSizeLimit() throws DbException {
     Limit limiter = new Limit(limit, source);
     limiter.open(TestEnvVars.get());
     long count = 0;
-    while (!limiter.eos()) {
-      TupleBatch tb = limiter.nextReady();
-      if (tb == null) {
-        continue;
-      }
-      count += tb.numTuples();
-    }
-    limiter.close();
+    TupleBatch tb = limiter.nextReady();
+    count += tb.numTuples();
     assertEquals(limit, count);
+    assertTrue(limiter.eos());
+    /* Limit closes itself as soon as it has returned a number of tuples equal to the limit. */
   }
 
   @Test
@@ -40,35 +40,28 @@ public void testLimitZero() throws DbException {
     TupleSource source = new TupleSource(TestUtils.range(total));
     Limit limiter = new Limit(limit, source);
     limiter.open(TestEnvVars.get());
-    long count = 0;
-    while (!limiter.eos()) {
-      TupleBatch tb = limiter.nextReady();
-      if (tb == null) {
-        continue;
-      }
-      count += tb.numTuples();
-    }
-    limiter.close();
-    assertEquals(limit, count);
+    TupleBatch tb = limiter.nextReady();
+    assertNull(tb);
+    assertTrue(limiter.eos());
+    /* Limit closes itself as soon as it has returned a number of tuples equal to the limit. 
   }
 
   @Test
   public void testLimitNumTuples() throws DbException {
-    final int total = 2 * TupleBatch.BATCH_SIZE + 2;
+    final int total = 100;
     final long limit = total;
-    TupleSource source = new TupleSource(TestUtils.range(total));
+    TupleBatchBuffer tbb1 = TestUtils.range((int) limit);
+    TupleBatchBuffer tbb2 = TestUtils.range((int) limit);
+    LinkedList<TupleBatch> sourceList = new LinkedList<TupleBatch>();
+    sourceList.add(tbb1.popAny());
+    sourceList.add(tbb2.popAny());
+    TupleSource source = new TupleSource(sourceList);
     Limit limiter = new Limit(limit, source);
     limiter.open(TestEnvVars.get());
-    long count = 0;
-    while (!limiter.eos()) {
-      TupleBatch tb = limiter.nextReady();
-      if (tb == null) {
-        continue;
-      }
-      count += tb.numTuples();
-    }
-    limiter.close();
-    assertEquals(limit, count);
+    TupleBatch tb = limiter.nextReady();
+    assertEquals(limit, tb.numTuples());
+    assertTrue(limiter.eos());
+    /* Limit closes itself as soon as it has returned as many tuples as the limit. */
   }
 
   @Test
@@ -80,14 +73,17 @@ public void testSimplePrefix() throws DbException {
     Limit limiter = new Limit(limit, source);
     limiter.open(TestEnvVars.get());
     long count = 0;
+    int numIteration = 0;
     while (!limiter.eos()) {
       TupleBatch tb = limiter.nextReady();
       if (tb == null) {
         continue;
       }
       count += tb.numTuples();
+      numIteration++;
     }
-    limiter.close();
     assertEquals(limit, count);
+    assertEquals(2, numIteration);
+    /* Limit closes itself as soon as it has returned as many tuples as the limit. */
   }
 }
\ No newline at end of file
diff --git a/test/edu/washington/escience/myria/operator/StreamingAggTest.java b/test/edu/washington/escience/myria/operator/StreamingAggTest.java
index a512eeef7..b31813913 100644
--- a/test/edu/washington/escience/myria/operator/StreamingAggTest.java
+++ b/test/edu/washington/escience/myria/operator/StreamingAggTest.java
@@ -6,8 +6,6 @@
 import org.joda.time.DateTime;
 import org.junit.Test;
 
-import com.google.common.collect.ImmutableList;
-
 import edu.washington.escience.myria.DbException;
 import edu.washington.escience.myria.Schema;
 import edu.washington.escience.myria.Type;
@@ -34,9 +32,9 @@ public class StreamingAggTest {
    */
   private TupleBatchBuffer fillInputTbb(final int numTuples) {
     final Schema schema =
-        new Schema(ImmutableList.of(Type.INT_TYPE, Type.DOUBLE_TYPE, Type.FLOAT_TYPE, Type.LONG_TYPE,
-            Type.DATETIME_TYPE, Type.STRING_TYPE, Type.BOOLEAN_TYPE, Type.LONG_TYPE), ImmutableList.of("Int", "Double",
-            "Float", "Long", "Datetime", "String", "Boolean", "value"));
+        Schema.ofFields(Type.INT_TYPE, "Int", Type.DOUBLE_TYPE, "Double", Type.FLOAT_TYPE, "Float", Type.LONG_TYPE,
+            "Long", Type.DATETIME_TYPE, "Datetime", Type.STRING_TYPE, "String", Type.BOOLEAN_TYPE, "Boolean",
+            Type.LONG_TYPE, "value");
 
     final TupleBatchBuffer source = new TupleBatchBuffer(schema);
     for (int i = 0; i < numTuples; i++) {
@@ -680,9 +678,7 @@ public void testSingleGroupKeySingleColumnMax() throws DbException {
   @Test
   public void testMultiGroupSingleColumnCount() throws DbException {
     final int numTuples = 50;
-    final Schema schema =
-        new Schema(ImmutableList.of(Type.LONG_TYPE, Type.LONG_TYPE, Type.LONG_TYPE), ImmutableList.of("g0", "g1",
-            "value"));
+    final Schema schema = Schema.ofFields(Type.LONG_TYPE, "g0", Type.LONG_TYPE, "g1", Type.LONG_TYPE, "value");
     final TupleBatchBuffer tbb = new TupleBatchBuffer(schema);
     // g0 same for all tuples, g1 split to 5 groups, g2 gets i
     for (long i = 0; i < numTuples; i++) {
@@ -707,9 +703,7 @@ public void testMultiGroupSingleColumnCount() throws DbException {
   @Test
   public void testMultiGroupSingleColumnMin() throws DbException
{ final int numTuples = 50; - final Schema schema = - new Schema(ImmutableList.of(Type.LONG_TYPE, Type.LONG_TYPE, Type.LONG_TYPE), ImmutableList.of("g0", "g1", - "value")); + final Schema schema = Schema.ofFields(Type.LONG_TYPE, "g0", Type.LONG_TYPE, "g1", Type.LONG_TYPE, "value"); final TupleBatchBuffer tbb = new TupleBatchBuffer(schema); // g0 same for all tuples, g1 split to 5 groups, g2 gets i for (long i = 0; i < numTuples; i++) { @@ -736,9 +730,7 @@ public void testMultiGroupSingleColumnMin() throws DbException { @Test public void testMultiGroupSingleColumnMax() throws DbException { final int numTuples = 50; - final Schema schema = - new Schema(ImmutableList.of(Type.LONG_TYPE, Type.LONG_TYPE, Type.LONG_TYPE), ImmutableList.of("g0", "g1", - "value")); + final Schema schema = Schema.ofFields(Type.LONG_TYPE, "g0", Type.LONG_TYPE, "g1", Type.LONG_TYPE, "value"); final TupleBatchBuffer tbb = new TupleBatchBuffer(schema); // g0 same for all tuples, g1 split to 5 groups, g2 gets i for (long i = 0; i < numTuples; i++) { @@ -765,9 +757,7 @@ public void testMultiGroupSingleColumnMax() throws DbException { @Test public void testMultiGroupSingleColumnSum() throws DbException { final int numTuples = 50; - final Schema schema = - new Schema(ImmutableList.of(Type.LONG_TYPE, Type.LONG_TYPE, Type.LONG_TYPE), ImmutableList.of("g0", "g1", - "value")); + final Schema schema = Schema.ofFields(Type.LONG_TYPE, "g0", Type.LONG_TYPE, "g1", Type.LONG_TYPE, "value"); final TupleBatchBuffer tbb = new TupleBatchBuffer(schema); // g0 same for all tuples, g1 split to 5 groups, g2 gets 10 for (long i = 0; i < numTuples; i++) { @@ -792,9 +782,7 @@ public void testMultiGroupSingleColumnSum() throws DbException { @Test public void testMultiGroupSingleColumnAvg() throws DbException { final int numTuples = 50; - final Schema schema = - new Schema(ImmutableList.of(Type.LONG_TYPE, Type.LONG_TYPE, Type.LONG_TYPE), ImmutableList.of("g0", "g1", - "value")); + final Schema schema = Schema.ofFields(Type.LONG_TYPE, "g0", Type.LONG_TYPE, "g1", Type.LONG_TYPE, "value"); final TupleBatchBuffer tbb = new TupleBatchBuffer(schema); // g0 same for all tuples, g1 split to 5 groups, g2 gets 10 for (long i = 0; i < numTuples; i++) { @@ -819,9 +807,7 @@ public void testMultiGroupSingleColumnAvg() throws DbException { @Test public void testMultiGroupSingleColumnStdev() throws DbException { final int numTuples = 50; - final Schema schema = - new Schema(ImmutableList.of(Type.LONG_TYPE, Type.LONG_TYPE, Type.LONG_TYPE), ImmutableList.of("g0", "g1", - "value")); + final Schema schema = Schema.ofFields(Type.LONG_TYPE, "g0", Type.LONG_TYPE, "g1", Type.LONG_TYPE, "value"); final TupleBatchBuffer tbb = new TupleBatchBuffer(schema); // g0 same for all tuples, g1 split to 5 groups, g2 gets 10 for (long i = 0; i < numTuples; i++) { @@ -846,8 +832,7 @@ public void testMultiGroupSingleColumnStdev() throws DbException { @Test public void testSingleGroupKeyMultiColumnAllAgg() throws DbException { final int numTuples = 50; - final Schema schema = - new Schema(ImmutableList.of(Type.LONG_TYPE, Type.LONG_TYPE), ImmutableList.of("gkey", "value")); + final Schema schema = Schema.ofFields(Type.LONG_TYPE, "gkey", Type.LONG_TYPE, "value"); final TupleBatchBuffer tbb = new TupleBatchBuffer(schema); // gkey split to 5 groups, value gets 10 for (long i = 0; i < numTuples; i++) { @@ -886,9 +871,7 @@ public void testSingleGroupKeyMultiColumnAllAgg() throws DbException { @Test public void testMultiGroupMultiColumn() throws DbException { final int numTuples = 50; - final 
Schema schema = - new Schema(ImmutableList.of(Type.LONG_TYPE, Type.LONG_TYPE, Type.LONG_TYPE), ImmutableList - .of("g0", "g1", "val")); + final Schema schema = Schema.ofFields(Type.LONG_TYPE, "g0", Type.LONG_TYPE, "g1", Type.LONG_TYPE, "value"); final TupleBatchBuffer tbb = new TupleBatchBuffer(schema); // {0, 2, i} on first half tuples, {0, 4, i} on the second half int sumFirst = 0; @@ -958,8 +941,7 @@ public void testMultiGroupMultiColumn() throws DbException { @Test public void testSingleGroupAllAggLargeInput() throws DbException { final int numTuples = 2 * TupleBatch.BATCH_SIZE; - final Schema schema = - new Schema(ImmutableList.of(Type.LONG_TYPE, Type.LONG_TYPE), ImmutableList.of("gkey", "value")); + final Schema schema = Schema.ofFields(Type.LONG_TYPE, "gkey", Type.LONG_TYPE, "value"); final TupleBatchBuffer tbb = new TupleBatchBuffer(schema); // {0, i} @@ -1009,9 +991,7 @@ public void testSingleGroupAllAggLargeInput() throws DbException { @Test public void testMultiGroupAllAggLargeInput() throws DbException { final int numTuples = 3 * TupleBatch.BATCH_SIZE; - final Schema schema = - new Schema(ImmutableList.of(Type.LONG_TYPE, Type.LONG_TYPE, Type.LONG_TYPE), ImmutableList - .of("g0", "g1", "val")); + final Schema schema = Schema.ofFields(Type.LONG_TYPE, "g0", Type.LONG_TYPE, "g1", Type.LONG_TYPE, "value"); final TupleBatchBuffer tbb = new TupleBatchBuffer(schema); // split into 4 groups, each group may spread across different batches // {0, 0, i} in first group, {0, 1, i} in second, {0, 2, i} in third, {0, 3, i} in fourth From 3bb49a159d754f85ae393d6ecc590353c203c5c5 Mon Sep 17 00:00:00 2001 From: Dan Radion Date: Fri, 22 May 2015 10:35:03 -0700 Subject: [PATCH 03/29] sampling operators --- .../myria/api/encoding/OperatorEncoding.java | 4 + .../myria/api/encoding/SampleEncoding.java | 15 ++ .../encoding/SampledDbInsertTempEncoding.java | 32 +++ .../SamplingDistributionEncoding.java | 21 ++ .../escience/myria/operator/Sample.java | 239 ++++++++++++++++++ .../myria/operator/SampledDbInsertTemp.java | 221 ++++++++++++++++ .../myria/operator/SamplingDistribution.java | 236 +++++++++++++++++ .../network/partition/PartitionFunction.java | 1 + .../partition/RawValuePartitionFunction.java | 56 ++++ 9 files changed, 825 insertions(+) create mode 100644 src/edu/washington/escience/myria/api/encoding/SampleEncoding.java create mode 100644 src/edu/washington/escience/myria/api/encoding/SampledDbInsertTempEncoding.java create mode 100644 src/edu/washington/escience/myria/api/encoding/SamplingDistributionEncoding.java create mode 100644 src/edu/washington/escience/myria/operator/Sample.java create mode 100644 src/edu/washington/escience/myria/operator/SampledDbInsertTemp.java create mode 100644 src/edu/washington/escience/myria/operator/SamplingDistribution.java create mode 100644 src/edu/washington/escience/myria/operator/network/partition/RawValuePartitionFunction.java diff --git a/src/edu/washington/escience/myria/api/encoding/OperatorEncoding.java b/src/edu/washington/escience/myria/api/encoding/OperatorEncoding.java index 236773e71..8f2c35f4e 100644 --- a/src/edu/washington/escience/myria/api/encoding/OperatorEncoding.java +++ b/src/edu/washington/escience/myria/api/encoding/OperatorEncoding.java @@ -11,6 +11,7 @@ import edu.washington.escience.myria.api.MyriaApiException; import edu.washington.escience.myria.api.encoding.QueryConstruct.ConstructArgs; import edu.washington.escience.myria.operator.Operator; +import edu.washington.escience.myria.operator.SampledDbInsertTemp; /** * A JSON-able 
wrapper for the expected wire message for an operator. To add a new operator, two things need to be done.
@@ -47,6 +48,9 @@
   @Type(name = "NChiladaFileScan", value = NChiladaFileScanEncoding.class),
   @Type(name = "RightHashCountingJoin", value = RightHashCountingJoinEncoding.class),
   @Type(name = "RightHashJoin", value = RightHashJoinEncoding.class),
+  @Type(name = "SampledDbInsertTemp", value = SampledDbInsertTempEncoding.class),
+  @Type(name = "Sample", value = SampleEncoding.class),
+  @Type(name = "SamplingDistribution", value = SamplingDistributionEncoding.class),
   @Type(name = "SeaFlowScan", value = SeaFlowFileScanEncoding.class),
   @Type(name = "SetGlobal", value = SetGlobalEncoding.class),
   @Type(name = "ShuffleConsumer", value = ShuffleConsumerEncoding.class),
diff --git a/src/edu/washington/escience/myria/api/encoding/SampleEncoding.java b/src/edu/washington/escience/myria/api/encoding/SampleEncoding.java
new file mode 100644
index 000000000..cc3290867
--- /dev/null
+++ b/src/edu/washington/escience/myria/api/encoding/SampleEncoding.java
@@ -0,0 +1,15 @@
+package edu.washington.escience.myria.api.encoding;
+
+import edu.washington.escience.myria.api.encoding.QueryConstruct.ConstructArgs;
+import edu.washington.escience.myria.operator.Sample;
+
+public class SampleEncoding extends BinaryOperatorEncoding<Sample> {
+
+  /** Used to make results deterministic. Null if no specified value. */
+  public Long randomSeed;
+
+  @Override
+  public Sample construct(final ConstructArgs args) {
+    return new Sample(null, null, randomSeed);
+  }
+}
diff --git a/src/edu/washington/escience/myria/api/encoding/SampledDbInsertTempEncoding.java b/src/edu/washington/escience/myria/api/encoding/SampledDbInsertTempEncoding.java
new file mode 100644
index 000000000..2cc10ae85
--- /dev/null
+++ b/src/edu/washington/escience/myria/api/encoding/SampledDbInsertTempEncoding.java
@@ -0,0 +1,32 @@
+package edu.washington.escience.myria.api.encoding;
+
+import edu.washington.escience.myria.RelationKey;
+import edu.washington.escience.myria.accessmethod.ConnectionInfo;
+import edu.washington.escience.myria.api.encoding.QueryConstruct.ConstructArgs;
+import edu.washington.escience.myria.operator.SampledDbInsertTemp;
+
+/**
+ * Encoding for the SampledDbInsertTemp operator.
+ */
+public class SampledDbInsertTempEncoding extends UnaryOperatorEncoding<SampledDbInsertTemp> {
+
+  @Required
+  public Integer sampleSize;
+  @Required
+  public String sampleTable;
+  @Required
+  public String countTable;
+  /**
+   * The ConnectionInfo struct determines what database the data will be written
+   * to. If null, the worker's default database will be used.
+   */
+  public ConnectionInfo connectionInfo;
+
+  @Override
+  public SampledDbInsertTemp construct(ConstructArgs args) {
+    return new SampledDbInsertTemp(null, sampleSize, RelationKey.ofTemp(
+        args.getQueryId(), sampleTable), RelationKey.ofTemp(args.getQueryId(),
+        countTable), connectionInfo);
+  }
+}
diff --git a/src/edu/washington/escience/myria/api/encoding/SamplingDistributionEncoding.java b/src/edu/washington/escience/myria/api/encoding/SamplingDistributionEncoding.java
new file mode 100644
index 000000000..a38385e26
--- /dev/null
+++ b/src/edu/washington/escience/myria/api/encoding/SamplingDistributionEncoding.java
@@ -0,0 +1,21 @@
+package edu.washington.escience.myria.api.encoding;
+
+import edu.washington.escience.myria.api.encoding.QueryConstruct.ConstructArgs;
+import edu.washington.escience.myria.operator.SamplingDistribution;
+
+public class SamplingDistributionEncoding extends UnaryOperatorEncoding<SamplingDistribution> {
+
+  @Required
+  public int sampleSize;
+
+  @Required
+  public boolean isWithReplacement;
+
+  /** Used to make results deterministic. Null if no specified value. */
+  public Long randomSeed;
+
+  @Override
+  public SamplingDistribution construct(final ConstructArgs args) {
+    return new SamplingDistribution(null, sampleSize, isWithReplacement, randomSeed);
+  }
+}
diff --git a/src/edu/washington/escience/myria/operator/Sample.java b/src/edu/washington/escience/myria/operator/Sample.java
new file mode 100644
index 000000000..17e6d7891
--- /dev/null
+++ b/src/edu/washington/escience/myria/operator/Sample.java
@@ -0,0 +1,239 @@
+package edu.washington.escience.myria.operator;
+
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Random;
+import java.util.Set;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableMap;
+
+import edu.washington.escience.myria.DbException;
+import edu.washington.escience.myria.Schema;
+import edu.washington.escience.myria.Type;
+import edu.washington.escience.myria.storage.TupleBatch;
+import edu.washington.escience.myria.storage.TupleBatchBuffer;
+
+public class Sample extends BinaryOperator {
+  /** Required for Java serialization. */
+  private static final long serialVersionUID = 1L;
+
+  /** Total number of tuples to expect from the right operator. */
+  private int streamSize;
+
+  /** Number of tuples to sample from the right operator. */
+  private int sampleSize;
+
+  /** Random generator used for index selection. */
+  private Random rand;
+
+  /** True if the sampling is WithReplacement. WithoutReplacement otherwise. */
+  private boolean isWithReplacement;
+
+  /** True if operator has extracted sampling info. */
+  private boolean computedSamplingInfo = false;
+
+  /** Buffer for tuples that will be returned. */
+  private TupleBatchBuffer ans;
+
+  /** Global count of the tuples seen so far. */
+  private int tupleNum = 0;
+
+  /** Sorted array of tuple indices that will be taken as samples. */
+  private int[] samples;
+
+  /** Current index of the samples array. */
+  private int curSampIdx = 0;
+
+  /**
+   * Instantiate a Sample operator using sampling info from the left operator
+   * and the stream from the right operator.
+   *
+   * @param left
+   *          inputs a (WorkerID, StreamSize, SampleSize, IsWithReplacement) tuple.
+   * @param right
+   *          tuples that will be sampled from.
+   * @param randomSeed
+   *          value to seed the random generator with.
null if no specified seed + */ + public Sample(final Operator left, final Operator right, Long randomSeed) { + super(left, right); + this.rand = new Random(); + if (randomSeed != null) { + this.rand.setSeed(randomSeed); + } + } + + @Override + protected TupleBatch fetchNextReady() throws Exception { + // Extract sampling info from left operator. + if (!computedSamplingInfo) { + TupleBatch tb = getLeft().nextReady(); + if (tb == null) + return null; + extractSamplingInfo(tb); + getLeft().close(); + + // Cannot sampleWoR more tuples than there are. + if (!isWithReplacement) { + Preconditions.checkState(sampleSize <= streamSize, + "Cannot SampleWoR %s tuples from a population of size %s", + sampleSize, streamSize); + } + + // Generate target indices to accept as samples. + if (isWithReplacement) { + samples = generateIndicesWR(streamSize, sampleSize); + } else { + samples = generateIndicesWoR(streamSize, sampleSize); + } + + computedSamplingInfo = true; + } + + // Return a ready tuple batch if possible. + TupleBatch nexttb = ans.popAny(); + if (nexttb != null) { + return nexttb; + } + // Check if there's nothing left to sample. + if (curSampIdx >= samples.length) { + getRight().close(); + setEOS(); + return null; + } + Operator right = getRight(); + for (TupleBatch tb = right.nextReady(); tb != null; tb = right.nextReady()) { + if (curSampIdx >= samples.length) { // done sampling + break; + } + if (samples[curSampIdx] > tupleNum + tb.numTuples()) { + // nextIndex is not in this batch. Continue with next batch. + tupleNum += tb.numTuples(); + continue; + } + while (curSampIdx < samples.length + && samples[curSampIdx] < tupleNum + tb.numTuples()) { + ans.put(tb, samples[curSampIdx] - tupleNum); + curSampIdx++; + } + tupleNum += tb.numTuples(); + if (ans.hasFilledTB()) { + return ans.popFilled(); + } + } + return ans.popAny(); + } + + /** Helper function to extract sampling information from a TupleBatch. */ + private void extractSamplingInfo(TupleBatch tb) throws Exception { + Preconditions.checkArgument(tb != null); + + int workerID; + Type col0Type = tb.getSchema().getColumnType(0); + if (col0Type == Type.INT_TYPE) { + workerID = tb.getInt(0, 0); + } else if (col0Type == Type.LONG_TYPE) { + workerID = (int) tb.getLong(0, 0); + } else { + throw new DbException("WorkerID column must be of type INT or LONG"); + } + Preconditions.checkState(workerID == getNodeID(), + "Invalid WorkerID for this worker. Expected %s, but received %s", + getNodeID(), workerID); + + Type col1Type = tb.getSchema().getColumnType(1); + if (col1Type == Type.INT_TYPE) { + streamSize = tb.getInt(1, 0); + } else if (col1Type == Type.LONG_TYPE) { + streamSize = (int) tb.getLong(1, 0); + } else { + throw new DbException("StreamSize column must be of type INT or LONG"); + } + Preconditions.checkState(streamSize >= 0, "streamSize cannot be negative"); + + Type col2Type = tb.getSchema().getColumnType(2); + if (col2Type == Type.INT_TYPE) { + sampleSize = tb.getInt(2, 0); + } else if (col2Type == Type.LONG_TYPE) { + sampleSize = (int) tb.getLong(2, 0); + } else { + throw new DbException("SampleSize column must be of type INT or LONG"); + } + Preconditions.checkState(sampleSize >= 0, "sampleSize cannot be negative"); + + Type col3Type = tb.getSchema().getColumnType(3); + if (col3Type == Type.BOOLEAN_TYPE) { + isWithReplacement= tb.getBoolean(3, 0); + } else { + throw new DbException("IsWithReplacement column must be of type BOOLEAN"); + } + } + + /** + * Generates a sorted array of random numbers to be taken as samples. 
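+   * Indices may repeat, since sampling is done with replacement.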
+   *
+   * @param populationSize
+   *          size of the population that will be sampled from.
+   * @param sampleSize
+   *          number of samples to draw from the population.
+   * @return a sorted array of indices.
+   */
+  private int[] generateIndicesWR(int populationSize, int sampleSize) {
+    int[] indices = new int[sampleSize];
+    for (int i = 0; i < sampleSize; i++) {
+      indices[i] = rand.nextInt(populationSize);
+    }
+    Arrays.sort(indices);
+    return indices;
+  }
+
+  /**
+   * Generates a sorted array of unique random numbers to be taken as samples.
+   *
+   * @param populationSize
+   *          size of the population that will be sampled from.
+   * @param sampleSize
+   *          number of samples to draw from the population.
+   * @return a sorted array of indices.
+   */
+  private int[] generateIndicesWoR(int populationSize, int sampleSize) {
+    Set<Integer> indices = new HashSet<Integer>(sampleSize);
+    for (int i = populationSize - sampleSize; i < populationSize; i++) {
+      int idx = rand.nextInt(i + 1);
+      if (indices.contains(idx)) {
+        indices.add(i);
+      } else {
+        indices.add(idx);
+      }
+    }
+    int[] indicesArr = new int[indices.size()];
+    int i = 0;
+    for (Integer val : indices) {
+      indicesArr[i] = val;
+      i++;
+    }
+    Arrays.sort(indicesArr);
+    return indicesArr;
+  }
+
+  @Override
+  public Schema generateSchema() {
+    Operator right = getRight();
+    if (right == null) {
+      return null;
+    }
+    return right.getSchema();
+  }
+
+  @Override
+  protected void init(final ImmutableMap<String, Object> execEnvVars) {
+    ans = new TupleBatchBuffer(getSchema());
+  }
+
+  @Override
+  public void cleanup() {
+    ans = null;
+  }
+
+}
diff --git a/src/edu/washington/escience/myria/operator/SampledDbInsertTemp.java b/src/edu/washington/escience/myria/operator/SampledDbInsertTemp.java
new file mode 100644
index 000000000..eac978f57
--- /dev/null
+++ b/src/edu/washington/escience/myria/operator/SampledDbInsertTemp.java
@@ -0,0 +1,221 @@
+/**
+ *
+ */
+package edu.washington.escience.myria.operator;
+
+import java.io.File;
+import java.util.*;
+
+import com.almworks.sqlite4java.SQLiteConnection;
+import com.almworks.sqlite4java.SQLiteException;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+
+import edu.washington.escience.myria.*;
+import edu.washington.escience.myria.accessmethod.AccessMethod;
+import edu.washington.escience.myria.accessmethod.ConnectionInfo;
+import edu.washington.escience.myria.accessmethod.SQLiteInfo;
+import edu.washington.escience.myria.column.Column;
+import edu.washington.escience.myria.column.builder.IntColumnBuilder;
+import edu.washington.escience.myria.parallel.RelationWriteMetadata;
+import edu.washington.escience.myria.storage.MutableTupleBuffer;
+import edu.washington.escience.myria.storage.TupleBatch;
+
+/**
+ * Samples the stream into a temp relation.
+ */
+public class SampledDbInsertTemp extends UnaryOperator implements DbWriter {
+
+  /** Required for Java serialization. */
+  private static final long serialVersionUID = 1L;
+
+  /** The connection to the database. */
+  private AccessMethod accessMethod;
+  /** The information for the database connection. */
+  private ConnectionInfo connectionInfo;
+  /** The name of the table the sampled tuples should be inserted into. */
+  private final RelationKey sampleRelationKey;
+  /** The name of the table the tuple counts should be inserted into. */
+  private final RelationKey countRelationKey;
+
+  /** Total number of tuples seen from the child. */
+  private int tupleCount = 0;
+  /** Number of tuples to sample from the stream.
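+   * This is also the capacity of the sampling reservoir.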
+   */
+  private final int streamSampleSize;
+  /** Reservoir that holds sampleSize number of tuples. */
+  private MutableTupleBuffer reservoir = null;
+  /** Sampled tuples ready to be returned. */
+  private List<TupleBatch> batches;
+  /** Next element of batches List that will be returned. */
+  private int batchNum = 0;
+  /** True if all samples have been gathered from the child. */
+  private boolean doneSamplingFromChild;
+
+  /** The schema of the (WorkerID, PartitionSize, StreamSize) count relation. */
+  private static final Schema COUNT_SCHEMA = Schema.of(
+      ImmutableList.of(Type.INT_TYPE, Type.INT_TYPE, Type.INT_TYPE),
+      ImmutableList.of("WorkerID", "PartitionSize", "StreamSize"));
+
+  /**
+   * @param child
+   *          the source of tuples to be inserted.
+   * @param streamSampleSize
+   *          number of tuples to store from the stream
+   * @param sampleRelationKey
+   *          the key of the table that the tuples should be inserted into.
+   * @param countRelationKey
+   *          the key of the table that the tuple counts will be inserted into.
+   * @param connectionInfo
+   *          the parameters of the database connection.
+   */
+  public SampledDbInsertTemp(final Operator child, final int streamSampleSize,
+      final RelationKey sampleRelationKey, final RelationKey countRelationKey,
+      final ConnectionInfo connectionInfo) {
+    super(child);
+    // Sampling setup.
+    Preconditions.checkArgument(streamSampleSize >= 0L,
+        "sampleSize must be non-negative");
+    this.streamSampleSize = streamSampleSize;
+    doneSamplingFromChild = false;
+
+    // Relation setup.
+    Objects.requireNonNull(sampleRelationKey, "sampleRelationKey");
+    this.sampleRelationKey = sampleRelationKey;
+    Objects.requireNonNull(countRelationKey, "countRelationKey");
+    this.countRelationKey = countRelationKey;
+    this.connectionInfo = connectionInfo;
+  }
+
+  @Override
+  protected TupleBatch fetchNextReady() throws DbException {
+    if (!doneSamplingFromChild) {
+      fillReservoir();
+      batches = reservoir.getAll();
+      // Insert sampled tuples into sampleRelationKey
+      while (batchNum < batches.size()) {
+        TupleBatch batch = batches.get(batchNum);
+        accessMethod.tupleBatchInsert(sampleRelationKey, batch);
+        batchNum++;
+      }
+
+      // Write (WorkerID, PartitionSize, StreamSize) to countRelationKey
+      IntColumnBuilder wIdCol = new IntColumnBuilder();
+      IntColumnBuilder tupCountCol = new IntColumnBuilder();
+      IntColumnBuilder streamSizeCol = new IntColumnBuilder();
+      wIdCol.appendInt(getNodeID());
+      tupCountCol.appendInt(tupleCount);
+      streamSizeCol.appendInt(reservoir.numTuples());
+      ImmutableList.Builder<Column<?>> columns = ImmutableList.builder();
+      columns.add(wIdCol.build(), tupCountCol.build(), streamSizeCol.build());
+      TupleBatch tb = new TupleBatch(COUNT_SCHEMA, columns.build());
+      accessMethod.tupleBatchInsert(countRelationKey, tb);
+    }
+    return null;
+  }
+
+  /**
+   * Fills reservoir with child tuples.
+   *
+   * @throws DbException
+   *           if TupleBatch fails to get nextReady
+   */
+  private void fillReservoir() throws DbException {
+    Random rand = new Random();
+    for (TupleBatch tb = getChild().nextReady(); tb != null; tb = getChild()
+        .nextReady()) {
+      final List<? extends Column<?>> columns = tb.getDataColumns();
+      for (int i = 0; i < tb.numTuples(); i++) {
+        if (reservoir.numTuples() < streamSampleSize) {
+          // Reservoir size < k. Add this tuple.
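+          // Classic reservoir sampling: the first k tuples fill the reservoir
+          // directly; every later tuple replaces a random reservoir slot with
+          // probability k / (number of tuples seen so far + 1).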
+          for (int j = 0; j < tb.numColumns(); j++) {
+            reservoir.put(j, columns.get(j), i);
+          }
+        } else {
+          // Replace a random slot with probability k / (tupleCount + 1).
+          int replaceIdx = rand.nextInt(tupleCount + 1);
+          if (replaceIdx < reservoir.numTuples()) {
+            for (int j = 0; j < tb.numColumns(); j++) {
+              reservoir.replace(j, replaceIdx, columns.get(j), i);
+            }
+          }
+        }
+        tupleCount++;
+      }
+    }
+    doneSamplingFromChild = true;
+  }
+
+  @Override
+  protected void init(final ImmutableMap<String, Object> execEnvVars)
+      throws DbException {
+    reservoir = new MutableTupleBuffer(getChild().getSchema());
+    /*
+     * retrieve connection information from the environment variables, if not
+     * already set
+     */
+    if (connectionInfo == null && execEnvVars != null) {
+      connectionInfo = (ConnectionInfo) execEnvVars
+          .get(MyriaConstants.EXEC_ENV_VAR_DATABASE_CONN_INFO);
+    }
+
+    if (connectionInfo == null) {
+      throw new DbException(
+          "Unable to instantiate SampledDbInsertTemp: connection information unknown");
+    }
+
+    if (connectionInfo instanceof SQLiteInfo) {
+      /* Set WAL in the beginning. */
+      final File dbFile = new File(
+          ((SQLiteInfo) connectionInfo).getDatabaseFilename());
+      SQLiteConnection conn = new SQLiteConnection(dbFile);
+      try {
+        conn.open(true);
+        conn.exec("PRAGMA journal_mode=WAL;");
+      } catch (SQLiteException e) {
+        e.printStackTrace();
+      }
+      conn.dispose();
+    }
+
+    /* open the database connection */
+    accessMethod = AccessMethod.of(connectionInfo.getDbms(), connectionInfo,
+        false);
+    accessMethod.dropTableIfExists(sampleRelationKey);
+    accessMethod.dropTableIfExists(countRelationKey);
+    // Create the temp tables.
+    accessMethod.createTableIfNotExists(sampleRelationKey, getSchema());
+    accessMethod.createTableIfNotExists(countRelationKey, COUNT_SCHEMA);
+  }
+
+  @Override
+  public void cleanup() {
+    reservoir = null;
+    batches = null;
+
+    try {
+      if (accessMethod != null) {
+        accessMethod.close();
+      }
+    } catch (DbException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  @Override
+  public final Schema generateSchema() {
+    if (getChild() == null) {
+      return null;
+    }
+    return getChild().getSchema();
+  }
+
+  @Override
+  public Map<RelationKey, RelationWriteMetadata> writeSet() {
+    Map<RelationKey, RelationWriteMetadata> map = new HashMap<RelationKey, RelationWriteMetadata>(2);
+    map.put(sampleRelationKey, new RelationWriteMetadata(sampleRelationKey, getSchema(), true, true));
+    map.put(countRelationKey, new RelationWriteMetadata(countRelationKey, COUNT_SCHEMA, true, true));
+    return map;
+  }
+
+}
diff --git a/src/edu/washington/escience/myria/operator/SamplingDistribution.java b/src/edu/washington/escience/myria/operator/SamplingDistribution.java
new file mode 100644
index 000000000..5d5e7bc8e
--- /dev/null
+++ b/src/edu/washington/escience/myria/operator/SamplingDistribution.java
@@ -0,0 +1,236 @@
+package edu.washington.escience.myria.operator;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+
+import edu.washington.escience.myria.DbException;
+import edu.washington.escience.myria.Schema;
+import edu.washington.escience.myria.Type;
+import edu.washington.escience.myria.column.Column;
+import edu.washington.escience.myria.column.builder.BooleanColumnBuilder;
+import edu.washington.escience.myria.column.builder.IntColumnBuilder;
+import edu.washington.escience.myria.storage.TupleBatch;
+
+public class SamplingDistribution extends UnaryOperator {
+  /** Required for Java serialization. */
+  private static final long serialVersionUID = 1L;
+
+  /** The output schema.
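+   * One row per worker: its stream size, its assigned sample size, and the sampling type.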
+   */
+  private static final Schema SCHEMA = Schema.of(
+      ImmutableList.of(Type.INT_TYPE, Type.INT_TYPE, Type.INT_TYPE, Type.BOOLEAN_TYPE),
+      ImmutableList.of("WorkerID", "StreamSize", "SampleSize", "IsWithReplacement"));
+
+  /** Total number of tuples to sample. */
+  private final int sampleSize;
+
+  /** True if the sampling is WithReplacement. WithoutReplacement otherwise. */
+  private final boolean isWithReplacement;
+
+  /** Random generator used for creating the distribution. */
+  private Random rand;
+
+  /**
+   * Instantiate a SamplingDistribution operator.
+   *
+   * @param sampleSize
+   *          total samples to create a distribution for.
+   * @param isWithReplacement
+   *          true if the distribution uses WithReplacement sampling.
+   * @param child
+   *          extracts (WorkerID, PartitionSize, StreamSize) information from
+   *          this child.
+   * @param randomSeed
+   *          value to seed the random generator with. null if no specified seed
+   */
+  public SamplingDistribution(Operator child, int sampleSize,
+      boolean isWithReplacement, Long randomSeed) {
+    super(child);
+    this.sampleSize = sampleSize;
+    Preconditions.checkState(sampleSize >= 0,
+        "Sample size cannot be negative: %s", sampleSize);
+    this.isWithReplacement = isWithReplacement;
+    this.rand = new Random();
+    if (randomSeed != null) {
+      this.rand.setSeed(randomSeed);
+    }
+  }
+
+  @Override
+  protected TupleBatch fetchNextReady() throws DbException {
+    if (getChild().eos()) {
+      return null;
+    }
+
+    // Distribution of the tuples across the workers.
+    // Value at index i == # of tuples on worker i.
+    ArrayList<Integer> tupleCounts = new ArrayList<Integer>();
+
+    // Distribution of the actual stream size across the workers.
+    // May be different from tupleCounts if worker i pre-sampled the data.
+    // Value at index i == # of tuples in stream on worker i.
+    ArrayList<Integer> streamCounts = new ArrayList<Integer>();
+
+    // Total number of tuples across all workers.
+    int totalTupleCount = 0;
+
+    // Drain out all the workerID and partitionSize info.
+    while (!getChild().eos()) {
+      TupleBatch tb = getChild().nextReady();
+      if (tb == null) {
+        continue;
+      }
+      Type col0Type = tb.getSchema().getColumnType(0);
+      Type col1Type = tb.getSchema().getColumnType(1);
+      boolean hasStreamSize = false;
+      Type col2Type = null;
+      if (tb.getSchema().numColumns() > 2) {
+        hasStreamSize = true;
+        col2Type = tb.getSchema().getColumnType(2);
+      }
+      for (int i = 0; i < tb.numTuples(); i++) {
+        int workerID;
+        if (col0Type == Type.INT_TYPE) {
+          workerID = tb.getInt(0, i);
+        } else if (col0Type == Type.LONG_TYPE) {
+          workerID = (int) tb.getLong(0, i);
+        } else {
+          throw new DbException("WorkerID must be of type INT or LONG");
+        }
+        Preconditions.checkState(workerID > 0, "WorkerID must be > 0");
+        // Ensure the future .set(workerID, -) calls will work.
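+        // Workers are 1-indexed, so worker w is stored at list index w - 1.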
+        for (int j = tupleCounts.size(); j < workerID; j++) {
+          tupleCounts.add(0);
+          streamCounts.add(0);
+        }
+
+        int partitionSize;
+        if (col1Type == Type.INT_TYPE) {
+          partitionSize = tb.getInt(1, i);
+        } else if (col1Type == Type.LONG_TYPE) {
+          partitionSize = (int) tb.getLong(1, i);
+        } else {
+          throw new DbException("PartitionSize must be of type INT or LONG");
+        }
+        Preconditions.checkState(partitionSize >= 0,
+            "Worker cannot have a negative PartitionSize: %s", partitionSize);
+        tupleCounts.set(workerID - 1, partitionSize);
+        totalTupleCount += partitionSize;
+
+        int streamSize = partitionSize;
+        if (hasStreamSize) {
+          if (col2Type == Type.INT_TYPE) {
+            streamSize = tb.getInt(2, i);
+          } else if (col2Type == Type.LONG_TYPE) {
+            streamSize = (int) tb.getLong(2, i);
+          } else {
+            throw new DbException("StreamSize must be of type INT or LONG");
+          }
+          Preconditions.checkState(streamSize >= 0,
+              "Worker cannot have a negative StreamSize: %s", streamSize);
+        }
+        streamCounts.set(workerID - 1, streamSize);
+      }
+    }
+    Preconditions.checkState(sampleSize <= totalTupleCount,
+        "Cannot extract %s samples from a population of size %s", sampleSize,
+        totalTupleCount);
+
+    // Generate a random distribution across the workers.
+    int[] sampleCounts;
+    if (isWithReplacement) {
+      sampleCounts = withReplacementDistribution(tupleCounts, sampleSize);
+    } else {
+      sampleCounts = withoutReplacementDistribution(tupleCounts, sampleSize);
+    }
+
+    // Build and return a TupleBatch with the distribution.
+    IntColumnBuilder wIdCol = new IntColumnBuilder();
+    IntColumnBuilder streamSizeCol = new IntColumnBuilder();
+    IntColumnBuilder sampCountCol = new IntColumnBuilder();
+    BooleanColumnBuilder wrCol = new BooleanColumnBuilder();
+    for (int i = 0; i < streamCounts.size(); i++) {
+      wIdCol.appendInt(i + 1);
+      streamSizeCol.appendInt(streamCounts.get(i));
+      sampCountCol.appendInt(sampleCounts[i]);
+      wrCol.appendBoolean(isWithReplacement);
+    }
+    ImmutableList.Builder<Column<?>> columns = ImmutableList.builder();
+    columns.add(wIdCol.build(), streamSizeCol.build(), sampCountCol.build(), wrCol.build());
+    return new TupleBatch(SCHEMA, columns.build());
+  }
+
+  /**
+   * Creates a WithReplacement distribution across the workers.
+   *
+   * @param tupleCounts
+   *          list of how many tuples each worker has.
+   * @param sampleSize
+   *          total number of samples to distribute across the workers.
+   * @return array representing the distribution across the workers.
+   */
+  private int[] withReplacementDistribution(List<Integer> tupleCounts,
+      int sampleSize) {
+    int[] distribution = new int[tupleCounts.size()];
+    int totalTupleCount = 0;
+    for (int val : tupleCounts) {
+      totalTupleCount += val;
+    }
+
+    for (int i = 0; i < sampleSize; i++) {
+      int sampleTupleIdx = rand.nextInt(totalTupleCount);
+      // Assign this tuple to the workerID that holds this sampleTupleIdx.
+      int tupleOffset = 0;
+      for (int j = 0; j < tupleCounts.size(); j++) {
+        if (sampleTupleIdx < tupleCounts.get(j) + tupleOffset) {
+          distribution[j] += 1;
+          break;
+        }
+        tupleOffset += tupleCounts.get(j);
+      }
+    }
+    return distribution;
+  }
+
+  /**
+   * Creates a WithoutReplacement distribution across the workers.
+   *
+   * @param tupleCounts
+   *          list of how many tuples each worker has.
+   * @param sampleSize
+   *          total number of samples to distribute across the workers.
+   * @return array representing the distribution across the workers.
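+   * Samples are drawn from a shrinking logical population, so no tuple can be assigned twice.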
+   */
+  private int[] withoutReplacementDistribution(List<Integer> tupleCounts,
+      int sampleSize) {
+    int[] distribution = new int[tupleCounts.size()];
+    int totalTupleCount = 0;
+    for (int val : tupleCounts) {
+      totalTupleCount += val;
+    }
+    List<Integer> logicalTupleCounts = new ArrayList<Integer>(tupleCounts);
+
+    for (int i = 0; i < sampleSize; i++) {
+      int sampleTupleIdx = rand.nextInt(totalTupleCount - i);
+      // Assign this tuple to the workerID that holds this sampleTupleIdx.
+      int tupleOffset = 0;
+      for (int j = 0; j < logicalTupleCounts.size(); j++) {
+        if (sampleTupleIdx < logicalTupleCounts.get(j) + tupleOffset) {
+          distribution[j] += 1;
+          // Cannot sample the same tuple, so pretend it doesn't exist anymore.
+          logicalTupleCounts.set(j, logicalTupleCounts.get(j) - 1);
+          break;
+        }
+        tupleOffset += logicalTupleCounts.get(j);
+      }
+    }
+    return distribution;
+  }
+
+  @Override
+  public Schema generateSchema() {
+    return SCHEMA;
+  }
+
+}
diff --git a/src/edu/washington/escience/myria/operator/network/partition/PartitionFunction.java b/src/edu/washington/escience/myria/operator/network/partition/PartitionFunction.java
index 66562ecc2..4147797d1 100644
--- a/src/edu/washington/escience/myria/operator/network/partition/PartitionFunction.java
+++ b/src/edu/washington/escience/myria/operator/network/partition/PartitionFunction.java
@@ -22,6 +22,7 @@
 @JsonSubTypes({
     @Type(value = RoundRobinPartitionFunction.class, name = "RoundRobin"),
     @Type(value = SingleFieldHashPartitionFunction.class, name = "SingleFieldHash"),
+    @Type(value = RawValuePartitionFunction.class, name = "RawValue"),
     @Type(value = MultiFieldHashPartitionFunction.class, name = "MultiFieldHash"),
     @Type(value = WholeTupleHashPartitionFunction.class, name = "WholeTupleHash") })
 public abstract class PartitionFunction implements Serializable {
diff --git a/src/edu/washington/escience/myria/operator/network/partition/RawValuePartitionFunction.java b/src/edu/washington/escience/myria/operator/network/partition/RawValuePartitionFunction.java
new file mode 100644
index 000000000..fade62ecc
--- /dev/null
+++ b/src/edu/washington/escience/myria/operator/network/partition/RawValuePartitionFunction.java
@@ -0,0 +1,56 @@
+package edu.washington.escience.myria.operator.network.partition;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.google.common.base.Preconditions;
+
+import edu.washington.escience.myria.Type;
+import edu.washington.escience.myria.storage.TupleBatch;
+
+/**
+ * Implementation of a PartitionFunction that decides based on the raw values of an integer field.
+ */
+public final class RawValuePartitionFunction extends PartitionFunction {
+
+  /** Required for Java serialization. */
+  private static final long serialVersionUID = 1L;
+
+  /** The index of the partition field. */
+  @JsonProperty
+  private final int index;
+
+  /**
+   * @param index the index of the partition field.
+   */
+  @JsonCreator
+  public RawValuePartitionFunction(
+      @JsonProperty(value = "index", required = true) final Integer index) {
+    super(null);
+    this.index = java.util.Objects.requireNonNull(index, "missing property index");
+    Preconditions.checkArgument(this.index >= 0,
+        "RawValue field index cannot take negative value %s", this.index);
+  }
+
+  /**
+   * @return the index
+   */
+  public int getIndex() {
+    return index;
+  }
+
+  /**
+   * @param tb data.
+   * @return partitions.
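+   *         One partition index per tuple: the raw integer value in the column, minus one (WorkerIDs are 1-indexed).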
+ * */ + @Override + public int[] partition(final TupleBatch tb) { + Preconditions.checkArgument(tb.getSchema().getColumnType(index) == Type.INT_TYPE, + "RawValue index column must be of type INT"); + final int[] result = new int[tb.numTuples()]; + for (int i = 0; i < result.length; i++) { + // Offset by -1 because WorkerIDs are 1-indexed. + result[i] = tb.getInt(index, i) - 1; + } + return result; + } +} From c03dad19ea2ed3779a49a5750ec205e81e5c7356 Mon Sep 17 00:00:00 2001 From: Dan Radion Date: Fri, 22 May 2015 10:40:20 -0700 Subject: [PATCH 04/29] sampling tests --- .../escience/myria/operator/SampleWRTest.java | 155 +++++++++++ .../myria/operator/SampleWoRTest.java | 141 ++++++++++ .../operator/SamplingDistributionTest.java | 243 ++++++++++++++++++ 3 files changed, 539 insertions(+) create mode 100644 test/edu/washington/escience/myria/operator/SampleWRTest.java create mode 100644 test/edu/washington/escience/myria/operator/SampleWoRTest.java create mode 100644 test/edu/washington/escience/myria/operator/SamplingDistributionTest.java diff --git a/test/edu/washington/escience/myria/operator/SampleWRTest.java b/test/edu/washington/escience/myria/operator/SampleWRTest.java new file mode 100644 index 000000000..e053e1846 --- /dev/null +++ b/test/edu/washington/escience/myria/operator/SampleWRTest.java @@ -0,0 +1,155 @@ +package edu.washington.escience.myria.operator; + +import static org.junit.Assert.assertEquals; + +import java.util.Arrays; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import edu.washington.escience.myria.DbException; +import edu.washington.escience.myria.Schema; +import edu.washington.escience.myria.Type; +import edu.washington.escience.myria.storage.TupleBatch; +import edu.washington.escience.myria.storage.TupleBatchBuffer; +import edu.washington.escience.myria.util.TestEnvVars; + +/** + * Tests SampleWR by verifying the results of various scenarios. + */ +public class SampleWRTest { + + final long RANDOM_SEED = 42; + final int[] INPUT_VALS = { 0, 1, 2, 3, 4, 5 }; + // values generated by rand.nextInt(INPUT_VALS.length) w/ seed=42 + final int[] SEED_OUTPUT = { 2, 3, 0, 2, 0, 1, 5, 2, 1, 5, 2, 2 }; + + final Schema LEFT_SCHEMA = Schema.ofFields("WorkerID", Type.INT_TYPE, + "PartitionSize", Type.INT_TYPE, "SampleSize", Type.INT_TYPE, + "IsWithReplacement", Type.BOOLEAN_TYPE); + final Schema RIGHT_SCHEMA = Schema.ofFields(Type.INT_TYPE, "SomeValue"); + final Schema OUTPUT_SCHEMA = RIGHT_SCHEMA; + + TupleBatchBuffer leftInput; + TupleBatchBuffer rightInput; + Sample sampOp; + + @Before + public void setup() { + leftInput = new TupleBatchBuffer(LEFT_SCHEMA); + leftInput.putInt(0, -1); // WorkerID for testing + rightInput = new TupleBatchBuffer(RIGHT_SCHEMA); + for (int val : INPUT_VALS) { + rightInput.putInt(0, val); + } + } + + /** Sample size 0. */ + @Test + public void testSampleSizeZero() throws DbException { + int partitionSize = INPUT_VALS.length; + int sampleSize = 0; + int[] expected = Arrays.copyOf(SEED_OUTPUT, sampleSize); + Arrays.sort(expected); + verifyExpectedResults(partitionSize, sampleSize, expected); + } + + /** Sample size 1. */ + @Test + public void testSampleSizeOne() throws DbException { + int partitionSize = INPUT_VALS.length; + int sampleSize = 1; + int[] expected = Arrays.copyOf(SEED_OUTPUT, sampleSize); + Arrays.sort(expected); + verifyExpectedResults(partitionSize, sampleSize, expected); + } + + /** Sample size 50%. 
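+   * (three of the six input values)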
*/ + @Test + public void testSampleSizeHalf() throws DbException { + int partitionSize = INPUT_VALS.length; + int sampleSize = INPUT_VALS.length / 2; + int[] expected = Arrays.copyOf(SEED_OUTPUT, sampleSize); + Arrays.sort(expected); + verifyExpectedResults(partitionSize, sampleSize, expected); + } + + /** Sample size all. */ + @Test + public void testSampleSizeAll() throws DbException { + int partitionSize = INPUT_VALS.length; + int sampleSize = INPUT_VALS.length; + int[] expected = Arrays.copyOf(SEED_OUTPUT, sampleSize); + Arrays.sort(expected); + verifyExpectedResults(partitionSize, sampleSize, expected); + } + + /** Sample size 200%. */ + @Test + public void testSampleSizeDouble() throws DbException { + int partitionSize = INPUT_VALS.length; + int sampleSize = INPUT_VALS.length * 2; + int[] expected = Arrays.copyOf(SEED_OUTPUT, sampleSize); + Arrays.sort(expected); + verifyExpectedResults(partitionSize, sampleSize, expected); + } + + /** Cannot have a negative sample size. */ + @Test(expected = IllegalStateException.class) + public void testSampleSizeNegative() throws DbException { + int partitionSize = INPUT_VALS.length; + int sampleSize = -1; + drainOperator(partitionSize, sampleSize); + } + + /** Cannot have a negative partition size. */ + @Test(expected = IllegalStateException.class) + public void testSamplePartitionNegative() throws DbException { + int partitionSize = -1; + int sampleSize = 3; + drainOperator(partitionSize, sampleSize); + } + + @After + public void cleanup() throws DbException { + if (sampOp != null && sampOp.isOpen()) { + sampOp.close(); + } + } + + /** Tests the correctness of a sampling operation using a seeded value. */ + private void verifyExpectedResults(int partitionSize, int sampleSize, + int[] expected) throws DbException { + leftInput.putInt(1, partitionSize); + leftInput.putInt(2, sampleSize); + leftInput.putBoolean(3, true); + sampOp = new Sample(new TupleSource(leftInput), new TupleSource(rightInput), RANDOM_SEED); + sampOp.open(TestEnvVars.get()); + + int rowIdx = 0; + while (!sampOp.eos()) { + TupleBatch result = sampOp.nextReady(); + if (result != null) { + assertEquals(OUTPUT_SCHEMA, result.getSchema()); + for (int i = 0; i < result.numTuples(); ++i, ++rowIdx) { + assertEquals(result.getInt(0, i), expected[rowIdx]); + } + } + } + assertEquals(sampleSize, rowIdx); + } + + /** Run through all results without doing anything. 
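+   * Used by tests that only expect an exception to be thrown.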
*/ + private void drainOperator(int partitionSize, int sampleSize) + throws DbException { + leftInput.putInt(1, partitionSize); + leftInput.putInt(2, sampleSize); + leftInput.putBoolean(3, true); + sampOp = new Sample(new TupleSource(leftInput), new TupleSource(rightInput), RANDOM_SEED); + sampOp.open(TestEnvVars.get()); + while (!sampOp.eos()) { + sampOp.nextReady(); + } + } +} diff --git a/test/edu/washington/escience/myria/operator/SampleWoRTest.java b/test/edu/washington/escience/myria/operator/SampleWoRTest.java new file mode 100644 index 000000000..0ec6dbe05 --- /dev/null +++ b/test/edu/washington/escience/myria/operator/SampleWoRTest.java @@ -0,0 +1,141 @@ +package edu.washington.escience.myria.operator; + +import static org.junit.Assert.assertEquals; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import edu.washington.escience.myria.DbException; +import edu.washington.escience.myria.Schema; +import edu.washington.escience.myria.Type; +import edu.washington.escience.myria.storage.TupleBatch; +import edu.washington.escience.myria.storage.TupleBatchBuffer; +import edu.washington.escience.myria.util.TestEnvVars; + +/** + * Tests SampleWoR by verifying the results of various scenarios. + */ +public class SampleWoRTest { + + final long RANDOM_SEED = 42; + final int[] INPUT_VALS = { 0, 1, 2, 3, 4, 5 }; + + final Schema LEFT_SCHEMA = Schema.ofFields("WorkerID", Type.INT_TYPE, + "PartitionSize", Type.INT_TYPE, "SampleSize", Type.INT_TYPE, + "IsWithReplacement", Type.BOOLEAN_TYPE); + final Schema RIGHT_SCHEMA = Schema.ofFields(Type.INT_TYPE, "SomeValue"); + final Schema OUTPUT_SCHEMA = RIGHT_SCHEMA; + + TupleBatchBuffer leftInput; + TupleBatchBuffer rightInput; + Sample sampOp; + + @Before + public void setup() { + leftInput = new TupleBatchBuffer(LEFT_SCHEMA); + leftInput.putInt(0, -1); // WorkerID for testing + rightInput = new TupleBatchBuffer(RIGHT_SCHEMA); + for (int val : INPUT_VALS) { + rightInput.putInt(0, val); + } + } + + /** Sample size 0. */ + @Test + public void testSampleSizeZero() throws DbException { + int partitionSize = INPUT_VALS.length; + int sampleSize = 0; + verifyExpectedResults(partitionSize, sampleSize); + } + + /** Sample size 1. */ + @Test + public void testSampleSizeOne() throws DbException { + int partitionSize = INPUT_VALS.length; + int sampleSize = 1; + verifyExpectedResults(partitionSize, sampleSize); + } + + /** Sample size 50%. */ + @Test + public void testSampleSizeHalf() throws DbException { + int partitionSize = INPUT_VALS.length; + int sampleSize = INPUT_VALS.length / 2; + verifyExpectedResults(partitionSize, sampleSize); + } + + /** Sample size all. */ + @Test + public void testSampleSizeAll() throws DbException { + int partitionSize = INPUT_VALS.length; + int sampleSize = INPUT_VALS.length; + verifyExpectedResults(partitionSize, sampleSize); + } + + /** Sample size greater than partition size. */ + @Test(expected = IllegalStateException.class) + public void testSampleSizeTooMany() throws DbException { + int partitionSize = INPUT_VALS.length; + int sampleSize = INPUT_VALS.length + 1; + drainOperator(partitionSize, sampleSize); + } + + /** Cannot have a negative sample size. */ + @Test(expected = IllegalStateException.class) + public void testSampleSizeNegative() throws DbException { + int partitionSize = INPUT_VALS.length; + int sampleSize = -1; + drainOperator(partitionSize, sampleSize); + } + + /** Cannot have a negative partition size. 
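+   * The operator's Preconditions check should reject it with an IllegalStateException.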
*/ + @Test(expected = IllegalStateException.class) + public void testSamplePartitionNegative() throws DbException { + int partitionSize = -1; + int sampleSize = 3; + drainOperator(partitionSize, sampleSize); + } + + @After + public void cleanup() throws DbException { + if (sampOp != null && sampOp.isOpen()) { + sampOp.close(); + } + } + + /** + * Tests whether the output could be a valid distribution. Note: doesn't + * currently test for statistical randomness. + */ + private void verifyExpectedResults(int partitionSize, int sampleSize) throws DbException { + leftInput.putInt(1, partitionSize); + leftInput.putInt(2, sampleSize); + leftInput.putBoolean(3, false); + sampOp = new Sample(new TupleSource(leftInput), new TupleSource(rightInput), RANDOM_SEED); + sampOp.open(TestEnvVars.get()); + + int rowIdx = 0; + while (!sampOp.eos()) { + TupleBatch result = sampOp.nextReady(); + if (result != null) { + assertEquals(OUTPUT_SCHEMA, result.getSchema()); + rowIdx += result.numTuples(); + } + } + assertEquals(sampleSize, rowIdx); + } + + /** Run through all results without doing anything. */ + private void drainOperator(int partitionSize, int sampleSize) + throws DbException { + leftInput.putInt(1, partitionSize); + leftInput.putInt(2, sampleSize); + leftInput.putBoolean(3, false); + sampOp = new Sample(new TupleSource(leftInput), new TupleSource(rightInput), RANDOM_SEED); + sampOp.open(TestEnvVars.get()); + while (!sampOp.eos()) { + sampOp.nextReady(); + } + } +} diff --git a/test/edu/washington/escience/myria/operator/SamplingDistributionTest.java b/test/edu/washington/escience/myria/operator/SamplingDistributionTest.java new file mode 100644 index 000000000..a80e290c1 --- /dev/null +++ b/test/edu/washington/escience/myria/operator/SamplingDistributionTest.java @@ -0,0 +1,243 @@ +package edu.washington.escience.myria.operator; + +import static org.junit.Assert.assertEquals; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import edu.washington.escience.myria.DbException; +import edu.washington.escience.myria.Schema; +import edu.washington.escience.myria.Type; +import edu.washington.escience.myria.storage.TupleBatch; +import edu.washington.escience.myria.storage.TupleBatchBuffer; +import edu.washington.escience.myria.util.TestEnvVars; + +/** + * Tests the SamplingDistribution operator by verifying the results of various + * scenarios. + */ +public class SamplingDistributionTest { + + final long RANDOM_SEED = 42; + + final Schema inputSchema = Schema.ofFields("WorkerID", Type.INT_TYPE, + "PartitionSize", Type.INT_TYPE); + final Schema expectedResultSchema = Schema.ofFields("WorkerID", Type.INT_TYPE, + "StreamSize", Type.INT_TYPE, "SampleSize", Type.INT_TYPE, "IsWithReplacement", Type.BOOLEAN_TYPE); + + TupleBatchBuffer input; + SamplingDistribution sampOp; + + @Before + public void setup() { + // (WorkerID, PartitionSize) + input = new TupleBatchBuffer(inputSchema); + input.putInt(0, 1); + input.putInt(1, 300); + input.putInt(0, 2); + input.putInt(1, 200); + input.putInt(0, 3); + input.putInt(1, 400); + input.putInt(0, 4); + input.putInt(1, 100); + } + + /** Sample size 0. 
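+   * Every worker should be assigned a sample size of zero.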
*/
+  @Test
+  public void testSampleWRSizeZero() throws DbException {
+    int sampleSize = 0;
+    boolean isWithReplacement = true;
+    final int[][] expectedResults = { { 1, 300, 0 }, { 2, 200, 0 },
+        { 3, 400, 0 }, { 4, 100, 0 } };
+    verifyExpectedResults(sampleSize, isWithReplacement, expectedResults);
+  }
+
+  @Test
+  public void testSampleWoRSizeZero() throws DbException {
+    int sampleSize = 0;
+    boolean isWithReplacement = false;
+    final int[][] expectedResults = { { 1, 300, 0 }, { 2, 200, 0 },
+        { 3, 400, 0 }, { 4, 100, 0 } };
+    verifyExpectedResults(sampleSize, isWithReplacement, expectedResults);
+  }
+
+  /** Sample size 1. */
+  @Test
+  public void testSampleWRSizeOne() throws DbException {
+    int sampleSize = 1;
+    boolean isWithReplacement = true;
+    verifyPossibleDistribution(sampleSize, isWithReplacement);
+  }
+
+  @Test
+  public void testSampleWoRSizeOne() throws DbException {
+    int sampleSize = 1;
+    boolean isWithReplacement = false;
+    verifyPossibleDistribution(sampleSize, isWithReplacement);
+  }
+
+  /** Sample size 50. */
+  @Test
+  public void testSampleWRSizeFifty() throws DbException {
+    int sampleSize = 50;
+    boolean isWithReplacement = true;
+    verifyPossibleDistribution(sampleSize, isWithReplacement);
+  }
+
+  @Test
+  public void testSampleWoRSizeFifty() throws DbException {
+    int sampleSize = 50;
+    boolean isWithReplacement = false;
+    verifyPossibleDistribution(sampleSize, isWithReplacement);
+  }
+
+  /** Sample all but one tuple. */
+  @Test
+  public void testSampleWoRSizeAllButOne() throws DbException {
+    int sampleSize = 999;
+    boolean isWithReplacement = false;
+    verifyPossibleDistribution(sampleSize, isWithReplacement);
+  }
+
+  @Test
+  public void testSampleWRSizeAllButOne() throws DbException {
+    int sampleSize = 999;
+    boolean isWithReplacement = true;
+    verifyPossibleDistribution(sampleSize, isWithReplacement);
+  }
+
+  /** SamplingWoR the entire population == return all. */
+  @Test
+  public void testSampleWoRSizeMax() throws DbException {
+    int sampleSize = 1000;
+    boolean isWithReplacement = false;
+    final int[][] expectedResults = { { 1, 300, 300 }, { 2, 200, 200 },
+        { 3, 400, 400 }, { 4, 100, 100 } };
+    verifyExpectedResults(sampleSize, isWithReplacement, expectedResults);
+  }
+
+  /** SamplingWR the entire population. */
+  @Test
+  public void testSampleWRSizeMax() throws DbException {
+    int sampleSize = 1000;
+    boolean isWithReplacement = true;
+    verifyPossibleDistribution(sampleSize, isWithReplacement);
+  }
+
+  /** Cannot sample more than total size. */
+  @Test(expected = IllegalStateException.class)
+  public void testSampleWoRSizeTooMany() throws DbException {
+    int sampleSize = 1001;
+    boolean isWithReplacement = false;
+    drainOperator(sampleSize, isWithReplacement);
+  }
+
+  @Test(expected = IllegalStateException.class)
+  public void testSampleWRSizeTooMany() throws DbException {
+    int sampleSize = 1001;
+    boolean isWithReplacement = true;
+    drainOperator(sampleSize, isWithReplacement);
+  }
+
+  /** Cannot sample a negative number of samples. */
+  @Test(expected = IllegalStateException.class)
+  public void testSampleWoRSizeNegative() throws DbException {
+    int sampleSize = -1;
+    boolean isWithReplacement = false;
+    drainOperator(sampleSize, isWithReplacement);
+  }
+
+  @Test(expected = IllegalStateException.class)
+  public void testSampleWRSizeNegative() throws DbException {
+    int sampleSize = -1;
+    boolean isWithReplacement = true;
+    drainOperator(sampleSize, isWithReplacement);
+  }
+
+  /** Worker cannot report a negative partition size.
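+   * SamplingDistribution should reject it with an IllegalStateException.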
*/ + @Test(expected = IllegalStateException.class) + public void testSampleWoRWorkerNegative() throws DbException { + int sampleSize = 50; + boolean isWithReplacement = false; + input.putInt(0, 5); + input.putInt(1, -1); + drainOperator(sampleSize, isWithReplacement); + } + + @Test(expected = IllegalStateException.class) + public void testSampleWRWorkerNegative() throws DbException { + int sampleSize = 50; + boolean isWithReplacement = true; + input.putInt(0, 5); + input.putInt(1, -1); + drainOperator(sampleSize, isWithReplacement); + } + + @After + public void cleanup() throws DbException { + if (sampOp != null && sampOp.isOpen()) { + sampOp.close(); + } + } + + /** Compare output results compared to some known expectedResults. */ + private void verifyExpectedResults(int sampleSize, + boolean isWithReplacement, int[][] expectedResults) throws DbException { + sampOp = new SamplingDistribution(new TupleSource(input), sampleSize, isWithReplacement, RANDOM_SEED); + sampOp.open(TestEnvVars.get()); + + int rowIdx = 0; + while (!sampOp.eos()) { + TupleBatch result = sampOp.nextReady(); + if (result != null) { + assertEquals(expectedResultSchema, result.getSchema()); + for (int i = 0; i < result.numTuples(); ++i, ++rowIdx) { + assertEquals(expectedResults[rowIdx][0], result.getInt(0, i)); + assertEquals(expectedResults[rowIdx][1], result.getInt(1, i)); + assertEquals(expectedResults[rowIdx][2], result.getInt(2, i)); + } + } + } + assertEquals(expectedResults.length, rowIdx); + } + + /** + * Tests the actual distribution against what could be possible. Note: doesn't + * test if it is statistically random. + */ + private void verifyPossibleDistribution(int sampleSize, + boolean isWithReplacement) throws DbException { + sampOp = new SamplingDistribution(new TupleSource(input), sampleSize, isWithReplacement, RANDOM_SEED); + sampOp.open(TestEnvVars.get()); + + int rowIdx = 0; + int computedSampleSize = 0; + while (!sampOp.eos()) { + TupleBatch result = sampOp.nextReady(); + if (result != null) { + assertEquals(expectedResultSchema, result.getSchema()); + for (int i = 0; i < result.numTuples(); ++i, ++rowIdx) { + assert (result.getInt(2, i) >= 0 && result.getInt(2, i) <= sampleSize); + if (!isWithReplacement) { + // SampleWoR cannot sample more than worker's population size. + assert (result.getInt(2, i) <= result.getInt(1, i)); + } + computedSampleSize += result.getInt(2, i); + } + } + } + assertEquals(input.numTuples(), rowIdx); + assertEquals(sampleSize, computedSampleSize); + } + + /** Run through all results without doing anything. 
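+   * Used by tests that only expect an exception to be thrown.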
*/ + private void drainOperator(int sampleSize, boolean isWithReplacement) + throws DbException { + sampOp = new SamplingDistribution(new TupleSource(input), sampleSize, isWithReplacement, RANDOM_SEED); + sampOp.open(TestEnvVars.get()); + while (!sampOp.eos()) { + sampOp.nextReady(); + } + } +} From 4c90e6f358f879a15df54d049136aa56489a13e1 Mon Sep 17 00:00:00 2001 From: Dan Radion Date: Fri, 22 May 2015 10:41:44 -0700 Subject: [PATCH 05/29] example sampling json plans --- jsonQueries/radion_queries/SampleScanWR.json | 130 +++++++++++++++++ jsonQueries/radion_queries/SampleScanWoR.json | 130 +++++++++++++++++ jsonQueries/radion_queries/SampleWR.json | 133 ++++++++++++++++++ jsonQueries/radion_queries/SampleWoR.json | 133 ++++++++++++++++++ 4 files changed, 526 insertions(+) create mode 100644 jsonQueries/radion_queries/SampleScanWR.json create mode 100644 jsonQueries/radion_queries/SampleScanWoR.json create mode 100644 jsonQueries/radion_queries/SampleWR.json create mode 100644 jsonQueries/radion_queries/SampleWoR.json diff --git a/jsonQueries/radion_queries/SampleScanWR.json b/jsonQueries/radion_queries/SampleScanWR.json new file mode 100644 index 000000000..65257601a --- /dev/null +++ b/jsonQueries/radion_queries/SampleScanWR.json @@ -0,0 +1,130 @@ +{ + "logicalRa": "Nothing", + "plan": + { + "type":"SubQuery", + "fragments": + [ + { + "operators": + [ + { + "opType":"DbQueryScan", + "opName":"MyriaQueryScan(public:adhoc:TwitterNodes)", + "opId":3, + "schema":{ + "columnNames": + [ + "cnt" + ], + "columnTypes": + [ + "LONG_TYPE" + ] + }, + "sql":"SELECT count(*) AS cnt FROM \"public:adhoc:TwitterNodes\"" + }, + { + "opType":"Apply", + "argChild":3, + "opId":4, + "opName":"MyriaApply(WorkerID=WORKERID)", + "emitExpressions": + [ + { + "outputName":"WorkerID", + "rootExpressionOperator": + { + "type":"WORKERID" + } + }, + { + "outputName":"WorkerCount", + "rootExpressionOperator": + { + "type":"VARIABLE", + "columnIdx":0 + } + } + ] + }, + { + "argChild":4, + "opId":5, + "opType": "CollectProducer" + } + ] + }, + { + "operators": + [ + { + "argOperatorId":5, + "opId":6, + "opType": "CollectConsumer" + }, + { + "opName": "MyriaSamplingDistribution", + "opType": "SamplingDistribution", + "argChild": 6, + "opId": 7, + "sampleSize": 200, + "isWithReplacement": true + }, + { + "argChild":7, + "argPf":{ + "index":0, + "type":"RawValue" + }, + "opType":"ShuffleProducer", + "opId":8, + "opName":"MyriaShuffleProducer($0)" + } + ] + }, + { + "operators": + [ + { + "opName":"MyriaShuffleConsumer", + "opType":"ShuffleConsumer", + "argOperatorId":8, + "opId":9 + }, + { + "opType":"TableScan", + "opName":"MyriaScan(public:adhoc:TwitterNodes)", + "opId":10, + "relationKey":{ + "userName":"public", + "programName":"adhoc", + "relationName":"TwitterNodes" + } + }, + { + "argChild1": 9, + "argChild2": 10, + "opType":"Sample", + "opId":11, + "opName":"MyriaSampleWR" + }, + { + "opType": "DbInsert", + "opName": "MyriaStore", + "argOverwriteTable": true, + "argChild": 11, + "relationKey": + { + "programName": "adhoc", + "relationName": "SampledScanWR", + "userName": "public" + }, + "opId": 12 + } + ] + } + ] + }, + "rawQuery": "SampleScanWR" +} diff --git a/jsonQueries/radion_queries/SampleScanWoR.json b/jsonQueries/radion_queries/SampleScanWoR.json new file mode 100644 index 000000000..8246a1586 --- /dev/null +++ b/jsonQueries/radion_queries/SampleScanWoR.json @@ -0,0 +1,130 @@ +{ + "logicalRa": "Nothing", + "plan": + { + "type":"SubQuery", + "fragments": + [ + { + "operators": + [ + { + 
"opType":"DbQueryScan", + "opName":"MyriaQueryScan(public:adhoc:TwitterNodes)", + "opId":3, + "schema":{ + "columnNames": + [ + "cnt" + ], + "columnTypes": + [ + "LONG_TYPE" + ] + }, + "sql":"SELECT count(*) AS cnt FROM \"public:adhoc:TwitterNodes\"" + }, + { + "opType":"Apply", + "argChild":3, + "opId":4, + "opName":"MyriaApply(WorkerID=WORKERID)", + "emitExpressions": + [ + { + "outputName":"WorkerID", + "rootExpressionOperator": + { + "type":"WORKERID" + } + }, + { + "outputName":"WorkerCount", + "rootExpressionOperator": + { + "type":"VARIABLE", + "columnIdx":0 + } + } + ] + }, + { + "argChild":4, + "opId":5, + "opType": "CollectProducer" + } + ] + }, + { + "operators": + [ + { + "argOperatorId":5, + "opId":6, + "opType": "CollectConsumer" + }, + { + "opName": "MyriaSamplingDistribution", + "opType": "SamplingDistribution", + "argChild": 6, + "opId": 7, + "sampleSize": 200, + "isWithReplacement": false + }, + { + "argChild":7, + "argPf":{ + "index":0, + "type":"RawValue" + }, + "opType":"ShuffleProducer", + "opId":8, + "opName":"MyriaShuffleProducer($0)" + } + ] + }, + { + "operators": + [ + { + "opName":"MyriaShuffleConsumer", + "opType":"ShuffleConsumer", + "argOperatorId":8, + "opId":9 + }, + { + "opType":"TableScan", + "opName":"MyriaScan(public:adhoc:TwitterNodes)", + "opId":10, + "relationKey":{ + "userName":"public", + "programName":"adhoc", + "relationName":"TwitterNodes" + } + }, + { + "argChild1": 9, + "argChild2": 10, + "opType":"Sample", + "opId":11, + "opName":"MyriaSampleWoR" + }, + { + "opType": "DbInsert", + "opName": "MyriaStore", + "argOverwriteTable": true, + "argChild": 11, + "relationKey": + { + "programName": "adhoc", + "relationName": "SampledScanWoR", + "userName": "public" + }, + "opId": 12 + } + ] + } + ] + }, + "rawQuery": "SampleScanWoR" +} diff --git a/jsonQueries/radion_queries/SampleWR.json b/jsonQueries/radion_queries/SampleWR.json new file mode 100644 index 000000000..82337e8ea --- /dev/null +++ b/jsonQueries/radion_queries/SampleWR.json @@ -0,0 +1,133 @@ +{ + "logicalRa": "Nothing", + "plan": + { + "type":"Sequence", + "plans": + [ + { + "type":"SubQuery", + "fragments": + [ + { + "operators": + [ + { + "opType":"TableScan", + "opName":"MyriaScan(public:adhoc:TwitterNodes)", + "opId":0, + "relationKey":{ + "userName":"public", + "programName":"adhoc", + "relationName":"TwitterNodes" + } + }, + { + "argChild": 0, + "opType":"SampledDbInsertTemp", + "opId":1, + "opName":"MyriaSampledDbInsertTemp", + "sampleSize":200, + "sampleTable":"TempSampleWoR", + "countTable":"TempCountSampleWoR" + }, + { + "argChild": 1, + "opType":"SinkRoot", + "opId":2, + "opName":"MyriaSinkRoot" + } + ] + } + ] + }, + { + "type":"SubQuery", + "fragments": + [ + { + "operators": + [ + { + "opType":"TempTableScan", + "table":"TempCountSampleWoR", + "opName":"MyriaScanTemp(TempCountSampleWoR)", + "opId":4 + }, + { + "argChild":4, + "opId":6, + "opType": "CollectProducer" + } + ] + }, + { + "operators": + [ + { + "argOperatorId":6, + "opId":7, + "opType": "CollectConsumer" + }, + { + "opName": "MyriaSamplingDistribution", + "opType": "SamplingDistribution", + "argChild": 7, + "opId": 8, + "sampleSize": 200, + "isWithReplacement": true + }, + { + "argChild":8, + "opId":9, + "argPf":{ + "index":0, + "type":"RawValue" + }, + "opType":"ShuffleProducer", + "opName":"MyriaShuffleProducer($0)" + } + ] + }, + { + "operators": + [ + { + "opName":"MyriaShuffleConsumer", + "opType":"ShuffleConsumer", + "argOperatorId":9, + "opId":10 + }, + { + "opType":"TempTableScan", + 
"table":"TempSampleWoR", + "opName":"MyriaScanTemp(TempSampleWoR)", + "opId":11 + }, + { + "argChild1": 10, + "argChild2": 11, + "opType":"Sample", + "opId":12, + "opName":"MyriaSampleWR" + }, + { + "opType": "DbInsert", + "argOverwriteTable": true, + "argChild": 12, + "relationKey": + { + "programName": "adhoc", + "relationName": "SampledWR", + "userName": "public" + }, + "opId": 13 + } + ] + } + ] + } + ] + }, + "rawQuery": "SampleWR" +} diff --git a/jsonQueries/radion_queries/SampleWoR.json b/jsonQueries/radion_queries/SampleWoR.json new file mode 100644 index 000000000..88a188f99 --- /dev/null +++ b/jsonQueries/radion_queries/SampleWoR.json @@ -0,0 +1,133 @@ +{ + "logicalRa": "Nothing", + "plan": + { + "type":"Sequence", + "plans": + [ + { + "type":"SubQuery", + "fragments": + [ + { + "operators": + [ + { + "opType":"TableScan", + "opName":"MyriaScan(public:adhoc:TwitterNodes)", + "opId":0, + "relationKey":{ + "userName":"public", + "programName":"adhoc", + "relationName":"TwitterNodes" + } + }, + { + "argChild": 0, + "opType":"SampledDbInsertTemp", + "opId":1, + "opName":"MyriaSampledDbInsertTemp", + "sampleSize":200, + "sampleTable":"TempSampleWoR", + "countTable":"TempCountSampleWoR" + }, + { + "argChild": 1, + "opType":"SinkRoot", + "opId":2, + "opName":"MyriaSinkRoot" + } + ] + } + ] + }, + { + "type":"SubQuery", + "fragments": + [ + { + "operators": + [ + { + "opType":"TempTableScan", + "table":"TempCountSampleWoR", + "opName":"MyriaScanTemp(TempCountSampleWoR)", + "opId":4 + }, + { + "argChild":4, + "opId":6, + "opType": "CollectProducer" + } + ] + }, + { + "operators": + [ + { + "argOperatorId":6, + "opId":7, + "opType": "CollectConsumer" + }, + { + "opName": "MyriaSamplingDistribution", + "opType": "SamplingDistribution", + "argChild": 7, + "opId": 8, + "sampleSize": 200, + "isWithReplacement": false + }, + { + "argChild":8, + "opId":9, + "argPf":{ + "index":0, + "type":"RawValue" + }, + "opType":"ShuffleProducer", + "opName":"MyriaShuffleProducer($0)" + } + ] + }, + { + "operators": + [ + { + "opName":"MyriaShuffleConsumer", + "opType":"ShuffleConsumer", + "argOperatorId":9, + "opId":10 + }, + { + "opType":"TempTableScan", + "table":"TempSampleWoR", + "opName":"MyriaScanTemp(TempSampleWoR)", + "opId":11 + }, + { + "argChild1": 10, + "argChild2": 11, + "opType":"Sample", + "opId":12, + "opName":"MyriaSampleWoR" + }, + { + "opType": "DbInsert", + "argOverwriteTable": true, + "argChild": 12, + "relationKey": + { + "programName": "adhoc", + "relationName": "SampledWoR", + "userName": "public" + }, + "opId": 13 + } + ] + } + ] + } + ] + }, + "rawQuery": "SampleWoR" +} From 1cd372d01e730ee63dff1cd6721ab1e8abb7bb89 Mon Sep 17 00:00:00 2001 From: Dan Radion Date: Sun, 24 May 2015 10:57:03 -0700 Subject: [PATCH 06/29] renamed RawValue to IdentityHash --- jsonQueries/radion_queries/SampleScanWR.json | 2 +- jsonQueries/radion_queries/SampleScanWoR.json | 2 +- jsonQueries/radion_queries/SampleWR.json | 2 +- jsonQueries/radion_queries/SampleWoR.json | 2 +- ...nction.java => IdentityHashPartitionFunction.java} | 11 ++++++----- .../operator/network/partition/PartitionFunction.java | 2 +- 6 files changed, 11 insertions(+), 10 deletions(-) rename src/edu/washington/escience/myria/operator/network/partition/{RawValuePartitionFunction.java => IdentityHashPartitionFunction.java} (79%) diff --git a/jsonQueries/radion_queries/SampleScanWR.json b/jsonQueries/radion_queries/SampleScanWR.json index 65257601a..19a8a0200 100644 --- a/jsonQueries/radion_queries/SampleScanWR.json +++ 
b/jsonQueries/radion_queries/SampleScanWR.json
@@ -75,7 +75,7 @@
             "argChild":7,
             "argPf":{
                "index":0,
-               "type":"RawValue"
+               "type":"IdentityHash"
             },
             "opType":"ShuffleProducer",
             "opId":8,
diff --git a/jsonQueries/radion_queries/SampleScanWoR.json b/jsonQueries/radion_queries/SampleScanWoR.json
index 8246a1586..0968ca307 100644
--- a/jsonQueries/radion_queries/SampleScanWoR.json
+++ b/jsonQueries/radion_queries/SampleScanWoR.json
@@ -75,7 +75,7 @@
             "argChild":7,
             "argPf":{
                "index":0,
-               "type":"RawValue"
+               "type":"IdentityHash"
             },
             "opType":"ShuffleProducer",
             "opId":8,
diff --git a/jsonQueries/radion_queries/SampleWR.json b/jsonQueries/radion_queries/SampleWR.json
index 82337e8ea..d4b021049 100644
--- a/jsonQueries/radion_queries/SampleWR.json
+++ b/jsonQueries/radion_queries/SampleWR.json
@@ -82,7 +82,7 @@
             "opId":9,
             "argPf":{
                "index":0,
-               "type":"RawValue"
+               "type":"IdentityHash"
             },
             "opType":"ShuffleProducer",
             "opName":"MyriaShuffleProducer($0)"
diff --git a/jsonQueries/radion_queries/SampleWoR.json b/jsonQueries/radion_queries/SampleWoR.json
index 88a188f99..99eb5ef9c 100644
--- a/jsonQueries/radion_queries/SampleWoR.json
+++ b/jsonQueries/radion_queries/SampleWoR.json
@@ -82,7 +82,7 @@
             "opId":9,
             "argPf":{
                "index":0,
-               "type":"RawValue"
+               "type":"IdentityHash"
             },
             "opType":"ShuffleProducer",
             "opName":"MyriaShuffleProducer($0)"
diff --git a/src/edu/washington/escience/myria/operator/network/partition/RawValuePartitionFunction.java b/src/edu/washington/escience/myria/operator/network/partition/IdentityHashPartitionFunction.java
similarity index 79%
rename from src/edu/washington/escience/myria/operator/network/partition/RawValuePartitionFunction.java
rename to src/edu/washington/escience/myria/operator/network/partition/IdentityHashPartitionFunction.java
index fade62ecc..4ceda2040 100644
--- a/src/edu/washington/escience/myria/operator/network/partition/RawValuePartitionFunction.java
+++ b/src/edu/washington/escience/myria/operator/network/partition/IdentityHashPartitionFunction.java
@@ -8,9 +8,10 @@ import edu.washington.escience.myria.storage.TupleBatch;
 
 /**
- * Implementation of a PartitionFunction that decides based on the raw values of an integer field.
+ * Implementation of a PartitionFunction that uses the trivial identity hash,
+ * i.e. a --> a.
  */
-public final class RawValuePartitionFunction extends PartitionFunction {
+public final class IdentityHashPartitionFunction extends PartitionFunction {
 
   /** Required for Java serialization. */
   private static final long serialVersionUID = 1L;
@@ -23,12 +24,12 @@ public final class RawValuePartitionFunction extends PartitionFunction {
    * @param index the index of the partition field.
    */
   @JsonCreator
-  public RawValuePartitionFunction(
+  public IdentityHashPartitionFunction(
       @JsonProperty(value = "index", required = true) final Integer index) {
     super(null);
     this.index = java.util.Objects.requireNonNull(index, "missing property index");
     Preconditions.checkArgument(this.index >= 0,
-        "RawValue field index cannot take negative value %s", this.index);
+        "IdentityHash field index cannot take negative value %s", this.index);
   }
 
   /**
@@ -45,7 +46,7 @@ public int getIndex() {
   @Override
   public int[] partition(final TupleBatch tb) {
     Preconditions.checkArgument(tb.getSchema().getColumnType(index) == Type.INT_TYPE,
-        "RawValue index column must be of type INT");
+        "IdentityHash index column must be of type INT");
     final int[] result = new int[tb.numTuples()];
     for (int i = 0; i < result.length; i++) {
       // Offset by -1 because WorkerIDs are 1-indexed.
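      // A hedged sketch of that mapping, standing in for the body elided
      // here (not necessarily the file's exact code): a tuple whose key
      // column holds WorkerID k is routed to partition k - 1, i.e.
      // result[i] = tb.getInt(index, i) - 1;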
diff --git a/src/edu/washington/escience/myria/operator/network/partition/PartitionFunction.java b/src/edu/washington/escience/myria/operator/network/partition/PartitionFunction.java index 4147797d1..877027fa2 100644 --- a/src/edu/washington/escience/myria/operator/network/partition/PartitionFunction.java +++ b/src/edu/washington/escience/myria/operator/network/partition/PartitionFunction.java @@ -22,7 +22,7 @@ @JsonSubTypes({ @Type(value = RoundRobinPartitionFunction.class, name = "RoundRobin"), @Type(value = SingleFieldHashPartitionFunction.class, name = "SingleFieldHash"), - @Type(value = RawValuePartitionFunction.class, name = "RawValue"), + @Type(value = IdentityHashPartitionFunction.class, name = "IdentityHash"), @Type(value = MultiFieldHashPartitionFunction.class, name = "MultiFieldHash"), @Type(value = WholeTupleHashPartitionFunction.class, name = "WholeTupleHash") }) public abstract class PartitionFunction implements Serializable { From e12b75bc7462845db8b7a93de4b61d832fba630a Mon Sep 17 00:00:00 2001 From: Dan Radion Date: Sun, 24 May 2015 13:31:25 -0700 Subject: [PATCH 07/29] support for sampling by a percentage of total tuples --- jsonQueries/radion_queries/SampleScanWR.json | 4 +- jsonQueries/radion_queries/SampleScanWoR.json | 4 +- .../SamplingDistributionEncoding.java | 22 +- .../escience/myria/operator/Sample.java | 20 +- .../myria/operator/SamplingDistribution.java | 114 +++++++--- .../escience/myria/operator/SampleWRTest.java | 6 +- .../myria/operator/SampleWoRTest.java | 6 +- .../operator/SamplingDistributionTest.java | 211 ++++++++++++++---- 8 files changed, 290 insertions(+), 97 deletions(-) diff --git a/jsonQueries/radion_queries/SampleScanWR.json b/jsonQueries/radion_queries/SampleScanWR.json index 19a8a0200..4e449084e 100644 --- a/jsonQueries/radion_queries/SampleScanWR.json +++ b/jsonQueries/radion_queries/SampleScanWR.json @@ -68,8 +68,8 @@ "opType": "SamplingDistribution", "argChild": 6, "opId": 7, - "sampleSize": 200, - "isWithReplacement": true + "samplePercentage": 10, + "sampleType": "WR" }, { "argChild":7, diff --git a/jsonQueries/radion_queries/SampleScanWoR.json b/jsonQueries/radion_queries/SampleScanWoR.json index 0968ca307..cfe46f8e6 100644 --- a/jsonQueries/radion_queries/SampleScanWoR.json +++ b/jsonQueries/radion_queries/SampleScanWoR.json @@ -68,8 +68,8 @@ "opType": "SamplingDistribution", "argChild": 6, "opId": 7, - "sampleSize": 200, - "isWithReplacement": false + "samplePercentage": 10, + "sampleType": "WoR" }, { "argChild":7, diff --git a/src/edu/washington/escience/myria/api/encoding/SamplingDistributionEncoding.java b/src/edu/washington/escience/myria/api/encoding/SamplingDistributionEncoding.java index a38385e26..323f17667 100644 --- a/src/edu/washington/escience/myria/api/encoding/SamplingDistributionEncoding.java +++ b/src/edu/washington/escience/myria/api/encoding/SamplingDistributionEncoding.java @@ -1,21 +1,35 @@ package edu.washington.escience.myria.api.encoding; +import javax.ws.rs.core.Response; + +import edu.washington.escience.myria.api.MyriaApiException; import edu.washington.escience.myria.api.encoding.QueryConstruct.ConstructArgs; import edu.washington.escience.myria.operator.SamplingDistribution; public class SamplingDistributionEncoding extends UnaryOperatorEncoding { - @Required - public int sampleSize; + /** A specific number of tuples to sample. */ + public Integer sampleSize; + + /** Percentage of total tuples to sample. 
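+   * (For example, samplePercentage = 10 requests roughly 10% of the total
+   * tuple count; the rounding happens inside SamplingDistribution.)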
*/ + public Float samplePercentage; @Required - public boolean isWithReplacement; + public String sampleType; /** Used to make results deterministic. Null if no specified value. */ public Long randomSeed; @Override public SamplingDistribution construct(final ConstructArgs args) { - return new SamplingDistribution(null, sampleSize, isWithReplacement, randomSeed); + if (sampleSize != null && samplePercentage == null) { + return new SamplingDistribution(null, sampleSize, sampleType, randomSeed); + } else if (sampleSize == null && samplePercentage != null) { + return new SamplingDistribution(null, samplePercentage, sampleType, + randomSeed); + } else { + throw new MyriaApiException(Response.Status.BAD_REQUEST, + "Must specify exactly one of sampleSize or samplePercentage"); + } } } diff --git a/src/edu/washington/escience/myria/operator/Sample.java b/src/edu/washington/escience/myria/operator/Sample.java index 17e6d7891..0890c7a0c 100644 --- a/src/edu/washington/escience/myria/operator/Sample.java +++ b/src/edu/washington/escience/myria/operator/Sample.java @@ -27,8 +27,8 @@ public class Sample extends BinaryOperator { /** Random generator used for index selection. */ private Random rand; - /** True if the sampling is WithReplacement. WithoutReplacement otherwise. */ - private boolean isWithReplacement; + /** The type of sampling to perform. Currently supports 'WR' and 'WoR'. */ + private String sampleType; /** True if operator has extracted sampling info. */ private boolean computedSamplingInfo = false; @@ -50,7 +50,7 @@ public class Sample extends BinaryOperator { * and the stream from the right operator. * * @param left - * inputs a (WorkerID, StreamSize, SampleSize, IsWithReplacement) tuple. + * inputs a (WorkerID, StreamSize, SampleSize, SampleType) tuple. * @param right * tuples that will be sampled from. * @param randomSeed @@ -75,17 +75,19 @@ protected TupleBatch fetchNextReady() throws Exception { getLeft().close(); // Cannot sampleWoR more tuples than there are. - if (!isWithReplacement) { + if (sampleType.equals("WoR")) { Preconditions.checkState(sampleSize <= streamSize, "Cannot SampleWoR %s tuples from a population of size %s", sampleSize, streamSize); } // Generate target indices to accept as samples. 
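    // (A hedged illustration: with streamSize = 5 and sampleSize = 3,
    // generateIndicesWR may return duplicate indices such as {0, 2, 2},
    // while generateIndicesWoR returns three distinct ones such as {0, 2, 4}.)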
- if (isWithReplacement) { + if (sampleType.equals("WR")) { samples = generateIndicesWR(streamSize, sampleSize); - } else { + } else if (sampleType.equals("WoR")) { samples = generateIndicesWoR(streamSize, sampleSize); + } else { + throw new DbException("Invalid sampleType: " + sampleType); } computedSamplingInfo = true; @@ -163,10 +165,10 @@ private void extractSamplingInfo(TupleBatch tb) throws Exception { Preconditions.checkState(sampleSize >= 0, "sampleSize cannot be negative"); Type col3Type = tb.getSchema().getColumnType(3); - if (col3Type == Type.BOOLEAN_TYPE) { - isWithReplacement= tb.getBoolean(3, 0); + if (col3Type == Type.STRING_TYPE) { + sampleType = tb.getString(3, 0); } else { - throw new DbException("IsWithReplacement column must be of type BOOLEAN"); + throw new DbException("SampleType column must be of type STRING"); } } diff --git a/src/edu/washington/escience/myria/operator/SamplingDistribution.java b/src/edu/washington/escience/myria/operator/SamplingDistribution.java index 5d5e7bc8e..11f9eb4e9 100644 --- a/src/edu/washington/escience/myria/operator/SamplingDistribution.java +++ b/src/edu/washington/escience/myria/operator/SamplingDistribution.java @@ -11,8 +11,8 @@ import edu.washington.escience.myria.Schema; import edu.washington.escience.myria.Type; import edu.washington.escience.myria.column.Column; -import edu.washington.escience.myria.column.builder.BooleanColumnBuilder; import edu.washington.escience.myria.column.builder.IntColumnBuilder; +import edu.washington.escience.myria.column.builder.StringColumnBuilder; import edu.washington.escience.myria.storage.TupleBatch; public class SamplingDistribution extends UnaryOperator { @@ -21,25 +21,40 @@ public class SamplingDistribution extends UnaryOperator { /** The output schema. */ private static final Schema SCHEMA = Schema.of( - ImmutableList.of(Type.INT_TYPE, Type.INT_TYPE, Type.INT_TYPE, Type.BOOLEAN_TYPE), - ImmutableList.of("WorkerID", "StreamSize", "SampleSize", "IsWithReplacement")); + ImmutableList.of(Type.INT_TYPE, Type.INT_TYPE, Type.INT_TYPE, Type.STRING_TYPE), + ImmutableList.of("WorkerID", "StreamSize", "SampleSize", "SampleType")); /** Total number of tuples to sample. */ - private final int sampleSize; + private int sampleSize; - /** True if the sampling is WithReplacement. WithoutReplacement otherwise. */ - private final boolean isWithReplacement; + /** True if using a percentage instead of a specific tuple count. */ + private boolean isPercentageSample = false; + + /** Percentage of total tuples to sample. */ + private float samplePercentage; + + /** The type of sampling to perform. Currently supports 'WR' and 'WoR'. */ + private final String sampleType; /** Random generator used for creating the distribution. */ private Random rand; + private SamplingDistribution(Operator child, String sampleType, Long randomSeed) { + super(child); + this.sampleType = sampleType; + this.rand = new Random(); + if (randomSeed != null) { + this.rand.setSeed(randomSeed); + } + } + /** - * Instantiate a SamplingDistribution operator. + * Instantiate a SamplingDistribution operator using a specific sample size. * * @param sampleSize * total samples to create a distribution for. - * @param isWithReplacement - * true if the distribution uses WithReplacement sampling. + * @param sampleType + * the type of sampling distribution to create * @param child * extracts (WorkerID, PartitionSize, StreamSize) information from * this child. 
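   *          (For illustration, assuming the constructor shown here:
   *          {@code new SamplingDistribution(child, 500, "WoR", 42L)}
   *          requests 500 samples without replacement under a fixed seed.)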
@@ -47,16 +62,33 @@ public class SamplingDistribution extends UnaryOperator { * value to seed the random generator with. null if no specified seed */ public SamplingDistribution(Operator child, int sampleSize, - boolean isWithReplacement, Long randomSeed) { - super(child); + String sampleType, Long randomSeed) { + this(child, sampleType, randomSeed); this.sampleSize = sampleSize; - Preconditions.checkState(sampleSize >= 0, - "Sample size cannot be negative: %s", sampleSize); - this.isWithReplacement = isWithReplacement; - this.rand = new Random(); - if (randomSeed != null) { - this.rand.setSeed(randomSeed); - } + Preconditions.checkState(this.sampleSize >= 0, + "Sample Size must be >= 0: %s", this.sampleSize); + } + + /** + * Instantiate a SamplingDistribution operator using a percentage of total tuples. + * + * @param samplePercentage + * percentage of total samples to create a distribution for. + * @param sampleType + * the type of sampling distribution to create + * @param child + * extracts (WorkerID, PartitionSize, StreamSize) information from + * this child. + * @param randomSeed + * value to seed the random generator with. null if no specified seed + */ + public SamplingDistribution(Operator child, float samplePercentage, + String sampleType, Long randomSeed) { + this(child, sampleType, randomSeed); + this.isPercentageSample = true; + this.samplePercentage = samplePercentage; + Preconditions.checkState(samplePercentage >= 0 && samplePercentage <= 100, + "Sample Percentage must be >= 0 && <= 100: %s", samplePercentage); } @Override @@ -130,36 +162,42 @@ protected TupleBatch fetchNextReady() throws DbException { throw new DbException("StreamSize must be of type INT or LONG"); } Preconditions.checkState(partitionSize >= 0, - "Worker cannot have a negative StreamSize: %s", streamSize); + "Worker cannot have a negative StreamSize: %d", streamSize); } streamCounts.set(workerID - 1, streamSize); } } - Preconditions.checkState(sampleSize <= totalTupleCount, + // Convert samplePct to sampleSize if using a percentage sample. + if (isPercentageSample) { + sampleSize = Math.round(totalTupleCount * (samplePercentage / 100)); + } + Preconditions.checkState(sampleSize >= 0 && sampleSize <= totalTupleCount, "Cannot extract %s samples from a population of size %s", sampleSize, totalTupleCount); // Generate a random distribution across the workers. int[] sampleCounts; - if (isWithReplacement) { + if (sampleType.equals("WR")) { sampleCounts = withReplacementDistribution(tupleCounts, sampleSize); - } else { + } else if (sampleType.equals("WoR")){ sampleCounts = withoutReplacementDistribution(tupleCounts, sampleSize); + } else { + throw new DbException("Invalid sampleType: " + sampleType); } // Build and return a TupleBatch with the distribution. 
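    // (One row per worker, in the operator's (WorkerID, StreamSize,
    // SampleSize, SampleType) schema; a hypothetical row for worker 3 of the
    // test fixture would look like (3, 400, 17, "WoR").)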
IntColumnBuilder wIdCol = new IntColumnBuilder(); IntColumnBuilder streamSizeCol = new IntColumnBuilder(); IntColumnBuilder sampCountCol = new IntColumnBuilder(); - BooleanColumnBuilder wrCol = new BooleanColumnBuilder(); + StringColumnBuilder sampTypeCol = new StringColumnBuilder(); for (int i = 0; i < streamCounts.size(); i++) { wIdCol.appendInt(i + 1); streamSizeCol.appendInt(streamCounts.get(i)); sampCountCol.appendInt(sampleCounts[i]); - wrCol.appendBoolean(isWithReplacement); + sampTypeCol.appendString(sampleType); } ImmutableList.Builder> columns = ImmutableList.builder(); - columns.add(wIdCol.build(), streamSizeCol.build(), sampCountCol.build(), wrCol.build()); + columns.add(wIdCol.build(), streamSizeCol.build(), sampCountCol.build(), sampTypeCol.build()); return new TupleBatch(SCHEMA, columns.build()); } @@ -228,6 +266,32 @@ private int[] withoutReplacementDistribution(List tupleCounts, return distribution; } + /** + * Returns the sample size of this operator. If operator was created using a + * samplePercentage, this value will be 0 until after fetchNextReady. + */ + public int getSampleSize() { + return sampleSize; + } + + /** Returns whether this operator is using a percentage sample. */ + public boolean isPercentageSample() { + return isPercentageSample; + } + + /** + * Returns the percentage of total tuples that this operator will distribute. + * Will be 0 if the operator was created using a specific sampleSize. + */ + public float getSamplePercentage() { + return samplePercentage; + } + + /** Returns the type of sampling distribution that this operator will create. */ + public String getSampleType() { + return sampleType; + } + @Override public Schema generateSchema() { return SCHEMA; diff --git a/test/edu/washington/escience/myria/operator/SampleWRTest.java b/test/edu/washington/escience/myria/operator/SampleWRTest.java index e053e1846..f6bb3a454 100644 --- a/test/edu/washington/escience/myria/operator/SampleWRTest.java +++ b/test/edu/washington/escience/myria/operator/SampleWRTest.java @@ -27,7 +27,7 @@ public class SampleWRTest { final Schema LEFT_SCHEMA = Schema.ofFields("WorkerID", Type.INT_TYPE, "PartitionSize", Type.INT_TYPE, "SampleSize", Type.INT_TYPE, - "IsWithReplacement", Type.BOOLEAN_TYPE); + "SampleType", Type.STRING_TYPE); final Schema RIGHT_SCHEMA = Schema.ofFields(Type.INT_TYPE, "SomeValue"); final Schema OUTPUT_SCHEMA = RIGHT_SCHEMA; @@ -123,7 +123,7 @@ private void verifyExpectedResults(int partitionSize, int sampleSize, int[] expected) throws DbException { leftInput.putInt(1, partitionSize); leftInput.putInt(2, sampleSize); - leftInput.putBoolean(3, true); + leftInput.putString(3, "WR"); sampOp = new Sample(new TupleSource(leftInput), new TupleSource(rightInput), RANDOM_SEED); sampOp.open(TestEnvVars.get()); @@ -145,7 +145,7 @@ private void drainOperator(int partitionSize, int sampleSize) throws DbException { leftInput.putInt(1, partitionSize); leftInput.putInt(2, sampleSize); - leftInput.putBoolean(3, true); + leftInput.putString(3, "WR"); sampOp = new Sample(new TupleSource(leftInput), new TupleSource(rightInput), RANDOM_SEED); sampOp.open(TestEnvVars.get()); while (!sampOp.eos()) { diff --git a/test/edu/washington/escience/myria/operator/SampleWoRTest.java b/test/edu/washington/escience/myria/operator/SampleWoRTest.java index 0ec6dbe05..365a1978b 100644 --- a/test/edu/washington/escience/myria/operator/SampleWoRTest.java +++ b/test/edu/washington/escience/myria/operator/SampleWoRTest.java @@ -23,7 +23,7 @@ public class SampleWoRTest { final Schema 
LEFT_SCHEMA = Schema.ofFields("WorkerID", Type.INT_TYPE, "PartitionSize", Type.INT_TYPE, "SampleSize", Type.INT_TYPE, - "IsWithReplacement", Type.BOOLEAN_TYPE); + "SampleType", Type.STRING_TYPE); final Schema RIGHT_SCHEMA = Schema.ofFields(Type.INT_TYPE, "SomeValue"); final Schema OUTPUT_SCHEMA = RIGHT_SCHEMA; @@ -111,7 +111,7 @@ public void cleanup() throws DbException { private void verifyExpectedResults(int partitionSize, int sampleSize) throws DbException { leftInput.putInt(1, partitionSize); leftInput.putInt(2, sampleSize); - leftInput.putBoolean(3, false); + leftInput.putString(3, "WoR"); sampOp = new Sample(new TupleSource(leftInput), new TupleSource(rightInput), RANDOM_SEED); sampOp.open(TestEnvVars.get()); @@ -131,7 +131,7 @@ private void drainOperator(int partitionSize, int sampleSize) throws DbException { leftInput.putInt(1, partitionSize); leftInput.putInt(2, sampleSize); - leftInput.putBoolean(3, false); + leftInput.putString(3, "WoR"); sampOp = new Sample(new TupleSource(leftInput), new TupleSource(rightInput), RANDOM_SEED); sampOp.open(TestEnvVars.get()); while (!sampOp.eos()) { diff --git a/test/edu/washington/escience/myria/operator/SamplingDistributionTest.java b/test/edu/washington/escience/myria/operator/SamplingDistributionTest.java index a80e290c1..8a146c01e 100644 --- a/test/edu/washington/escience/myria/operator/SamplingDistributionTest.java +++ b/test/edu/washington/escience/myria/operator/SamplingDistributionTest.java @@ -23,8 +23,9 @@ public class SamplingDistributionTest { final Schema inputSchema = Schema.ofFields("WorkerID", Type.INT_TYPE, "PartitionSize", Type.INT_TYPE); - final Schema expectedResultSchema = Schema.ofFields("WorkerID", Type.INT_TYPE, - "StreamSize", Type.INT_TYPE, "SampleSize", Type.INT_TYPE, "IsWithReplacement", Type.BOOLEAN_TYPE); + final Schema expectedResultSchema = Schema.ofFields("WorkerID", + Type.INT_TYPE, "StreamSize", Type.INT_TYPE, "SampleSize", Type.INT_TYPE, + "SampleType", Type.STRING_TYPE); TupleBatchBuffer input; SamplingDistribution sampOp; @@ -47,131 +48,209 @@ public void setup() { @Test public void testSampleWRSizeZero() throws DbException { int sampleSize = 0; - boolean isWithReplacement = true; + String sampleType = "WR"; final int[][] expectedResults = { { 1, 300, 0 }, { 2, 200, 0 }, { 3, 400, 0 }, { 4, 100, 0 } }; - verifyExpectedResults(sampleSize, isWithReplacement, expectedResults); + verifyExpectedResults(sampleSize, sampleType, expectedResults); } @Test public void testSampleWoRSizeZero() throws DbException { int sampleSize = 0; - boolean isWithReplacement = false; + String sampleType = "WoR"; final int[][] expectedResults = { { 1, 300, 0 }, { 2, 200, 0 }, { 3, 400, 0 }, { 4, 100, 0 } }; - verifyExpectedResults(sampleSize, isWithReplacement, expectedResults); + verifyExpectedResults(sampleSize, sampleType, expectedResults); + } + + /** Sample size 0%. */ + @Test + public void testSampleWRPctZero() throws DbException { + float samplePct = 0; + String sampleType = "WR"; + final int[][] expectedResults = { { 1, 300, 0 }, { 2, 200, 0 }, + { 3, 400, 0 }, { 4, 100, 0 } }; + verifyExpectedResults(samplePct, sampleType, expectedResults); + } + + @Test + public void testSampleWoRPctZero() throws DbException { + float samplePct = 0; + String sampleType = "WoR"; + final int[][] expectedResults = { { 1, 300, 0 }, { 2, 200, 0 }, + { 3, 400, 0 }, { 4, 100, 0 } }; + verifyExpectedResults(samplePct, sampleType, expectedResults); } /** Sample size 1. 
*/ @Test public void testSampleWRSizeOne() throws DbException { int sampleSize = 1; - boolean isWithReplacement = true; - verifyPossibleDistribution(sampleSize, isWithReplacement); + String sampleType = "WR"; + verifyPossibleDistribution(sampleSize, sampleType); } @Test public void testSampleWoRSizeOne() throws DbException { int sampleSize = 1; - boolean isWithReplacement = false; - verifyPossibleDistribution(sampleSize, isWithReplacement); + String sampleType = "WoR"; + verifyPossibleDistribution(sampleSize, sampleType); } /** Sample size 50. */ @Test public void testSampleWRSizeFifty() throws DbException { int sampleSize = 50; - boolean isWithReplacement = true; - verifyPossibleDistribution(sampleSize, isWithReplacement); + String sampleType = "WR"; + verifyPossibleDistribution(sampleSize, sampleType); } @Test public void testSampleWoRSizeFifty() throws DbException { int sampleSize = 50; - boolean isWithReplacement = false; - verifyPossibleDistribution(sampleSize, isWithReplacement); + String sampleType = "WoR"; + verifyPossibleDistribution(sampleSize, sampleType); + } + + /** Sample size 50%. */ + @Test + public void testSampleWRPctFifty() throws DbException { + float samplePct = 50; + String sampleType = "WR"; + verifyPossibleDistribution(samplePct, sampleType); + } + + @Test + public void testSampleWoRPctFifty() throws DbException { + float samplePct = 50; + String sampleType = "WoR"; + verifyPossibleDistribution(samplePct, sampleType); } /** Sample all but one tuple. */ @Test public void testSampleWoRSizeAllButOne() throws DbException { int sampleSize = 999; - boolean isWithReplacement = false; - verifyPossibleDistribution(sampleSize, isWithReplacement); + String sampleType = "WoR"; + verifyPossibleDistribution(sampleSize, sampleType); } @Test public void testSampleWRSizeAllButOne() throws DbException { int sampleSize = 999; - boolean isWithReplacement = true; - verifyPossibleDistribution(sampleSize, isWithReplacement); + String sampleType = "WR"; + verifyPossibleDistribution(sampleSize, sampleType); } /** SamplingWoR the entire population == return all. */ @Test public void testSampleWoRSizeMax() throws DbException { int sampleSize = 1000; - boolean isWithReplacement = false; + String sampleType = "WoR"; final int[][] expectedResults = { { 1, 300, 300 }, { 2, 200, 200 }, { 3, 400, 400 }, { 4, 100, 100 } }; - verifyExpectedResults(sampleSize, isWithReplacement, expectedResults); + verifyExpectedResults(sampleSize, sampleType, expectedResults); + } + + @Test + public void testSampleWoRPctMax() throws DbException { + float samplePct = 100; + String sampleType = "WoR"; + final int[][] expectedResults = { { 1, 300, 300 }, { 2, 200, 200 }, + { 3, 400, 400 }, { 4, 100, 100 } }; + verifyExpectedResults(samplePct, sampleType, expectedResults); } /** SamplingWR the entire population. */ @Test public void testSampleWRSizeMax() throws DbException { int sampleSize = 1000; - boolean isWithReplacement = true; - verifyPossibleDistribution(sampleSize, isWithReplacement); + String sampleType = "WR"; + verifyPossibleDistribution(sampleSize, sampleType); + } + + @Test + public void testSampleWRPctMax() throws DbException { + float samplePct = 100; + String sampleType = "WR"; + verifyPossibleDistribution(samplePct, sampleType); } /** Cannot sample more than total size. 
*/ @Test(expected = IllegalStateException.class) public void testSampleWoRSizeTooMany() throws DbException { int sampleSize = 1001; - boolean isWithReplacement = false; - drainOperator(sampleSize, isWithReplacement); + String sampleType = "WoR"; + drainOperator(sampleSize, sampleType); + } + + @Test(expected = IllegalStateException.class) + public void testSampleWoRPctTooMany() throws DbException { + float samplePct = 100.1f; + String sampleType = "WoR"; + drainOperator(samplePct, sampleType); } @Test(expected = IllegalStateException.class) public void testSampleWRSizeTooMany() throws DbException { int sampleSize = 1001; - boolean isWithReplacement = true; - drainOperator(sampleSize, isWithReplacement); + String sampleType = "WR"; + drainOperator(sampleSize, sampleType); + } + + @Test(expected = IllegalStateException.class) + public void testSampleWRPctTooMany() throws DbException { + float samplePct = 100.1f; + String sampleType = "WR"; + drainOperator(samplePct, sampleType); } /** Cannot sample a negative number of samples. */ @Test(expected = IllegalStateException.class) public void testSampleWoRSizeNegative() throws DbException { int sampleSize = -1; - boolean isWithReplacement = false; - drainOperator(sampleSize, isWithReplacement); + String sampleType = "WoR"; + drainOperator(sampleSize, sampleType); + } + + @Test(expected = IllegalStateException.class) + public void testSampleWoRPctNegative() throws DbException { + float samplePct = -0.01f; + String sampleType = "WoR"; + drainOperator(samplePct, sampleType); } @Test(expected = IllegalStateException.class) public void testSampleWRSizeNegative() throws DbException { int sampleSize = -1; - boolean isWithReplacement = false; - drainOperator(sampleSize, isWithReplacement); + String sampleType = "WoR"; + drainOperator(sampleSize, sampleType); + } + + @Test(expected = IllegalStateException.class) + public void testSampleWRPctNegative() throws DbException { + float samplePct = -0.01f; + String sampleType = "WoR"; + drainOperator(samplePct, sampleType); } /** Worker cannot report a negative partition size. */ @Test(expected = IllegalStateException.class) public void testSampleWoRWorkerNegative() throws DbException { int sampleSize = 50; - boolean isWithReplacement = false; + String sampleType = "WoR"; input.putInt(0, 5); input.putInt(1, -1); - drainOperator(sampleSize, isWithReplacement); + drainOperator(sampleSize, sampleType); } @Test(expected = IllegalStateException.class) public void testSampleWRWorkerNegative() throws DbException { int sampleSize = 50; - boolean isWithReplacement = true; + String sampleType = "WR"; input.putInt(0, 5); input.putInt(1, -1); - drainOperator(sampleSize, isWithReplacement); + drainOperator(sampleSize, sampleType); } @After @@ -182,11 +261,8 @@ public void cleanup() throws DbException { } /** Compare output results compared to some known expectedResults. 
*/ - private void verifyExpectedResults(int sampleSize, - boolean isWithReplacement, int[][] expectedResults) throws DbException { - sampOp = new SamplingDistribution(new TupleSource(input), sampleSize, isWithReplacement, RANDOM_SEED); - sampOp.open(TestEnvVars.get()); - + private void verifyExpectedResults(SamplingDistribution sampOp, + int[][] expectedResults) throws DbException { int rowIdx = 0; while (!sampOp.eos()) { TupleBatch result = sampOp.nextReady(); @@ -201,16 +277,27 @@ private void verifyExpectedResults(int sampleSize, } assertEquals(expectedResults.length, rowIdx); } + private void verifyExpectedResults(int sampleSize, String sampleType, + int[][] expectedResults) throws DbException { + sampOp = new SamplingDistribution(new TupleSource(input), sampleSize, + sampleType, RANDOM_SEED); + sampOp.open(TestEnvVars.get()); + verifyExpectedResults(sampOp, expectedResults); + } + private void verifyExpectedResults(float samplePct, String sampleType, + int[][] expectedResults) throws DbException { + sampOp = new SamplingDistribution(new TupleSource(input), samplePct, + sampleType, RANDOM_SEED); + sampOp.open(TestEnvVars.get()); + verifyExpectedResults(sampOp, expectedResults); + } /** * Tests the actual distribution against what could be possible. Note: doesn't * test if it is statistically random. */ - private void verifyPossibleDistribution(int sampleSize, - boolean isWithReplacement) throws DbException { - sampOp = new SamplingDistribution(new TupleSource(input), sampleSize, isWithReplacement, RANDOM_SEED); - sampOp.open(TestEnvVars.get()); - + private void verifyPossibleDistribution(SamplingDistribution sampOp) + throws DbException { int rowIdx = 0; int computedSampleSize = 0; while (!sampOp.eos()) { @@ -218,8 +305,8 @@ private void verifyPossibleDistribution(int sampleSize, if (result != null) { assertEquals(expectedResultSchema, result.getSchema()); for (int i = 0; i < result.numTuples(); ++i, ++rowIdx) { - assert (result.getInt(2, i) >= 0 && result.getInt(2, i) <= sampleSize); - if (!isWithReplacement) { + assert (result.getInt(2, i) >= 0 && result.getInt(2, i) <= sampOp.getSampleSize()); + if (sampOp.getSampleType().equals("WoR")) { // SampleWoR cannot sample more than worker's population size. assert (result.getInt(2, i) <= result.getInt(1, i)); } @@ -228,13 +315,39 @@ private void verifyPossibleDistribution(int sampleSize, } } assertEquals(input.numTuples(), rowIdx); - assertEquals(sampleSize, computedSampleSize); + assertEquals(sampOp.getSampleSize(), computedSampleSize); + } + private void verifyPossibleDistribution(int sampleSize, + String sampleType) throws DbException { + sampOp = new SamplingDistribution(new TupleSource(input), sampleSize, + sampleType, RANDOM_SEED); + sampOp.open(TestEnvVars.get()); + verifyPossibleDistribution(sampOp); + } + + private void verifyPossibleDistribution(float samplePct, + String sampleType) throws DbException { + sampOp = new SamplingDistribution(new TupleSource(input), samplePct, + sampleType, RANDOM_SEED); + sampOp.open(TestEnvVars.get()); + verifyPossibleDistribution(sampOp); } /** Run through all results without doing anything. 
*/ - private void drainOperator(int sampleSize, boolean isWithReplacement) + private void drainOperator(int sampleSize, String sampleType) + throws DbException { + sampOp = new SamplingDistribution(new TupleSource(input), sampleSize, + sampleType, RANDOM_SEED); + sampOp.open(TestEnvVars.get()); + while (!sampOp.eos()) { + sampOp.nextReady(); + } + } + + private void drainOperator(float samplePct, String sampleType) throws DbException { - sampOp = new SamplingDistribution(new TupleSource(input), sampleSize, isWithReplacement, RANDOM_SEED); + sampOp = new SamplingDistribution(new TupleSource(input), samplePct, + sampleType, RANDOM_SEED); sampOp.open(TestEnvVars.get()); while (!sampOp.eos()) { sampOp.nextReady(); From 48fd3be407614e61ce7d7f876a30826fd0fc8096 Mon Sep 17 00:00:00 2001 From: Yuqing Guo Date: Thu, 28 May 2015 17:49:17 -0700 Subject: [PATCH 08/29] changed the way to set eos for Limit --- .../escience/myria/operator/Limit.java | 32 ++++++++++--------- .../operator/agg/StreamingAggregate.java | 14 ++++---- .../escience/myria/operator/LimitTest.java | 26 ++++++++------- 3 files changed, 38 insertions(+), 34 deletions(-) diff --git a/src/edu/washington/escience/myria/operator/Limit.java b/src/edu/washington/escience/myria/operator/Limit.java index ca9bbd88c..2e8e18df0 100644 --- a/src/edu/washington/escience/myria/operator/Limit.java +++ b/src/edu/washington/escience/myria/operator/Limit.java @@ -43,23 +43,25 @@ public Limit(@Nonnull final Long limit, final Operator child) { @Override protected TupleBatch fetchNextReady() throws DbException { Operator child = getChild(); - TupleBatch tb = child.nextReady(); - TupleBatch result = null; - if (tb != null) { - if (tb.numTuples() <= toEmit) { - toEmit -= tb.numTuples(); - result = tb; - } else if (toEmit > 0) { - result = tb.prefix(Ints.checkedCast(toEmit)); - toEmit = 0; - } - if (toEmit == 0) { - /* Close child and self. No more stream is needed. */ - child.close(); - close(); + if (child.isOpen()) { + TupleBatch tb = child.nextReady(); + TupleBatch result = null; + if (tb != null) { + if (tb.numTuples() <= toEmit) { + toEmit -= tb.numTuples(); + result = tb; + } else if (toEmit > 0) { + result = tb.prefix(Ints.checkedCast(toEmit)); + toEmit = 0; + } + if (toEmit == 0) { + /* Close child. No more stream is needed. */ + child.close(); + } } + return result; } - return result; + return null; } @Override diff --git a/src/edu/washington/escience/myria/operator/agg/StreamingAggregate.java b/src/edu/washington/escience/myria/operator/agg/StreamingAggregate.java index 65875e650..586279983 100644 --- a/src/edu/washington/escience/myria/operator/agg/StreamingAggregate.java +++ b/src/edu/washington/escience/myria/operator/agg/StreamingAggregate.java @@ -2,6 +2,7 @@ import java.util.Objects; +import javax.annotation.Nonnull; import javax.annotation.Nullable; import com.google.common.base.Preconditions; @@ -64,13 +65,14 @@ public class StreamingAggregate extends UnaryOperator { * @param gfields The columns over which we are grouping the result. * @param factories The factories that will produce the {@link Aggregator}s for each group. */ - public StreamingAggregate(@Nullable final Operator child, final int[] gfields, final AggregatorFactory... factories) { + public StreamingAggregate(@Nullable final Operator child, @Nonnull final int[] gfields, + @Nonnull final AggregatorFactory... 
factories) { super(child); gFields = Objects.requireNonNull(gfields, "gfields"); gTypes = new Type[gfields.length]; this.factories = Objects.requireNonNull(factories, "factories"); Preconditions.checkArgument(gfields.length > 0, " must have at least one group by field"); - Preconditions.checkArgument(factories.length != 0, "to use StreamingAggregate, must specify some aggregates"); + Preconditions.checkArgument(factories.length > 0, "to use StreamingAggregate, must specify some aggregates"); gRange = new int[gfields.length]; for (int i = 0; i < gfields.length; ++i) { gRange[i] = i; @@ -95,9 +97,7 @@ protected TupleBatch fetchNextReady() throws DbException { while (tb != null) { for (int row = 0; row < tb.numTuples(); ++row) { if (curGroupKey == null) { - /* - * first time accessing this tb, no aggregation performed previously - */ + /* First time accessing this tb, no aggregation performed previously. */ // store current group key as a tuple curGroupKey = new Tuple(groupSchema); for (int gKey = 0; gKey < gFields.length; ++gKey) { @@ -127,9 +127,7 @@ protected TupleBatch fetchNextReady() throws DbException { } } } else if (!TupleUtils.tupleEquals(tb, gFields, row, curGroupKey, gRange, 0)) { - /* - * different grouping key than current one, flush current agg result to result buffer - */ + /* Different grouping key than current one, flush current agg result to result buffer. */ addToResult(); // store current group key as a tuple for (int gKey = 0; gKey < gFields.length; ++gKey) { diff --git a/test/edu/washington/escience/myria/operator/LimitTest.java b/test/edu/washington/escience/myria/operator/LimitTest.java index 29cedf10e..50e1f0bfe 100644 --- a/test/edu/washington/escience/myria/operator/LimitTest.java +++ b/test/edu/washington/escience/myria/operator/LimitTest.java @@ -4,10 +4,12 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; -import java.util.LinkedList; +import java.util.List; import org.junit.Test; +import com.google.common.collect.ImmutableList; + import edu.washington.escience.myria.DbException; import edu.washington.escience.myria.storage.TupleBatch; import edu.washington.escience.myria.storage.TupleBatchBuffer; @@ -24,12 +26,13 @@ public void testWithinBatchSizeLimit() throws DbException { TupleSource source = new TupleSource(TestUtils.range(total)); Limit limiter = new Limit(limit, source); limiter.open(TestEnvVars.get()); - long count = 0; TupleBatch tb = limiter.nextReady(); - count += tb.numTuples(); - assertEquals(limit, count); + assertEquals(limit, tb.numTuples()); + // reached limit, limiter gets eos and next call to nextReady() returns null + tb = limiter.nextReady(); + assertNull(tb); assertTrue(limiter.eos()); - /* Limit closes itself as soon as it returned # tuples == limit. */ + limiter.close(); } @Test @@ -43,7 +46,7 @@ public void testLimitZero() throws DbException { TupleBatch tb = limiter.nextReady(); assertNull(tb); assertTrue(limiter.eos()); - /* Limit closes itself as soon as it returned # tuples == limit. 
*/ + limiter.close(); } @Test @@ -52,16 +55,17 @@ public void testLimitNumTuples() throws DbException { final long limit = total; TupleBatchBuffer tbb1 = TestUtils.range((int) limit); TupleBatchBuffer tbb2 = TestUtils.range((int) limit); - LinkedList sourceList = new LinkedList(); - sourceList.add(tbb1.popAny()); - sourceList.add(tbb2.popAny()); + List sourceList = ImmutableList.of(tbb1.popAny(), tbb2.popAny()); TupleSource source = new TupleSource(sourceList); Limit limiter = new Limit(limit, source); limiter.open(TestEnvVars.get()); TupleBatch tb = limiter.nextReady(); assertEquals(limit, tb.numTuples()); + // reached limit, limiter gets eos and next call to nextReady() returns null + tb = limiter.nextReady(); + assertNull(tb); assertTrue(limiter.eos()); - /* Limit closes itself as soon as it returned # tuples == limit. */ + limiter.close(); } @Test @@ -84,6 +88,6 @@ public void testSimplePrefix() throws DbException { } assertEquals(limit, count); assertEquals(2, numIteration); - /* Limit closes itself as soon as it returned # tuples == limit. */ + limiter.close(); } } \ No newline at end of file From 2e7b9dff5852ad4f70593eed069579b88241952e Mon Sep 17 00:00:00 2001 From: Dominik Moritz Date: Tue, 2 Jun 2015 18:14:12 +0200 Subject: [PATCH 09/29] Copy db query scan --- .../myria/operator/CatalogQueryScan.java | 252 ++++++++++++++++++ 1 file changed, 252 insertions(+) create mode 100644 src/edu/washington/escience/myria/operator/CatalogQueryScan.java diff --git a/src/edu/washington/escience/myria/operator/CatalogQueryScan.java b/src/edu/washington/escience/myria/operator/CatalogQueryScan.java new file mode 100644 index 000000000..b3212bd30 --- /dev/null +++ b/src/edu/washington/escience/myria/operator/CatalogQueryScan.java @@ -0,0 +1,252 @@ +package edu.washington.escience.myria.operator; + +import java.util.Iterator; +import java.util.Objects; +import java.util.Set; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +import edu.washington.escience.myria.DbException; +import edu.washington.escience.myria.MyriaConstants; +import edu.washington.escience.myria.RelationKey; +import edu.washington.escience.myria.Schema; +import edu.washington.escience.myria.accessmethod.AccessMethod; +import edu.washington.escience.myria.accessmethod.ConnectionInfo; +import edu.washington.escience.myria.storage.TupleBatch; + +/** + * Push a select query down into a JDBC based database and scan over the query result. + * */ +public class CatalogQueryScan extends LeafOperator implements DbReader { + + /** + * The connection info. + */ + private ConnectionInfo connectionInfo; + + /** + * The name of the relation (RelationKey) for a SELECT * query. + */ + private RelationKey relationKey; + + /** + * Iterate over data from the JDBC database. + * */ + private transient Iterator tuples; + /** + * The result schema. + * */ + private final Schema outputSchema; + + /** + * The SQL template. + * */ + private String baseSQL; + + /** + * Column indexes that the output should be ordered by. + */ + private final int[] sortedColumns; + + /** + * True for each column in {@link #sortedColumns} that should be ordered ascending. + */ + private final boolean[] ascending; + + /** Required for Java serialization. */ + private static final long serialVersionUID = 1L; + + /** The logger for debug, trace, etc. messages in this class. 
*/ + private static final org.slf4j.Logger LOGGER = org.slf4j.LoggerFactory.getLogger(CatalogQueryScan.class); + + /** + * Constructor. + * + * @param baseSQL see the corresponding field. + * @param outputSchema see the corresponding field. + * */ + public CatalogQueryScan(final String baseSQL, final Schema outputSchema) { + Objects.requireNonNull(baseSQL); + Objects.requireNonNull(outputSchema); + + this.baseSQL = baseSQL; + this.outputSchema = outputSchema; + connectionInfo = null; + tuples = null; + sortedColumns = null; + ascending = null; + } + + /** + * Constructor that receives the connection info as input. + * + * @param connectionInfo see the corresponding field. + * @param baseSQL see the corresponding field. + * @param outputSchema see the corresponding field. + * */ + public CatalogQueryScan(final ConnectionInfo connectionInfo, final String baseSQL, final Schema outputSchema) { + this(baseSQL, outputSchema); + Objects.requireNonNull(connectionInfo); + this.connectionInfo = connectionInfo; + } + + /** + * Construct a new DbQueryScan object that simply runs SELECT * FROM relationKey. + * + * @param relationKey the relation to be scanned. + * @param outputSchema the Schema of the returned tuples. + */ + public CatalogQueryScan(final RelationKey relationKey, final Schema outputSchema) { + Objects.requireNonNull(relationKey); + Objects.requireNonNull(outputSchema); + + this.relationKey = relationKey; + this.outputSchema = outputSchema; + baseSQL = null; + connectionInfo = null; + tuples = null; + sortedColumns = null; + ascending = null; + } + + /** + * Construct a new DbQueryScan object that simply runs SELECT * FROM relationKey, but receiving the + * connection info as input. + * + * @param connectionInfo the connection information. + * @param relationKey the relation to be scanned. + * @param outputSchema the Schema of the returned tuples. + */ + public CatalogQueryScan(final ConnectionInfo connectionInfo, final RelationKey relationKey, final Schema outputSchema) { + this(relationKey, outputSchema); + Objects.requireNonNull(connectionInfo); + this.connectionInfo = connectionInfo; + } + + /** + * Construct a new DbQueryScan object that runs SELECT * FROM relationKey ORDER BY [...]. + * + * @param relationKey the relation to be scanned. + * @param outputSchema the Schema of the returned tuples. + * @param sortedColumns the columns by which the tuples should be ordered by. + * @param ascending true for columns that should be ordered ascending. + */ + public CatalogQueryScan(final RelationKey relationKey, final Schema outputSchema, final int[] sortedColumns, + final boolean[] ascending) { + Objects.requireNonNull(relationKey); + Objects.requireNonNull(outputSchema); + + this.relationKey = relationKey; + this.outputSchema = outputSchema; + this.sortedColumns = sortedColumns; + this.ascending = ascending; + baseSQL = null; + connectionInfo = null; + tuples = null; + } + + /** + * Construct a new DbQueryScan object that runs SELECT * FROM relationKey ORDER BY [...], but receiving + * the connection info as input. + * + * @param connectionInfo the connection information. + * @param relationKey the relation to be scanned. + * @param outputSchema the Schema of the returned tuples. + * @param sortedColumns the columns by which the tuples should be ordered by. + * @param ascending true for columns that should be ordered ascending. 
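+   *          (For illustration with hypothetical column names: sortedColumns =
+   *          {0, 1} and ascending = {true, false} produce a suffix like
+   *          " ORDER BY col0 ASC, col1 DESC".)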
+ */ + public CatalogQueryScan(final ConnectionInfo connectionInfo, final RelationKey relationKey, final Schema outputSchema, + final int[] sortedColumns, final boolean[] ascending) { + this(relationKey, outputSchema, sortedColumns, ascending); + Objects.requireNonNull(connectionInfo); + this.connectionInfo = connectionInfo; + } + + @Override + public final void cleanup() { + tuples = null; + } + + @Override + protected final TupleBatch fetchNextReady() throws DbException { + Objects.requireNonNull(connectionInfo); + if (tuples == null) { + tuples = + AccessMethod.of(connectionInfo.getDbms(), connectionInfo, true).tupleBatchIteratorFromQuery(baseSQL, + outputSchema); + } + if (tuples.hasNext()) { + final TupleBatch tb = tuples.next(); + LOGGER.trace("Got {} tuples", tb.numTuples()); + return tb; + } else { + return null; + } + } + + @Override + public final Schema generateSchema() { + return outputSchema; + } + + @Override + protected final void init(final ImmutableMap execEnvVars) throws DbException { + if (connectionInfo == null) { + final String dbms = (String) execEnvVars.get(MyriaConstants.EXEC_ENV_VAR_DATABASE_SYSTEM); + if (dbms == null) { + throw new DbException("Unable to instantiate DbQueryScan: database system unknown"); + } + + connectionInfo = (ConnectionInfo) execEnvVars.get(MyriaConstants.EXEC_ENV_VAR_DATABASE_CONN_INFO); + if (connectionInfo == null) { + throw new DbException("Unable to instantiate DbQueryScan: connection information unknown"); + } + + if (!dbms.equals(connectionInfo.getDbms())) { + throw new DbException( + "Unable to instantiate DbQueryScan: database system does not conform with connection information"); + } + } + + if (relationKey != null) { + baseSQL = "SELECT * FROM " + relationKey.toString(connectionInfo.getDbms()); + + String prefix = ""; + if (sortedColumns != null && sortedColumns.length > 0) { + Preconditions.checkArgument(sortedColumns.length == ascending.length); + StringBuilder orderByClause = new StringBuilder(" ORDER BY"); + + for (int columnIdx : sortedColumns) { + orderByClause.append(prefix + " " + getSchema().getColumnName(columnIdx)); + if (ascending[columnIdx]) { + orderByClause.append(" ASC"); + } else { + orderByClause.append(" DESC"); + } + + prefix = ","; + } + + baseSQL = baseSQL.concat(orderByClause.toString()); + } + } + } + + /** + * @return the connection info in this DbQueryScan. 
+ */ + public ConnectionInfo getConnectionInfo() { + return connectionInfo; + } + + @Override + public Set readSet() { + if (relationKey == null) { + LOGGER.error("DbQueryScan does not support the DbReader interface properly for SQL queries."); + return ImmutableSet.of(); + } + return ImmutableSet.of(relationKey); + } +} From 944f6537c7510fb42d33bda9122895c77fe0f2ce Mon Sep 17 00:00:00 2001 From: Dominik Moritz Date: Tue, 2 Jun 2015 18:25:05 +0200 Subject: [PATCH 10/29] Add catalog scan encoding --- .../api/encoding/CatalogScanEncoding.java | 18 ++++++++++++++++++ .../myria/api/encoding/OperatorEncoding.java | 1 + 2 files changed, 19 insertions(+) create mode 100644 src/edu/washington/escience/myria/api/encoding/CatalogScanEncoding.java diff --git a/src/edu/washington/escience/myria/api/encoding/CatalogScanEncoding.java b/src/edu/washington/escience/myria/api/encoding/CatalogScanEncoding.java new file mode 100644 index 000000000..1c0364eab --- /dev/null +++ b/src/edu/washington/escience/myria/api/encoding/CatalogScanEncoding.java @@ -0,0 +1,18 @@ +package edu.washington.escience.myria.api.encoding; + +import edu.washington.escience.myria.Schema; +import edu.washington.escience.myria.api.encoding.QueryConstruct.ConstructArgs; +import edu.washington.escience.myria.operator.DbQueryScan; + +public class CatalogScanEncoding extends LeafOperatorEncoding { + @Required + public Schema schema; + @Required + public String sql; + + @Override + public DbQueryScan construct(ConstructArgs args) { + return new DbQueryScan(sql, schema); + } + +} \ No newline at end of file diff --git a/src/edu/washington/escience/myria/api/encoding/OperatorEncoding.java b/src/edu/washington/escience/myria/api/encoding/OperatorEncoding.java index 236773e71..8f6a9e59d 100644 --- a/src/edu/washington/escience/myria/api/encoding/OperatorEncoding.java +++ b/src/edu/washington/escience/myria/api/encoding/OperatorEncoding.java @@ -25,6 +25,7 @@ @Type(name = "BinaryFileScan", value = BinaryFileScanEncoding.class), @Type(name = "BroadcastConsumer", value = BroadcastConsumerEncoding.class), @Type(name = "BroadcastProducer", value = BroadcastProducerEncoding.class), + @Type(name = "CatalogScan", value = CatalogScanEncoding.class), @Type(name = "CollectConsumer", value = CollectConsumerEncoding.class), @Type(name = "CollectProducer", value = CollectProducerEncoding.class), @Type(name = "Consumer", value = ConsumerEncoding.class), @Type(name = "Counter", value = CounterEncoding.class), From 043bc232171a06d92b94baa5c3fd9b702995e965 Mon Sep 17 00:00:00 2001 From: Dominik Moritz Date: Tue, 2 Jun 2015 19:38:33 +0200 Subject: [PATCH 11/29] Move sqlitetuplebatchiterator into its own class --- .../accessmethod/SQLiteAccessMethod.java | 100 --------------- .../SQLiteTupleBatchIterator.java | 117 ++++++++++++++++++ 2 files changed, 117 insertions(+), 100 deletions(-) create mode 100644 src/edu/washington/escience/myria/accessmethod/SQLiteTupleBatchIterator.java diff --git a/src/edu/washington/escience/myria/accessmethod/SQLiteAccessMethod.java b/src/edu/washington/escience/myria/accessmethod/SQLiteAccessMethod.java index 184906d6a..d2fd64741 100644 --- a/src/edu/washington/escience/myria/accessmethod/SQLiteAccessMethod.java +++ b/src/edu/washington/escience/myria/accessmethod/SQLiteAccessMethod.java @@ -2,7 +2,6 @@ import java.io.File; import java.io.IOException; -import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; import java.util.List; @@ -25,9 +24,6 @@ import edu.washington.escience.myria.RelationKey; import 
edu.washington.escience.myria.Schema; import edu.washington.escience.myria.Type; -import edu.washington.escience.myria.column.Column; -import edu.washington.escience.myria.column.builder.ColumnBuilder; -import edu.washington.escience.myria.column.builder.ColumnFactory; import edu.washington.escience.myria.storage.TupleBatch; /** @@ -456,99 +452,3 @@ public void createIndexIfNotExists(final RelationKey relationKey, final Schema s throw new UnsupportedOperationException("create index if not exists is not supported in sqlite yet, implement me"); } } - -/** - * Wraps a SQLiteStatement result set in a Iterator. - * - * - */ -class SQLiteTupleBatchIterator implements Iterator { - /** The logger for this class. Uses SQLiteAccessMethod settings. */ - private static final Logger LOGGER = LoggerFactory.getLogger(SQLiteAccessMethod.class); - /** The results from a SQLite query that will be returned in TupleBatches by this Iterator. */ - private final SQLiteStatement statement; - /** The connection to the SQLite database. */ - private final SQLiteConnection connection; - /** The Schema of the TupleBatches returned by this Iterator. */ - private final Schema schema; - - /** - * Wraps a SQLiteStatement result set in an Iterator. - * - * @param statement the SQLiteStatement containing the results. - * @param schema the Schema describing the format of the TupleBatch containing these results. - * @param connection the connection to the SQLite database. - */ - SQLiteTupleBatchIterator(final SQLiteStatement statement, final Schema schema, final SQLiteConnection connection) { - this.statement = statement; - this.connection = connection; - this.schema = schema; - } - - /** - * Wraps a SQLiteStatement result set in an Iterator. - * - * @param statement the SQLiteStatement containing the results. If it has not yet stepped, this constructor will step - * it. Then the Schema of the generated TupleBatchs will be extracted from the statement. - * @param connection the connection to the SQLite database. - * @param schema the Schema describing the format of the TupleBatch containing these results. - */ - SQLiteTupleBatchIterator(final SQLiteStatement statement, final SQLiteConnection connection, final Schema schema) { - this.connection = connection; - this.statement = statement; - try { - if (!statement.hasStepped()) { - statement.step(); - } - this.schema = schema; - } catch (final SQLiteException e) { - throw new RuntimeException(e); - } - } - - @Override - public boolean hasNext() { - final boolean hasRow = statement.hasRow(); - if (!hasRow) { - statement.dispose(); - connection.dispose(); - } - return hasRow; - } - - @Override - public TupleBatch next() { - /* Allocate TupleBatch parameters */ - final int numFields = schema.numColumns(); - final List> columnBuilders = ColumnFactory.allocateColumns(schema); - - /** - * Loop through resultSet, adding one row at a time. Stop when numTuples hits BATCH_SIZE or there are no more - * results. 
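// A minimal consumption sketch for this iterator (illustrative; the variable
// names are invented): hasNext() disposes of both the statement and the
// connection once the last row has been read, so fully draining the iterator
// is also what releases the SQLite resources.
//
//   Iterator<TupleBatch> it = new SQLiteTupleBatchIterator(statement, connection, schema);
//   long rows = 0;
//   while (it.hasNext()) {
//     rows += it.next().numTuples();
//   }
//
// The flip side of tying disposal to exhaustion is that an iterator abandoned
// part-way through never closes its connection.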
- */ - int numTuples; - try { - for (numTuples = 0; numTuples < TupleBatch.BATCH_SIZE && statement.hasRow(); ++numTuples) { - for (int column = 0; column < numFields; ++column) { - columnBuilders.get(column).appendFromSQLite(statement, column); - } - statement.step(); - } - } catch (final SQLiteException e) { - LOGGER.error("Got SQLiteException:" + e + "in TupleBatchIterator.next()"); - throw new RuntimeException(e); - } - - List> columns = new ArrayList>(columnBuilders.size()); - for (ColumnBuilder cb : columnBuilders) { - columns.add(cb.build()); - } - - return new TupleBatch(schema, columns, numTuples); - } - - @Override - public void remove() { - throw new UnsupportedOperationException("SQLiteTupleBatchIterator.remove()"); - } -} diff --git a/src/edu/washington/escience/myria/accessmethod/SQLiteTupleBatchIterator.java b/src/edu/washington/escience/myria/accessmethod/SQLiteTupleBatchIterator.java new file mode 100644 index 000000000..24a6fe1b2 --- /dev/null +++ b/src/edu/washington/escience/myria/accessmethod/SQLiteTupleBatchIterator.java @@ -0,0 +1,117 @@ +/** + * + */ +package edu.washington.escience.myria.accessmethod; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.almworks.sqlite4java.SQLiteConnection; +import com.almworks.sqlite4java.SQLiteException; +import com.almworks.sqlite4java.SQLiteStatement; + +import edu.washington.escience.myria.Schema; +import edu.washington.escience.myria.column.Column; +import edu.washington.escience.myria.column.builder.ColumnBuilder; +import edu.washington.escience.myria.column.builder.ColumnFactory; +import edu.washington.escience.myria.storage.TupleBatch; + +/** + * Wraps a SQLiteStatement result set in a Iterator. + * + */ +public class SQLiteTupleBatchIterator implements Iterator { + /** The logger for this class. Uses SQLiteAccessMethod settings. */ + private static final Logger LOGGER = LoggerFactory.getLogger(SQLiteAccessMethod.class); + /** The results from a SQLite query that will be returned in TupleBatches by this Iterator. */ + private final SQLiteStatement statement; + /** The connection to the SQLite database. */ + private final SQLiteConnection connection; + /** The Schema of the TupleBatches returned by this Iterator. */ + private final Schema schema; + + /** + * Wraps a SQLiteStatement result set in an Iterator. + * + * @param statement the SQLiteStatement containing the results. + * @param schema the Schema describing the format of the TupleBatch containing these results. + * @param connection the connection to the SQLite database. + */ + SQLiteTupleBatchIterator(final SQLiteStatement statement, final Schema schema, final SQLiteConnection connection) { + this.statement = statement; + this.connection = connection; + this.schema = schema; + } + + /** + * Wraps a SQLiteStatement result set in an Iterator. + * + * @param statement the SQLiteStatement containing the results. If it has not yet stepped, this constructor will step + * it. Then the Schema of the generated TupleBatchs will be extracted from the statement. + * @param connection the connection to the SQLite database. + * @param schema the Schema describing the format of the TupleBatch containing these results. 
+ */ + public SQLiteTupleBatchIterator(final SQLiteStatement statement, final SQLiteConnection connection, + final Schema schema) { + this.connection = connection; + this.statement = statement; + try { + if (!statement.hasStepped()) { + statement.step(); + } + this.schema = schema; + } catch (final SQLiteException e) { + throw new RuntimeException(e); + } + } + + @Override + public boolean hasNext() { + final boolean hasRow = statement.hasRow(); + if (!hasRow) { + statement.dispose(); + connection.dispose(); + } + return hasRow; + } + + @Override + public TupleBatch next() { + /* Allocate TupleBatch parameters */ + final int numFields = schema.numColumns(); + final List> columnBuilders = ColumnFactory.allocateColumns(schema); + + /** + * Loop through resultSet, adding one row at a time. Stop when numTuples hits BATCH_SIZE or there are no more + * results. + */ + int numTuples; + try { + for (numTuples = 0; numTuples < TupleBatch.BATCH_SIZE && statement.hasRow(); ++numTuples) { + for (int column = 0; column < numFields; ++column) { + columnBuilders.get(column).appendFromSQLite(statement, column); + } + statement.step(); + } + } catch (final SQLiteException e) { + LOGGER.error("Got SQLiteException:" + e + "in TupleBatchIterator.next()"); + throw new RuntimeException(e); + } + + List> columns = new ArrayList>(columnBuilders.size()); + for (ColumnBuilder cb : columnBuilders) { + columns.add(cb.build()); + } + + return new TupleBatch(schema, columns, numTuples); + } + + @Override + public void remove() { + throw new UnsupportedOperationException("SQLiteTupleBatchIterator.remove()"); + } +} \ No newline at end of file From b5c47400c506ff279d63cb7241e8005b04fb1daf Mon Sep 17 00:00:00 2001 From: Dominik Moritz Date: Tue, 2 Jun 2015 19:39:08 +0200 Subject: [PATCH 12/29] Implement catalog query scan with test --- .../api/encoding/CatalogScanEncoding.java | 8 +- .../coordinator/catalog/MasterCatalog.java | 27 +++ .../myria/operator/CatalogQueryScan.java | 200 ++---------------- .../escience/myria/parallel/Server.java | 7 + .../myria/operator/CatalogScanTest.java | 66 ++++++ 5 files changed, 125 insertions(+), 183 deletions(-) create mode 100644 test/edu/washington/escience/myria/operator/CatalogScanTest.java diff --git a/src/edu/washington/escience/myria/api/encoding/CatalogScanEncoding.java b/src/edu/washington/escience/myria/api/encoding/CatalogScanEncoding.java index 1c0364eab..afcb80101 100644 --- a/src/edu/washington/escience/myria/api/encoding/CatalogScanEncoding.java +++ b/src/edu/washington/escience/myria/api/encoding/CatalogScanEncoding.java @@ -2,17 +2,17 @@ import edu.washington.escience.myria.Schema; import edu.washington.escience.myria.api.encoding.QueryConstruct.ConstructArgs; -import edu.washington.escience.myria.operator.DbQueryScan; +import edu.washington.escience.myria.operator.CatalogQueryScan; -public class CatalogScanEncoding extends LeafOperatorEncoding { +public class CatalogScanEncoding extends LeafOperatorEncoding { @Required public Schema schema; @Required public String sql; @Override - public DbQueryScan construct(ConstructArgs args) { - return new DbQueryScan(sql, schema); + public CatalogQueryScan construct(final ConstructArgs args) { + return new CatalogQueryScan(sql, schema, args.getServer().getCatalog()); } } \ No newline at end of file diff --git a/src/edu/washington/escience/myria/coordinator/catalog/MasterCatalog.java b/src/edu/washington/escience/myria/coordinator/catalog/MasterCatalog.java index 7be77fb24..39a545c68 100644 --- 
a/src/edu/washington/escience/myria/coordinator/catalog/MasterCatalog.java
+++ b/src/edu/washington/escience/myria/coordinator/catalog/MasterCatalog.java
@@ -6,6 +6,7 @@
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.Iterator;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
@@ -39,6 +40,7 @@
 import edu.washington.escience.myria.RelationKey;
 import edu.washington.escience.myria.Schema;
 import edu.washington.escience.myria.Type;
+import edu.washington.escience.myria.accessmethod.SQLiteTupleBatchIterator;
 import edu.washington.escience.myria.api.MyriaJsonMapperProvider;
 import edu.washington.escience.myria.api.encoding.DatasetStatus;
 import edu.washington.escience.myria.api.encoding.QueryEncoding;
@@ -48,6 +50,7 @@
 import edu.washington.escience.myria.parallel.RelationWriteMetadata;
 import edu.washington.escience.myria.parallel.SocketInfo;
 import edu.washington.escience.myria.parallel.SubQueryId;
+import edu.washington.escience.myria.storage.TupleBatch;

 /**
  * This class is intended to store the configuration information for a Myria installation.
@@ -1918,4 +1921,28 @@ protected String job(final SQLiteConnection sqliteConnection) throws CatalogExce
       throw new CatalogException(e);
     }
   }
+
+  /**
+   * Run a query on the catalog.
+   *
+   * @param queryString a SQL query on the catalog
+   * @param outputSchema the schema of the query result
+   * @return a tuple iterator over the result
+   * @throws CatalogException if there is an error.
+   */
+  public Iterator<TupleBatch> tupleBatchIteratorFromQuery(final String queryString, final Schema outputSchema)
+      throws CatalogException {
+    try {
+      return queue.execute(new SQLiteJob<SQLiteTupleBatchIterator>() {
+        @Override
+        protected SQLiteTupleBatchIterator job(final SQLiteConnection sqliteConnection) throws CatalogException,
+            SQLiteException {
+          SQLiteStatement statement = sqliteConnection.prepare(queryString);
+          return new SQLiteTupleBatchIterator(statement, sqliteConnection, outputSchema);
+        }
+      }).get();
+    } catch (InterruptedException | ExecutionException e) {
+      throw new CatalogException(e);
+    }
+  }
 }
diff --git a/src/edu/washington/escience/myria/operator/CatalogQueryScan.java b/src/edu/washington/escience/myria/operator/CatalogQueryScan.java
index b3212bd30..5a25da552 100644
--- a/src/edu/washington/escience/myria/operator/CatalogQueryScan.java
+++ b/src/edu/washington/escience/myria/operator/CatalogQueryScan.java
@@ -2,58 +2,39 @@
 import java.util.Iterator;
 import java.util.Objects;
-import java.util.Set;

-import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.ImmutableSet;

 import edu.washington.escience.myria.DbException;
-import edu.washington.escience.myria.MyriaConstants;
-import edu.washington.escience.myria.RelationKey;
 import edu.washington.escience.myria.Schema;
-import edu.washington.escience.myria.accessmethod.AccessMethod;
-import edu.washington.escience.myria.accessmethod.ConnectionInfo;
+import edu.washington.escience.myria.coordinator.catalog.CatalogException;
+import edu.washington.escience.myria.coordinator.catalog.MasterCatalog;
 import edu.washington.escience.myria.storage.TupleBatch;

 /**
  * Push a select query down into a JDBC based database and scan over the query result.
  *
  */
-public class CatalogQueryScan extends LeafOperator implements DbReader {
+public class CatalogQueryScan extends LeafOperator {

   /**
-   * The connection info.
+   * Iterate over data from the catalog.
*/ - private ConnectionInfo connectionInfo; - - /** - * The name of the relation (RelationKey) for a SELECT * query. - */ - private RelationKey relationKey; - - /** - * Iterate over data from the JDBC database. - * */ private transient Iterator tuples; + /** * The result schema. * */ private final Schema outputSchema; /** - * The SQL template. - * */ - private String baseSQL; - - /** - * Column indexes that the output should be ordered by. + * The SQL query. */ - private final int[] sortedColumns; + private final String sql; /** - * True for each column in {@link #sortedColumns} that should be ordered ascending. + * The master catalog. */ - private final boolean[] ascending; + private final MasterCatalog catalog; /** Required for Java serialization. */ private static final long serialVersionUID = 1L; @@ -64,106 +45,21 @@ public class CatalogQueryScan extends LeafOperator implements DbReader { /** * Constructor. * - * @param baseSQL see the corresponding field. + * @param sql see the corresponding field. * @param outputSchema see the corresponding field. + * @param catalog see the corresponding field. * */ - public CatalogQueryScan(final String baseSQL, final Schema outputSchema) { - Objects.requireNonNull(baseSQL); - Objects.requireNonNull(outputSchema); - - this.baseSQL = baseSQL; - this.outputSchema = outputSchema; - connectionInfo = null; - tuples = null; - sortedColumns = null; - ascending = null; - } - - /** - * Constructor that receives the connection info as input. - * - * @param connectionInfo see the corresponding field. - * @param baseSQL see the corresponding field. - * @param outputSchema see the corresponding field. - * */ - public CatalogQueryScan(final ConnectionInfo connectionInfo, final String baseSQL, final Schema outputSchema) { - this(baseSQL, outputSchema); - Objects.requireNonNull(connectionInfo); - this.connectionInfo = connectionInfo; - } - - /** - * Construct a new DbQueryScan object that simply runs SELECT * FROM relationKey. - * - * @param relationKey the relation to be scanned. - * @param outputSchema the Schema of the returned tuples. - */ - public CatalogQueryScan(final RelationKey relationKey, final Schema outputSchema) { - Objects.requireNonNull(relationKey); - Objects.requireNonNull(outputSchema); - - this.relationKey = relationKey; - this.outputSchema = outputSchema; - baseSQL = null; - connectionInfo = null; - tuples = null; - sortedColumns = null; - ascending = null; - } - - /** - * Construct a new DbQueryScan object that simply runs SELECT * FROM relationKey, but receiving the - * connection info as input. - * - * @param connectionInfo the connection information. - * @param relationKey the relation to be scanned. - * @param outputSchema the Schema of the returned tuples. - */ - public CatalogQueryScan(final ConnectionInfo connectionInfo, final RelationKey relationKey, final Schema outputSchema) { - this(relationKey, outputSchema); - Objects.requireNonNull(connectionInfo); - this.connectionInfo = connectionInfo; - } - - /** - * Construct a new DbQueryScan object that runs SELECT * FROM relationKey ORDER BY [...]. - * - * @param relationKey the relation to be scanned. - * @param outputSchema the Schema of the returned tuples. - * @param sortedColumns the columns by which the tuples should be ordered by. - * @param ascending true for columns that should be ordered ascending. 
- */ - public CatalogQueryScan(final RelationKey relationKey, final Schema outputSchema, final int[] sortedColumns, - final boolean[] ascending) { - Objects.requireNonNull(relationKey); + public CatalogQueryScan(final String sql, final Schema outputSchema, final MasterCatalog catalog) { + Objects.requireNonNull(sql); Objects.requireNonNull(outputSchema); + Objects.requireNonNull(catalog); - this.relationKey = relationKey; + this.sql = sql; this.outputSchema = outputSchema; - this.sortedColumns = sortedColumns; - this.ascending = ascending; - baseSQL = null; - connectionInfo = null; + this.catalog = catalog; tuples = null; } - /** - * Construct a new DbQueryScan object that runs SELECT * FROM relationKey ORDER BY [...], but receiving - * the connection info as input. - * - * @param connectionInfo the connection information. - * @param relationKey the relation to be scanned. - * @param outputSchema the Schema of the returned tuples. - * @param sortedColumns the columns by which the tuples should be ordered by. - * @param ascending true for columns that should be ordered ascending. - */ - public CatalogQueryScan(final ConnectionInfo connectionInfo, final RelationKey relationKey, final Schema outputSchema, - final int[] sortedColumns, final boolean[] ascending) { - this(relationKey, outputSchema, sortedColumns, ascending); - Objects.requireNonNull(connectionInfo); - this.connectionInfo = connectionInfo; - } - @Override public final void cleanup() { tuples = null; @@ -171,11 +67,12 @@ public final void cleanup() { @Override protected final TupleBatch fetchNextReady() throws DbException { - Objects.requireNonNull(connectionInfo); if (tuples == null) { - tuples = - AccessMethod.of(connectionInfo.getDbms(), connectionInfo, true).tupleBatchIteratorFromQuery(baseSQL, - outputSchema); + try { + tuples = catalog.tupleBatchIteratorFromQuery(sql, outputSchema); + } catch (CatalogException e) { + throw new DbException(e); + } } if (tuples.hasNext()) { final TupleBatch tb = tuples.next(); @@ -193,60 +90,5 @@ public final Schema generateSchema() { @Override protected final void init(final ImmutableMap execEnvVars) throws DbException { - if (connectionInfo == null) { - final String dbms = (String) execEnvVars.get(MyriaConstants.EXEC_ENV_VAR_DATABASE_SYSTEM); - if (dbms == null) { - throw new DbException("Unable to instantiate DbQueryScan: database system unknown"); - } - - connectionInfo = (ConnectionInfo) execEnvVars.get(MyriaConstants.EXEC_ENV_VAR_DATABASE_CONN_INFO); - if (connectionInfo == null) { - throw new DbException("Unable to instantiate DbQueryScan: connection information unknown"); - } - - if (!dbms.equals(connectionInfo.getDbms())) { - throw new DbException( - "Unable to instantiate DbQueryScan: database system does not conform with connection information"); - } - } - - if (relationKey != null) { - baseSQL = "SELECT * FROM " + relationKey.toString(connectionInfo.getDbms()); - - String prefix = ""; - if (sortedColumns != null && sortedColumns.length > 0) { - Preconditions.checkArgument(sortedColumns.length == ascending.length); - StringBuilder orderByClause = new StringBuilder(" ORDER BY"); - - for (int columnIdx : sortedColumns) { - orderByClause.append(prefix + " " + getSchema().getColumnName(columnIdx)); - if (ascending[columnIdx]) { - orderByClause.append(" ASC"); - } else { - orderByClause.append(" DESC"); - } - - prefix = ","; - } - - baseSQL = baseSQL.concat(orderByClause.toString()); - } - } - } - - /** - * @return the connection info in this DbQueryScan. 
- */ - public ConnectionInfo getConnectionInfo() { - return connectionInfo; - } - - @Override - public Set readSet() { - if (relationKey == null) { - LOGGER.error("DbQueryScan does not support the DbReader interface properly for SQL queries."); - return ImmutableSet.of(); - } - return ImmutableSet.of(relationKey); } } diff --git a/src/edu/washington/escience/myria/parallel/Server.java b/src/edu/washington/escience/myria/parallel/Server.java index d9e17d454..495934ef2 100644 --- a/src/edu/washington/escience/myria/parallel/Server.java +++ b/src/edu/washington/escience/myria/parallel/Server.java @@ -1943,4 +1943,11 @@ public String getQueryPlan(@Nonnull final SubQueryId subQueryId) throws DbExcept throw new DbException(e); } } + + /** + * @return the master catalog. + */ + public MasterCatalog getCatalog() { + return catalog; + } } diff --git a/test/edu/washington/escience/myria/operator/CatalogScanTest.java b/test/edu/washington/escience/myria/operator/CatalogScanTest.java new file mode 100644 index 000000000..a980ffc9a --- /dev/null +++ b/test/edu/washington/escience/myria/operator/CatalogScanTest.java @@ -0,0 +1,66 @@ +/** + * + */ +package edu.washington.escience.myria.operator; + +import static org.junit.Assert.assertEquals; + +import java.util.logging.Level; +import java.util.logging.Logger; + +import org.junit.Before; +import org.junit.Test; + +import com.google.common.collect.ImmutableList; + +import edu.washington.escience.myria.DbException; +import edu.washington.escience.myria.Schema; +import edu.washington.escience.myria.Type; +import edu.washington.escience.myria.api.encoding.QueryEncoding; +import edu.washington.escience.myria.coordinator.catalog.CatalogException; +import edu.washington.escience.myria.coordinator.catalog.MasterCatalog; +import edu.washington.escience.myria.storage.TupleBatch; + +/** + * + */ +public class CatalogScanTest { + + /** + * The catalog + */ + private MasterCatalog catalog; + + /** + * @throws java.lang.Exception + */ + @Before + public void setUp() throws Exception { + /* Turn off SQLite logging, it's annoying. */ + Logger.getLogger("com.almworks.sqlite4java").setLevel(Level.OFF); + + catalog = MasterCatalog.createInMemory(); + + // add a query to the catalog + QueryEncoding query = new QueryEncoding(); + query.rawQuery = "query 1 is about baseball"; + query.logicalRa = ""; + catalog.newQuery(query); + } + + @Test + public final void testQueryQueries() throws DbException, CatalogException { + Schema schema = new Schema(ImmutableList.of(Type.LONG_TYPE, Type.STRING_TYPE), ImmutableList.of("id", "raw")); + CatalogQueryScan scan = new CatalogQueryScan("select query_id, raw_query from queries", schema, catalog); + scan.open(null); + + assertEquals(false, scan.eos()); + TupleBatch tb = scan.nextReady(); + assertEquals(false, scan.eos()); + assertEquals(1, tb.numTuples()); + assertEquals(schema, tb.getSchema()); + + tb = scan.nextReady(); + assertEquals(true, scan.eos()); + } +} From 16b4d2dfabcde9fa2716ae45d254c5dc2707e47f Mon Sep 17 00:00:00 2001 From: Dylan Hutchison Date: Wed, 3 Jun 2015 15:41:01 -0400 Subject: [PATCH 13/29] Add IntelliJ to .gitignore These are auto-generated files by the IntelliJ IDE that should not be stored in git repositories. 
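Stepping back to CatalogScanTest in the previous patch: it exercises the
standard Myria operator lifecycle. As a minimal sketch of that drain idiom
(assuming an already-constructed operator scan; error handling elided):

  scan.open(null);
  while (!scan.eos()) {
    TupleBatch tb = scan.nextReady();
    if (tb != null) {
      // consume tb; a null batch only means "nothing ready yet", not EOS
    }
  }
  scan.close();

Checking eos() rather than a null return is what separates "no data yet" from
"stream finished" in this API.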
--- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index f33c5d202..7783f1d9d 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,7 @@ build *catalog-journal englink-log4j.log /bin + +### Intellij ### +*.iml +.idea/ From 09aa6aa56bb8a86ec8bf6bfb80c7941fa92e9e28 Mon Sep 17 00:00:00 2001 From: Dan Radion Date: Wed, 3 Jun 2015 16:32:50 -0700 Subject: [PATCH 14/29] some code cleanup of sampling operators --- .../SamplingDistributionEncoding.java | 3 +- .../escience/myria/operator/Sample.java | 69 ++++++++++--------- .../myria/operator/SamplingDistribution.java | 36 +++++----- .../escience/myria/util/SamplingType.java | 9 +++ .../operator/SamplingDistributionTest.java | 67 +++++++++--------- 5 files changed, 103 insertions(+), 81 deletions(-) create mode 100644 src/edu/washington/escience/myria/util/SamplingType.java diff --git a/src/edu/washington/escience/myria/api/encoding/SamplingDistributionEncoding.java b/src/edu/washington/escience/myria/api/encoding/SamplingDistributionEncoding.java index 323f17667..603c0ed95 100644 --- a/src/edu/washington/escience/myria/api/encoding/SamplingDistributionEncoding.java +++ b/src/edu/washington/escience/myria/api/encoding/SamplingDistributionEncoding.java @@ -5,6 +5,7 @@ import edu.washington.escience.myria.api.MyriaApiException; import edu.washington.escience.myria.api.encoding.QueryConstruct.ConstructArgs; import edu.washington.escience.myria.operator.SamplingDistribution; +import edu.washington.escience.myria.util.SamplingType; public class SamplingDistributionEncoding extends UnaryOperatorEncoding { @@ -15,7 +16,7 @@ public class SamplingDistributionEncoding extends UnaryOperatorEncoding= samples.length) { + if (curSampIdx >= sampleIndices.length) { getRight().close(); - setEOS(); return null; } Operator right = getRight(); for (TupleBatch tb = right.nextReady(); tb != null; tb = right.nextReady()) { - if (curSampIdx >= samples.length) { // done sampling + if (curSampIdx >= sampleIndices.length) { // done sampling break; } - if (samples[curSampIdx] > tupleNum + tb.numTuples()) { + if (sampleIndices[curSampIdx] >= tuplesSeen + tb.numTuples()) { // nextIndex is not in this batch. Continue with next batch. - tupleNum += tb.numTuples(); + tuplesSeen += tb.numTuples(); continue; } - while (curSampIdx < samples.length - && samples[curSampIdx] < tupleNum + tb.numTuples()) { - ans.put(tb, samples[curSampIdx] - tupleNum); + while (curSampIdx < sampleIndices.length + && sampleIndices[curSampIdx] < tuplesSeen + tb.numTuples()) { + ans.put(tb, sampleIndices[curSampIdx] - tuplesSeen); curSampIdx++; } - tupleNum += tb.numTuples(); + tuplesSeen += tb.numTuples(); if (ans.hasFilledTB()) { return ans.popFilled(); } @@ -141,8 +142,8 @@ private void extractSamplingInfo(TupleBatch tb) throws Exception { throw new DbException("WorkerID column must be of type INT or LONG"); } Preconditions.checkState(workerID == getNodeID(), - "Invalid WorkerID for this worker. Expected %s, but received %s", - getNodeID(), workerID); + "Invalid WorkerID for this worker. 
Expected %s, but received %s", + getNodeID(), workerID); Type col1Type = tb.getSchema().getColumnType(1); if (col1Type == Type.INT_TYPE) { @@ -166,7 +167,12 @@ private void extractSamplingInfo(TupleBatch tb) throws Exception { Type col3Type = tb.getSchema().getColumnType(3); if (col3Type == Type.STRING_TYPE) { - sampleType = tb.getString(3, 0); + String col3Val = tb.getString(3, 0); + try { + sampleType = SamplingType.valueOf(col3Val); + } catch (IllegalArgumentException e) { + throw new DbException("Invalid SampleType: " + col3Val); + } } else { throw new DbException("SampleType column must be of type STRING"); } @@ -192,6 +198,8 @@ private int[] generateIndicesWR(int populationSize, int sampleSize) { /** * Generates a sorted array of unique random numbers to be taken as samples. + * The implementation uses Floyd's algorithm. For an explanation: + * www.nowherenearithaca.com/2013/05/robert-floyds-tiny-and-beautiful.html * * @param populationSize * size of the population that will be sampled from. @@ -209,12 +217,7 @@ private int[] generateIndicesWoR(int populationSize, int sampleSize) { indices.add(idx); } } - int[] indicesArr = new int[indices.size()]; - int i = 0; - for (Integer val : indices) { - indicesArr[i] = val; - i++; - } + int[] indicesArr = Ints.toArray(indices); Arrays.sort(indicesArr); return indicesArr; } @@ -231,6 +234,10 @@ public Schema generateSchema() { @Override protected void init(final ImmutableMap execEnvVars) { ans = new TupleBatchBuffer(getSchema()); + rand = new Random(); + if (randomSeed != null) { + rand.setSeed(randomSeed); + } } @Override diff --git a/src/edu/washington/escience/myria/operator/SamplingDistribution.java b/src/edu/washington/escience/myria/operator/SamplingDistribution.java index 11f9eb4e9..99dc9d20b 100644 --- a/src/edu/washington/escience/myria/operator/SamplingDistribution.java +++ b/src/edu/washington/escience/myria/operator/SamplingDistribution.java @@ -14,14 +14,15 @@ import edu.washington.escience.myria.column.builder.IntColumnBuilder; import edu.washington.escience.myria.column.builder.StringColumnBuilder; import edu.washington.escience.myria.storage.TupleBatch; +import edu.washington.escience.myria.util.SamplingType; public class SamplingDistribution extends UnaryOperator { /** Required for Java serialization. */ private static final long serialVersionUID = 1L; /** The output schema. */ - private static final Schema SCHEMA = Schema.of( - ImmutableList.of(Type.INT_TYPE, Type.INT_TYPE, Type.INT_TYPE, Type.STRING_TYPE), + private static final Schema SCHEMA = Schema.of(ImmutableList.of( + Type.INT_TYPE, Type.INT_TYPE, Type.INT_TYPE, Type.STRING_TYPE), ImmutableList.of("WorkerID", "StreamSize", "SampleSize", "SampleType")); /** Total number of tuples to sample. */ @@ -33,13 +34,14 @@ public class SamplingDistribution extends UnaryOperator { /** Percentage of total tuples to sample. */ private float samplePercentage; - /** The type of sampling to perform. Currently supports 'WR' and 'WoR'. */ - private final String sampleType; + /** The type of sampling to perform. */ + private final SamplingType sampleType; /** Random generator used for creating the distribution. 
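// Aside: a self-contained sketch of the Floyd sampling step referenced in
// generateIndicesWoR above (illustrative only; the wrapper class is invented,
// the loop mirrors the patch). Each pass draws uniformly from [0, i] and, on
// a collision, admits i itself, which keeps every size-k subset equally likely.
final class FloydSamplingSketch {
  static java.util.Set<Integer> sample(final java.util.Random rand, final int populationSize, final int sampleSize) {
    final java.util.Set<Integer> indices = new java.util.HashSet<>();
    for (int i = populationSize - sampleSize; i < populationSize; i++) {
      final int idx = rand.nextInt(i + 1); // uniform in [0, i]
      if (!indices.add(idx)) {
        indices.add(i); // idx was already chosen; i itself cannot be yet
      }
    }
    return indices;
  }
}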
*/ private Random rand; - private SamplingDistribution(Operator child, String sampleType, Long randomSeed) { + private SamplingDistribution(Operator child, SamplingType sampleType, + Long randomSeed) { super(child); this.sampleType = sampleType; this.rand = new Random(); @@ -62,7 +64,7 @@ private SamplingDistribution(Operator child, String sampleType, Long randomSeed) * value to seed the random generator with. null if no specified seed */ public SamplingDistribution(Operator child, int sampleSize, - String sampleType, Long randomSeed) { + SamplingType sampleType, Long randomSeed) { this(child, sampleType, randomSeed); this.sampleSize = sampleSize; Preconditions.checkState(this.sampleSize >= 0, @@ -70,7 +72,8 @@ public SamplingDistribution(Operator child, int sampleSize, } /** - * Instantiate a SamplingDistribution operator using a percentage of total tuples. + * Instantiate a SamplingDistribution operator using a percentage of total + * tuples. * * @param samplePercentage * percentage of total samples to create a distribution for. @@ -83,12 +86,12 @@ public SamplingDistribution(Operator child, int sampleSize, * value to seed the random generator with. null if no specified seed */ public SamplingDistribution(Operator child, float samplePercentage, - String sampleType, Long randomSeed) { + SamplingType sampleType, Long randomSeed) { this(child, sampleType, randomSeed); this.isPercentageSample = true; this.samplePercentage = samplePercentage; Preconditions.checkState(samplePercentage >= 0 && samplePercentage <= 100, - "Sample Percentage must be >= 0 && <= 100: %s", samplePercentage); + "Sample Percentage must be >= 0 && <= 100: %s", samplePercentage); } @Override @@ -162,14 +165,14 @@ protected TupleBatch fetchNextReady() throws DbException { throw new DbException("StreamSize must be of type INT or LONG"); } Preconditions.checkState(partitionSize >= 0, - "Worker cannot have a negative StreamSize: %d", streamSize); + "Worker cannot have a negative StreamSize: %d", streamSize); } streamCounts.set(workerID - 1, streamSize); } } // Convert samplePct to sampleSize if using a percentage sample. if (isPercentageSample) { - sampleSize = Math.round(totalTupleCount * (samplePercentage / 100)); + sampleSize = Math.round(totalTupleCount * (samplePercentage / 100)); } Preconditions.checkState(sampleSize >= 0 && sampleSize <= totalTupleCount, "Cannot extract %s samples from a population of size %s", sampleSize, @@ -177,9 +180,9 @@ protected TupleBatch fetchNextReady() throws DbException { // Generate a random distribution across the workers. 
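// Aside: a stand-alone sketch of the weighted draw behind the with-replacement
// branch below (illustrative; the wrapper class and names are invented). Each
// of the k draws picks a global tuple index uniformly at random and credits
// the worker that owns that index, so workers are weighted by tuple count.
final class WeightedDrawSketch {
  static int[] weightedDraws(final java.util.Random rand, final int[] tupleCounts, final int k) {
    int total = 0;
    for (final int c : tupleCounts) {
      total += c;
    }
    final int[] distribution = new int[tupleCounts.length];
    for (int i = 0; i < k; i++) {
      int t = rand.nextInt(total); // a uniformly random global tuple index
      int w = 0;
      while (t >= tupleCounts[w]) { // walk to the worker holding that index
        t -= tupleCounts[w];
        w++;
      }
      distribution[w]++;
    }
    return distribution;
  }
}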
int[] sampleCounts; - if (sampleType.equals("WR")) { + if (sampleType == SamplingType.WR) { sampleCounts = withReplacementDistribution(tupleCounts, sampleSize); - } else if (sampleType.equals("WoR")){ + } else if (sampleType == SamplingType.WoR) { sampleCounts = withoutReplacementDistribution(tupleCounts, sampleSize); } else { throw new DbException("Invalid sampleType: " + sampleType); @@ -194,10 +197,11 @@ protected TupleBatch fetchNextReady() throws DbException { wIdCol.appendInt(i + 1); streamSizeCol.appendInt(streamCounts.get(i)); sampCountCol.appendInt(sampleCounts[i]); - sampTypeCol.appendString(sampleType); + sampTypeCol.appendString(sampleType.name()); } ImmutableList.Builder> columns = ImmutableList.builder(); - columns.add(wIdCol.build(), streamSizeCol.build(), sampCountCol.build(), sampTypeCol.build()); + columns.add(wIdCol.build(), streamSizeCol.build(), sampCountCol.build(), + sampTypeCol.build()); return new TupleBatch(SCHEMA, columns.build()); } @@ -288,7 +292,7 @@ public float getSamplePercentage() { } /** Returns the type of sampling distribution that this operator will create. */ - public String getSampleType() { + public SamplingType getSampleType() { return sampleType; } diff --git a/src/edu/washington/escience/myria/util/SamplingType.java b/src/edu/washington/escience/myria/util/SamplingType.java new file mode 100644 index 000000000..943a5619c --- /dev/null +++ b/src/edu/washington/escience/myria/util/SamplingType.java @@ -0,0 +1,9 @@ +package edu.washington.escience.myria.util; + +/** + * Enumeration of supported sampling types. + */ +public enum SamplingType { + // WithReplacement, WithoutReplacement + WR, WoR +} diff --git a/test/edu/washington/escience/myria/operator/SamplingDistributionTest.java b/test/edu/washington/escience/myria/operator/SamplingDistributionTest.java index 8a146c01e..0644849f2 100644 --- a/test/edu/washington/escience/myria/operator/SamplingDistributionTest.java +++ b/test/edu/washington/escience/myria/operator/SamplingDistributionTest.java @@ -2,6 +2,7 @@ import static org.junit.Assert.assertEquals; +import edu.washington.escience.myria.util.SamplingType; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -48,7 +49,7 @@ public void setup() { @Test public void testSampleWRSizeZero() throws DbException { int sampleSize = 0; - String sampleType = "WR"; + SamplingType sampleType = SamplingType.WR; final int[][] expectedResults = { { 1, 300, 0 }, { 2, 200, 0 }, { 3, 400, 0 }, { 4, 100, 0 } }; verifyExpectedResults(sampleSize, sampleType, expectedResults); @@ -57,7 +58,7 @@ public void testSampleWRSizeZero() throws DbException { @Test public void testSampleWoRSizeZero() throws DbException { int sampleSize = 0; - String sampleType = "WoR"; + SamplingType sampleType = SamplingType.WoR; final int[][] expectedResults = { { 1, 300, 0 }, { 2, 200, 0 }, { 3, 400, 0 }, { 4, 100, 0 } }; verifyExpectedResults(sampleSize, sampleType, expectedResults); @@ -67,7 +68,7 @@ public void testSampleWoRSizeZero() throws DbException { @Test public void testSampleWRPctZero() throws DbException { float samplePct = 0; - String sampleType = "WR"; + SamplingType sampleType = SamplingType.WR; final int[][] expectedResults = { { 1, 300, 0 }, { 2, 200, 0 }, { 3, 400, 0 }, { 4, 100, 0 } }; verifyExpectedResults(samplePct, sampleType, expectedResults); @@ -76,7 +77,7 @@ public void testSampleWRPctZero() throws DbException { @Test public void testSampleWoRPctZero() throws DbException { float samplePct = 0; - String sampleType = "WoR"; + SamplingType 
sampleType = SamplingType.WoR; final int[][] expectedResults = { { 1, 300, 0 }, { 2, 200, 0 }, { 3, 400, 0 }, { 4, 100, 0 } }; verifyExpectedResults(samplePct, sampleType, expectedResults); @@ -86,14 +87,14 @@ public void testSampleWoRPctZero() throws DbException { @Test public void testSampleWRSizeOne() throws DbException { int sampleSize = 1; - String sampleType = "WR"; + SamplingType sampleType = SamplingType.WR; verifyPossibleDistribution(sampleSize, sampleType); } @Test public void testSampleWoRSizeOne() throws DbException { int sampleSize = 1; - String sampleType = "WoR"; + SamplingType sampleType = SamplingType.WoR; verifyPossibleDistribution(sampleSize, sampleType); } @@ -101,14 +102,14 @@ public void testSampleWoRSizeOne() throws DbException { @Test public void testSampleWRSizeFifty() throws DbException { int sampleSize = 50; - String sampleType = "WR"; + SamplingType sampleType = SamplingType.WR; verifyPossibleDistribution(sampleSize, sampleType); } @Test public void testSampleWoRSizeFifty() throws DbException { int sampleSize = 50; - String sampleType = "WoR"; + SamplingType sampleType = SamplingType.WoR; verifyPossibleDistribution(sampleSize, sampleType); } @@ -116,14 +117,14 @@ public void testSampleWoRSizeFifty() throws DbException { @Test public void testSampleWRPctFifty() throws DbException { float samplePct = 50; - String sampleType = "WR"; + SamplingType sampleType = SamplingType.WR; verifyPossibleDistribution(samplePct, sampleType); } @Test public void testSampleWoRPctFifty() throws DbException { float samplePct = 50; - String sampleType = "WoR"; + SamplingType sampleType = SamplingType.WoR; verifyPossibleDistribution(samplePct, sampleType); } @@ -131,14 +132,14 @@ public void testSampleWoRPctFifty() throws DbException { @Test public void testSampleWoRSizeAllButOne() throws DbException { int sampleSize = 999; - String sampleType = "WoR"; + SamplingType sampleType = SamplingType.WoR; verifyPossibleDistribution(sampleSize, sampleType); } @Test public void testSampleWRSizeAllButOne() throws DbException { int sampleSize = 999; - String sampleType = "WR"; + SamplingType sampleType = SamplingType.WR; verifyPossibleDistribution(sampleSize, sampleType); } @@ -146,7 +147,7 @@ public void testSampleWRSizeAllButOne() throws DbException { @Test public void testSampleWoRSizeMax() throws DbException { int sampleSize = 1000; - String sampleType = "WoR"; + SamplingType sampleType = SamplingType.WoR; final int[][] expectedResults = { { 1, 300, 300 }, { 2, 200, 200 }, { 3, 400, 400 }, { 4, 100, 100 } }; verifyExpectedResults(sampleSize, sampleType, expectedResults); @@ -155,7 +156,7 @@ public void testSampleWoRSizeMax() throws DbException { @Test public void testSampleWoRPctMax() throws DbException { float samplePct = 100; - String sampleType = "WoR"; + SamplingType sampleType = SamplingType.WoR; final int[][] expectedResults = { { 1, 300, 300 }, { 2, 200, 200 }, { 3, 400, 400 }, { 4, 100, 100 } }; verifyExpectedResults(samplePct, sampleType, expectedResults); @@ -165,14 +166,14 @@ public void testSampleWoRPctMax() throws DbException { @Test public void testSampleWRSizeMax() throws DbException { int sampleSize = 1000; - String sampleType = "WR"; + SamplingType sampleType = SamplingType.WR; verifyPossibleDistribution(sampleSize, sampleType); } @Test public void testSampleWRPctMax() throws DbException { float samplePct = 100; - String sampleType = "WR"; + SamplingType sampleType = SamplingType.WR; verifyPossibleDistribution(samplePct, sampleType); } @@ -180,28 +181,28 @@ public void 
testSampleWRPctMax() throws DbException { @Test(expected = IllegalStateException.class) public void testSampleWoRSizeTooMany() throws DbException { int sampleSize = 1001; - String sampleType = "WoR"; + SamplingType sampleType = SamplingType.WoR; drainOperator(sampleSize, sampleType); } @Test(expected = IllegalStateException.class) public void testSampleWoRPctTooMany() throws DbException { float samplePct = 100.1f; - String sampleType = "WoR"; + SamplingType sampleType = SamplingType.WoR; drainOperator(samplePct, sampleType); } @Test(expected = IllegalStateException.class) public void testSampleWRSizeTooMany() throws DbException { int sampleSize = 1001; - String sampleType = "WR"; + SamplingType sampleType = SamplingType.WR; drainOperator(sampleSize, sampleType); } @Test(expected = IllegalStateException.class) public void testSampleWRPctTooMany() throws DbException { float samplePct = 100.1f; - String sampleType = "WR"; + SamplingType sampleType = SamplingType.WR; drainOperator(samplePct, sampleType); } @@ -209,28 +210,28 @@ public void testSampleWRPctTooMany() throws DbException { @Test(expected = IllegalStateException.class) public void testSampleWoRSizeNegative() throws DbException { int sampleSize = -1; - String sampleType = "WoR"; + SamplingType sampleType = SamplingType.WoR; drainOperator(sampleSize, sampleType); } @Test(expected = IllegalStateException.class) public void testSampleWoRPctNegative() throws DbException { float samplePct = -0.01f; - String sampleType = "WoR"; + SamplingType sampleType = SamplingType.WoR; drainOperator(samplePct, sampleType); } @Test(expected = IllegalStateException.class) public void testSampleWRSizeNegative() throws DbException { int sampleSize = -1; - String sampleType = "WoR"; + SamplingType sampleType = SamplingType.WoR; drainOperator(sampleSize, sampleType); } @Test(expected = IllegalStateException.class) public void testSampleWRPctNegative() throws DbException { float samplePct = -0.01f; - String sampleType = "WoR"; + SamplingType sampleType = SamplingType.WoR; drainOperator(samplePct, sampleType); } @@ -238,7 +239,7 @@ public void testSampleWRPctNegative() throws DbException { @Test(expected = IllegalStateException.class) public void testSampleWoRWorkerNegative() throws DbException { int sampleSize = 50; - String sampleType = "WoR"; + SamplingType sampleType = SamplingType.WoR; input.putInt(0, 5); input.putInt(1, -1); drainOperator(sampleSize, sampleType); @@ -247,7 +248,7 @@ public void testSampleWoRWorkerNegative() throws DbException { @Test(expected = IllegalStateException.class) public void testSampleWRWorkerNegative() throws DbException { int sampleSize = 50; - String sampleType = "WR"; + SamplingType sampleType = SamplingType.WR; input.putInt(0, 5); input.putInt(1, -1); drainOperator(sampleSize, sampleType); @@ -277,14 +278,14 @@ private void verifyExpectedResults(SamplingDistribution sampOp, } assertEquals(expectedResults.length, rowIdx); } - private void verifyExpectedResults(int sampleSize, String sampleType, + private void verifyExpectedResults(int sampleSize, SamplingType sampleType, int[][] expectedResults) throws DbException { sampOp = new SamplingDistribution(new TupleSource(input), sampleSize, sampleType, RANDOM_SEED); sampOp.open(TestEnvVars.get()); verifyExpectedResults(sampOp, expectedResults); } - private void verifyExpectedResults(float samplePct, String sampleType, + private void verifyExpectedResults(float samplePct, SamplingType sampleType, int[][] expectedResults) throws DbException { sampOp = new SamplingDistribution(new 
TupleSource(input), samplePct, sampleType, RANDOM_SEED); @@ -306,7 +307,7 @@ private void verifyPossibleDistribution(SamplingDistribution sampOp) assertEquals(expectedResultSchema, result.getSchema()); for (int i = 0; i < result.numTuples(); ++i, ++rowIdx) { assert (result.getInt(2, i) >= 0 && result.getInt(2, i) <= sampOp.getSampleSize()); - if (sampOp.getSampleType().equals("WoR")) { + if (sampOp.getSampleType().equals(SamplingType.WoR)) { // SampleWoR cannot sample more than worker's population size. assert (result.getInt(2, i) <= result.getInt(1, i)); } @@ -318,7 +319,7 @@ private void verifyPossibleDistribution(SamplingDistribution sampOp) assertEquals(sampOp.getSampleSize(), computedSampleSize); } private void verifyPossibleDistribution(int sampleSize, - String sampleType) throws DbException { + SamplingType sampleType) throws DbException { sampOp = new SamplingDistribution(new TupleSource(input), sampleSize, sampleType, RANDOM_SEED); sampOp.open(TestEnvVars.get()); @@ -326,7 +327,7 @@ private void verifyPossibleDistribution(int sampleSize, } private void verifyPossibleDistribution(float samplePct, - String sampleType) throws DbException { + SamplingType sampleType) throws DbException { sampOp = new SamplingDistribution(new TupleSource(input), samplePct, sampleType, RANDOM_SEED); sampOp.open(TestEnvVars.get()); @@ -334,7 +335,7 @@ private void verifyPossibleDistribution(float samplePct, } /** Run through all results without doing anything. */ - private void drainOperator(int sampleSize, String sampleType) + private void drainOperator(int sampleSize, SamplingType sampleType) throws DbException { sampOp = new SamplingDistribution(new TupleSource(input), sampleSize, sampleType, RANDOM_SEED); @@ -344,7 +345,7 @@ private void drainOperator(int sampleSize, String sampleType) } } - private void drainOperator(float samplePct, String sampleType) + private void drainOperator(float samplePct, SamplingType sampleType) throws DbException { sampOp = new SamplingDistribution(new TupleSource(input), samplePct, sampleType, RANDOM_SEED); From a882e561831b98774abfb4382601525d78c8a032 Mon Sep 17 00:00:00 2001 From: Dominik Moritz Date: Wed, 3 Jun 2015 17:18:29 -0700 Subject: [PATCH 15/29] Fix findbugs --- .../washington/escience/myria/operator/CatalogQueryScan.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/edu/washington/escience/myria/operator/CatalogQueryScan.java b/src/edu/washington/escience/myria/operator/CatalogQueryScan.java index 5a25da552..cc3054d2b 100644 --- a/src/edu/washington/escience/myria/operator/CatalogQueryScan.java +++ b/src/edu/washington/escience/myria/operator/CatalogQueryScan.java @@ -34,7 +34,7 @@ public class CatalogQueryScan extends LeafOperator { /** * The master catalog. */ - private final MasterCatalog catalog; + private final transient MasterCatalog catalog; /** Required for Java serialization. 
*/ private static final long serialVersionUID = 1L; From 047a6b19c8bdf8b3b178502f57171078ddb9e636 Mon Sep 17 00:00:00 2001 From: Dan Radion Date: Wed, 3 Jun 2015 21:15:23 -0700 Subject: [PATCH 16/29] fixed SamplingDistribution and general cleanup --- .../escience/myria/operator/Sample.java | 3 +- .../myria/operator/SamplingDistribution.java | 69 ++++++++++++------- .../operator/SamplingDistributionTest.java | 7 +- 3 files changed, 49 insertions(+), 30 deletions(-) diff --git a/src/edu/washington/escience/myria/operator/Sample.java b/src/edu/washington/escience/myria/operator/Sample.java index e63558109..1c02f3ef7 100644 --- a/src/edu/washington/escience/myria/operator/Sample.java +++ b/src/edu/washington/escience/myria/operator/Sample.java @@ -71,8 +71,9 @@ protected TupleBatch fetchNextReady() throws Exception { // Extract sampling info from left operator. if (!computedSamplingInfo) { TupleBatch tb = getLeft().nextReady(); - if (tb == null) + if (tb == null) { return null; + } extractSamplingInfo(tb); getLeft().close(); diff --git a/src/edu/washington/escience/myria/operator/SamplingDistribution.java b/src/edu/washington/escience/myria/operator/SamplingDistribution.java index 99dc9d20b..4734da2d4 100644 --- a/src/edu/washington/escience/myria/operator/SamplingDistribution.java +++ b/src/edu/washington/escience/myria/operator/SamplingDistribution.java @@ -6,6 +6,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import edu.washington.escience.myria.DbException; import edu.washington.escience.myria.Schema; @@ -40,19 +41,35 @@ public class SamplingDistribution extends UnaryOperator { /** Random generator used for creating the distribution. */ private Random rand; + /** Seed for the random generator. */ + private Long randomSeed; + + /** + * Distribution of the tuples across the workers. Value at index i == # of + * tuples on worker i. + */ + ArrayList tupleCounts; + + /** + * Distribution of the actual stream size across the workers. May be different + * from tupleCounts if workers pre-sampled the data. Value at index i == # of + * tuples in stream on worker i. + */ + ArrayList streamCounts; + + /** Total number of tuples across all workers. */ + int totalTupleCount = 0; + private SamplingDistribution(Operator child, SamplingType sampleType, Long randomSeed) { super(child); this.sampleType = sampleType; - this.rand = new Random(); - if (randomSeed != null) { - this.rand.setSeed(randomSeed); - } + this.randomSeed = randomSeed; } /** * Instantiate a SamplingDistribution operator using a specific sample size. - * + * * @param sampleSize * total samples to create a distribution for. * @param sampleType @@ -96,27 +113,14 @@ public SamplingDistribution(Operator child, float samplePercentage, @Override protected TupleBatch fetchNextReady() throws DbException { - if (getChild().eos()) { - return null; - } - - // Distribution of the tuples across the workers. - // Value at index i == # of tuples on worker i. - ArrayList tupleCounts = new ArrayList(); - - // Distribution of the actual stream size across the workers. - // May be different from tupleCounts if worker i pre-sampled the data. - // Value at index i == # of tuples in stream on worker i. - ArrayList streamCounts = new ArrayList(); - - // Total number of tuples across all workers. - int totalTupleCount = 0; - // Drain out all the workerID and partitionSize info. 
while (!getChild().eos()) { TupleBatch tb = getChild().nextReady(); if (tb == null) { - continue; + if (getChild().eos()) { + break; + } + return null; } Type col0Type = tb.getSchema().getColumnType(0); Type col1Type = tb.getSchema().getColumnType(1); @@ -154,7 +158,6 @@ protected TupleBatch fetchNextReady() throws DbException { "Worker cannot have a negative PartitionSize: %s", partitionSize); tupleCounts.set(workerID - 1, partitionSize); totalTupleCount += partitionSize; - int streamSize = partitionSize; if (hasStreamSize) { if (col2Type == Type.INT_TYPE) { @@ -202,6 +205,8 @@ protected TupleBatch fetchNextReady() throws DbException { ImmutableList.Builder> columns = ImmutableList.builder(); columns.add(wIdCol.build(), streamSizeCol.build(), sampCountCol.build(), sampTypeCol.build()); + getChild().close(); + close(); return new TupleBatch(SCHEMA, columns.build()); } @@ -218,8 +223,9 @@ private int[] withReplacementDistribution(List tupleCounts, int sampleSize) { int[] distribution = new int[tupleCounts.size()]; int totalTupleCount = 0; - for (int val : tupleCounts) + for (int val : tupleCounts) { totalTupleCount += val; + } for (int i = 0; i < sampleSize; i++) { int sampleTupleIdx = rand.nextInt(totalTupleCount); @@ -249,9 +255,10 @@ private int[] withoutReplacementDistribution(List tupleCounts, int sampleSize) { int[] distribution = new int[tupleCounts.size()]; int totalTupleCount = 0; - for (int val : tupleCounts) + for (int val : tupleCounts) { totalTupleCount += val; - List logicalTupleCounts = new ArrayList(tupleCounts); + } + List logicalTupleCounts = new ArrayList<>(tupleCounts); for (int i = 0; i < sampleSize; i++) { int sampleTupleIdx = rand.nextInt(totalTupleCount - i); @@ -301,4 +308,14 @@ public Schema generateSchema() { return SCHEMA; } + @Override + protected void init(final ImmutableMap execEnvVars) { + rand = new Random(); + if (randomSeed != null) { + rand.setSeed(randomSeed); + } + tupleCounts = new ArrayList<>(); + streamCounts = new ArrayList<>(); + } + } diff --git a/test/edu/washington/escience/myria/operator/SamplingDistributionTest.java b/test/edu/washington/escience/myria/operator/SamplingDistributionTest.java index 0644849f2..1c0584617 100644 --- a/test/edu/washington/escience/myria/operator/SamplingDistributionTest.java +++ b/test/edu/washington/escience/myria/operator/SamplingDistributionTest.java @@ -1,6 +1,7 @@ package edu.washington.escience.myria.operator; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import edu.washington.escience.myria.util.SamplingType; import org.junit.After; @@ -306,10 +307,10 @@ private void verifyPossibleDistribution(SamplingDistribution sampOp) if (result != null) { assertEquals(expectedResultSchema, result.getSchema()); for (int i = 0; i < result.numTuples(); ++i, ++rowIdx) { - assert (result.getInt(2, i) >= 0 && result.getInt(2, i) <= sampOp.getSampleSize()); - if (sampOp.getSampleType().equals(SamplingType.WoR)) { + assertTrue(result.getInt(2, i) >= 0 && result.getInt(2, i) <= sampOp.getSampleSize()); + if (sampOp.getSampleType() == SamplingType.WoR) { // SampleWoR cannot sample more than worker's population size. 
- assert (result.getInt(2, i) <= result.getInt(1, i)); + assertTrue(result.getInt(2, i) <= result.getInt(1, i)); } computedSampleSize += result.getInt(2, i); } From 7551d606d5691c25ceaae063a0668f7b3bf86846 Mon Sep 17 00:00:00 2001 From: Dominik Moritz Date: Wed, 3 Jun 2015 22:24:07 -0700 Subject: [PATCH 17/29] Need to run all queries inside the queue thread --- .../myria/coordinator/catalog/MasterCatalog.java | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/edu/washington/escience/myria/coordinator/catalog/MasterCatalog.java b/src/edu/washington/escience/myria/coordinator/catalog/MasterCatalog.java index 1fdbae9fd..84d8ab8f1 100644 --- a/src/edu/washington/escience/myria/coordinator/catalog/MasterCatalog.java +++ b/src/edu/washington/escience/myria/coordinator/catalog/MasterCatalog.java @@ -18,8 +18,6 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; -import jersey.repackaged.com.google.common.collect.ImmutableSet; - import org.joda.time.DateTime; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -1998,12 +1996,17 @@ protected String job(final SQLiteConnection sqliteConnection) throws CatalogExce public Iterator tupleBatchIteratorFromQuery(final String queryString, final Schema outputSchema) throws CatalogException { try { - return queue.execute(new SQLiteJob() { + return queue.execute(new SQLiteJob>() { @Override - protected SQLiteTupleBatchIterator job(final SQLiteConnection sqliteConnection) throws CatalogException, + protected Iterator job(final SQLiteConnection sqliteConnection) throws CatalogException, SQLiteException { SQLiteStatement statement = sqliteConnection.prepare(queryString); - return new SQLiteTupleBatchIterator(statement, sqliteConnection, outputSchema); + List tuples = Lists.newLinkedList(); + Iterator iter = new SQLiteTupleBatchIterator(statement, sqliteConnection, outputSchema); + while (iter.hasNext()) { + tuples.add(iter.next()); + } + return tuples.iterator(); } }).get(); } catch (InterruptedException | ExecutionException e) { From 6780c98e931686ad729ba0d7135f0a33010667c7 Mon Sep 17 00:00:00 2001 From: Dominik Moritz Date: Wed, 3 Jun 2015 22:24:30 -0700 Subject: [PATCH 18/29] Close the catalog after the test. Not really needed but nice to do (and test). 
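A note on the previous patch's queue-thread fix: sqlite4java confines a
SQLiteConnection to the thread that opened it, so everything that touches the
statement has to run inside the SQLiteJob, and only materialized batches may
escape to the caller. Reduced to a sketch (mirrors the patch; assumes the
surrounding MasterCatalog fields and imports):

  return queue.execute(new SQLiteJob<Iterator<TupleBatch>>() {
    @Override
    protected Iterator<TupleBatch> job(final SQLiteConnection sqliteConnection) throws CatalogException,
        SQLiteException {
      // Drain the lazy iterator here, on the queue thread that owns the
      // connection, then hand back only plain data.
      SQLiteStatement statement = sqliteConnection.prepare(queryString);
      List<TupleBatch> tuples = Lists.newLinkedList();
      Iterator<TupleBatch> iter = new SQLiteTupleBatchIterator(statement, sqliteConnection, outputSchema);
      while (iter.hasNext()) {
        tuples.add(iter.next());
      }
      return tuples.iterator();
    }
  }).get();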
---
 .../escience/myria/coordinator/catalog/CatalogTest.java | 6 ++++++
 .../escience/myria/operator/CatalogScanTest.java        | 9 +++++++++
 2 files changed, 15 insertions(+)

diff --git a/test/edu/washington/escience/myria/coordinator/catalog/CatalogTest.java b/test/edu/washington/escience/myria/coordinator/catalog/CatalogTest.java
index 90daa9c7c..a82536ca9 100644
--- a/test/edu/washington/escience/myria/coordinator/catalog/CatalogTest.java
+++ b/test/edu/washington/escience/myria/coordinator/catalog/CatalogTest.java
@@ -84,6 +84,8 @@ public void testCatalogCreation() throws CatalogException {
    for (final String worker : WORKERS) {
      assertTrue(values.contains(SocketInfo.valueOf(worker)));
    }
+
+    catalog.close();
  }

  /**
@@ -252,6 +254,8 @@ public void testCatalogQuerySearch() throws CatalogException {
    assertEquals(2, queries.size());
    assertEquals(Long.valueOf(3L), queries.get(0).queryId);
    assertEquals(Long.valueOf(2L), queries.get(1).queryId);
+
+    catalog.close();
  }

  /**
@@ -342,5 +346,7 @@ public void testCatalogExtraFieldsList() throws CatalogException {
    assertEquals(null, qs.logicalRa);
    assertEquals(ImmutableSet.copyOf(qs.profilingMode), ImmutableSet.copyOf(query.profilingMode));
    assertEquals(qs.language, query.language);
+
+    catalog.close();
  }
 }
diff --git a/test/edu/washington/escience/myria/operator/CatalogScanTest.java b/test/edu/washington/escience/myria/operator/CatalogScanTest.java
index a980ffc9a..1d8262e67 100644
--- a/test/edu/washington/escience/myria/operator/CatalogScanTest.java
+++ b/test/edu/washington/escience/myria/operator/CatalogScanTest.java
@@ -8,6 +8,7 @@
 import java.util.logging.Level;
 import java.util.logging.Logger;

+import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;

@@ -48,6 +49,14 @@ public void setUp() throws Exception {
    catalog.newQuery(query);
  }

+  /**
+   * Close the catalog after each test.
+   */
+  @After
+  public void cleanup() {
+    catalog.close();
+  }
+
  @Test
  public final void testQueryQueries() throws DbException, CatalogException {
    Schema schema = new Schema(ImmutableList.of(Type.LONG_TYPE, Type.STRING_TYPE), ImmutableList.of("id", "raw"));

From d179d0414b697fd4acf823a0dc6773223ff49582 Mon Sep 17 00:00:00 2001
From: Dominik Moritz
Date: Wed, 3 Jun 2015 22:39:46 -0700
Subject: [PATCH 19/29] Remove unused code

---
 .../accessmethod/SQLiteTupleBatchIterator.java | 13 -------------
 1 file changed, 13 deletions(-)

diff --git a/src/edu/washington/escience/myria/accessmethod/SQLiteTupleBatchIterator.java b/src/edu/washington/escience/myria/accessmethod/SQLiteTupleBatchIterator.java
index 24a6fe1b2..7ded88784 100644
--- a/src/edu/washington/escience/myria/accessmethod/SQLiteTupleBatchIterator.java
+++ b/src/edu/washington/escience/myria/accessmethod/SQLiteTupleBatchIterator.java
@@ -34,19 +34,6 @@ public class SQLiteTupleBatchIterator implements Iterator<TupleBatch> {
  /** The Schema of the TupleBatches returned by this Iterator. */
  private final Schema schema;

-  /**
-   * Wraps a SQLiteStatement result set in an Iterator.
-   *
-   * @param statement the SQLiteStatement containing the results.
-   * @param schema the Schema describing the format of the TupleBatch containing these results.
-   * @param connection the connection to the SQLite database.
-   */
-  SQLiteTupleBatchIterator(final SQLiteStatement statement, final Schema schema, final SQLiteConnection connection) {
-    this.statement = statement;
-    this.connection = connection;
-    this.schema = schema;
-  }
-
  /**
   * Wraps a SQLiteStatement result set in an Iterator.
* From e570e45a4701a17ad2b54c15d64e78aeb473aa6e Mon Sep 17 00:00:00 2001 From: Dan Radion Date: Wed, 3 Jun 2015 23:49:42 -0700 Subject: [PATCH 20/29] bug fix for sampling distribution --- .../washington/escience/myria/operator/SamplingDistribution.java | 1 - 1 file changed, 1 deletion(-) diff --git a/src/edu/washington/escience/myria/operator/SamplingDistribution.java b/src/edu/washington/escience/myria/operator/SamplingDistribution.java index 4734da2d4..4e1ad12a8 100644 --- a/src/edu/washington/escience/myria/operator/SamplingDistribution.java +++ b/src/edu/washington/escience/myria/operator/SamplingDistribution.java @@ -205,7 +205,6 @@ protected TupleBatch fetchNextReady() throws DbException { ImmutableList.Builder> columns = ImmutableList.builder(); columns.add(wIdCol.build(), streamSizeCol.build(), sampCountCol.build(), sampTypeCol.build()); - getChild().close(); close(); return new TupleBatch(SCHEMA, columns.build()); } From 5086b8624f72e5210af2b6c2c40af8d49fe2129d Mon Sep 17 00:00:00 2001 From: Dan Radion Date: Wed, 3 Jun 2015 23:58:26 -0700 Subject: [PATCH 21/29] set eos instead of close --- .../escience/myria/operator/SamplingDistribution.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/edu/washington/escience/myria/operator/SamplingDistribution.java b/src/edu/washington/escience/myria/operator/SamplingDistribution.java index 4e1ad12a8..a08d7f81c 100644 --- a/src/edu/washington/escience/myria/operator/SamplingDistribution.java +++ b/src/edu/washington/escience/myria/operator/SamplingDistribution.java @@ -205,7 +205,7 @@ protected TupleBatch fetchNextReady() throws DbException { ImmutableList.Builder> columns = ImmutableList.builder(); columns.add(wIdCol.build(), streamSizeCol.build(), sampCountCol.build(), sampTypeCol.build()); - close(); + setEOS(); return new TupleBatch(SCHEMA, columns.build()); } From 7d6f44f5f630c87d055979bb17a03b78ce05157f Mon Sep 17 00:00:00 2001 From: Yuqing Guo Date: Fri, 5 Jun 2015 00:20:41 -0700 Subject: [PATCH 22/29] change in fetchNextReady() to return filled tuple batch as soon as one is constructed --- .../operator/agg/StreamingAggregate.java | 138 ++++-------------- .../myria/operator/StreamingAggTest.java | 47 ++++++ 2 files changed, 78 insertions(+), 107 deletions(-) diff --git a/src/edu/washington/escience/myria/operator/agg/StreamingAggregate.java b/src/edu/washington/escience/myria/operator/agg/StreamingAggregate.java index 586279983..ec404c877 100644 --- a/src/edu/washington/escience/myria/operator/agg/StreamingAggregate.java +++ b/src/edu/washington/escience/myria/operator/agg/StreamingAggregate.java @@ -6,17 +6,15 @@ import javax.annotation.Nullable; import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import edu.washington.escience.myria.DbException; import edu.washington.escience.myria.Schema; -import edu.washington.escience.myria.Type; import edu.washington.escience.myria.operator.Operator; import edu.washington.escience.myria.operator.UnaryOperator; import edu.washington.escience.myria.storage.Tuple; import edu.washington.escience.myria.storage.TupleBatch; -import edu.washington.escience.myria.storage.TupleBuffer; +import edu.washington.escience.myria.storage.TupleBatchBuffer; import edu.washington.escience.myria.storage.TupleUtils; /** @@ -33,17 +31,22 @@ public class StreamingAggregate extends UnaryOperator { /** Required for Java serialization. 
*/ private static final long serialVersionUID = 1L; - /** The schema of the aggregation result. */ - private Schema aggSchema; + /** The schema of the result. */ + private Schema resultSchema; /** The schema of the columns indicated by the group keys. */ private Schema groupSchema; + /** Holds the current grouping key. */ private Tuple curGroupKey; + /** Child of this aggregator. **/ + private final Operator child = getChild(); + /** Currently processing input tuple batch. **/ + private TupleBatch tb; + /** Current row in the input tuple batch. **/ + private int row; /** Group fields. **/ private final int[] gFields; - /** Group field types. **/ - private final Type[] gTypes; /** An array [0, 1, .., gFields.length-1] used for comparing tuples. */ private final int[] gRange; /** Factories to make the Aggregators. **/ @@ -54,9 +57,7 @@ public class StreamingAggregate extends UnaryOperator { private Object[] aggregatorStates; /** Buffer for holding intermediate results. */ - private transient TupleBuffer resultBuffer; - /** Buffer for holding finished results as tuple batches. */ - private transient ImmutableList finalBuffer; + private transient TupleBatchBuffer resultBuffer; /** * Groups the input tuples according to the specified grouping fields, then produces the specified aggregates. @@ -69,12 +70,11 @@ public StreamingAggregate(@Nullable final Operator child, @Nonnull final int[] g @Nonnull final AggregatorFactory... factories) { super(child); gFields = Objects.requireNonNull(gfields, "gfields"); - gTypes = new Type[gfields.length]; this.factories = Objects.requireNonNull(factories, "factories"); Preconditions.checkArgument(gfields.length > 0, " must have at least one group by field"); Preconditions.checkArgument(factories.length > 0, "to use StreamingAggregate, must specify some aggregates"); gRange = new int[gfields.length]; - for (int i = 0; i < gfields.length; ++i) { + for (int i = 0; i < gRange.length; ++i) { gRange[i] = i; } } @@ -88,81 +88,42 @@ public StreamingAggregate(@Nullable final Operator child, @Nonnull final int[] g */ @Override protected TupleBatch fetchNextReady() throws DbException { - final Operator child = getChild(); if (child.eos()) { - return getResultBatch(); + return resultBuffer.popAny(); + } + if (tb == null) { + tb = child.nextReady(); + row = 0; } - - TupleBatch tb = child.nextReady(); while (tb != null) { - for (int row = 0; row < tb.numTuples(); ++row) { + while (row < tb.numTuples()) { if (curGroupKey == null) { /* First time accessing this tb, no aggregation performed previously. 
*/ // store current group key as a tuple curGroupKey = new Tuple(groupSchema); for (int gKey = 0; gKey < gFields.length; ++gKey) { - gTypes[gKey] = tb.getSchema().getColumnType(gFields[gKey]); - switch (gTypes[gKey]) { - case BOOLEAN_TYPE: - curGroupKey.set(gKey, tb.getBoolean(gFields[gKey], row)); - break; - case STRING_TYPE: - curGroupKey.set(gKey, tb.getString(gFields[gKey], row)); - break; - case DATETIME_TYPE: - curGroupKey.set(gKey, tb.getDateTime(gFields[gKey], row)); - break; - case INT_TYPE: - curGroupKey.set(gKey, tb.getInt(gFields[gKey], row)); - break; - case LONG_TYPE: - curGroupKey.set(gKey, tb.getLong(gFields[gKey], row)); - break; - case FLOAT_TYPE: - curGroupKey.set(gKey, tb.getFloat(gFields[gKey], row)); - break; - case DOUBLE_TYPE: - curGroupKey.set(gKey, tb.getDouble(gFields[gKey], row)); - break; - } + TupleUtils.copyValue(tb, gFields[gKey], row, curGroupKey, gKey); } } else if (!TupleUtils.tupleEquals(tb, gFields, row, curGroupKey, gRange, 0)) { /* Different grouping key than current one, flush current agg result to result buffer. */ addToResult(); // store current group key as a tuple for (int gKey = 0; gKey < gFields.length; ++gKey) { - switch (gTypes[gKey]) { - case BOOLEAN_TYPE: - curGroupKey.set(gKey, tb.getBoolean(gFields[gKey], row)); - break; - case STRING_TYPE: - curGroupKey.set(gKey, tb.getString(gFields[gKey], row)); - break; - case DATETIME_TYPE: - curGroupKey.set(gKey, tb.getDateTime(gFields[gKey], row)); - break; - case INT_TYPE: - curGroupKey.set(gKey, tb.getInt(gFields[gKey], row)); - break; - case LONG_TYPE: - curGroupKey.set(gKey, tb.getLong(gFields[gKey], row)); - break; - case FLOAT_TYPE: - curGroupKey.set(gKey, tb.getFloat(gFields[gKey], row)); - break; - case DOUBLE_TYPE: - curGroupKey.set(gKey, tb.getDouble(gFields[gKey], row)); - break; - } + TupleUtils.copyValue(tb, gFields[gKey], row, curGroupKey, gKey); + } + aggregatorStates = AggUtils.allocateAggStates(aggregators); + if (resultBuffer.hasFilledTB()) { + return resultBuffer.popFilled(); } - reinitializeAggStates(); } // update aggregator states with current tuple for (int agg = 0; agg < aggregators.length; ++agg) { aggregators[agg].addRow(tb, row, aggregatorStates[agg]); } + row++; } tb = child.nextReady(); + row = 0; } /* @@ -171,21 +132,11 @@ protected TupleBatch fetchNextReady() throws DbException { */ if (child.eos()) { addToResult(); - return getResultBatch(); + return resultBuffer.popAny(); } return null; } - /** - * Re-initialize aggregator states for new group key. - * - * @throws DbException if any error - */ - private void reinitializeAggStates() throws DbException { - aggregatorStates = null; - aggregatorStates = AggUtils.allocateAggStates(aggregators); - } - /** * Add aggregate results with previous grouping key to result buffer. * @@ -202,26 +153,6 @@ private void addToResult() throws DbException { } } - /** - * @return A batch's worth of result tuples from this aggregate. - * @throws DbException if there is an error. - */ - private TupleBatch getResultBatch() throws DbException { - Preconditions.checkState(getChild().eos(), "cannot extract results from an aggregate until child has reached EOS"); - if (finalBuffer == null) { - finalBuffer = resultBuffer.finalResult(); - if (resultBuffer.numTuples() == 0) { - throw new DbException("0 tuples in result buffer"); - } - resultBuffer = null; - } - if (finalBuffer.isEmpty()) { - return null; - } else { - return finalBuffer.get(0); - } - } - /** * The schema of the aggregate output. Grouping fields first and then aggregate fields. 
* @@ -239,22 +170,16 @@ protected Schema generateSchema() { } groupSchema = inputSchema.getSubSchema(gFields); - - /* Build the output schema from the group schema and the aggregates. */ - final ImmutableList.Builder aggTypes = ImmutableList. builder(); - final ImmutableList.Builder aggNames = ImmutableList. builder(); - + resultSchema = groupSchema; try { for (Aggregator agg : AggUtils.allocateAggs(factories, inputSchema)) { Schema curAggSchema = agg.getResultSchema(); - aggTypes.addAll(curAggSchema.getColumnTypes()); - aggNames.addAll(curAggSchema.getColumnNames()); + resultSchema = Schema.merge(resultSchema, curAggSchema); } } catch (DbException e) { throw new RuntimeException("unable to allocate aggregators to determine output schema", e); } - aggSchema = new Schema(aggTypes, aggNames); - return Schema.merge(groupSchema, aggSchema); + return resultSchema; } @Override @@ -262,7 +187,7 @@ protected void init(final ImmutableMap execEnvVars) throws DbExc Preconditions.checkState(getSchema() != null, "unable to determine schema in init"); aggregators = AggUtils.allocateAggs(factories, getChild().getSchema()); aggregatorStates = AggUtils.allocateAggStates(aggregators); - resultBuffer = new TupleBuffer(getSchema()); + resultBuffer = new TupleBatchBuffer(getSchema()); } @Override @@ -270,6 +195,5 @@ protected void cleanup() throws DbException { aggregatorStates = null; curGroupKey = null; resultBuffer = null; - finalBuffer = null; } } diff --git a/test/edu/washington/escience/myria/operator/StreamingAggTest.java b/test/edu/washington/escience/myria/operator/StreamingAggTest.java index b31813913..c52e57274 100644 --- a/test/edu/washington/escience/myria/operator/StreamingAggTest.java +++ b/test/edu/washington/escience/myria/operator/StreamingAggTest.java @@ -1,7 +1,10 @@ package edu.washington.escience.myria.operator; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; import org.joda.time.DateTime; import org.junit.Test; @@ -1090,4 +1093,48 @@ public void testMultiGroupAllAggLargeInput() throws DbException { assertEquals(expectedFourthStdev, result.getDouble(7, 3), 0.0001); agg.close(); } + + @Test + public void testMultiBatchResult() throws DbException { + final int numTuples = 3 * TupleBatch.BATCH_SIZE + 3; + final Schema schema = Schema.ofFields(Type.LONG_TYPE, "gkey", Type.LONG_TYPE, "value"); + final TupleBatchBuffer tbb = new TupleBatchBuffer(schema); + // gkey: 0, 1, 2, ..., numTuples-1; value: 1, 1, 1, ... 
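+    // Every gkey is distinct, so each group aggregates exactly one tuple and
+    // COUNT(value) is 1 per group; the aggregate should emit numTuples result
+    // rows in total: three full batches followed by a 3-row remainder.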
+ for (long i = 0; i < numTuples; i++) { + tbb.putLong(0, i); + tbb.putLong(1, 1L); + } + // group by col0, count + StreamingAggregate agg = + new StreamingAggregate(new TupleSource(tbb), new int[] { 0 }, new SingleColumnAggregatorFactory(1, + AggregationOp.COUNT)); + agg.open(null); + TupleBatch result = agg.nextReady(); + assertNotNull(result); + assertEquals(TupleBatch.BATCH_SIZE, result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + // aggregator should return filled tuple batch, even if it hasn't finished processing all input + assertFalse(agg.getChild().eos()); + // get second tuple batch + result = agg.nextReady(); + assertEquals(TupleBatch.BATCH_SIZE, result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + assertFalse(agg.getChild().eos()); + // get third tuple batch + result = agg.nextReady(); + assertEquals(TupleBatch.BATCH_SIZE, result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + assertFalse(agg.getChild().eos()); + // get last, non-filled tuple batch + result = agg.nextReady(); + assertEquals(3, result.numTuples()); + assertEquals(2, result.getSchema().numColumns()); + // child reaches eos() + assertTrue(agg.getChild().eos()); + // exhaust aggregator + result = agg.nextReady(); + assertNull(result); + assertTrue(agg.eos()); + agg.close(); + } } \ No newline at end of file From 7d837b43222b1ceacc14f124d4ace17c15ee93c0 Mon Sep 17 00:00:00 2001 From: Dominik Moritz Date: Fri, 5 Jun 2015 15:52:52 -0700 Subject: [PATCH 23/29] Revert "Remove unused code" This reverts commit d179d0414b697fd4acf823a0dc6773223ff49582. --- .../accessmethod/SQLiteTupleBatchIterator.java | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/edu/washington/escience/myria/accessmethod/SQLiteTupleBatchIterator.java b/src/edu/washington/escience/myria/accessmethod/SQLiteTupleBatchIterator.java index 7ded88784..24a6fe1b2 100644 --- a/src/edu/washington/escience/myria/accessmethod/SQLiteTupleBatchIterator.java +++ b/src/edu/washington/escience/myria/accessmethod/SQLiteTupleBatchIterator.java @@ -34,6 +34,19 @@ public class SQLiteTupleBatchIterator implements Iterator { /** The Schema of the TupleBatches returned by this Iterator. */ private final Schema schema; + /** + * Wraps a SQLiteStatement result set in an Iterator. + * + * @param statement the SQLiteStatement containing the results. + * @param schema the Schema describing the format of the TupleBatch containing these results. + * @param connection the connection to the SQLite database. + */ + SQLiteTupleBatchIterator(final SQLiteStatement statement, final Schema schema, final SQLiteConnection connection) { + this.statement = statement; + this.connection = connection; + this.schema = schema; + } + /** * Wraps a SQLiteStatement result set in an Iterator. 
* From 9474f6edabd158011849ccd76167fbafad3fd4e2 Mon Sep 17 00:00:00 2001 From: Dan Radion Date: Sat, 6 Jun 2015 16:21:35 -0700 Subject: [PATCH 24/29] refactoring and cleanup of sampling operators --- .../encoding/SampledDbInsertTempEncoding.java | 6 +- .../escience/myria/operator/DbInsertTemp.java | 2 +- .../escience/myria/operator/Sample.java | 24 +- .../myria/operator/SampledDbInsertTemp.java | 237 ++++++--------- .../myria/operator/SamplingDistribution.java | 286 +++++++++--------- .../IdentityHashPartitionFunction.java | 20 +- .../escience/myria/util/SamplingType.java | 16 +- .../escience/myria/operator/SampleWRTest.java | 4 +- .../myria/operator/SampleWoRTest.java | 4 +- .../operator/SamplingDistributionTest.java | 66 ++-- 10 files changed, 316 insertions(+), 349 deletions(-) diff --git a/src/edu/washington/escience/myria/api/encoding/SampledDbInsertTempEncoding.java b/src/edu/washington/escience/myria/api/encoding/SampledDbInsertTempEncoding.java index 2cc10ae85..a917ac743 100644 --- a/src/edu/washington/escience/myria/api/encoding/SampledDbInsertTempEncoding.java +++ b/src/edu/washington/escience/myria/api/encoding/SampledDbInsertTempEncoding.java @@ -17,6 +17,10 @@ public class SampledDbInsertTempEncoding extends UnaryOperatorEncoding indices = new HashSet(sampleSize); + Set indices = new HashSet<>(sampleSize); for (int i = populationSize - sampleSize; i < populationSize; i++) { int idx = rand.nextInt(i + 1); if (indices.contains(idx)) { @@ -235,10 +239,6 @@ public Schema generateSchema() { @Override protected void init(final ImmutableMap execEnvVars) { ans = new TupleBatchBuffer(getSchema()); - rand = new Random(); - if (randomSeed != null) { - rand.setSeed(randomSeed); - } } @Override diff --git a/src/edu/washington/escience/myria/operator/SampledDbInsertTemp.java b/src/edu/washington/escience/myria/operator/SampledDbInsertTemp.java index eac978f57..d993a1c75 100644 --- a/src/edu/washington/escience/myria/operator/SampledDbInsertTemp.java +++ b/src/edu/washington/escience/myria/operator/SampledDbInsertTemp.java @@ -3,19 +3,19 @@ */ package edu.washington.escience.myria.operator; -import java.io.File; -import java.util.*; +import java.util.List; +import java.util.Map; +import java.util.Random; -import com.almworks.sqlite4java.SQLiteConnection; -import com.almworks.sqlite4java.SQLiteException; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import edu.washington.escience.myria.*; -import edu.washington.escience.myria.accessmethod.AccessMethod; +import edu.washington.escience.myria.DbException; +import edu.washington.escience.myria.RelationKey; +import edu.washington.escience.myria.Schema; +import edu.washington.escience.myria.Type; import edu.washington.escience.myria.accessmethod.ConnectionInfo; -import edu.washington.escience.myria.accessmethod.SQLiteInfo; import edu.washington.escience.myria.column.Column; import edu.washington.escience.myria.column.builder.IntColumnBuilder; import edu.washington.escience.myria.parallel.RelationWriteMetadata; @@ -25,197 +25,132 @@ /** * Samples the stream into a temp relation. */ -public class SampledDbInsertTemp extends UnaryOperator implements DbWriter { +public class SampledDbInsertTemp extends DbInsertTemp implements DbWriter { /** Required for Java serialization. */ private static final long serialVersionUID = 1L; - /** The connection to the database database. */ - private AccessMethod accessMethod; - /** The information for the database connection. 
*/ - private ConnectionInfo connectionInfo; - /** The name of the table the tuples should be inserted into. */ - private final RelationKey sampleRelationKey; /** The name of the table the tuples should be inserted into. */ private final RelationKey countRelationKey; - /** Total number of tuples seen from the child. */ - private int tupleCount = 0; + /** Number of tuples seen so far from the child. */ + private int currentTupleCount; + /** Number of tuples to sample from the stream. */ - private final int streamSampleSize; + private final int sampleSize; + /** Reservoir that holds sampleSize number of tuples. */ - private MutableTupleBuffer reservoir = null; - /** Sampled tuples ready to be returned. */ - private List batches; - /** Next element of batches List that will be returned. */ - private int batchNum = 0; - /** True if all samples have been gathered from the child. */ - private boolean doneSamplingFromChild; - - /** The output schema. */ - private static final Schema COUNT_SCHEMA = Schema.of( - ImmutableList.of(Type.INT_TYPE, Type.INT_TYPE, Type.INT_TYPE), - ImmutableList.of("WorkerID", "PartitionSize", "StreamSize")); + private MutableTupleBuffer reservoir; + + /** Random generator used for creating the distribution. */ + private Random rand; + + /** Schema that will be written to the countRelationKey. */ + private static final Schema COUNT_SCHEMA = Schema.ofFields("WorkerID", + Type.INT_TYPE, "PartitionSize", Type.INT_TYPE, "PartitionSampleSize", + Type.INT_TYPE); /** + * * @param child - * the source of tuples to be inserted. - * @param streamSampleSize + * the source of tuples to be inserted + * @param sampleSize * number of tuples to store from the stream * @param sampleRelationKey - * the key of the table that the tuples should be inserted into. + * the key of the table that tuples will be inserted into * @param countRelationKey - * the key of the table that the tuple counts will be inserted into. + * the key of the table that tuple count info will be inserted into * @param connectionInfo - * the parameters of the database connection. + * parameters of the database connection + * @param randomSeed + * value to seed the random generator with. null if no specified seed */ - public SampledDbInsertTemp(final Operator child, final int streamSampleSize, + public SampledDbInsertTemp(final Operator child, final int sampleSize, final RelationKey sampleRelationKey, final RelationKey countRelationKey, - final ConnectionInfo connectionInfo) { - super(child); - // Sampling setup. - Preconditions.checkArgument(streamSampleSize >= 0L, + final ConnectionInfo connectionInfo, Long randomSeed) { + super(child, sampleRelationKey, connectionInfo, true, null); + Preconditions.checkArgument(sampleSize >= 0, "sampleSize must be non-negative"); - this.streamSampleSize = streamSampleSize; - doneSamplingFromChild = false; - - // Relation setup. 
- Objects.requireNonNull(sampleRelationKey, "sampleRelationKey"); - this.sampleRelationKey = sampleRelationKey; - Objects.requireNonNull(countRelationKey, "countRelationKey"); + this.sampleSize = sampleSize; + Preconditions.checkNotNull(countRelationKey, + "countRelationKey cannot be null"); this.countRelationKey = countRelationKey; - this.connectionInfo = connectionInfo; - } - - @Override - protected TupleBatch fetchNextReady() throws DbException { - if (!doneSamplingFromChild) { - fillReservoir(); - batches = reservoir.getAll(); - // Insert sampled tuples into sampleRelationKey - while (batchNum < batches.size()) { - TupleBatch batch = batches.get(batchNum); - accessMethod.tupleBatchInsert(sampleRelationKey, batch); - batchNum++; - } - - // Write (WorkerID, PartitionSize, StreamSize) to countRelationKey - IntColumnBuilder wIdCol = new IntColumnBuilder(); - IntColumnBuilder tupCountCol = new IntColumnBuilder(); - IntColumnBuilder streamSizeCol = new IntColumnBuilder(); - wIdCol.appendInt(getNodeID()); - tupCountCol.appendInt(tupleCount); - streamSizeCol.appendInt(reservoir.numTuples()); - ImmutableList.Builder> columns = ImmutableList.builder(); - columns.add(wIdCol.build(), tupCountCol.build(), streamSizeCol.build()); - TupleBatch tb = new TupleBatch(COUNT_SCHEMA, columns.build()); - accessMethod.tupleBatchInsert(countRelationKey, tb); + rand = new Random(); + if (randomSeed != null) { + rand.setSeed(randomSeed); } - return null; } /** - * Fills reservoir with child tuples. - * - * @throws DbException - * if TupleBatch fails to get nextReady + * Uses reservoir sampling to insert the specified sampleSize. + * https://en.wikipedia.org/wiki/Reservoir_sampling */ - private void fillReservoir() throws DbException { - Random rand = new Random(); - for (TupleBatch tb = getChild().nextReady(); tb != null; tb = getChild() - .nextReady()) { - final List> columns = tb.getDataColumns(); - for (int i = 0; i < tb.numTuples(); i++) { - if (reservoir.numTuples() < streamSampleSize) { - // Reservoir size < k. Add this tuple. + @Override + protected void consumeTuples(final TupleBatch tb) throws DbException { + final List> columns = tb.getDataColumns(); + for (int i = 0; i < tb.numTuples(); i++) { + if (reservoir.numTuples() < sampleSize) { + // Reservoir size < k. Add this tuple. 
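+        // (Classic reservoir sampling, cf. Vitter's Algorithm R: the first
+        // sampleSize tuples fill the reservoir; thereafter the n-th incoming
+        // tuple replaces a uniformly chosen slot with probability sampleSize/n,
+        // so every tuple seen so far is retained with equal probability.)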
+        for (int j = 0; j < tb.numColumns(); j++) {
+          reservoir.put(j, columns.get(j), i);
+        }
+      } else {
+        // Replace probabilistically (keep the n-th tuple with probability sampleSize/n).
+        int replaceIdx = rand.nextInt(currentTupleCount + 1);
+        if (replaceIdx < sampleSize) {
          for (int j = 0; j < tb.numColumns(); j++) {
-            reservoir.put(j, columns.get(j), i);
-          }
-        } else {
-          // Replace probabilistically
-          int replaceIdx = rand.nextInt(tupleCount);
-          if (replaceIdx < reservoir.numTuples()) {
-            for (int j = 0; j < tb.numColumns(); j++) {
-              reservoir.replace(j, replaceIdx, columns.get(j), i);
-            }
+            reservoir.replace(j, replaceIdx, columns.get(j), i);
          }
        }
-        tupleCount++;
      }
+      currentTupleCount++;
    }
-    doneSamplingFromChild = true;
  }

  @Override
-  protected void init(final ImmutableMap<String, Object> execEnvVars)
-      throws DbException {
-    reservoir = new MutableTupleBuffer(getChild().getSchema());
-    /*
-     * retrieve connection information from the environment variables, if not
-     * already set
-     */
-    if (connectionInfo == null && execEnvVars != null) {
-      connectionInfo = (ConnectionInfo) execEnvVars
-          .get(MyriaConstants.EXEC_ENV_VAR_DATABASE_CONN_INFO);
-    }
-
-    if (connectionInfo == null) {
-      throw new DbException(
-          "Unable to instantiate SampledDbInsertTemp: connection information unknown");
+  protected void childEOS() throws DbException {
+    // Insert the reservoir samples.
+    for (TupleBatch tb : reservoir.getAll()) {
+      accessMethod.tupleBatchInsert(getRelationKey(), tb);
    }
+    // Insert (WorkerID, PartitionSize, PartitionSampleSize) to
+    // countRelationKey.
+    IntColumnBuilder wIdCol = new IntColumnBuilder();
+    IntColumnBuilder tupCountCol = new IntColumnBuilder();
+    IntColumnBuilder sampledSizeCol = new IntColumnBuilder();
+    wIdCol.appendInt(getNodeID());
+    tupCountCol.appendInt(currentTupleCount);
+    sampledSizeCol.appendInt(reservoir.numTuples());
+    ImmutableList.Builder<Column<?>> columns = ImmutableList.builder();
+    columns.add(wIdCol.build(), tupCountCol.build(), sampledSizeCol.build());
+    TupleBatch tb = new TupleBatch(COUNT_SCHEMA, columns.build());
+    accessMethod.tupleBatchInsert(countRelationKey, tb);
+  }

-    if (connectionInfo instanceof SQLiteInfo) {
-      /* Set WAL in the beginning. */
-      final File dbFile = new File(
-          ((SQLiteInfo) connectionInfo).getDatabaseFilename());
-      SQLiteConnection conn = new SQLiteConnection(dbFile);
-      try {
-        conn.open(true);
-        conn.exec("PRAGMA journal_mode=WAL;");
-      } catch (SQLiteException e) {
-        e.printStackTrace();
-      }
-      conn.dispose();
-    }
+  @Override
+  protected void init(final ImmutableMap<String, Object> execEnvVars)
+      throws DbException {
+    // Will set up the database connection and create the reservoir table.
+    super.init(execEnvVars);

-    /* open the database connection */
-    accessMethod = AccessMethod.of(connectionInfo.getDbms(), connectionInfo,
-        false);
-    accessMethod.dropTableIfExists(sampleRelationKey);
+    // Set up the tuple count table.
    accessMethod.dropTableIfExists(countRelationKey);
-    // Create the temp tables.
- accessMethod.createTableIfNotExists(sampleRelationKey, getSchema()); accessMethod.createTableIfNotExists(countRelationKey, COUNT_SCHEMA); + + reservoir = new MutableTupleBuffer(getChild().getSchema()); } @Override public void cleanup() { + super.cleanup(); reservoir = null; - batches = null; - - try { - if (accessMethod != null) { - accessMethod.close(); - } - } catch (DbException e) { - throw new RuntimeException(e); - } - } - - @Override - public final Schema generateSchema() { - if (getChild() == null) { - return null; - } - return getChild().getSchema(); } @Override public Map writeSet() { - Map map = new HashMap(2); - map.put(sampleRelationKey, new RelationWriteMetadata(sampleRelationKey, getSchema(), true, true)); - map.put(countRelationKey, new RelationWriteMetadata(countRelationKey, COUNT_SCHEMA, true, true)); - return map; + return ImmutableMap.of(getRelationKey(), new RelationWriteMetadata( + getRelationKey(), getSchema(), true, true), countRelationKey, + new RelationWriteMetadata(countRelationKey, COUNT_SCHEMA, true, true)); } } diff --git a/src/edu/washington/escience/myria/operator/SamplingDistribution.java b/src/edu/washington/escience/myria/operator/SamplingDistribution.java index a08d7f81c..ae837ea82 100644 --- a/src/edu/washington/escience/myria/operator/SamplingDistribution.java +++ b/src/edu/washington/escience/myria/operator/SamplingDistribution.java @@ -1,12 +1,11 @@ package edu.washington.escience.myria.operator; -import java.util.ArrayList; -import java.util.List; +import java.util.Map; import java.util.Random; +import java.util.TreeMap; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import edu.washington.escience.myria.DbException; import edu.washington.escience.myria.Schema; @@ -17,17 +16,21 @@ import edu.washington.escience.myria.storage.TupleBatch; import edu.washington.escience.myria.util.SamplingType; +/** + * Given the sizes of each worker, computes a distribution of how much each + * worker should sample. + */ public class SamplingDistribution extends UnaryOperator { /** Required for Java serialization. */ private static final long serialVersionUID = 1L; /** The output schema. */ - private static final Schema SCHEMA = Schema.of(ImmutableList.of( - Type.INT_TYPE, Type.INT_TYPE, Type.INT_TYPE, Type.STRING_TYPE), - ImmutableList.of("WorkerID", "StreamSize", "SampleSize", "SampleType")); + private static final Schema SCHEMA = Schema.ofFields("WorkerID", + Type.INT_TYPE, "StreamSize", Type.INT_TYPE, "SampleSize", Type.INT_TYPE, + "SampleType", Type.STRING_TYPE); /** Total number of tuples to sample. */ - private int sampleSize; + private int sampleSize = 0; /** True if using a percentage instead of a specific tuple count. */ private boolean isPercentageSample = false; @@ -41,21 +44,8 @@ public class SamplingDistribution extends UnaryOperator { /** Random generator used for creating the distribution. */ private Random rand; - /** Seed for the random generator. */ - private Long randomSeed; - - /** - * Distribution of the tuples across the workers. Value at index i == # of - * tuples on worker i. - */ - ArrayList tupleCounts; - - /** - * Distribution of the actual stream size across the workers. May be different - * from tupleCounts if workers pre-sampled the data. Value at index i == # of - * tuples in stream on worker i. 
- */ - ArrayList streamCounts; + /** Maps (worker_i) --> (sampling info for worker_i) */ + TreeMap workerInfo = new TreeMap<>(); /** Total number of tuples across all workers. */ int totalTupleCount = 0; @@ -64,7 +54,10 @@ private SamplingDistribution(Operator child, SamplingType sampleType, Long randomSeed) { super(child); this.sampleType = sampleType; - this.randomSeed = randomSeed; + rand = new Random(); + if (randomSeed != null) { + rand.setSeed(randomSeed); + } } /** @@ -83,9 +76,9 @@ private SamplingDistribution(Operator child, SamplingType sampleType, public SamplingDistribution(Operator child, int sampleSize, SamplingType sampleType, Long randomSeed) { this(child, sampleType, randomSeed); + Preconditions.checkArgument(sampleSize >= 0, + "Sample Size must be >= 0: %s", sampleSize); this.sampleSize = sampleSize; - Preconditions.checkState(this.sampleSize >= 0, - "Sample Size must be >= 0: %s", this.sampleSize); } /** @@ -107,13 +100,14 @@ public SamplingDistribution(Operator child, float samplePercentage, this(child, sampleType, randomSeed); this.isPercentageSample = true; this.samplePercentage = samplePercentage; - Preconditions.checkState(samplePercentage >= 0 && samplePercentage <= 100, + Preconditions.checkArgument(samplePercentage >= 0 + && samplePercentage <= 100, "Sample Percentage must be >= 0 && <= 100: %s", samplePercentage); } @Override protected TupleBatch fetchNextReady() throws DbException { - // Drain out all the workerID and partitionSize info. + // Drain out all the worker info. while (!getChild().eos()) { TupleBatch tb = getChild().nextReady(); if (tb == null) { @@ -122,57 +116,9 @@ protected TupleBatch fetchNextReady() throws DbException { } return null; } - Type col0Type = tb.getSchema().getColumnType(0); - Type col1Type = tb.getSchema().getColumnType(1); - boolean hasStreamSize = false; - Type col2Type = null; - if (tb.getSchema().numColumns() > 2) { - hasStreamSize = true; - col2Type = tb.getSchema().getColumnType(2); - } - for (int i = 0; i < tb.numTuples(); i++) { - int workerID; - if (col0Type == Type.INT_TYPE) { - workerID = tb.getInt(0, i); - } else if (col0Type == Type.LONG_TYPE) { - workerID = (int) tb.getLong(0, i); - } else { - throw new DbException("WorkerID must be of type INT or LONG"); - } - Preconditions.checkState(workerID > 0, "WorkerID must be > 0"); - // Ensure the future .set(workerID, -) calls will work. - for (int j = tupleCounts.size(); j < workerID; j++) { - tupleCounts.add(0); - streamCounts.add(0); - } - - int partitionSize; - if (col1Type == Type.INT_TYPE) { - partitionSize = tb.getInt(1, i); - } else if (col1Type == Type.LONG_TYPE) { - partitionSize = (int) tb.getLong(1, i); - } else { - throw new DbException("PartitionSize must be of type INT or LONG"); - } - Preconditions.checkState(partitionSize >= 0, - "Worker cannot have a negative PartitionSize: %s", partitionSize); - tupleCounts.set(workerID - 1, partitionSize); - totalTupleCount += partitionSize; - int streamSize = partitionSize; - if (hasStreamSize) { - if (col2Type == Type.INT_TYPE) { - streamSize = tb.getInt(2, i); - } else if (col2Type == Type.LONG_TYPE) { - streamSize = (int) tb.getLong(2, i); - } else { - throw new DbException("StreamSize must be of type INT or LONG"); - } - Preconditions.checkState(partitionSize >= 0, - "Worker cannot have a negative StreamSize: %d", streamSize); - } - streamCounts.set(workerID - 1, streamSize); - } + extractWorkerInfo(tb); } + // Convert samplePct to sampleSize if using a percentage sample. 
if (isPercentageSample) { sampleSize = Math.round(totalTupleCount * (samplePercentage / 100)); @@ -181,99 +127,151 @@ protected TupleBatch fetchNextReady() throws DbException { "Cannot extract %s samples from a population of size %s", sampleSize, totalTupleCount); - // Generate a random distribution across the workers. - int[] sampleCounts; - if (sampleType == SamplingType.WR) { - sampleCounts = withReplacementDistribution(tupleCounts, sampleSize); - } else if (sampleType == SamplingType.WoR) { - sampleCounts = withoutReplacementDistribution(tupleCounts, sampleSize); + // Generate a sampling distribution across the workers. + if (sampleType == SamplingType.WithReplacement) { + withReplacementDistribution(workerInfo, totalTupleCount, sampleSize); + } else if (sampleType == SamplingType.WithoutReplacement) { + withoutReplacementDistribution(workerInfo, totalTupleCount, sampleSize); } else { throw new DbException("Invalid sampleType: " + sampleType); } // Build and return a TupleBatch with the distribution. - IntColumnBuilder wIdCol = new IntColumnBuilder(); - IntColumnBuilder streamSizeCol = new IntColumnBuilder(); - IntColumnBuilder sampCountCol = new IntColumnBuilder(); - StringColumnBuilder sampTypeCol = new StringColumnBuilder(); - for (int i = 0; i < streamCounts.size(); i++) { - wIdCol.appendInt(i + 1); - streamSizeCol.appendInt(streamCounts.get(i)); - sampCountCol.appendInt(sampleCounts[i]); - sampTypeCol.appendString(sampleType.name()); + // Assumes that the sampling information can fit into one tuple batch. + IntColumnBuilder wIDs = new IntColumnBuilder(); + IntColumnBuilder actualSizes = new IntColumnBuilder(); + IntColumnBuilder sampSizes = new IntColumnBuilder(); + StringColumnBuilder sampTypes = new StringColumnBuilder(); + for (Map.Entry iWorker : workerInfo.entrySet()) { + wIDs.appendInt(iWorker.getKey()); + actualSizes.appendInt(iWorker.getValue().actualTupleCount); + sampSizes.appendInt(iWorker.getValue().sampleSize); + sampTypes.appendString(sampleType.name()); } ImmutableList.Builder> columns = ImmutableList.builder(); - columns.add(wIdCol.build(), streamSizeCol.build(), sampCountCol.build(), - sampTypeCol.build()); + columns.add(wIDs.build(), actualSizes.build(), sampSizes.build(), + sampTypes.build()); setEOS(); return new TupleBatch(SCHEMA, columns.build()); } + /** Helper function to extract worker information from a tuple batch. 
+   */
+  private void extractWorkerInfo(TupleBatch tb) throws DbException {
+    Type col0Type = tb.getSchema().getColumnType(0);
+    Type col1Type = tb.getSchema().getColumnType(1);
+    boolean hasActualTupleCount = false;
+    Type col2Type = null;
+    if (tb.getSchema().numColumns() > 2) {
+      hasActualTupleCount = true;
+      col2Type = tb.getSchema().getColumnType(2);
+    }
+
+    for (int i = 0; i < tb.numTuples(); i++) {
+      int workerID;
+      if (col0Type == Type.INT_TYPE) {
+        workerID = tb.getInt(0, i);
+      } else if (col0Type == Type.LONG_TYPE) {
+        workerID = (int) tb.getLong(0, i);
+      } else {
+        throw new DbException("WorkerID must be of type INT or LONG");
+      }
+      Preconditions.checkState(workerID > 0, "WorkerID must be > 0");
+      Preconditions.checkState(!workerInfo.containsKey(workerID),
+          "Duplicate WorkerIDs");
+
+      int tupleCount;
+      if (col1Type == Type.INT_TYPE) {
+        tupleCount = tb.getInt(1, i);
+      } else if (col1Type == Type.LONG_TYPE) {
+        tupleCount = (int) tb.getLong(1, i);
+      } else {
+        throw new DbException("TupleCount must be of type INT or LONG");
+      }
+      Preconditions.checkState(tupleCount >= 0,
+          "Worker cannot have a negative TupleCount: %s", tupleCount);
+
+      int actualTupleCount = tupleCount;
+      if (hasActualTupleCount) {
+        if (col2Type == Type.INT_TYPE) {
+          actualTupleCount = tb.getInt(2, i);
+        } else if (col2Type == Type.LONG_TYPE) {
+          actualTupleCount = (int) tb.getLong(2, i);
+        } else {
+          throw new DbException("ActualTupleCount must be of type INT or LONG");
+        }
+        Preconditions.checkState(actualTupleCount >= 0,
+            "Worker cannot have a negative ActualTupleCount: %d",
+            actualTupleCount);
+      }
+
+      WorkerInfo wInfo = new WorkerInfo(tupleCount, actualTupleCount);
+      workerInfo.put(workerID, wInfo);
+      totalTupleCount += tupleCount;
+    }
+  }
+
  /**
   * Creates a WithReplacement distribution across the workers.
   *
-   * @param tupleCounts
-   *          list of how many tuples each worker has.
+   * @param workerInfo
+   *          reference to the workerInfo to modify.
+   * @param totalTupleCount
+   *          total # of tuples across all workers.
   * @param sampleSize
-   *          total number of samples to distribute across the workers.
-   * @return array representing the distribution across the workers.
+   *          total # of samples to distribute across the workers.
   */
-  private int[] withReplacementDistribution(List<Integer> tupleCounts,
+  private void withReplacementDistribution(
+      TreeMap<Integer, WorkerInfo> workerInfo, int totalTupleCount,
      int sampleSize) {
-    int[] distribution = new int[tupleCounts.size()];
-    int totalTupleCount = 0;
-    for (int val : tupleCounts) {
-      totalTupleCount += val;
-    }
-
    for (int i = 0; i < sampleSize; i++) {
      int sampleTupleIdx = rand.nextInt(totalTupleCount);
      // Assign this tuple to the workerID that holds this sampleTupleIdx.
      int tupleOffset = 0;
-      for (int j = 0; j < tupleCounts.size(); j++) {
-        if (sampleTupleIdx < tupleCounts.get(j) + tupleOffset) {
-          distribution[j] += 1;
+      for (Map.Entry<Integer, WorkerInfo> iWorker : workerInfo.entrySet()) {
+        WorkerInfo wInfo = iWorker.getValue();
+        if (sampleTupleIdx < wInfo.tupleCount + tupleOffset) {
+          wInfo.sampleSize += 1;
          break;
        }
-        tupleOffset += tupleCounts.get(j);
+        tupleOffset += wInfo.tupleCount;
      }
    }
-    return distribution;
  }

  /**
   * Creates a WithoutReplacement distribution across the workers.
   *
-   * @param tupleCounts
-   *          list of how many tuples each worker has.
+   * @param workerInfo
+   *          reference to the workerInfo to modify.
+   * @param totalTupleCount
+   *          total # of tuples across all workers.
   * @param sampleSize
-   *          total number of samples to distribute across the workers.
-   * @return array representing the distribution across the workers.
+   *          total # of samples to distribute across the workers.
   */
-  private int[] withoutReplacementDistribution(List<Integer> tupleCounts,
+  private void withoutReplacementDistribution(
+      TreeMap<Integer, WorkerInfo> workerInfo, int totalTupleCount,
      int sampleSize) {
-    int[] distribution = new int[tupleCounts.size()];
-    int totalTupleCount = 0;
-    for (int val : tupleCounts) {
-      totalTupleCount += val;
+    Map<Integer, Integer> logicalTupleCounts = new TreeMap<>();
+    for (Map.Entry<Integer, WorkerInfo> wInfo : workerInfo.entrySet()) {
+      logicalTupleCounts.put(wInfo.getKey(), wInfo.getValue().tupleCount);
    }
-    List<Integer> logicalTupleCounts = new ArrayList<>(tupleCounts);

    for (int i = 0; i < sampleSize; i++) {
      int sampleTupleIdx = rand.nextInt(totalTupleCount - i);
      // Assign this tuple to the workerID that holds this sampleTupleIdx.
      int tupleOffset = 0;
-      for (int j = 0; j < logicalTupleCounts.size(); j++) {
-        if (sampleTupleIdx < logicalTupleCounts.get(j) + tupleOffset) {
-          distribution[j] += 1;
+      for (Map.Entry<Integer, WorkerInfo> iWorker : workerInfo.entrySet()) {
+        int wID = iWorker.getKey();
+        WorkerInfo wInfo = iWorker.getValue();
+        if (sampleTupleIdx < logicalTupleCounts.get(wID) + tupleOffset) {
+          wInfo.sampleSize += 1;
          // Cannot sample the same tuple, so pretend it doesn't exist anymore.
-          logicalTupleCounts.set(j, logicalTupleCounts.get(j) - 1);
+          logicalTupleCounts.put(wID, logicalTupleCounts.get(wID) - 1);
          break;
        }
-        tupleOffset += logicalTupleCounts.get(j);
+        tupleOffset += logicalTupleCounts.get(wID);
      }
    }
-    return distribution;
  }

  /**
@@ -284,11 +282,6 @@ public int getSampleSize() {
    return sampleSize;
  }

-  /** Returns whether this operator is using a percentage sample. */
-  public boolean isPercentageSample() {
-    return isPercentageSample;
-  }
-
  /**
   * Returns the percentage of total tuples that this operator will distribute.
   * Will be 0 if the operator was created using a specific sampleSize.
@@ -308,13 +301,28 @@ public Schema generateSchema() {
    return SCHEMA;
  }

  @Override
-  protected void init(final ImmutableMap<String, Object> execEnvVars) {
-    rand = new Random();
-    if (randomSeed != null) {
-      rand.setSeed(randomSeed);
+  public void cleanup() {
+    workerInfo = null;
+  }
+
+  /** Encapsulates sampling information about a worker. */
+  private class WorkerInfo {
+    /** # of tuples that this worker owns. */
+    int tupleCount;
+
+    /**
+     * Actual # of tuples that the worker has stored. May be different from
+     * tupleCount if the worker pre-sampled the data.
+     **/
+    int actualTupleCount;
+
+    /** # of tuples that the distribution assigned to this worker.
*/ + int sampleSize = 0; + + WorkerInfo(int tupleCount, int actualTupleCount) { + this.tupleCount = tupleCount; + this.actualTupleCount = actualTupleCount; } - tupleCounts = new ArrayList<>(); - streamCounts = new ArrayList<>(); } } diff --git a/src/edu/washington/escience/myria/operator/network/partition/IdentityHashPartitionFunction.java b/src/edu/washington/escience/myria/operator/network/partition/IdentityHashPartitionFunction.java index 4ceda2040..3cc90b7ef 100644 --- a/src/edu/washington/escience/myria/operator/network/partition/IdentityHashPartitionFunction.java +++ b/src/edu/washington/escience/myria/operator/network/partition/IdentityHashPartitionFunction.java @@ -1,5 +1,9 @@ package edu.washington.escience.myria.operator.network.partition; +import java.util.Objects; + +import javax.annotation.Nonnull; + import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Preconditions; @@ -9,7 +13,8 @@ /** * Implementation of a PartitionFunction that use the trivial identity hash. - * i.e. a --> a + * (i.e. a --> a) The attribute to hash on must be an INT column and should + * represent a workerID */ public final class IdentityHashPartitionFunction extends PartitionFunction { @@ -21,13 +26,14 @@ public final class IdentityHashPartitionFunction extends PartitionFunction { private final int index; /** - * @param index the index of the partition field. + * @param index + * the index of the partition field. */ @JsonCreator public IdentityHashPartitionFunction( @JsonProperty(value = "index", required = true) final Integer index) { super(null); - this.index = java.util.Objects.requireNonNull(index, "missing property index"); + this.index = Objects.requireNonNull(index, "missing property index"); Preconditions.checkArgument(this.index >= 0, "IdentityHash field index cannot take negative value %s", this.index); } @@ -40,12 +46,14 @@ public int getIndex() { } /** - * @param tb data. + * @param tb + * data. * @return partitions. * */ @Override - public int[] partition(final TupleBatch tb) { - Preconditions.checkArgument(tb.getSchema().getColumnType(index) == Type.INT_TYPE, + public int[] partition(@Nonnull final TupleBatch tb) { + Preconditions.checkArgument( + tb.getSchema().getColumnType(index) == Type.INT_TYPE, "IdentityHash index column must be of type INT"); final int[] result = new int[tb.numTuples()]; for (int i = 0; i < result.length; i++) { diff --git a/src/edu/washington/escience/myria/util/SamplingType.java b/src/edu/washington/escience/myria/util/SamplingType.java index 943a5619c..157c4337d 100644 --- a/src/edu/washington/escience/myria/util/SamplingType.java +++ b/src/edu/washington/escience/myria/util/SamplingType.java @@ -1,9 +1,21 @@ package edu.washington.escience.myria.util; +import com.fasterxml.jackson.annotation.JsonValue; + /** * Enumeration of supported sampling types. 
*/ public enum SamplingType { - // WithReplacement, WithoutReplacement - WR, WoR + WithReplacement("WR"), WithoutReplacement("WoR"); + + private String shortName; + + SamplingType(String shortName) { + this.shortName = shortName; + } + + @JsonValue + public String getShortName() { + return shortName; + } } diff --git a/test/edu/washington/escience/myria/operator/SampleWRTest.java b/test/edu/washington/escience/myria/operator/SampleWRTest.java index f6bb3a454..a998431f7 100644 --- a/test/edu/washington/escience/myria/operator/SampleWRTest.java +++ b/test/edu/washington/escience/myria/operator/SampleWRTest.java @@ -123,7 +123,7 @@ private void verifyExpectedResults(int partitionSize, int sampleSize, int[] expected) throws DbException { leftInput.putInt(1, partitionSize); leftInput.putInt(2, sampleSize); - leftInput.putString(3, "WR"); + leftInput.putString(3, "WithReplacement"); sampOp = new Sample(new TupleSource(leftInput), new TupleSource(rightInput), RANDOM_SEED); sampOp.open(TestEnvVars.get()); @@ -145,7 +145,7 @@ private void drainOperator(int partitionSize, int sampleSize) throws DbException { leftInput.putInt(1, partitionSize); leftInput.putInt(2, sampleSize); - leftInput.putString(3, "WR"); + leftInput.putString(3, "WithReplacement"); sampOp = new Sample(new TupleSource(leftInput), new TupleSource(rightInput), RANDOM_SEED); sampOp.open(TestEnvVars.get()); while (!sampOp.eos()) { diff --git a/test/edu/washington/escience/myria/operator/SampleWoRTest.java b/test/edu/washington/escience/myria/operator/SampleWoRTest.java index 365a1978b..36712b149 100644 --- a/test/edu/washington/escience/myria/operator/SampleWoRTest.java +++ b/test/edu/washington/escience/myria/operator/SampleWoRTest.java @@ -111,7 +111,7 @@ public void cleanup() throws DbException { private void verifyExpectedResults(int partitionSize, int sampleSize) throws DbException { leftInput.putInt(1, partitionSize); leftInput.putInt(2, sampleSize); - leftInput.putString(3, "WoR"); + leftInput.putString(3, "WithoutReplacement"); sampOp = new Sample(new TupleSource(leftInput), new TupleSource(rightInput), RANDOM_SEED); sampOp.open(TestEnvVars.get()); @@ -131,7 +131,7 @@ private void drainOperator(int partitionSize, int sampleSize) throws DbException { leftInput.putInt(1, partitionSize); leftInput.putInt(2, sampleSize); - leftInput.putString(3, "WoR"); + leftInput.putString(3, "WithoutReplacement"); sampOp = new Sample(new TupleSource(leftInput), new TupleSource(rightInput), RANDOM_SEED); sampOp.open(TestEnvVars.get()); while (!sampOp.eos()) { diff --git a/test/edu/washington/escience/myria/operator/SamplingDistributionTest.java b/test/edu/washington/escience/myria/operator/SamplingDistributionTest.java index 1c0584617..22de2fce6 100644 --- a/test/edu/washington/escience/myria/operator/SamplingDistributionTest.java +++ b/test/edu/washington/escience/myria/operator/SamplingDistributionTest.java @@ -50,7 +50,7 @@ public void setup() { @Test public void testSampleWRSizeZero() throws DbException { int sampleSize = 0; - SamplingType sampleType = SamplingType.WR; + SamplingType sampleType = SamplingType.WithReplacement; final int[][] expectedResults = { { 1, 300, 0 }, { 2, 200, 0 }, { 3, 400, 0 }, { 4, 100, 0 } }; verifyExpectedResults(sampleSize, sampleType, expectedResults); @@ -59,7 +59,7 @@ public void testSampleWRSizeZero() throws DbException { @Test public void testSampleWoRSizeZero() throws DbException { int sampleSize = 0; - SamplingType sampleType = SamplingType.WoR; + SamplingType sampleType = SamplingType.WithoutReplacement; 
final int[][] expectedResults = { { 1, 300, 0 }, { 2, 200, 0 }, { 3, 400, 0 }, { 4, 100, 0 } }; verifyExpectedResults(sampleSize, sampleType, expectedResults); @@ -69,7 +69,7 @@ public void testSampleWoRSizeZero() throws DbException { @Test public void testSampleWRPctZero() throws DbException { float samplePct = 0; - SamplingType sampleType = SamplingType.WR; + SamplingType sampleType = SamplingType.WithReplacement; final int[][] expectedResults = { { 1, 300, 0 }, { 2, 200, 0 }, { 3, 400, 0 }, { 4, 100, 0 } }; verifyExpectedResults(samplePct, sampleType, expectedResults); @@ -78,7 +78,7 @@ public void testSampleWRPctZero() throws DbException { @Test public void testSampleWoRPctZero() throws DbException { float samplePct = 0; - SamplingType sampleType = SamplingType.WoR; + SamplingType sampleType = SamplingType.WithoutReplacement; final int[][] expectedResults = { { 1, 300, 0 }, { 2, 200, 0 }, { 3, 400, 0 }, { 4, 100, 0 } }; verifyExpectedResults(samplePct, sampleType, expectedResults); @@ -88,14 +88,14 @@ public void testSampleWoRPctZero() throws DbException { @Test public void testSampleWRSizeOne() throws DbException { int sampleSize = 1; - SamplingType sampleType = SamplingType.WR; + SamplingType sampleType = SamplingType.WithReplacement; verifyPossibleDistribution(sampleSize, sampleType); } @Test public void testSampleWoRSizeOne() throws DbException { int sampleSize = 1; - SamplingType sampleType = SamplingType.WoR; + SamplingType sampleType = SamplingType.WithoutReplacement; verifyPossibleDistribution(sampleSize, sampleType); } @@ -103,14 +103,14 @@ public void testSampleWoRSizeOne() throws DbException { @Test public void testSampleWRSizeFifty() throws DbException { int sampleSize = 50; - SamplingType sampleType = SamplingType.WR; + SamplingType sampleType = SamplingType.WithReplacement; verifyPossibleDistribution(sampleSize, sampleType); } @Test public void testSampleWoRSizeFifty() throws DbException { int sampleSize = 50; - SamplingType sampleType = SamplingType.WoR; + SamplingType sampleType = SamplingType.WithoutReplacement; verifyPossibleDistribution(sampleSize, sampleType); } @@ -118,14 +118,14 @@ public void testSampleWoRSizeFifty() throws DbException { @Test public void testSampleWRPctFifty() throws DbException { float samplePct = 50; - SamplingType sampleType = SamplingType.WR; + SamplingType sampleType = SamplingType.WithReplacement; verifyPossibleDistribution(samplePct, sampleType); } @Test public void testSampleWoRPctFifty() throws DbException { float samplePct = 50; - SamplingType sampleType = SamplingType.WoR; + SamplingType sampleType = SamplingType.WithoutReplacement; verifyPossibleDistribution(samplePct, sampleType); } @@ -133,14 +133,14 @@ public void testSampleWoRPctFifty() throws DbException { @Test public void testSampleWoRSizeAllButOne() throws DbException { int sampleSize = 999; - SamplingType sampleType = SamplingType.WoR; + SamplingType sampleType = SamplingType.WithoutReplacement; verifyPossibleDistribution(sampleSize, sampleType); } @Test public void testSampleWRSizeAllButOne() throws DbException { int sampleSize = 999; - SamplingType sampleType = SamplingType.WR; + SamplingType sampleType = SamplingType.WithReplacement; verifyPossibleDistribution(sampleSize, sampleType); } @@ -148,7 +148,7 @@ public void testSampleWRSizeAllButOne() throws DbException { @Test public void testSampleWoRSizeMax() throws DbException { int sampleSize = 1000; - SamplingType sampleType = SamplingType.WoR; + SamplingType sampleType = SamplingType.WithoutReplacement; final int[][] 
expectedResults = { { 1, 300, 300 }, { 2, 200, 200 }, { 3, 400, 400 }, { 4, 100, 100 } }; verifyExpectedResults(sampleSize, sampleType, expectedResults); @@ -157,7 +157,7 @@ public void testSampleWoRSizeMax() throws DbException { @Test public void testSampleWoRPctMax() throws DbException { float samplePct = 100; - SamplingType sampleType = SamplingType.WoR; + SamplingType sampleType = SamplingType.WithoutReplacement; final int[][] expectedResults = { { 1, 300, 300 }, { 2, 200, 200 }, { 3, 400, 400 }, { 4, 100, 100 } }; verifyExpectedResults(samplePct, sampleType, expectedResults); @@ -167,14 +167,14 @@ public void testSampleWoRPctMax() throws DbException { @Test public void testSampleWRSizeMax() throws DbException { int sampleSize = 1000; - SamplingType sampleType = SamplingType.WR; + SamplingType sampleType = SamplingType.WithReplacement; verifyPossibleDistribution(sampleSize, sampleType); } @Test public void testSampleWRPctMax() throws DbException { float samplePct = 100; - SamplingType sampleType = SamplingType.WR; + SamplingType sampleType = SamplingType.WithReplacement; verifyPossibleDistribution(samplePct, sampleType); } @@ -182,57 +182,57 @@ public void testSampleWRPctMax() throws DbException { @Test(expected = IllegalStateException.class) public void testSampleWoRSizeTooMany() throws DbException { int sampleSize = 1001; - SamplingType sampleType = SamplingType.WoR; + SamplingType sampleType = SamplingType.WithoutReplacement; drainOperator(sampleSize, sampleType); } - @Test(expected = IllegalStateException.class) + @Test(expected = IllegalArgumentException.class) public void testSampleWoRPctTooMany() throws DbException { float samplePct = 100.1f; - SamplingType sampleType = SamplingType.WoR; + SamplingType sampleType = SamplingType.WithoutReplacement; drainOperator(samplePct, sampleType); } @Test(expected = IllegalStateException.class) public void testSampleWRSizeTooMany() throws DbException { int sampleSize = 1001; - SamplingType sampleType = SamplingType.WR; + SamplingType sampleType = SamplingType.WithReplacement; drainOperator(sampleSize, sampleType); } - @Test(expected = IllegalStateException.class) + @Test(expected = IllegalArgumentException.class) public void testSampleWRPctTooMany() throws DbException { float samplePct = 100.1f; - SamplingType sampleType = SamplingType.WR; + SamplingType sampleType = SamplingType.WithReplacement; drainOperator(samplePct, sampleType); } /** Cannot sample a negative number of samples. 
/** Cannot sample a negative number of samples. */ - @Test(expected = IllegalStateException.class) + @Test(expected = IllegalArgumentException.class) public void testSampleWoRSizeNegative() throws DbException { int sampleSize = -1; - SamplingType sampleType = SamplingType.WoR; + SamplingType sampleType = SamplingType.WithoutReplacement; drainOperator(sampleSize, sampleType); } - @Test(expected = IllegalStateException.class) + @Test(expected = IllegalArgumentException.class) public void testSampleWoRPctNegative() throws DbException { float samplePct = -0.01f; - SamplingType sampleType = SamplingType.WoR; + SamplingType sampleType = SamplingType.WithoutReplacement; drainOperator(samplePct, sampleType); } - @Test(expected = IllegalStateException.class) + @Test(expected = IllegalArgumentException.class) public void testSampleWRSizeNegative() throws DbException { int sampleSize = -1; - SamplingType sampleType = SamplingType.WoR; + SamplingType sampleType = SamplingType.WithReplacement; drainOperator(sampleSize, sampleType); } - @Test(expected = IllegalStateException.class) + @Test(expected = IllegalArgumentException.class) public void testSampleWRPctNegative() throws DbException { float samplePct = -0.01f; - SamplingType sampleType = SamplingType.WoR; + SamplingType sampleType = SamplingType.WithReplacement; drainOperator(samplePct, sampleType); } @@ -240,7 +240,7 @@ public void testSampleWRPctNegative() throws DbException { @Test(expected = IllegalStateException.class) public void testSampleWoRWorkerNegative() throws DbException { int sampleSize = 50; - SamplingType sampleType = SamplingType.WoR; + SamplingType sampleType = SamplingType.WithoutReplacement; input.putInt(0, 5); input.putInt(1, -1); drainOperator(sampleSize, sampleType); @@ -249,7 +249,7 @@ public void testSampleWoRWorkerNegative() throws DbException { @Test(expected = IllegalStateException.class) public void testSampleWRWorkerNegative() throws DbException { int sampleSize = 50; - SamplingType sampleType = SamplingType.WR; + SamplingType sampleType = SamplingType.WithReplacement; input.putInt(0, 5); input.putInt(1, -1); drainOperator(sampleSize, sampleType); @@ -308,7 +308,7 @@ private void verifyPossibleDistribution(SamplingDistribution sampOp) assertEquals(expectedResultSchema, result.getSchema()); for (int i = 0; i < result.numTuples(); ++i, ++rowIdx) { assertTrue(result.getInt(2, i) >= 0 && result.getInt(2, i) <= sampOp.getSampleSize()); - if (sampOp.getSampleType() == SamplingType.WoR) { + if (sampOp.getSampleType() == SamplingType.WithoutReplacement) { // SampleWoR cannot sample more than worker's population size. assertTrue(result.getInt(2, i) <= result.getInt(1, i)); }
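Note: the expectation changes in this test patch imply a two-level validation scheme in SamplingDistribution: arguments that are invalid on their face (negative sizes, percentages outside [0, 100]) now fail fast with IllegalArgumentException, while conditions only discoverable from the child's data (more samples requested than tuples exist, a negative per-worker count) still surface as IllegalStateException while the operator drains. A minimal sketch of the constructor-side half, using Guava Preconditions as elsewhere in this code base; the exact signature and field names are assumptions, not the actual source:

    // Sketch only: fail fast on bad arguments. Data-dependent problems are
    // assumed to throw IllegalStateException later, in fetchNextReady().
    public SamplingDistribution(final Operator child, final float samplePct, final SamplingType sampleType) {
      super(child);
      Preconditions.checkArgument(samplePct >= 0 && samplePct <= 100,
          "samplePct must be in [0, 100], got %s", samplePct);
      this.samplePct = samplePct;
      this.sampleType = sampleType;
    }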
From 1b66306aed4ba6071c2258c55bfb3ac6e42b9baf Mon Sep 17 00:00:00 2001 From: Yuqing Guo Date: Sat, 6 Jun 2015 23:17:46 -0700 Subject: [PATCH 25/29] quick fixes --- .../escience/myria/operator/agg/StreamingAggregate.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/edu/washington/escience/myria/operator/agg/StreamingAggregate.java b/src/edu/washington/escience/myria/operator/agg/StreamingAggregate.java index ec404c877..d89424d4f 100644 --- a/src/edu/washington/escience/myria/operator/agg/StreamingAggregate.java +++ b/src/edu/washington/escience/myria/operator/agg/StreamingAggregate.java @@ -89,7 +89,7 @@ public StreamingAggregate(@Nullable final Operator child, @Nonnull final int[] g @Override protected TupleBatch fetchNextReady() throws DbException { if (child.eos()) { - return resultBuffer.popAny(); + return null; } if (tb == null) { tb = child.nextReady(); @@ -170,7 +170,7 @@ protected Schema generateSchema() { } groupSchema = inputSchema.getSubSchema(gFields); - resultSchema = groupSchema; + resultSchema = Schema.of(groupSchema.getColumnTypes(), groupSchema.getColumnNames()); try { for (Aggregator agg : AggUtils.allocateAggs(factories, inputSchema)) { Schema curAggSchema = agg.getResultSchema(); From 48fd0d0f8478248839ab4548e296b6c7f671f3aa Mon Sep 17 00:00:00 2001 From: Dan Radion Date: Sun, 7 Jun 2015 00:30:49 -0700 Subject: [PATCH 26/29] minor edits to SamplingDistribution --- .../myria/operator/SamplingDistribution.java | 37 +++++++++---------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/src/edu/washington/escience/myria/operator/SamplingDistribution.java b/src/edu/washington/escience/myria/operator/SamplingDistribution.java index ae837ea82..3283ca11c 100644 --- a/src/edu/washington/escience/myria/operator/SamplingDistribution.java +++ b/src/edu/washington/escience/myria/operator/SamplingDistribution.java @@ -1,8 +1,6 @@ package edu.washington.escience.myria.operator; -import java.util.Map; -import java.util.Random; -import java.util.TreeMap; +import java.util.*; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -11,8 +9,8 @@ import edu.washington.escience.myria.Schema; import edu.washington.escience.myria.Type; import edu.washington.escience.myria.column.Column; -import edu.washington.escience.myria.column.builder.IntColumnBuilder; -import edu.washington.escience.myria.column.builder.StringColumnBuilder; +import edu.washington.escience.myria.column.builder.ColumnBuilder; +import edu.washington.escience.myria.column.builder.ColumnFactory; import edu.washington.escience.myria.storage.TupleBatch; import edu.washington.escience.myria.util.SamplingType; @@ -45,7 +43,7 @@ public class SamplingDistribution extends UnaryOperator { private Random rand; /** Maps (worker_i) --> (sampling info for worker_i) */ - TreeMap<Integer, WorkerInfo> workerInfo = new TreeMap<>(); + SortedMap<Integer, WorkerInfo> workerInfo = new TreeMap<>(); /** Total number of tuples across all workers. */ int totalTupleCount = 0; @@ -118,6 +116,7 @@ protected TupleBatch fetchNextReady() throws DbException { } extractWorkerInfo(tb); } + getChild().close(); // Convert samplePct to sampleSize if using a percentage sample. if (isPercentageSample) { @@ -138,19 +137,17 @@ protected TupleBatch fetchNextReady() throws DbException { // Build and return a TupleBatch with the distribution. // Assumes that the sampling information can fit into one tuple batch.
- IntColumnBuilder wIDs = new IntColumnBuilder(); - IntColumnBuilder actualSizes = new IntColumnBuilder(); - IntColumnBuilder sampSizes = new IntColumnBuilder(); - StringColumnBuilder sampTypes = new StringColumnBuilder(); + List> colBuilders = ColumnFactory.allocateColumns(SCHEMA); for (Map.Entry iWorker : workerInfo.entrySet()) { - wIDs.appendInt(iWorker.getKey()); - actualSizes.appendInt(iWorker.getValue().actualTupleCount); - sampSizes.appendInt(iWorker.getValue().sampleSize); - sampTypes.appendString(sampleType.name()); + colBuilders.get(0).appendInt(iWorker.getKey()); + colBuilders.get(1).appendInt(iWorker.getValue().actualTupleCount); + colBuilders.get(2).appendInt(iWorker.getValue().sampleSize); + colBuilders.get(3).appendString(sampleType.name()); } ImmutableList.Builder> columns = ImmutableList.builder(); - columns.add(wIDs.build(), actualSizes.build(), sampSizes.build(), - sampTypes.build()); + for (ColumnBuilder cb : colBuilders) { + columns.add(cb.build()); + } setEOS(); return new TupleBatch(SCHEMA, columns.build()); } @@ -211,7 +208,7 @@ private void extractWorkerInfo(TupleBatch tb) throws DbException { } /** - * Creates a WithoutReplacement distribution across the workers. + * Creates a WithReplacement distribution across the workers. * * @param workerInfo * reference to the workerInfo to modify. @@ -221,7 +218,7 @@ private void extractWorkerInfo(TupleBatch tb) throws DbException { * total # of samples to distribute across the workers. */ private void withReplacementDistribution( - TreeMap workerInfo, int totalTupleCount, + SortedMap workerInfo, int totalTupleCount, int sampleSize) { for (int i = 0; i < sampleSize; i++) { int sampleTupleIdx = rand.nextInt(totalTupleCount); @@ -249,9 +246,9 @@ private void withReplacementDistribution( * total # of samples to distribute across the workers. */ private void withoutReplacementDistribution( - TreeMap workerInfo, int totalTupleCount, + SortedMap workerInfo, int totalTupleCount, int sampleSize) { - Map logicalTupleCounts = new TreeMap<>(); + SortedMap logicalTupleCounts = new TreeMap<>(); for (Map.Entry wInfo : workerInfo.entrySet()) { logicalTupleCounts.put(wInfo.getKey(), wInfo.getValue().tupleCount); } From 26dfea855028232c02f71f45fa2f59784de48e62 Mon Sep 17 00:00:00 2001 From: Dan Radion Date: Sun, 7 Jun 2015 01:54:59 -0700 Subject: [PATCH 27/29] fixed SampledDbInsertTemp --- .../escience/myria/operator/DbInsertTemp.java | 58 ++++++++++--------- .../myria/operator/SampledDbInsertTemp.java | 9 ++- 2 files changed, 38 insertions(+), 29 deletions(-) diff --git a/src/edu/washington/escience/myria/operator/DbInsertTemp.java b/src/edu/washington/escience/myria/operator/DbInsertTemp.java index 4a57ca604..6bd702474 100644 --- a/src/edu/washington/escience/myria/operator/DbInsertTemp.java +++ b/src/edu/washington/escience/myria/operator/DbInsertTemp.java @@ -109,32 +109,9 @@ protected void consumeTuples(final TupleBatch tupleBatch) throws DbException { } @Override - protected void init(final ImmutableMap execEnvVars) throws DbException { - - /* retrieve connection information from the environment variables, if not already set */ - if (connectionInfo == null && execEnvVars != null) { - connectionInfo = (ConnectionInfo) execEnvVars.get(MyriaConstants.EXEC_ENV_VAR_DATABASE_CONN_INFO); - } - - if (connectionInfo == null) { - throw new DbException("Unable to instantiate DbInsertTemp: connection information unknown"); - } - - if (connectionInfo instanceof SQLiteInfo) { - /* Set WAL in the beginning. 
*/ - final File dbFile = new File(((SQLiteInfo) connectionInfo).getDatabaseFilename()); - SQLiteConnection conn = new SQLiteConnection(dbFile); - try { - conn.open(true); - conn.exec("PRAGMA journal_mode=WAL;"); - } catch (SQLiteException e) { - e.printStackTrace(); - } - conn.dispose(); - } - - /* open the database connection */ - accessMethod = AccessMethod.of(connectionInfo.getDbms(), connectionInfo, false); + protected void init(final ImmutableMap<String, Object> execEnvVars) + throws DbException { + setupConnection(execEnvVars); if (overwriteTable) { stagingRelationKey = @@ -176,4 +153,33 @@ public Map<RelationKey, RelationWriteMetadata> writeSet() { return ImmutableMap.of(relationKey, new RelationWriteMetadata(relationKey, getSchema(), overwriteTable, true)); } + /** Sets up the database connection, taking connection information from the environment variables if not already set. */ + protected void setupConnection(final ImmutableMap<String, Object> execEnvVars) + throws DbException { + // Extract connection info from environment + if (connectionInfo == null && execEnvVars != null) { + connectionInfo = (ConnectionInfo) execEnvVars.get(MyriaConstants.EXEC_ENV_VAR_DATABASE_CONN_INFO); + } + + if (connectionInfo == null) { + throw new DbException("Unknown connection information."); + } + + if (connectionInfo instanceof SQLiteInfo) { + /* Set WAL in the beginning. */ + final File dbFile = new File(((SQLiteInfo) connectionInfo).getDatabaseFilename()); + SQLiteConnection conn = new SQLiteConnection(dbFile); + try { + conn.open(true); + conn.exec("PRAGMA journal_mode=WAL;"); + } catch (SQLiteException e) { + e.printStackTrace(); + } + conn.dispose(); + } + + // Open the database connection. + accessMethod = AccessMethod.of(connectionInfo.getDbms(), connectionInfo, false); + } + } diff --git a/src/edu/washington/escience/myria/operator/SampledDbInsertTemp.java b/src/edu/washington/escience/myria/operator/SampledDbInsertTemp.java index d993a1c75..6daa1573b 100644 --- a/src/edu/washington/escience/myria/operator/SampledDbInsertTemp.java +++ b/src/edu/washington/escience/myria/operator/SampledDbInsertTemp.java @@ -68,7 +68,7 @@ public class SampledDbInsertTemp extends DbInsertTemp implements DbWriter { public SampledDbInsertTemp(final Operator child, final int sampleSize, final RelationKey sampleRelationKey, final RelationKey countRelationKey, final ConnectionInfo connectionInfo, Long randomSeed) { - super(child, sampleRelationKey, connectionInfo, true, null); + super(child, sampleRelationKey, connectionInfo, false, null); Preconditions.checkArgument(sampleSize >= 0, "sampleSize must be non-negative"); this.sampleSize = sampleSize; @@ -130,8 +130,11 @@ protected void childEOS() throws DbException { @Override protected void init(final ImmutableMap<String, Object> execEnvVars) throws DbException { - // Will set up the database connection and create the reservoir table. - super.init(execEnvVars); + setupConnection(execEnvVars); + + // Set up the reservoir table. + accessMethod.dropTableIfExists(getRelationKey()); + accessMethod.createTableIfNotExists(getRelationKey(), getSchema()); // Set up the tuple count table. accessMethod.dropTableIfExists(countRelationKey);
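Note: the point of the init()/setupConnection() split above is that SampledDbInsertTemp can reuse the connection bootstrap (resolve ConnectionInfo, set SQLite WAL, open the AccessMethod) without inheriting DbInsertTemp's staging-table behavior, which is also why its super call now passes false for overwriteTable. A sketch of the resulting override, assembled from the hunks above rather than quoted from the final file:

    // SampledDbInsertTemp.init(): shared bootstrap, then only its own tables.
    @Override
    protected void init(final ImmutableMap<String, Object> execEnvVars) throws DbException {
      setupConnection(execEnvVars);
      accessMethod.dropTableIfExists(getRelationKey());                    // reservoir table
      accessMethod.createTableIfNotExists(getRelationKey(), getSchema());
      accessMethod.dropTableIfExists(countRelationKey);                    // tuple count table
      // ... creation of the count table continues as in the hunk above.
    }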
From 8687fa3b3b42c7880a3caeb4e62358f1185c2b49 Mon Sep 17 00:00:00 2001 From: Dan Radion Date: Sun, 7 Jun 2015 11:43:54 -0700 Subject: [PATCH 28/29] fix up imports and remove redundant implements --- .../escience/myria/operator/SampledDbInsertTemp.java | 2 +- .../escience/myria/operator/SamplingDistribution.java | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/edu/washington/escience/myria/operator/SampledDbInsertTemp.java b/src/edu/washington/escience/myria/operator/SampledDbInsertTemp.java index 6daa1573b..3ba76d859 100644 --- a/src/edu/washington/escience/myria/operator/SampledDbInsertTemp.java +++ b/src/edu/washington/escience/myria/operator/SampledDbInsertTemp.java @@ -25,7 +25,7 @@ /** * Samples the stream into a temp relation. */ -public class SampledDbInsertTemp extends DbInsertTemp implements DbWriter { +public class SampledDbInsertTemp extends DbInsertTemp { /** Required for Java serialization. */ private static final long serialVersionUID = 1L; diff --git a/src/edu/washington/escience/myria/operator/SamplingDistribution.java b/src/edu/washington/escience/myria/operator/SamplingDistribution.java index 3283ca11c..d2a4787d7 100644 --- a/src/edu/washington/escience/myria/operator/SamplingDistribution.java +++ b/src/edu/washington/escience/myria/operator/SamplingDistribution.java @@ -1,6 +1,10 @@ package edu.washington.escience.myria.operator; -import java.util.*; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.SortedMap; +import java.util.TreeMap; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList;
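Note: dropping implements DbWriter is safe because interface implementation is inherited in Java; the superclass DbInsertTemp already declares it (its writeSet() override appears in the previous patch). The language rule, illustrated with stand-in names rather than the Myria classes:

    interface Writer { }
    class Base implements Writer { }
    class Derived extends Base { }  // no 'implements Writer' needed

    // (new Derived() instanceof Writer) is true: Derived picks up the
    // interface through Base, so redeclaring it was redundant.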
From 278316b82e336ce9313811cd7da87bf83d6c2401 Mon Sep 17 00:00:00 2001 From: Dominik Moritz Date: Tue, 9 Jun 2015 18:16:37 -0700 Subject: [PATCH 29/29] Address comments from review and simplify code. --- .../accessmethod/SQLiteAccessMethod.java | 2 +- .../SQLiteTupleBatchIterator.java | 18 +++------- .../coordinator/catalog/MasterCatalog.java | 6 ++-- .../myria/operator/CatalogQueryScan.java | 21 +++-------- .../escience/myria/operator/DbQueryScan.java | 35 +++++-------------- .../myria/operator/CatalogScanTest.java | 4 +-- 6 files changed, 22 insertions(+), 64 deletions(-) diff --git a/src/edu/washington/escience/myria/accessmethod/SQLiteAccessMethod.java b/src/edu/washington/escience/myria/accessmethod/SQLiteAccessMethod.java index d2fd64741..f9f877166 100644 --- a/src/edu/washington/escience/myria/accessmethod/SQLiteAccessMethod.java +++ b/src/edu/washington/escience/myria/accessmethod/SQLiteAccessMethod.java @@ -228,7 +228,7 @@ public Iterator<TupleBatch> tupleBatchIteratorFromQuery(final String queryString throw new DbException(e); } - return new SQLiteTupleBatchIterator(statement, schema, sqliteConnection); + return new SQLiteTupleBatchIterator(statement, sqliteConnection, schema); } @Override diff --git a/src/edu/washington/escience/myria/accessmethod/SQLiteTupleBatchIterator.java b/src/edu/washington/escience/myria/accessmethod/SQLiteTupleBatchIterator.java index 24a6fe1b2..ac71652b5 100644 --- a/src/edu/washington/escience/myria/accessmethod/SQLiteTupleBatchIterator.java +++ b/src/edu/washington/escience/myria/accessmethod/SQLiteTupleBatchIterator.java @@ -6,6 +6,7 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.NoSuchElementException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -34,19 +35,6 @@ public class SQLiteTupleBatchIterator implements Iterator<TupleBatch> { /** The Schema of the TupleBatches returned by this Iterator. */ private final Schema schema; - /** - * Wraps a SQLiteStatement result set in an Iterator. - * - * @param statement the SQLiteStatement containing the results. - * @param schema the Schema describing the format of the TupleBatch containing these results. - * @param connection the connection to the SQLite database. - */ - SQLiteTupleBatchIterator(final SQLiteStatement statement, final Schema schema, final SQLiteConnection connection) { - this.statement = statement; - this.connection = connection; - this.schema = schema; - } - /** * Wraps a SQLiteStatement result set in an Iterator.
* @@ -81,6 +69,10 @@ public boolean hasNext() { @Override public TupleBatch next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + /* Allocate TupleBatch parameters */ final int numFields = schema.numColumns(); final List<ColumnBuilder<?>> columnBuilders = ColumnFactory.allocateColumns(schema); diff --git a/src/edu/washington/escience/myria/coordinator/catalog/MasterCatalog.java b/src/edu/washington/escience/myria/coordinator/catalog/MasterCatalog.java index 84d8ab8f1..99cee1fd8 100644 --- a/src/edu/washington/escience/myria/coordinator/catalog/MasterCatalog.java +++ b/src/edu/washington/escience/myria/coordinator/catalog/MasterCatalog.java @@ -33,6 +33,7 @@ import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterators; import com.google.common.collect.Lists; import edu.washington.escience.myria.MyriaConstants.FTMode; @@ -2002,10 +2003,7 @@ protected Iterator<TupleBatch> job(final SQLiteConnection sqliteConnection) thro SQLiteException { SQLiteStatement statement = sqliteConnection.prepare(queryString); List<TupleBatch> tuples = Lists.newLinkedList(); - Iterator<TupleBatch> iter = new SQLiteTupleBatchIterator(statement, sqliteConnection, outputSchema); - while (iter.hasNext()) { - tuples.add(iter.next()); - } + Iterators.addAll(tuples, new SQLiteTupleBatchIterator(statement, sqliteConnection, outputSchema)); return tuples.iterator(); } }).get();
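Note: two standard idioms land above. next() now honors the java.util.Iterator contract, which requires NoSuchElementException once the sequence is exhausted, and the hand-written drain loop in MasterCatalog collapses into Guava's Iterators.addAll. A self-contained illustration with a stand-in element type, not Myria code:

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.Iterator;
    import java.util.List;

    import com.google.common.collect.Iterators;

    public class DrainExample {
      public static void main(final String[] args) {
        final List<String> sink = new ArrayList<>();
        final Iterator<String> source = Arrays.asList("a", "b").iterator();
        // Equivalent to: while (source.hasNext()) { sink.add(source.next()); }
        Iterators.addAll(sink, source);
        System.out.println(sink);  // prints [a, b]
        // A further source.next() would throw NoSuchElementException --
        // the same contract the SQLiteTupleBatchIterator fix adopts.
      }
    }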
diff --git a/src/edu/washington/escience/myria/operator/CatalogQueryScan.java b/src/edu/washington/escience/myria/operator/CatalogQueryScan.java index cc3054d2b..41fd441b3 100644 --- a/src/edu/washington/escience/myria/operator/CatalogQueryScan.java +++ b/src/edu/washington/escience/myria/operator/CatalogQueryScan.java @@ -3,8 +3,6 @@ import java.util.Iterator; import java.util.Objects; -import com.google.common.collect.ImmutableMap; - import edu.washington.escience.myria.DbException; import edu.washington.escience.myria.Schema; import edu.washington.escience.myria.coordinator.catalog.CatalogException; @@ -12,8 +10,8 @@ import edu.washington.escience.myria.storage.TupleBatch; /** - * Push a select query down into a JDBC based database and scan over the query result. - * */ + * Operator to get the result of a query on the catalog. The catalog is a SQLite database. + */ public class CatalogQueryScan extends LeafOperator { /** @@ -50,14 +48,9 @@ public class CatalogQueryScan extends LeafOperator { * @param catalog see the corresponding field. * */ public CatalogQueryScan(final String sql, final Schema outputSchema, final MasterCatalog catalog) { - Objects.requireNonNull(sql); - Objects.requireNonNull(outputSchema); - Objects.requireNonNull(catalog); - - this.sql = sql; - this.outputSchema = outputSchema; - this.catalog = catalog; - tuples = null; + this.sql = Objects.requireNonNull(sql); + this.outputSchema = Objects.requireNonNull(outputSchema); + this.catalog = Objects.requireNonNull(catalog); } @Override @@ -87,8 +80,4 @@ protected final TupleBatch fetchNextReady() throws DbException { public final Schema generateSchema() { return outputSchema; } - - @Override - protected final void init(final ImmutableMap<String, Object> execEnvVars) throws DbException { - } } diff --git a/src/edu/washington/escience/myria/operator/DbQueryScan.java b/src/edu/washington/escience/myria/operator/DbQueryScan.java index 0ddb932c1..4aafa858d 100644 --- a/src/edu/washington/escience/myria/operator/DbQueryScan.java +++ b/src/edu/washington/escience/myria/operator/DbQueryScan.java @@ -68,13 +68,8 @@ public class DbQueryScan extends LeafOperator implements DbReader { * @param outputSchema see the corresponding field. * */ public DbQueryScan(final String baseSQL, final Schema outputSchema) { - Objects.requireNonNull(baseSQL); - Objects.requireNonNull(outputSchema); - - this.baseSQL = baseSQL; - this.outputSchema = outputSchema; - connectionInfo = null; - tuples = null; + this.baseSQL = Objects.requireNonNull(baseSQL); + this.outputSchema = Objects.requireNonNull(outputSchema); sortedColumns = null; ascending = null; } @@ -88,8 +83,7 @@ public DbQueryScan(final String baseSQL, final Schema outputSchema) { * */ public DbQueryScan(final ConnectionInfo connectionInfo, final String baseSQL, final Schema outputSchema) { this(baseSQL, outputSchema); - Objects.requireNonNull(connectionInfo); - this.connectionInfo = connectionInfo; + this.connectionInfo = Objects.requireNonNull(connectionInfo); } /** @@ -99,14 +93,8 @@ public DbQueryScan(final ConnectionInfo connectionInfo, final String baseSQL, fi * @param outputSchema the Schema of the returned tuples.
*/ public DbQueryScan(final RelationKey relationKey, final Schema outputSchema) { - Objects.requireNonNull(relationKey); - Objects.requireNonNull(outputSchema); - - this.relationKey = relationKey; - this.outputSchema = outputSchema; - baseSQL = null; - connectionInfo = null; - tuples = null; + this.relationKey = Objects.requireNonNull(relationKey); + this.outputSchema = Objects.requireNonNull(outputSchema); sortedColumns = null; ascending = null; } @@ -121,8 +109,7 @@ public DbQueryScan(final RelationKey relationKey, final Schema outputSchema) { */ public DbQueryScan(final ConnectionInfo connectionInfo, final RelationKey relationKey, final Schema outputSchema) { this(relationKey, outputSchema); - Objects.requireNonNull(connectionInfo); - this.connectionInfo = connectionInfo; + this.connectionInfo = Objects.requireNonNull(connectionInfo); } /** @@ -135,16 +122,10 @@ public DbQueryScan(final ConnectionInfo connectionInfo, final RelationKey relati */ public DbQueryScan(final RelationKey relationKey, final Schema outputSchema, final int[] sortedColumns, final boolean[] ascending) { - Objects.requireNonNull(relationKey); - Objects.requireNonNull(outputSchema); - - this.relationKey = relationKey; - this.outputSchema = outputSchema; + this.relationKey = Objects.requireNonNull(relationKey); + this.outputSchema = Objects.requireNonNull(outputSchema); this.sortedColumns = sortedColumns; this.ascending = ascending; - baseSQL = null; - connectionInfo = null; - tuples = null; } /** diff --git a/test/edu/washington/escience/myria/operator/CatalogScanTest.java b/test/edu/washington/escience/myria/operator/CatalogScanTest.java index 1d8262e67..bfd9f8492 100644 --- a/test/edu/washington/escience/myria/operator/CatalogScanTest.java +++ b/test/edu/washington/escience/myria/operator/CatalogScanTest.java @@ -12,8 +12,6 @@ import org.junit.Before; import org.junit.Test; -import com.google.common.collect.ImmutableList; - import edu.washington.escience.myria.DbException; import edu.washington.escience.myria.Schema; import edu.washington.escience.myria.Type; @@ -59,7 +57,7 @@ public void Cleanup() { @Test public final void testQueryQueries() throws DbException, CatalogException { - Schema schema = new Schema(ImmutableList.of(Type.LONG_TYPE, Type.STRING_TYPE), ImmutableList.of("id", "raw")); + Schema schema = Schema.ofFields(Type.LONG_TYPE, "id", Type.STRING_TYPE, "raw"); CatalogQueryScan scan = new CatalogQueryScan("select query_id, raw_query from queries", schema, catalog); scan.open(null);