Skip to content

Commit

Permalink
Use 64bit hashes for join/group by
Browse files Browse the repository at this point in the history
  • Loading branch information
fiedukow authored and martint committed Apr 12, 2016
1 parent 7c159ec commit 8777b5d
Show file tree
Hide file tree
Showing 71 changed files with 150 additions and 149 deletions.
Empty file added dummy
Empty file.
Expand Up @@ -32,7 +32,6 @@

import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED;
import static com.facebook.presto.spi.StandardErrorCode.NO_NODES_AVAILABLE;
import static java.lang.Math.abs;
import static java.util.Objects.requireNonNull;

public class BlackHoleNodePartitioningProvider
Expand Down Expand Up @@ -84,12 +83,16 @@ public BucketFunction getBucketFunction(
List<Type> partitionChannelTypes, int bucketCount)
{
return (page, position) -> {
int hash = 13;
long hash = 13;
for (int i = 0; i < partitionChannelTypes.size(); i++) {
Type type = partitionChannelTypes.get(i);
hash = 31 * hash + type.hash(page.getBlock(i), position);
}
return abs(hash) % bucketCount;

// clear the sign bit
hash &= 0x7fff_ffff_ffff_ffffL;

return (int) (hash % bucketCount);
};
}
}
Expand Up @@ -191,7 +191,7 @@ public boolean contains(int position, Page page, int[] hashChannels)
}

long value = BIGINT.getLong(block, position);
int hashPosition = getHashPosition(value, mask);
long hashPosition = getHashPosition(value, mask);

// look for an empty slot or a slot containing this key
while (true) {
Expand Down Expand Up @@ -227,7 +227,7 @@ private int putIfAbsent(int position, Block block)
}

long value = BIGINT.getLong(block, position);
int hashPosition = getHashPosition(value, mask);
long hashPosition = getHashPosition(value, mask);

// look for an empty slot or a slot containing this key
while (true) {
Expand All @@ -247,7 +247,7 @@ private int putIfAbsent(int position, Block block)
return addNewGroup(hashPosition, value);
}

private int addNewGroup(int hashPosition, long value)
private int addNewGroup(long hashPosition, long value)
{
// record group id in hash
int groupId = nextGroupId++;
Expand Down Expand Up @@ -284,7 +284,7 @@ private void rehash()
long value = valuesByGroupId.get(groupId);

// find an empty slot for the address
int hashPosition = getHashPosition(value, newMask);
long hashPosition = getHashPosition(value, newMask);
while (newGroupIds.get(hashPosition) != -1) {
hashPosition = (hashPosition + 1) & newMask;
}
Expand All @@ -303,9 +303,9 @@ private void rehash()
this.valuesByGroupId.ensureCapacity(maxFill);
}

private static int getHashPosition(long rawHash, int mask)
private static long getHashPosition(long rawHash, int mask)
{
return ((int) murmurHash3(rawHash)) & mask;
return murmurHash3(rawHash) & mask;
}

private static int calculateMaxFill(int hashSize)
Expand Down
Expand Up @@ -147,7 +147,7 @@ public boolean equals(int position, int offset, Block otherBlock, int otherPosit
}

@Override
public int hash(int position, int offset, int length)
public long hash(int position, int offset, int length)
{
return block.hash(position, offset, length);
}
Expand Down
Expand Up @@ -19,16 +19,17 @@

