Skip to content

Commit

Permalink
Run grouped aggregation using multiple threads
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed May 4, 2015
1 parent d0fa12e commit dc58f4f
Show file tree
Hide file tree
Showing 27 changed files with 705 additions and 66 deletions.
Expand Up @@ -107,6 +107,7 @@ protected List<? extends OperatorFactory> createOperatorFactories()
COUNT.bind(ImmutableList.of(2), Optional.empty(), Optional.empty(), 1.0)
),
Optional.empty(),
Optional.empty(),
10_000,
new DataSize(16, MEGABYTE));

Expand Down
Expand Up @@ -48,6 +48,7 @@ protected List<? extends OperatorFactory> createOperatorFactories()
Step.SINGLE,
ImmutableList.of(DOUBLE_SUM.bind(ImmutableList.of(1), Optional.empty(), Optional.empty(), 1.0)),
Optional.empty(),
Optional.empty(),
100_000,
new DataSize(16, MEGABYTE));
return ImmutableList.of(tableScanOperator, aggregationOperator);
Expand Down
Expand Up @@ -23,6 +23,7 @@ public final class SystemSessionProperties
private static final String TASK_WRITER_COUNT = "task_writer_count";
private static final String TASK_DEFAULT_CONCURRENCY = "task_default_concurrency";
private static final String TASK_JOIN_CONCURRENCY = "task_join_concurrency";
private static final String TASK_AGGREGATION_CONCURRENCY = "task_aggregation_concurrency";

private SystemSessionProperties() {}

Expand Down Expand Up @@ -89,4 +90,9 @@ public static int getTaskJoinConcurrency(Session session, int defaultValue)
{
return getNumber(TASK_JOIN_CONCURRENCY, session, getTaskDefaultConcurrency(session, defaultValue));
}

/**
 * Reads the {@code task_aggregation_concurrency} session property as a number.
 *
 * <p>The value obtained from {@code getTaskDefaultConcurrency(session, defaultValue)}
 * is handed to {@code getNumber} as its last argument, mirroring
 * {@code getTaskJoinConcurrency}.
 * NOTE(review): presumably {@code getNumber} returns that argument when the
 * property is not set on the session; confirm against its definition.
 *
 * @param session the session to read the property from
 * @param defaultValue value forwarded to {@code getTaskDefaultConcurrency}
 * @return the number produced by {@code getNumber} for this property
 */
public static int getTaskAggregationConcurrency(Session session, int defaultValue)
{
    int fallbackConcurrency = getTaskDefaultConcurrency(session, defaultValue);
    return getNumber(TASK_AGGREGATION_CONCURRENCY, session, fallbackConcurrency);
}
}
Expand Up @@ -24,10 +24,13 @@
import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.Optional;

import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.spi.type.BooleanType.BOOLEAN;
import static com.facebook.presto.type.TypeUtils.NULL_HASH_CODE;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static it.unimi.dsi.fastutil.HashCommon.arraySize;
import static it.unimi.dsi.fastutil.HashCommon.murmurHash3;

Expand All @@ -39,6 +42,7 @@ public class BigintGroupByHash
private static final List<Type> TYPES_WITH_RAW_HASH = ImmutableList.of(BIGINT, BIGINT);

private final int hashChannel;
private final int maskChannel;
private final boolean outputRawHash;

private int maxFill;
Expand All @@ -56,12 +60,13 @@ public class BigintGroupByHash

private int nextGroupId;

