Skip to content

Commit

Permalink
Extract GroupByHash interface
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed Apr 2, 2015
1 parent 5dd9470 commit 66c8a20
Show file tree
Hide file tree
Showing 13 changed files with 396 additions and 331 deletions.
Expand Up @@ -21,6 +21,7 @@
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;


import static com.facebook.presto.operator.GroupByHash.createGroupByHash;
import static com.facebook.presto.type.UnknownType.UNKNOWN; import static com.facebook.presto.type.UnknownType.UNKNOWN;


public class ChannelSet public class ChannelSet
Expand Down Expand Up @@ -68,7 +69,7 @@ public static class ChannelSetBuilder
public ChannelSetBuilder(Type type, Optional<Integer> hashChannel, int expectedPositions, OperatorContext operatorContext) public ChannelSetBuilder(Type type, Optional<Integer> hashChannel, int expectedPositions, OperatorContext operatorContext)
{ {
List<Type> types = ImmutableList.of(type); List<Type> types = ImmutableList.of(type);
this.hash = new GroupByHash(types, new int[] {0}, hashChannel, expectedPositions); this.hash = createGroupByHash(types, new int[] {0}, hashChannel, expectedPositions);
this.operatorContext = operatorContext; this.operatorContext = operatorContext;
this.nullBlockPage = new Page(type.createBlockBuilder(new BlockBuilderStatus(), 1, UNKNOWN.getFixedSize()).appendNull().build()); this.nullBlockPage = new Page(type.createBlockBuilder(new BlockBuilderStatus(), 1, UNKNOWN.getFixedSize()).appendNull().build());
} }
Expand Down
Expand Up @@ -22,6 +22,7 @@
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;


import static com.facebook.presto.operator.GroupByHash.createGroupByHash;
import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Preconditions.checkState;
Expand Down Expand Up @@ -95,7 +96,7 @@ public DistinctLimitOperator(OperatorContext operatorContext, List<Type> types,
for (int channel : distinctChannels) { for (int channel : distinctChannels) {
distinctTypes.add(types.get(channel)); distinctTypes.add(types.get(channel));
} }
this.groupByHash = new GroupByHash(distinctTypes.build(), Ints.toArray(distinctChannels), hashChannel, Math.min((int) limit, 10_000)); this.groupByHash = createGroupByHash(distinctTypes.build(), Ints.toArray(distinctChannels), hashChannel, Math.min((int) limit, 10_000));
this.pageBuilder = new PageBuilder(types); this.pageBuilder = new PageBuilder(types);
remainingLimit = limit; remainingLimit = limit;
} }
Expand Down
323 changes: 11 additions & 312 deletions presto-main/src/main/java/com/facebook/presto/operator/GroupByHash.java
Expand Up @@ -15,332 +15,31 @@


import com.facebook.presto.spi.Page; import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.PageBuilder;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.gen.JoinCompiler;
import com.facebook.presto.util.array.LongBigArray;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import it.unimi.dsi.fastutil.objects.ObjectArrayList;


import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;


import static com.facebook.presto.operator.SyntheticAddress.decodePosition; public interface GroupByHash
import static com.facebook.presto.operator.SyntheticAddress.decodeSliceIndex;
import static com.facebook.presto.operator.SyntheticAddress.encodeSyntheticAddress;
import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.sql.gen.JoinCompiler.PagesHashStrategyFactory;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static io.airlift.slice.SizeOf.sizeOf;
import static it.unimi.dsi.fastutil.HashCommon.arraySize;
import static it.unimi.dsi.fastutil.HashCommon.maxFill;
import static it.unimi.dsi.fastutil.HashCommon.murmurHash3;