public interface HashGenerator
{
int hashPosition(int position, Page page);
long hashPosition(int position, Page page);

default int getPartition(int partitionCount, int position, Page page)
{
int rawHash = hashPosition(position, page);
long rawHash = hashPosition(position, page);

// clear the sign bit
rawHash &= 0x7fff_ffffL;
rawHash &= 0x7fff_ffff_ffff_ffffL;

int partition = (int) (rawHash % partitionCount);

int partition = rawHash % partitionCount;
checkState(partition >= 0 && partition < partitionCount);
return partition;
}
Expand Down
Expand Up @@ -199,10 +199,10 @@ public void addInput(Page page)
maskBuilders[i] = BOOLEAN.createBlockBuilder(new BlockBuilderStatus(), page.getPositionCount());
}
for (int position = 0; position < page.getPositionCount(); position++) {
int rawHash = hashGenerator.hashPosition(position, page);
long rawHash = hashGenerator.hashPosition(position, page);
// mix the bits so we don't use the same hash used to distribute between stages
rawHash = (int) XxHash64.hash(Integer.reverse(rawHash));
rawHash &= Integer.MAX_VALUE;
rawHash = XxHash64.hash(Long.reverse(rawHash));
rawHash &= Long.MAX_VALUE;

boolean active = (rawHash % partitionCount == partition);
BOOLEAN.writeBoolean(activePositions, active);
Expand Down
Expand Up @@ -63,7 +63,7 @@ public InMemoryJoinHash(LongArrayList addresses, PagesHashStrategy pagesHashStra

// index pages
for (int position = 0; position < addresses.size(); position++) {
int pos = getHashPosition(hashPosition(position), mask);
int pos = (int) getHashPosition(hashPosition(position), mask);

// look for an empty slot or a slot containing this key
while (key[pos] != -1) {
Expand Down Expand Up @@ -109,9 +109,9 @@ public long getJoinPosition(int position, Page page)
}

@Override
public long getJoinPosition(int position, Page page, int rawHash)
public long getJoinPosition(int position, Page page, long rawHash)
{
int pos = getHashPosition(rawHash, mask);
int pos = (int) getHashPosition(rawHash, mask);

while (key[pos] != -1) {
if (positionEqualsCurrentRow(key[pos], position, page.getBlocks())) {
Expand Down Expand Up @@ -144,7 +144,7 @@ public void close()
{
}

private int hashPosition(int position)
private long hashPosition(int position)
{
long pageAddress = addresses.getLong(position);
int blockIndex = decodeSliceIndex(pageAddress);
Expand Down Expand Up @@ -175,8 +175,8 @@ private boolean positionEqualsPosition(int leftPosition, int rightPosition)
return pagesHashStrategy.positionEqualsPosition(leftBlockIndex, leftBlockPosition, rightBlockIndex, rightBlockPosition);
}

private static int getHashPosition(int rawHash, int mask)
private static long getHashPosition(long rawHash, long mask)
{
return ((int) XxHash64.hash(rawHash)) & mask;
return (XxHash64.hash(rawHash)) & mask;
}
}
Expand Up @@ -41,13 +41,13 @@ public InterpretedHashGenerator(List<Type> hashChannelTypes, int[] hashChannels)
}

@Override
public int hashPosition(int position, Page page)
public long hashPosition(int position, Page page)
{
Block[] blocks = page.getBlocks();
int result = HashGenerationOptimizer.INITIAL_HASH_VALUE;
long result = HashGenerationOptimizer.INITIAL_HASH_VALUE;
for (int i = 0; i < hashChannels.length; i++) {
Type type = hashChannelTypes.get(i);
result = (int) CombineHashFunction.getHash(result, TypeUtils.hashPosition(type, blocks[hashChannels[i]], position));
result = CombineHashFunction.getHash(result, TypeUtils.hashPosition(type, blocks[hashChannels[i]], position));
}
return result;
}
Expand Down
Expand Up @@ -364,7 +364,7 @@ public long getNextJoinPosition(long currentPosition)
}

@Override
public long getJoinPosition(int position, Page page, int rawHash)
public long getJoinPosition(int position, Page page, long rawHash)
{
return lookupSource.getJoinPosition(position, page, rawHash);
}
Expand Down
Expand Up @@ -27,7 +27,7 @@ public interface LookupSource

int getJoinPositionCount();

long getJoinPosition(int position, Page page, int rawHash);
long getJoinPosition(int position, Page page, long rawHash);

long getJoinPosition(int position, Page page);

Expand Down
Expand Up @@ -233,8 +233,8 @@ public GroupByIdBlock getGroupIds(Page page)
@Override
public boolean contains(int position, Page page, int[] hashChannels)
{
int rawHash = hashStrategy.hashRow(position, page.getBlocks());
int hashPosition = getHashPosition(rawHash, mask);
long rawHash = hashStrategy.hashRow(position, page.getBlocks());
int hashPosition = (int) getHashPosition(rawHash, mask);

// look for a slot containing this key
while (groupAddressByHash[hashPosition] != -1) {
Expand All @@ -252,13 +252,13 @@ public boolean contains(int position, Page page, int[] hashChannels)
@Override
public int putIfAbsent(int position, Page page)
{
int rawHash = hashGenerator.hashPosition(position, page);
long rawHash = hashGenerator.hashPosition(position, page);
return putIfAbsent(position, page, rawHash);
}

private int putIfAbsent(int position, Page page, int rawHash)
private int putIfAbsent(int position, Page page, long rawHash)
{
int hashPosition = getHashPosition(rawHash, mask);
int hashPosition = (int) getHashPosition(rawHash, mask);

// look for an empty slot or a slot containing this key
int groupId = -1;
Expand All @@ -280,7 +280,7 @@ private int putIfAbsent(int position, Page page, int rawHash)
return groupId;
}

private int addNewGroup(int hashPosition, int position, Page page, int rawHash)
private int addNewGroup(int hashPosition, int position, Page page, long rawHash)
{
// add the row to the open page
for (int i = 0; i < channels.length; i++) {
Expand Down Expand Up @@ -352,9 +352,9 @@ private void rehash()
// get the address for this slot
long address = groupAddressByHash[oldIndex];

int rawHash = hashPosition(address);
long rawHash = hashPosition(address);
// find an empty slot for the address
int pos = getHashPosition(rawHash, newMask);
int pos = (int) getHashPosition(rawHash, newMask);
while (newKey[pos] != -1) {
pos = (pos + 1) & newMask;
}
Expand All @@ -374,7 +374,7 @@ private void rehash()
groupAddressByGroupId.ensureCapacity(maxFill);
}

private int hashPosition(long sliceAddress)
private long hashPosition(long sliceAddress)
{
int sliceIndex = decodeSliceIndex(sliceAddress);
int position = decodePosition(sliceAddress);
Expand All @@ -384,9 +384,9 @@ private int hashPosition(long sliceAddress)
return hashStrategy.hashPosition(sliceIndex, position);
}

private int getRawHash(int sliceIndex, int position)
private long getRawHash(int sliceIndex, int position)
{
return (int) channelBuilders.get(precomputedHashChannel.get()).get(sliceIndex).getLong(position, 0);
return channelBuilders.get(precomputedHashChannel.get()).get(sliceIndex).getLong(position, 0);
}

private boolean positionEqualsCurrentRow(long address, int hashPosition, int position, Page page, byte rawHash, int[] hashChannels)
Expand All @@ -397,7 +397,7 @@ private boolean positionEqualsCurrentRow(long address, int hashPosition, int pos
return hashStrategy.positionEqualsRow(decodeSliceIndex(address), decodePosition(address), position, page, hashChannels);
}

private static int getHashPosition(int rawHash, int mask)
private static long getHashPosition(long rawHash, int mask)
{
return murmurHash3(rawHash) & mask;
}
Expand Down
Expand Up @@ -38,13 +38,13 @@ public interface PagesHashStrategy
/**
* Calculates the hash code the hashed columns in this PagesHashStrategy at the specified position.
*/
int hashPosition(int blockIndex, int position);
long hashPosition(int blockIndex, int position);

/**
* Calculates the hash code at {@code position} in {@code blocks}. Blocks must have the same number of
* entries as the hashed columns and each entry is expected to be the same type.
*/
int hashRow(int position, Block... blocks);
long hashRow(int position, Block... blocks);

/**
* Compares the values in the specified blocks. The values are compared positionally, so {@code leftBlocks}
Expand Down
Expand Up @@ -257,8 +257,8 @@ public void addInput(Page page)
// build a block containing the partition id of each position
BlockBuilder blockBuilder = BIGINT.createBlockBuilder(new BlockBuilderStatus(), page.getPositionCount());
for (int position = 0; position < page.getPositionCount(); position++) {
int rawHash = hashGenerator.hashPosition(position, page);
int partition = murmurHash3(rawHash) & parallelStreamMask;
long rawHash = hashGenerator.hashPosition(position, page);
int partition = (int) (murmurHash3(rawHash) & parallelStreamMask);
BIGINT.writeLong(blockBuilder, partition);
}
Block partitionIds = blockBuilder.build();
Expand Down
Expand Up @@ -69,9 +69,9 @@ public long getJoinPosition(int position, Page page)
}

@Override
public long getJoinPosition(int position, Page page, int rawHash)
public long getJoinPosition(int position, Page page, long rawHash)
{
int partition = murmurHash3(rawHash) & partitionMask;
int partition = (int) murmurHash3(rawHash) & partitionMask;
LookupSource lookupSource = lookupSources[partition];
long joinPosition = lookupSource.getJoinPosition(position, page, rawHash);
if (joinPosition < 0) {
Expand Down
Expand Up @@ -28,9 +28,9 @@ public PrecomputedHashGenerator(int hashChannel)
}

@Override
public int hashPosition(int position, Page page)
public long hashPosition(int position, Page page)
{
return (int) BigintType.BIGINT.getLong(page.getBlock(hashChannel), position);
return BigintType.BIGINT.getLong(page.getBlock(hashChannel), position);
}

@Override
Expand Down
Expand Up @@ -56,7 +56,7 @@ public long getInMemorySizeInBytes()
}

@Override
public long getJoinPosition(int position, Page page, int rawHash)
public long getJoinPosition(int position, Page page, long rawHash)
{
return lookupSource.getJoinPosition(position, page, rawHash);
}
Expand Down
Expand Up @@ -106,7 +106,7 @@ public long getCurrentJoinPosition()
return -1;
}
if (probeHashBlock.isPresent()) {
int rawHash = (int) BIGINT.getLong(probeHashBlock.get(), position);
long rawHash = BIGINT.getLong(probeHashBlock.get(), position);
return lookupSource.getJoinPosition(position, probePage, rawHash);
}
return lookupSource.getJoinPosition(position, probePage);
Expand Down
Expand Up @@ -77,12 +77,12 @@ public void appendTo(int blockIndex, int position, PageBuilder pageBuilder, int
}

@Override
public int hashPosition(int blockIndex, int position)
public long hashPosition(int blockIndex, int position)
{
if (precomputedHashChannel != null) {
return (int) BIGINT.getLong(precomputedHashChannel.get(blockIndex), position);
return BIGINT.getLong(precomputedHashChannel.get(blockIndex), position);
}
int result = 0;
long result = 0;
for (int hashChannel : hashChannels) {
Type type = types.get(hashChannel);
Block block = channels.get(hashChannel).get(blockIndex);
Expand All @@ -92,9 +92,9 @@ public int hashPosition(int blockIndex, int position)
}

@Override
public int hashRow(int position, Block... blocks)
public long hashRow(int position, Block... blocks)
{
int result = 0;
long result = 0;
for (int i = 0; i < hashChannels.size(); i++) {
int hashChannel = hashChannels.get(i);
Type type = types.get(hashChannel);
Expand Down

0 comments on commit 8777b5d

Please sign in to comment.