# Reviewer annotation (text before the first "diff --git" header is not part of any hunk).
# This patch:
#   - adds AggregationUtil, CsvFileScanner, GroupBy, HashAggregation, PairsIterator and RunLengthEncodedBlock
#   - adds equals/hashCode to Pair and Tuple, plus a varargs Tuple constructor
#   - changes PipelinedAggregation to consume pre-grouped input and delegate to AggregationUtil.processGroup
#   - adds CsvFileScannerTest with its data.csv fixture and replaces printing in TestSumAggregation with assertions
# NOTE(review): generic type arguments look stripped throughout this text (e.g. "Iterator> aggregations"), and the
# third PipelinedAggregation hunk shows fewer lines than its header declares; confirm against the original commit.
diff --git a/src/main/java/com/facebook/presto/AggregationUtil.java b/src/main/java/com/facebook/presto/AggregationUtil.java
new file mode 100644
index 000000000000..823aa23d89e9
--- /dev/null
+++ b/src/main/java/com/facebook/presto/AggregationUtil.java
@@ -0,0 +1,21 @@
+package com.facebook.presto;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Range;
+
+public class AggregationUtil { // static helper; processGroup is called from HashAggregation and PipelinedAggregation in this patch
+    public static void processGroup(SeekableIterator aggregationSource, AggregationFunction aggregation, Range positions)
+    { // feeds every source block whose range is connected to 'positions' into 'aggregation'
+        RangePositionBlock positionBlock = new RangePositionBlock(positions); // passed to aggregation.add alongside each block
+
+        // seek the aggregation source to the first position of the group
+        aggregationSource.seekTo(positions.lowerEndpoint());
+        Preconditions.checkState(aggregationSource.hasNext(), "Group start position not found in aggregation source"); // fails if the source is exhausted after the seek
+
+        // consume blocks while the source has data and the next block's range is connected to the group's range
+        while (aggregationSource.hasNext() && aggregationSource.peek().getRange().isConnected(positions)) {
+            // hand the block to the aggregation together with the group's positions
+            aggregation.add(aggregationSource.next(), positionBlock);
+        }
+    }
+}
diff --git a/src/main/java/com/facebook/presto/CsvFileScanner.java b/src/main/java/com/facebook/presto/CsvFileScanner.java
new file mode 100644
index 000000000000..39e6d3d8338b
--- /dev/null
+++ b/src/main/java/com/facebook/presto/CsvFileScanner.java
@@ -0,0 +1,74 @@
+package com.facebook.presto;
+
+import com.google.common.base.Preconditions;
+import com.google.common.base.Splitter;
+import com.google.common.base.Throwables;
+import com.google.common.collect.AbstractIterator;
+import com.google.common.collect.Iterables;
+import com.google.common.io.InputSupplier;
+import com.google.common.io.LineReader;
+
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.util.Iterator;
+
+public class CsvFileScanner implements Iterable // iterates one column of a delimited text source, one ValueBlock per input line
+{
+    private final InputSupplier inputSupplier;
+    private final Splitter columnSplitter;
+    private final int columnIndex;
+
+    public CsvFileScanner(InputSupplier inputSupplier, int columnIndex, char columnSeparator)
+    {
+        this.columnIndex = columnIndex; // index handed to Iterables.get for every line; not range-checked here
+        Preconditions.checkNotNull(inputSupplier, "inputSupplier is null");
+        this.inputSupplier = inputSupplier;
+        columnSplitter = Splitter.on(columnSeparator); // NOTE(review): plain split on the separator; no quoting/escaping handling is visible here
+    }
+
+    @Override
+    public Iterator iterator()
+    {
+        return new ColumnIterator(inputSupplier, columnIndex, columnSplitter); // the ColumnIterator constructor calls inputSupplier.getInput()
+    }
+
+    private static class ColumnIterator extends AbstractIterator
+    {
+        private long position; // position given to the next emitted block; incremented once per line
+        private final LineReader reader; // NOTE(review): nothing in this class closes the underlying input
+        private int columnIndex;
+        private Splitter columnSplitter;
+
+        public ColumnIterator(InputSupplier inputSupplier, int columnIndex, Splitter columnSplitter)
+        {
+            try {
+                this.reader = new LineReader(inputSupplier.getInput());
+            }
+            catch (IOException e) {
+                throw Throwables.propagate(e); // rethrow the IOException as an unchecked exception
+            }
+            this.columnIndex = columnIndex;
+            this.columnSplitter = columnSplitter;
+        }
+
+        @Override
+        protected ValueBlock computeNext()
+        {
+            String line;
+            try {
+                line = reader.readLine();
+            }
+            catch (IOException e) {
+                throw Throwables.propagate(e); // rethrow the IOException as an unchecked exception
+            }
+            if (line == null) { // readLine returned null: no more lines
+                endOfData();
+                return null;
+            }
+            Iterable split = columnSplitter.split(line);
+            String value = Iterables.get(split, columnIndex); // NOTE(review): no bounds check; a line with too few columns fails here
+            return new UncompressedValueBlock(position++, value); // one block per line, holding only the selected column
+
+        }
+    }
+}
diff --git a/src/main/java/com/facebook/presto/GroupBy.java b/src/main/java/com/facebook/presto/GroupBy.java
new file mode 100644
index 000000000000..60ab0abe41a3
--- /dev/null
+++ b/src/main/java/com/facebook/presto/GroupBy.java
@@ -0,0 +1,82 @@
+package com.facebook.presto;
+
+import com.google.common.collect.AbstractIterator;
+import com.google.common.collect.PeekingIterator;
+import com.google.common.collect.Range;
+import com.google.common.collect.Ranges;
+
+import java.util.Iterator;
+
+/**
+ * Groups input data, producing a single run-length-encoded block for each sequence of identical consecutive values.
+ */
+public class GroupBy
+        extends AbstractIterator
+{
+    private final Iterator groupBySource;
+
+    private PeekingIterator currentGroupByBlock; // pair iterator over the block being consumed; null until the first block is opened
+
+    public GroupBy(Iterator keySource)
+    {
+        this.groupBySource = keySource;
+    }
+
+    @Override
+    protected RunLengthEncodedBlock computeNext()
+    {
+        // if no more data, signal end of data
+        if (!advanceGroupByBlock()) {
+            endOfData();
+            return null;
+        }
+
+        // form a group from the current position, until the value changes
+        Pair entry = currentGroupByBlock.next();
+        Object groupByKey = entry.getValue();
+        long startPosition = entry.getPosition();
+
+        while (true) {
+            // skip entries until the current key changes or we've consumed this block
+            while (currentGroupByBlock.hasNext() && currentGroupByBlock.peek().getValue().equals(groupByKey)) { // NOTE(review): equals() is called on the peeked value, so a null value would throw
+                entry = currentGroupByBlock.next(); // keep the last entry that belongs to this group
+            }
+
+            // stop if there is more data in the current block since the next entry will be for a new group
+            if (currentGroupByBlock.hasNext()) {
+                break;
+            }
+
+            // stop if we are at the end of the stream
+            if (!groupBySource.hasNext()) {
+                break;
+            }
+
+            // open the next block and loop again, since the group may continue into it
+            currentGroupByBlock = groupBySource.next().pairIterator();
+        }
+
+        long endPosition = entry.getPosition(); // position of the last entry consumed for this group
+        Range range = Ranges.closed(startPosition, endPosition); // both endpoints inclusive
+
+        RunLengthEncodedBlock group = new RunLengthEncodedBlock(groupByKey, range);
+        return group;
+    }
+
+    private boolean advanceGroupByBlock()
+    {
+        // does current block iterator have more data?
+        if (currentGroupByBlock != null && currentGroupByBlock.hasNext()) {
+            return true;
+        }
+
+        // are there more blocks?
+        if (!groupBySource.hasNext()) {
+            return false;
+        }
+
+        // advance to next block and open an iterator
+        currentGroupByBlock = groupBySource.next().pairIterator();
+        return true; // NOTE(review): true even if the newly opened block has no pairs; computeNext then calls next() on it
+    }
+}
diff --git a/src/main/java/com/facebook/presto/HashAggregation.java b/src/main/java/com/facebook/presto/HashAggregation.java
new file mode 100644
index 000000000000..715a5a62c783
--- /dev/null
+++ b/src/main/java/com/facebook/presto/HashAggregation.java
@@ -0,0 +1,68 @@
+package com.facebook.presto;
+
+import com.google.common.collect.AbstractIterator;
+
+import javax.inject.Provider;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Map.Entry;
+
+import static com.facebook.presto.AggregationUtil.processGroup;
+
+public class HashAggregation // aggregates every group into a map keyed by group value, then emits one block per key
+        extends AbstractIterator
+{
+    private final Iterator groupBySource;
+    private final SeekableIterator aggregationSource;
+
+    private final Provider functionProvider; // supplies a new AggregationFunction for each distinct key
+
+    private Iterator> aggregations; // null until the first computeNext call; NOTE(review): "Iterator>" looks like a stripped generic type
+
+    private long position; // position given to the next output block
+
+    public HashAggregation(Iterator keySource, SeekableIterator valueSource, Provider functionProvider)
+    {
+        this.groupBySource = keySource;
+        this.aggregationSource = valueSource;
+
+        this.functionProvider = functionProvider;
+    }
+
+    @Override
+    protected ValueBlock computeNext()
+    {
+        // on the first call, consume the whole group source and build all aggregates
+        if (aggregations == null) {
+            Map aggregationMap = new HashMap<>(); // HashMap, so the order in which results are emitted is unspecified
+            while (groupBySource.hasNext()) {
+                RunLengthEncodedBlock group = groupBySource.next();
+
+                AggregationFunction aggregation = aggregationMap.get(group.getValue());
+                if (aggregation == null) { // first time this key is seen: create and register its aggregate
+                    aggregation = functionProvider.get();
+                    aggregationMap.put(group.getValue(), aggregation);
+                }
+                processGroup(aggregationSource, aggregation, group.getRange()); // seeks the value source to this group's range
+            }
+
+            this.aggregations = aggregationMap.entrySet().iterator();
+        }
+
+        // if no more data, signal end of data
+        if (!aggregations.hasNext()) {
+            endOfData();
+            return null;
+        }
+
+        // get next aggregation
+        Entry aggregation = aggregations.next();
+
+        // calculate final value for this group
+        Object value = aggregation.getValue().evaluate();
+
+        // build an output block holding a (key, aggregate) tuple
+        return new UncompressedValueBlock(position++, new Tuple(aggregation.getKey(), value));
+    }
+}
diff --git a/src/main/java/com/facebook/presto/Pair.java b/src/main/java/com/facebook/presto/Pair.java
index e4e59496d41a..b397fe796ef0 100644
--- a/src/main/java/com/facebook/presto/Pair.java
+++ b/src/main/java/com/facebook/presto/Pair.java
@@ -47,6 +47,35 @@ public Object apply(Pair input)
         };
     }
 