// This implementation assumes arrays used in the hash are always a power of 2
public class GroupByHash
{ {
private static final JoinCompiler JOIN_COMPILER = new JoinCompiler(); static GroupByHash createGroupByHash(List<? extends Type> hashTypes, int[] hashChannels, Optional<Integer> inputHashChannel, int expectedSize)

private static final float FILL_RATIO = 0.9f;
private final List<Type> types;
private final int[] channels;

private final PagesHashStrategy hashStrategy;
private final List<ObjectArrayList<Block>> channelBuilders;
private final HashGenerator hashGenerator;
private final Optional<Integer> precomputedHashChannel;
private PageBuilder currentPageBuilder;

private long completedPagesMemorySize;

private int maxFill;
private int mask;
private long[] key;
private int[] value;

private final LongBigArray groupAddress;

private int nextGroupId;

/**
 * Creates a hash that assigns a group id to each distinct combination of values found
 * in the given channels of the pages passed to it.
 *
 * @param hashTypes types of the hashed channels, one per entry of {@code hashChannels}
 * @param hashChannels indexes of the page channels that make up the group key; the array is copied
 * @param inputHashChannel channel holding an already computed hash of the key, or empty to
 *        have the hash computed from the key channels
 * @param expectedSize expected number of groups, used to size the initial table; must be positive
 * @throws NullPointerException if {@code hashTypes}, {@code hashChannels} or {@code inputHashChannel} is null
 * @throws IllegalArgumentException if {@code hashTypes} and {@code hashChannels} differ in size,
 *         or {@code expectedSize} is not greater than zero
 */
public GroupByHash(List<? extends Type> hashTypes, int[] hashChannels, Optional<Integer> inputHashChannel, int expectedSize)
{
    checkNotNull(hashTypes, "hashTypes is null");
    // validate hashChannels before its length is read, so a null array fails with a clear message
    checkNotNull(hashChannels, "hashChannels is null");
    checkArgument(hashTypes.size() == hashChannels.length, "hashTypes and hashChannels have different sizes");
    checkNotNull(inputHashChannel, "inputHashChannel is null");
    checkArgument(expectedSize > 0, "expectedSize must be greater than zero");

    // when a precomputed hash is supplied it is stored as an extra BIGINT column after the key columns
    this.types = inputHashChannel.isPresent() ? ImmutableList.copyOf(Iterables.concat(hashTypes, ImmutableList.of(BIGINT))) : ImmutableList.copyOf(hashTypes);
    this.channels = hashChannels.clone();
    this.hashGenerator = inputHashChannel.isPresent() ? new PrecomputedHashGenerator(inputHashChannel.get()) : new InterpretedHashGenerator(ImmutableList.copyOf(hashTypes), hashChannels);

    // For each hashed channel, create an appendable list to hold the blocks (builders). As we
    // add new values we append them to the existing block builder until it fills up and then
    // we add a new block builder to each list.
    ImmutableList.Builder<Integer> outputChannels = ImmutableList.builder();
    ImmutableList.Builder<ObjectArrayList<Block>> channelBuilders = ImmutableList.builder();
    for (int i = 0; i < hashChannels.length; i++) {
        outputChannels.add(i);
        channelBuilders.add(ObjectArrayList.wrap(new Block[1024], 0));
    }
    if (inputHashChannel.isPresent()) {
        // the stored hash column sits right after the key columns
        this.precomputedHashChannel = Optional.of(hashChannels.length);
        channelBuilders.add(ObjectArrayList.wrap(new Block[1024], 0));
    }
    else {
        this.precomputedHashChannel = Optional.empty();
    }
    this.channelBuilders = channelBuilders.build();
    PagesHashStrategyFactory pagesHashStrategyFactory = JOIN_COMPILER.compilePagesHashStrategyFactory(this.types, outputChannels.build());
    hashStrategy = pagesHashStrategyFactory.createPagesHashStrategy(this.channelBuilders, this.precomputedHashChannel);

    startNewPage();

    // reserve memory for the arrays
    int hashSize = arraySize(expectedSize, FILL_RATIO);

    maxFill = maxFill(hashSize, FILL_RATIO);
    mask = hashSize - 1;
    key = new long[hashSize];
    // -1 marks an empty slot
    Arrays.fill(key, -1);

    value = new int[hashSize];

    groupAddress = new LongBigArray();
    groupAddress.ensureCapacity(maxFill);
}

/**
 * Estimates the memory held by this hash, in bytes.
 *
 * The estimate adds up the per-channel block arrays, the completed pages, the page
 * currently being built, the {@code key} and {@code value} tables, and the group address array.
 *
 * @return the estimated size in bytes
 */
public long getEstimatedSize()
{
    // every channel is charged the size of the first channel's backing array
    long estimate = sizeOf(channelBuilders.get(0).elements()) * channelBuilders.size();
    estimate += completedPagesMemorySize;
    estimate += currentPageBuilder.getSizeInBytes();
    estimate += sizeOf(key);
    estimate += sizeOf(value);
    estimate += groupAddress.sizeOf();
    return estimate;
}

/**
 * Returns the types of the columns stored for each group: the hashed key types,
 * followed by BIGINT when an input hash channel was supplied at construction.
 *
 * @return the stored column types
 */
public List<Type> getTypes()
{
    return this.types;
}

/**
 * Returns the number of groups registered so far. Group ids are handed out
 * sequentially, so this is also the id the next new group will receive.
 *
 * @return the current group count
 */
public int getGroupCount()
{
    return this.nextGroupId;
}

/**
 * Appends the stored values of a group to a page builder.
 *
 * @param groupId id of the group whose values are copied
 * @param pageBuilder destination builder
 * @param outputChannelOffset offset passed through to the hash strategy for the output channels
 */
public void appendValuesTo(int groupId, PageBuilder pageBuilder, int outputChannelOffset)
{
    // the synthetic address encodes which stored block holds the group and the row within it
    long syntheticAddress = groupAddress.get(groupId);
    hashStrategy.appendTo(
            decodeSliceIndex(syntheticAddress),
            decodePosition(syntheticAddress),
            pageBuilder,
            outputChannelOffset);
}

public void addPage(Page page)
{ {
Block[] hashBlocks = extractHashColumns(page); return new MultiChannelGroupByHash(hashTypes, hashChannels, inputHashChannel, expectedSize);

// get the group id for each position
int positionCount = page.getPositionCount();
for (int position = 0; position < positionCount; position++) {
// get the group for the current row
putIfAbsent(position, page, hashBlocks);
}
} }


public GroupByIdBlock getGroupIds(Page page) long getEstimatedSize();
{
int positionCount = page.getPositionCount();

// we know the exact size required for the block
BlockBuilder blockBuilder = BIGINT.createFixedSizeBlockBuilder(positionCount);


// extract the hash columns List<Type> getTypes();
Block[] hashBlocks = extractHashColumns(page);


// get the group id for each position int getGroupCount();
for (int position = 0; position < positionCount; position++) {
// get the group for the current row
int groupId = putIfAbsent(position, page, hashBlocks);

// output the group id for this row
BIGINT.writeLong(blockBuilder, groupId);
}
return new GroupByIdBlock(nextGroupId, blockBuilder.build());
}


public boolean contains(int position, Page page) void appendValuesTo(int groupId, PageBuilder pageBuilder, int outputChannelOffset);
{
int rawHash = hashStrategy.hashRow(position, page.getBlocks());
int hashPosition = getHashPosition(rawHash, mask);


// look for a slot containing this key void addPage(Page page);
while (key[hashPosition] != -1) {
long address = key[hashPosition];
if (hashStrategy.positionEqualsRow(decodeSliceIndex(address), decodePosition(address), position, page.getBlocks())) {
// found an existing slot for this key
return true;
}
// increment position and mask to handle wrap around
hashPosition = (hashPosition + 1) & mask;
}


return false; GroupByIdBlock getGroupIds(Page page);
}


public int putIfAbsent(int position, Page page) boolean contains(int position, Page page);
{
return putIfAbsent(position, page, extractHashColumns(page));
}


private int putIfAbsent(int position, Page page, Block[] hashBlocks) int putIfAbsent(int position, Page page);
{
int rawHash = hashGenerator.hashPosition(position, page);
int hashPosition = getHashPosition(rawHash, mask);

// look for an empty slot or a slot containing this key
int groupId = -1;
while (key[hashPosition] != -1) {
long address = key[hashPosition];
if (positionEqualsCurrentRow(decodeSliceIndex(address), decodePosition(address), position, hashBlocks)) {
// found an existing slot for this key
groupId = value[hashPosition];

break;
}
// increment position and mask to handle wrap around
hashPosition = (hashPosition + 1) & mask;
}

// did we find an existing group?
if (groupId < 0) {
groupId = addNewGroup(hashPosition, position, page, rawHash);
}
return groupId;
}

/**
 * Registers the row at {@code position} of {@code page} as a new group and returns its id.
 *
 * The key values are copied into the page currently being built, the row's synthetic
 * address is written into the hash slot {@code hashPosition}, and the next sequential
 * group id is assigned. Afterwards a new page is started if the current one is full,
 * and the table is rehashed once the group count reaches {@code maxFill}.
 *
 * @param hashPosition empty slot in the {@code key}/{@code value} tables to fill
 * @param position row of {@code page} holding the new group's key
 * @param page input page the key is copied from
 * @param rawHash hash of the row; stored with the key when a precomputed hash channel is in use
 * @return the id assigned to the new group
 */
private int addNewGroup(int hashPosition, int position, Page page, int rawHash)
{
// add the row to the open page
Block[] blocks = page.getBlocks();
for (int i = 0; i < channels.length; i++) {
int hashChannel = channels[i];
Type type = types.get(i);
type.appendTo(blocks[hashChannel], position, currentPageBuilder.getBlockBuilder(i));
}
if (precomputedHashChannel.isPresent()) {
// keep the hash next to the key values
BIGINT.writeLong(currentPageBuilder.getBlockBuilder(precomputedHashChannel.get()), rawHash);
}
currentPageBuilder.declarePosition();
// the row just written is the last position of the last builder in each channel list;
// this must be computed before startNewPage() appends a fresh builder
int pageIndex = channelBuilders.get(0).size() - 1;
int pagePosition = currentPageBuilder.getPositionCount() - 1;
long address = encodeSyntheticAddress(pageIndex, pagePosition);

// record group id in hash
int groupId = nextGroupId++;

key[hashPosition] = address;
value[hashPosition] = groupId;
groupAddress.set(groupId, address);

// create new page builder if this page is full
if (currentPageBuilder.isFull()) {
startNewPage();
}

// increase capacity, if necessary
if (nextGroupId >= maxFill) {
rehash(maxFill * 2);
}
return groupId;
}

/**
 * Starts a fresh page builder and appends its block builders to the per-channel lists.
 * The size of the builder being replaced, if there is one, is added to the
 * completed-pages memory total first.
 */
private void startNewPage()
{
    PageBuilder finishedBuilder = currentPageBuilder;
    if (finishedBuilder != null) {
        completedPagesMemorySize += finishedBuilder.getSizeInBytes();
    }

    currentPageBuilder = new PageBuilder(types);
    int channelCount = types.size();
    for (int channel = 0; channel < channelCount; channel++) {
        channelBuilders.get(channel).add(currentPageBuilder.getBlockBuilder(channel));
    }
}

/**
 * Rebuilds the hash table with room for at least {@code size} entries.
 *
 * New {@code key} and {@code value} arrays are allocated, every used slot of the old
 * table is re-inserted by linear probing, and then the mask, fill limit and arrays are
 * swapped in. Group ids and stored addresses are unchanged; only their slot positions move.
 *
 * @param size number of entries the new table must be able to hold
 */
private void rehash(int size)
{
int newSize = arraySize(size + 1, FILL_RATIO);

int newMask = newSize - 1;
long[] newKey = new long[newSize];
// -1 marks an empty slot
Arrays.fill(newKey, -1);
int[] newValue = new int[newSize];

// the old table is scanned once from the start; the loop relies on it holding
// exactly nextGroupId used slots
int oldIndex = 0;
for (int groupId = 0; groupId < nextGroupId; groupId++) {
// seek to the next used slot
while (key[oldIndex] == -1) {
oldIndex++;
}

// get the address for this slot
long address = key[oldIndex];

// find an empty slot for the address
int pos = getHashPosition(hashPosition(address), newMask);
while (newKey[pos] != -1) {
// step to the next slot, wrapping around via the mask
pos = (pos + 1) & newMask;
}

// record the mapping
newKey[pos] = address;
newValue[pos] = value[oldIndex];
oldIndex++;
}

this.mask = newMask;
this.maxFill = maxFill(newSize, FILL_RATIO);
this.key = newKey;
this.value = newValue;
// make sure the group address array can hold every group up to the new fill limit
groupAddress.ensureCapacity(maxFill);
}

/**
 * Collects the blocks of the hashed channels from a page, in channel order.
 *
 * @param page input page
 * @return one block per entry of {@code channels}
 */
private Block[] extractHashColumns(Page page)
{
    Block[] keyBlocks = new Block[channels.length];
    int index = 0;
    for (int channel : channels) {
        keyBlocks[index++] = page.getBlock(channel);
    }
    return keyBlocks;
}

/**
 * Returns the hash of the stored row at the given synthetic address. The hash saved
 * with the row is used when a precomputed hash channel exists; otherwise the hash
 * strategy computes it from the stored values.
 *
 * @param sliceAddress synthetic address of a stored row
 * @return the row's hash
 */
private int hashPosition(long sliceAddress)
{
    int blockIndex = decodeSliceIndex(sliceAddress);
    int blockPosition = decodePosition(sliceAddress);
    return precomputedHashChannel.isPresent()
            ? getRawHash(blockIndex, blockPosition)
            : hashStrategy.hashPosition(blockIndex, blockPosition);
}

/**
 * Reads the hash stored alongside a row in the precomputed hash column.
 * Must only be called when {@code precomputedHashChannel} is present.
 *
 * @param sliceIndex index of the stored block holding the row
 * @param position row position within that block
 * @return the stored hash, narrowed to an int
 */
private int getRawHash(int sliceIndex, int position)
{
    int hashChannel = precomputedHashChannel.get();
    Block hashBlock = channelBuilders.get(hashChannel).get(sliceIndex);
    return (int) hashBlock.getLong(position, 0);
}

/**
 * Asks the hash strategy whether a stored row equals a row of the incoming blocks.
 *
 * @param sliceIndex index of the stored block
 * @param slicePosition row position within the stored block
 * @param position row position in the incoming blocks
 * @param blocks incoming blocks of the hashed channels
 * @return the result of the hash strategy's row comparison
 */
private boolean positionEqualsCurrentRow(int sliceIndex, int slicePosition, int position, Block[] blocks)
{
    boolean sameKey = hashStrategy.positionEqualsRow(sliceIndex, slicePosition, position, blocks);
    return sameKey;
}

/**
 * Maps a raw hash to a slot index by mixing it with {@code murmurHash3} and applying the mask.
 *
 * @param rawHash hash of the row
 * @param mask bit mask applied to the mixed hash; callers pass the table size minus one
 * @return the mixed hash ANDed with {@code mask}
 */
private static int getHashPosition(int rawHash, int mask)
{
    int mixedHash = murmurHash3(rawHash);
    return mixedHash & mask;
}
} }
Expand Up @@ -31,6 +31,7 @@
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;


import static com.facebook.presto.operator.GroupByHash.createGroupByHash;
import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
Expand Down Expand Up @@ -268,7 +269,7 @@ private GroupByHashAggregationBuilder(
Optional<Integer> hashChannel, Optional<Integer> hashChannel,
MemoryManager memoryManager) MemoryManager memoryManager)
{ {
this.groupByHash = new GroupByHash(groupByTypes, Ints.toArray(groupByChannels), hashChannel, expectedGroups); this.groupByHash = createGroupByHash(groupByTypes, Ints.toArray(groupByChannels), hashChannel, expectedGroups);
this.memoryManager = memoryManager; this.memoryManager = memoryManager;


// wrap each function with an aggregator // wrap each function with an aggregator
Expand Down

0 comments on commit 66c8a20

Please sign in to comment.