public BigintGroupByHash(int hashChannel, boolean outputRawHash, int expectedSize)
public BigintGroupByHash(int hashChannel, Optional<Integer> maskChannel, boolean outputRawHash, int expectedSize)
{
checkArgument(hashChannel >= 0, "hashChannel must be at least zero");
checkArgument(expectedSize > 0, "expectedSize must be greater than zero");

this.hashChannel = hashChannel;
this.maskChannel = checkNotNull(maskChannel, "maskChannel is null").orElse(-1);
this.outputRawHash = outputRawHash;

int hashSize = arraySize(expectedSize, FILL_RATIO);
Expand Down Expand Up @@ -124,9 +129,19 @@ public void addPage(Page page)
{
int positionCount = page.getPositionCount();

Block maskBlock = null;
if (maskChannel >= 0) {
maskBlock = page.getBlock(maskChannel);
}

// get the group id for each position
Block block = page.getBlock(hashChannel);
for (int position = 0; position < positionCount; position++) {
// skip masked rows
if (maskBlock != null && !BOOLEAN.getBoolean(maskBlock, position)) {
continue;
}

// get the group for the current row
putIfAbsent(position, block);
}
Expand All @@ -140,9 +155,20 @@ public GroupByIdBlock getGroupIds(Page page)
// we know the exact size required for the block
BlockBuilder blockBuilder = BIGINT.createFixedSizeBlockBuilder(positionCount);

Block maskBlock = null;
if (maskChannel >= 0) {
maskBlock = page.getBlock(maskChannel);
}

// get the group id for each position
Block block = page.getBlock(hashChannel);
for (int position = 0; position < positionCount; position++) {
// skip masked rows
if (maskBlock != null && !BOOLEAN.getBoolean(maskBlock, position)) {
blockBuilder.appendNull();
continue;
}

// get the group for the current row
int groupId = putIfAbsent(position, block);

Expand Down
Expand Up @@ -69,7 +69,7 @@ public static class ChannelSetBuilder
/**
 * Creates a builder whose group-by hash is keyed on a single channel
 * (channel index 0) of the given type.
 *
 * @param type the type of the single hashed channel; also used to build the one-position null page
 * @param hashChannel forwarded to {@code createGroupByHash} as the input hash channel
 * @param expectedPositions forwarded to {@code createGroupByHash} as the expected size
 * @param operatorContext stored on the builder
 */
public ChannelSetBuilder(Type type, Optional<Integer> hashChannel, int expectedPositions, OperatorContext operatorContext)
{
List<Type> types = ImmutableList.of(type);
// NOTE(review): the next two lines are the before/after forms of the same
// statement as rendered by this one-line diff hunk. The second form passes
// an empty mask channel, so no rows are masked out for the channel set.
this.hash = createGroupByHash(types, new int[] {0}, hashChannel, expectedPositions);
this.hash = createGroupByHash(types, new int[] {0}, Optional.<Integer>empty(), hashChannel, expectedPositions);
this.operatorContext = operatorContext;
// A page holding a single block with exactly one null position of the given type.
this.nullBlockPage = new Page(type.createBlockBuilder(new BlockBuilderStatus(), 1, UNKNOWN.getFixedSize()).appendNull().build());
}
Expand Down
Expand Up @@ -96,7 +96,7 @@ public DistinctLimitOperator(OperatorContext operatorContext, List<Type> types,
for (int channel : distinctChannels) {
distinctTypes.add(types.get(channel));
}
this.groupByHash = createGroupByHash(distinctTypes.build(), Ints.toArray(distinctChannels), hashChannel, Math.min((int) limit, 10_000));
this.groupByHash = createGroupByHash(distinctTypes.build(), Ints.toArray(distinctChannels), Optional.<Integer>empty(), hashChannel, Math.min((int) limit, 10_000));
this.pageBuilder = new PageBuilder(types);
remainingLimit = limit;
}
Expand Down
Expand Up @@ -24,12 +24,12 @@

public interface GroupByHash
{
/**
 * Creates a {@code GroupByHash} for the given hash types and channels.
 *
 * <p>Returns a {@code BigintGroupByHash} when there is exactly one hash type,
 * it equals {@code BIGINT}, and there is exactly one hash channel; otherwise
 * returns a {@code MultiChannelGroupByHash}.
 *
 * <p>NOTE(review): this span is rendered from a diff, so each signature and
 * constructor call appears twice: the form without {@code maskChannel} is the
 * pre-change line and the form with it is the post-change line.
 *
 * @param hashTypes types of the grouping channels
 * @param hashChannels indexes of the grouping channels
 * @param maskChannel optional mask channel, forwarded unchanged to the chosen implementation (post-change form only)
 * @param inputHashChannel optional input hash channel; the bigint path receives only whether it is present
 * @param expectedSize forwarded unchanged to the chosen implementation
 */
static GroupByHash createGroupByHash(List<? extends Type> hashTypes, int[] hashChannels, Optional<Integer> inputHashChannel, int expectedSize)
static GroupByHash createGroupByHash(List<? extends Type> hashTypes, int[] hashChannels, Optional<Integer> maskChannel, Optional<Integer> inputHashChannel, int expectedSize)
{
// Specialized implementation for a single BIGINT grouping channel.
if (hashTypes.size() == 1 && hashTypes.get(0).equals(BIGINT) && hashChannels.length == 1) {
return new BigintGroupByHash(hashChannels[0], inputHashChannel.isPresent(), expectedSize);
return new BigintGroupByHash(hashChannels[0], maskChannel, inputHashChannel.isPresent(), expectedSize);
}
// General case: any other combination of types and channels.
return new MultiChannelGroupByHash(hashTypes, hashChannels, inputHashChannel, expectedSize);
return new MultiChannelGroupByHash(hashTypes, hashChannels, maskChannel, inputHashChannel, expectedSize);
}

long getEstimatedSize();
Expand Down
Expand Up @@ -43,6 +43,7 @@ public static class HashAggregationOperatorFactory
implements OperatorFactory
{
private final int operatorId;
private final Optional<Integer> maskChannel;
private final List<Type> groupByTypes;
private final List<Integer> groupByChannels;
private final Step step;
Expand All @@ -60,11 +61,13 @@ public HashAggregationOperatorFactory(
List<Integer> groupByChannels,
Step step,
List<AccumulatorFactory> accumulatorFactories,
Optional<Integer> maskChannel,
Optional<Integer> hashChannel,
int expectedGroups,
DataSize maxPartialMemory)
{
this.operatorId = operatorId;
this.maskChannel = checkNotNull(maskChannel, "maskChannel is null");
this.hashChannel = checkNotNull(hashChannel, "hashChannel is null");
this.groupByTypes = ImmutableList.copyOf(groupByTypes);
this.groupByChannels = ImmutableList.copyOf(groupByChannels);
Expand Down Expand Up @@ -94,14 +97,16 @@ public Operator createOperator(DriverContext driverContext)
else {
operatorContext = driverContext.addOperatorContext(operatorId, HashAggregationOperator.class.getSimpleName());
}
return new HashAggregationOperator(
HashAggregationOperator hashAggregationOperator = new HashAggregationOperator(
operatorContext,
groupByTypes,
groupByChannels,
step,
accumulatorFactories,
maskChannel,
hashChannel,
expectedGroups);
return hashAggregationOperator;
}

@Override
Expand All @@ -116,6 +121,7 @@ public void close()
private final List<Integer> groupByChannels;
private final Step step;
private final List<AccumulatorFactory> accumulatorFactories;
private final Optional<Integer> maskChannel;
private final Optional<Integer> hashChannel;
private final int expectedGroups;

Expand All @@ -131,6 +137,7 @@ public HashAggregationOperator(
List<Integer> groupByChannels,
Step step,
List<AccumulatorFactory> accumulatorFactories,
Optional<Integer> maskChannel,
Optional<Integer> hashChannel,
int expectedGroups)
{
Expand All @@ -142,6 +149,7 @@ public HashAggregationOperator(
this.groupByTypes = ImmutableList.copyOf(groupByTypes);
this.groupByChannels = ImmutableList.copyOf(groupByChannels);
this.accumulatorFactories = ImmutableList.copyOf(accumulatorFactories);
this.maskChannel = checkNotNull(maskChannel, "maskChannel is null");
this.hashChannel = checkNotNull(hashChannel, "hashChannel is null");
this.step = step;
this.expectedGroups = expectedGroups;
Expand Down Expand Up @@ -190,6 +198,7 @@ public void addInput(Page page)
expectedGroups,
groupByTypes,
groupByChannels,
maskChannel,
hashChannel,
operatorContext);

Expand Down Expand Up @@ -257,10 +266,11 @@ private GroupByHashAggregationBuilder(
int expectedGroups,
List<Type> groupByTypes,
List<Integer> groupByChannels,
Optional<Integer> maskChannel,
Optional<Integer> hashChannel,
OperatorContext operatorContext)
{
this.groupByHash = createGroupByHash(groupByTypes, Ints.toArray(groupByChannels), hashChannel, expectedGroups);
this.groupByHash = createGroupByHash(groupByTypes, Ints.toArray(groupByChannels), maskChannel, hashChannel, expectedGroups);
this.operatorContext = operatorContext;
this.partial = (step == Step.PARTIAL);

Expand Down

0 comments on commit dc58f4f

Please sign in to comment.