+    @Override
+    public boolean equals(Object o) // equal when the other object is the same class with the same position and value
+    {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+
+        Pair pair = (Pair) o;
+
+        if (position != pair.position) {
+            return false;
+        }
+        if (value != null ? !value.equals(pair.value) : pair.value != null) { // null-safe value comparison
+            return false;
+        }
+
+        return true;
+    }
+
+    @Override
+    public int hashCode() // combines the same two fields that equals compares
+    {
+        int result = (int) (position ^ (position >>> 32)); // fold the upper bits of position into the lower 32
+        result = 31 * result + (value != null ? value.hashCode() : 0);
+        return result;
+    }
 
     @Override
     public String toString()
diff --git a/src/main/java/com/facebook/presto/PairsIterator.java b/src/main/java/com/facebook/presto/PairsIterator.java
new file mode 100644
index 000000000000..923aed66fe90
--- /dev/null
+++ b/src/main/java/com/facebook/presto/PairsIterator.java
@@ -0,0 +1,46 @@
+package com.facebook.presto;
+
+import com.google.common.collect.AbstractIterator;
+import com.google.common.collect.PeekingIterator;
+
+import java.util.Iterator;
+
+public class PairsIterator // flattens an iterator of blocks into an iterator over their pairs
+        extends AbstractIterator
+{
+    private final Iterator blockIterator;
+    private PeekingIterator currentBlock; // pair iterator of the block being consumed; null until the first block is opened
+
+    public PairsIterator(Iterator blockIterator)
+    {
+        this.blockIterator = blockIterator;
+    }
+
+    @Override
+    protected Pair computeNext()
+    {
+        if (!advance()) { // no current pairs and no further blocks
+            endOfData();
+            return null;
+        }
+        return currentBlock.next();
+    }
+
+    private boolean advance()
+    {
+        // does current block iterator have more data?
+        if (currentBlock != null && currentBlock.hasNext()) {
+            return true;
+        }
+
+        // are there more blocks?
+        if (!blockIterator.hasNext()) {
+            return false;
+        }
+
+        // advance to next block and open an iterator
+        currentBlock = blockIterator.next().pairIterator();
+        return true; // NOTE(review): true even if the newly opened block has no pairs; computeNext then calls next() on it
+    }
+
+}
diff --git a/src/main/java/com/facebook/presto/PipelinedAggregation.java b/src/main/java/com/facebook/presto/PipelinedAggregation.java
index 417346bb0ffb..4519da902b43 100644
--- a/src/main/java/com/facebook/presto/PipelinedAggregation.java
+++ b/src/main/java/com/facebook/presto/PipelinedAggregation.java
@@ -1,10 +1,6 @@
 package com.facebook.presto;
 
-import com.google.common.base.Preconditions;
 import com.google.common.collect.AbstractIterator;
-import com.google.common.collect.PeekingIterator;
-import com.google.common.collect.Range;
-import com.google.common.collect.Ranges;
 
 import javax.inject.Provider;
 import java.util.Iterator;
@@ -12,14 +8,12 @@
 public class PipelinedAggregation
         extends AbstractIterator
 {
-    private final Iterator groupBySource;
+    private final Iterator groupBySource; // NOTE(review): reads identical to the removed line here; presumably only the generic type argument changed — confirm
     private final SeekableIterator aggregationSource;
 
     private final Provider functionProvider;
 
-    private PeekingIterator currentGroupByBlock;
-
-    public PipelinedAggregation(Iterator keySource, SeekableIterator valueSource, Provider functionProvider)
+    public PipelinedAggregation(Iterator keySource, SeekableIterator valueSource, Provider functionProvider)
     {
         this.groupBySource = keySource;
         this.aggregationSource = valueSource;
@@ -31,76 +25,23 @@ public PipelinedAggregation(Iterator keySource, SeekableIterator positions)
-    {
+        // create a new aggregate for this group
         AggregationFunction aggregationFunction = functionProvider.get();
-        RangePositionBlock positionBlock = new RangePositionBlock(positions);
 
-        // goto start of range
-        aggregationSource.seekTo(positions.lowerEndpoint());
-        Preconditions.checkState(aggregationSource.hasNext(), "Group start position not found in aggregation source");
-
-        // while we have data...
-        while (aggregationSource.hasNext() && aggregationSource.peek().getRange().isConnected(positions)) {
-            // process aggregation
-            aggregationFunction.add(aggregationSource.next(), positionBlock);
-        }
+        AggregationUtil.processGroup(aggregationSource, aggregationFunction, group.getRange()); // replaces the inline seek-and-add loop removed above
 
         // calculate final value for this group
         Object value = aggregationFunction.evaluate();
 
         // build an output block
-        return new UncompressedValueBlock(positions.lowerEndpoint(), value);
-    }
-
-    private boolean advanceGroupByBlock()
-    {
-        // does current block iterator have more data?
-        if (currentGroupByBlock != null && currentGroupByBlock.hasNext()) {
-            return true;
-        }
-
-        // are there more blocks?
-        if (!groupBySource.hasNext()) {
-            return false;
-        }
-
-        // advance to next block and open an iterator
-        currentGroupByBlock = groupBySource.next().pairIterator();
-        return true;
+        return new UncompressedValueBlock(group.getRange().lowerEndpoint(), new Tuple(group.getValue(), value)); // emits a (group key, aggregate) tuple at the group's first position
     }
 }
diff --git a/src/main/java/com/facebook/presto/RunLengthEncodedBlock.java b/src/main/java/com/facebook/presto/RunLengthEncodedBlock.java
new file mode 100644
index 000000000000..3bcd8a915900
--- /dev/null
+++ b/src/main/java/com/facebook/presto/RunLengthEncodedBlock.java
@@ -0,0 +1,125 @@
+package com.facebook.presto;
+
+import com.google.common.base.Function;
+import com.google.common.base.Predicate;
+import com.google.common.collect.DiscreteDomains;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterators;
+import com.google.common.collect.PeekingIterator;
+import com.google.common.collect.Range;
+
+import java.util.Collections;
+import java.util.Iterator;
+
+public class RunLengthEncodedBlock // a ValueBlock holding one value that applies to every position in a range
+        implements ValueBlock
+{
+    private final Object value;
+    private final Range range;
+
+    public RunLengthEncodedBlock(Object value, Range range)
+    {
+        this.value = value;
+        this.range = range;
+    }
+
+    public Object getValue()
+    {
+        return value;
+    }
+
+    @Override
+    public PositionBlock selectPositions(Predicate predicate)
+    {
+        return null; // NOTE(review): stub; returns null regardless of the predicate
+    }
+
+    @Override
+    public ValueBlock selectPairs(Predicate predicate)
+    {
+        return null; // NOTE(review): stub; returns null regardless of the predicate
+    }
+
+    @Override
+    public ValueBlock filter(PositionBlock positions)
+    {
+        ImmutableList.Builder builder = ImmutableList.builder();
+        for (Long position : positions.getPositions()) {
+            if (range.contains(position)) { // keep only the requested positions that fall inside this block's range
+                builder.add(new Pair(position, value));
+            }
+        }
+
+        ImmutableList pairs = builder.build();
+        if (pairs.isEmpty()) {
+            return new EmptyValueBlock();
+        }
+
+        return new UncompressedValueBlock(pairs);
+    }
+
+    @Override
+    public PeekingIterator pairIterator()
+    {
+        return Iterators.peekingIterator(Iterators.transform(getPositions().iterator(), new Function() { // pairs each position in the range with the single value
+            @Override
+            public Pair apply(Long position)
+            {
+                return new Pair(position, value);
+            }
+        }));
+    }
+
+    @Override
+    public boolean isEmpty()
+    {
+        return false;
+    }
+
+    @Override
+    public int getCount()
+    {
+        return (int) (range.upperEndpoint() - range.lowerEndpoint() + 1); // +1 counts both endpoints; NOTE(review): the int cast can truncate for very wide ranges
+    }
+
+    @Override
+    public boolean isSorted()
+    {
+        return false;
+    }
+
+    @Override
+    public boolean isSingleValue()
+    {
+        return true;
+    }
+
+    @Override
+    public boolean isPositionsContiguous()
+    {
+        return true;
+    }
+
+    @Override
+    public Iterable getPositions()
+    {
+        return range.asSet(DiscreteDomains.longs()); // view of the long values contained in the range
+    }
+
+    @Override
+    public Range getRange()
+    {
+        return range;
+    }
+
+    public String toString() // NOTE(review): lacks the @Override annotation used on the other methods
+    {
+        return Iterators.toString(pairIterator());
+    }
+
+    @Override
+    public Iterator iterator()
+    {
+        return Iterators.peekingIterator(Collections.nCopies(getCount(), value).iterator()); // the value repeated getCount() times
+    }
+}
diff --git a/src/main/java/com/facebook/presto/Tuple.java b/src/main/java/com/facebook/presto/Tuple.java
index 284b7d8de1e8..88693cbf139a 100644
--- a/src/main/java/com/facebook/presto/Tuple.java
+++ b/src/main/java/com/facebook/presto/Tuple.java
@@ -9,6 +9,10 @@ public class Tuple
 {
     private final List values;
 
+    public Tuple(Object... values) { // varargs convenience: copies the arguments into an ImmutableList and delegates to the List constructor
+        this(ImmutableList.copyOf(values));
+    }
+
     public Tuple(List values)
     {
         Preconditions.checkNotNull(values, "values is null");
@@ -20,6 +24,31 @@ public List getValues()
         return values;
     }
 
+    @Override
+    public boolean equals(Object o) // equal when the other object is a Tuple of the same class with an equal values list
+    {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+
+        Tuple tuple = (Tuple) o;
+
+        if (!values.equals(tuple.values)) {
+            return false;
+        }
+
+        return true;
+    }
+
+    @Override
+    public int hashCode()
+    {
+        return values.hashCode(); // delegates to the list, matching equals
+    }
+
     @Override
     public String toString()
     {
diff --git a/src/test/java/com/facebook/presto/CsvFileScannerTest.java b/src/test/java/com/facebook/presto/CsvFileScannerTest.java
new file mode 100644
index 000000000000..82edce087ba4
--- /dev/null
+++ b/src/test/java/com/facebook/presto/CsvFileScannerTest.java
@@ -0,0 +1,47 @@
+package com.facebook.presto;
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.ImmutableList;
+import com.google.common.io.InputSupplier;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+import java.io.InputStreamReader;
+
+import static com.google.common.io.Resources.getResource;
+import static com.google.common.io.Resources.newReaderSupplier;
+
+public class CsvFileScannerTest
+{
+    private final InputSupplier inputSupplier = newReaderSupplier(getResource("data.csv"), Charsets.UTF_8); // shared by all three scanners below
+
+    @Test
+    public void testIterator()
+            throws Exception
+    {
+        CsvFileScanner firstColumn = new CsvFileScanner(inputSupplier, 0, ','); // column index 0 is expected to yield "0".."3"
+
+        Assert.assertEquals(ImmutableList.copyOf(new PairsIterator(firstColumn.iterator())),
+                ImmutableList.of(
+                        new Pair(0, "0"),
+                        new Pair(1, "1"),
+                        new Pair(2, "2"),
+                        new Pair(3, "3")));
+
+        CsvFileScanner secondColumn = new CsvFileScanner(inputSupplier, 1, ',');
+        Assert.assertEquals(ImmutableList.copyOf(new PairsIterator(secondColumn.iterator())),
+                ImmutableList.of(
+                        new Pair(0, "apple"),
+                        new Pair(1, "banana"),
+                        new Pair(2, "cherry"),
+                        new Pair(3, "date")));
+
+        CsvFileScanner thirdColumn = new CsvFileScanner(inputSupplier, 2, ',');
+        Assert.assertEquals(ImmutableList.copyOf(new PairsIterator(thirdColumn.iterator())),
+                ImmutableList.of(
+                        new Pair(0, "alice"),
+                        new Pair(1, "bob"),
+                        new Pair(2, "charlie"),
+                        new Pair(3, "dave")));
+    }
+}
diff --git a/src/test/java/com/facebook/presto/TestSumAggregation.java b/src/test/java/com/facebook/presto/TestSumAggregation.java
index e7b55497db01..13d615919231 100644
--- a/src/test/java/com/facebook/presto/TestSumAggregation.java
+++ b/src/test/java/com/facebook/presto/TestSumAggregation.java
@@ -1,18 +1,25 @@
 package com.facebook.presto;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.PeekingIterator;
+import org.testng.Assert;
 import org.testng.annotations.Test;
 
 import javax.inject.Provider;
+import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
 
 public class TestSumAggregation
 {
     @Test
-    public void test()
+    public void testPipelinedAggregation()
     {
-        PipelinedAggregation aggregation = new PipelinedAggregation(newGroupColumn(), new ForwardingSeekableIterator<>(newAggregateColumn()), new Provider()
+        GroupBy groupBy = new GroupBy(newGroupColumn()); // the aggregation now receives pre-grouped blocks instead of the raw group column
+        PipelinedAggregation aggregation = new PipelinedAggregation(groupBy, new ForwardingSeekableIterator<>(newAggregateColumn()), new Provider()
         {
             @Override
             public AggregationFunction get()
@@ -21,15 +28,58 @@ public AggregationFunction get()
             }
         });
 
-//        DataScan3 materialize = new DataScan3(newGroupColumn(), )
+        List expected = ImmutableList.of( // one pair per group: (first position of the group, (key, sum))
+                new Pair(0, new Tuple("a", 10L)),
+                new Pair(4, new Tuple("b", 17L)),
+                new Pair(23, new Tuple("c", 15L)),
+                new Pair(30, new Tuple("d", 6L))
+        );
+
+        List actual = new ArrayList<>(); // collected in emission order, so ordering is part of the assertion
         while (aggregation.hasNext()) {
             ValueBlock block = aggregation.next();
             PeekingIterator pairs = block.pairIterator();
             while (pairs.hasNext()) {
-                System.out.println(pairs.next());
+                Pair pair = pairs.next();
+                actual.add(pair);
             }
-            System.out.println();
         }
+
+        Assert.assertEquals(actual, expected);
+    }
+
+    @Test
+    public void testHashAggregation()
+    {
+        GroupBy groupBy = new GroupBy(newGroupColumn());
+        HashAggregation aggregation = new HashAggregation(groupBy, new ForwardingSeekableIterator<>(newAggregateColumn()), new Provider()
+        {
+            @Override
+            public AggregationFunction get()
+            {
+                return new SumAggregation();
+            }
+        });
+
+        Map expected = ImmutableMap.of( // group key -> (key, sum)
+                "a", new Tuple("a", 10L),
+                "b", new Tuple("b", 17L),
+                "c", new Tuple("c", 15L),
+                "d", new Tuple("d", 6L)
+        );
+
+        Map actual = new HashMap<>(); // keyed by group value, so the comparison does not depend on emission order or positions
+        while (aggregation.hasNext()) {
+            ValueBlock block = aggregation.next();
+            PeekingIterator pairs = block.pairIterator();
+            while (pairs.hasNext()) {
+                Pair pair = pairs.next();
+                Tuple tuple = (Tuple) pair.getValue();
+                actual.put(tuple.getValues().get(0), tuple); // first tuple element is the group key
+            }
+        }
+
+        Assert.assertEquals(actual, expected);
     }
 
     public Iterator newGroupColumn()
@@ -59,5 +109,4 @@ public Iterator newAggregateColumn()
         return values;
     }
 
-
 }
diff --git a/src/test/resources/data.csv b/src/test/resources/data.csv
new file mode 100644
index 000000000000..b7f8171de26b
--- /dev/null
+++ b/src/test/resources/data.csv
@@ -0,0 +1,4 @@
+0,apple,alice
+1,banana,bob
+2,cherry,charlie
+3,date,dave