Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve TopN row number / rank performance #16753

Merged
merged 4 commits into from
Apr 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.common.annotations.VisibleForTesting;
import io.trino.array.LongBigArray;
import io.trino.spi.Page;
import io.trino.util.HeapTraversal;
import io.trino.util.LongBigArrayFIFOQueue;

Expand Down Expand Up @@ -92,6 +93,31 @@ public long sizeOf()
+ peerGroupLookup.sizeOf();
}

public int findFirstPositionToAdd(Page newPage, GroupByIdBlock groupIds, PageWithPositionComparator comparator, RowReferencePageManager pageManager)
{
long currentGroups = groupIdToHeapBuffer.getTotalGroups();
groupIdToHeapBuffer.allocateGroupIfNeeded(groupIds.getGroupCount());

for (int position = 0; position < newPage.getPositionCount(); position++) {
long groupId = groupIds.getGroupId(position);
if (groupId >= currentGroups || groupIdToHeapBuffer.getHeapValueCount(groupId) < topN) {
return position;
}
long heapRootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
if (heapRootNodeIndex == UNKNOWN_INDEX) {
return position;
}
long rightPageRowId = peekRootRowIdByHeapNodeIndex(heapRootNodeIndex);
Page rightPage = pageManager.getPage(rightPageRowId);
int rightPosition = pageManager.getPosition(rightPageRowId);
// If the current position is equal to or less than the current heap root index, then we may need to insert it
if (comparator.compareTo(newPage, position, rightPage, rightPosition) <= 0) {
return position;
}
}
return -1;
}

/**
* Add the specified row to this accumulator.
* <p>
Expand All @@ -105,7 +131,7 @@ public boolean add(long groupId, RowReference rowReference)
long peerHeapNodeIndex = peerGroupLookup.get(groupId, rowReference);
if (peerHeapNodeIndex != UNKNOWN_INDEX) {
directPeerGroupInsert(groupId, peerHeapNodeIndex, rowReference.allocateRowId());
if (calculateRootRank(groupId) > topN) {
if (calculateRootRank(groupId, groupIdToHeapBuffer.getHeapRootNodeIndex(groupId)) > topN) {
heapPop(groupId, rowIdEvictionListener);
}
// Return true because heapPop is guaranteed not to evict the newly inserted row (by definition of rank)
Expand All @@ -119,11 +145,12 @@ public boolean add(long groupId, RowReference rowReference)
heapInsert(groupId, newPeerGroupIndex, 1);
return true;
}
if (rowReference.compareTo(strategy, peekRootRowId(groupId)) < 0) {
long heapRootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
if (rowReference.compareTo(strategy, peekRootRowIdByHeapNodeIndex(heapRootNodeIndex)) < 0) {
// Given that total number of values >= topN, we can only consider values that are less than the root (otherwise topN would be violated)
long newPeerGroupIndex = peerGroupBuffer.allocateNewNode(rowReference.allocateRowId(), UNKNOWN_INDEX);
// Rank will increase by +1 after insertion, so only need to pop if root rank is already == topN.
if (calculateRootRank(groupId) < topN) {
if (calculateRootRank(groupId, heapRootNodeIndex) < topN) {
heapInsert(groupId, newPeerGroupIndex, 1);
}
else {
Expand Down Expand Up @@ -158,7 +185,7 @@ public long drainTo(long groupId, LongBigArray rowIdOutput, LongBigArray ranking
long peerGroupIndex = heapNodeBuffer.getPeerGroupIndex(heapRootNodeIndex);
verify(peerGroupIndex != UNKNOWN_INDEX, "Peer group should have at least one value");

long rank = calculateRootRank(groupId);
long rank = calculateRootRank(groupId, heapRootNodeIndex);
do {
rowIdOutput.set(insertionIndex, peerGroupBuffer.getRowId(peerGroupIndex));
rankingOutput.set(insertionIndex, rank);
Expand Down Expand Up @@ -206,10 +233,9 @@ public long drainTo(long groupId, LongBigArray rowIdOutput)
return valueCount;
}

private long calculateRootRank(long groupId)
private long calculateRootRank(long groupId, long heapRootIndex)
{
long heapValueCount = groupIdToHeapBuffer.getHeapValueCount(groupId);
long heapRootIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
checkArgument(heapRootIndex != UNKNOWN_INDEX, "Group does not have a root");
long rootPeerGroupCount = heapNodeBuffer.getPeerGroupCount(heapRootIndex);
return heapValueCount - rootPeerGroupCount + 1;
Expand All @@ -224,9 +250,8 @@ private void directPeerGroupInsert(long groupId, long heapNodeIndex, long rowId)
groupIdToHeapBuffer.incrementHeapValueCount(groupId);
}

private long peekRootRowId(long groupId)
private long peekRootRowIdByHeapNodeIndex(long heapRootNodeIndex)
{
long heapRootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
checkArgument(heapRootNodeIndex != UNKNOWN_INDEX, "Group has nothing to peek");
return peerGroupBuffer.getRowId(heapNodeBuffer.getPeerGroupIndex(heapRootNodeIndex));
}
Expand Down Expand Up @@ -487,7 +512,7 @@ void verifyIntegrity()
long heapSize = groupIdToHeapBuffer.getHeapSize(groupId);
long heapValueCount = groupIdToHeapBuffer.getHeapValueCount(groupId);
long rootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
verify(rootNodeIndex == UNKNOWN_INDEX || calculateRootRank(rootNodeIndex) <= topN, "Max heap has more values than needed");
verify(rootNodeIndex == UNKNOWN_INDEX || calculateRootRank(groupId, rootNodeIndex) <= topN, "Max heap has more values than needed");
IntegrityStats integrityStats = verifyHeapIntegrity(groupId, rootNodeIndex);
verify(integrityStats.getPeerGroupCount() == heapSize, "Recorded heap size does not match actual heap size");
totalHeapNodes += integrityStats.getPeerGroupCount();
Expand Down Expand Up @@ -577,7 +602,7 @@ public long getValueCount()
/**
* Buffer abstracting a mapping from group ID to a heap. The group ID provides the index for all operations.
*/
private static class GroupIdToHeapBuffer
private static final class GroupIdToHeapBuffer
{
private static final long INSTANCE_SIZE = instanceSize(GroupIdToHeapBuffer.class);
private static final int METRICS_POSITIONS_PER_ENTRY = 2;
Expand All @@ -604,9 +629,12 @@ private static class GroupIdToHeapBuffer

public void allocateGroupIfNeeded(long groupId)
{
if (totalGroups > groupId) {
return;
}
// Group IDs generated by GroupByHash are always generated consecutively starting from 0, so observing a
// group ID N means groups [0, N] inclusive must exist.
totalGroups = max(groupId + 1, totalGroups);
totalGroups = groupId + 1;
heapIndexBuffer.ensureCapacity(totalGroups);
metricsBuffer.ensureCapacity(totalGroups * METRICS_POSITIONS_PER_ENTRY);
}
Expand Down Expand Up @@ -675,7 +703,7 @@ public long sizeOf()
/**
* Buffer abstracting storage of nodes in the heap. Nodes are referenced by their node index for operations.
*/
private static class HeapNodeBuffer
private static final class HeapNodeBuffer
{
private static final long INSTANCE_SIZE = instanceSize(HeapNodeBuffer.class);
private static final int POSITIONS_PER_ENTRY = 4;
Expand Down Expand Up @@ -790,7 +818,7 @@ public long sizeOf()
* Buffer abstracting storage of peer groups as linked chains of matching values. Peer groups are referenced by
* their node index for operations.
*/
private static class PeerGroupBuffer
private static final class PeerGroupBuffer
{
private static final long INSTANCE_SIZE = instanceSize(PeerGroupBuffer.class);
private static final int POSITIONS_PER_ENTRY = 2;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public class GroupedTopNRankBuilder
private final List<Type> sourceTypes;
private final boolean produceRanking;
private final GroupByHash groupByHash;
private final PageWithPositionComparator comparator;
private final RowReferencePageManager pageManager = new RowReferencePageManager();
private final GroupedTopNRankAccumulator groupedTopNRankAccumulator;

Expand All @@ -58,7 +59,7 @@ public GroupedTopNRankBuilder(
this.produceRanking = produceRanking;
this.groupByHash = requireNonNull(groupByHash, "groupByHash is null");

requireNonNull(comparator, "comparator is null");
this.comparator = requireNonNull(comparator, "comparator is null");
requireNonNull(equalsAndHash, "equalsAndHash is null");
groupedTopNRankAccumulator = new GroupedTopNRankAccumulator(
new RowIdComparisonHashStrategy()
Expand Down Expand Up @@ -123,8 +124,13 @@ public long getEstimatedSizeInBytes()

private void processPage(Page newPage, GroupByIdBlock groupIds)
{
try (LoadCursor loadCursor = pageManager.add(newPage)) {
for (int position = 0; position < newPage.getPositionCount(); position++) {
int firstPositionToAdd = groupedTopNRankAccumulator.findFirstPositionToAdd(newPage, groupIds, comparator, pageManager);
if (firstPositionToAdd < 0) {
return;
}

try (LoadCursor loadCursor = pageManager.add(newPage, firstPositionToAdd)) {
for (int position = firstPositionToAdd; position < newPage.getPositionCount(); position++) {
long groupId = groupIds.getGroupId(position);
loadCursor.advance();
groupedTopNRankAccumulator.add(groupId, loadCursor);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.common.annotations.VisibleForTesting;
import io.trino.array.LongBigArray;
import io.trino.spi.Page;
import io.trino.util.HeapTraversal;
import io.trino.util.LongBigArrayFIFOQueue;

Expand Down Expand Up @@ -72,6 +73,30 @@ public long sizeOf()
return INSTANCE_SIZE + groupIdToHeapBuffer.sizeOf() + heapNodeBuffer.sizeOf() + heapTraversal.sizeOf();
}

public int findFirstPositionToAdd(Page newPage, GroupByIdBlock groupIds, PageWithPositionComparator comparator, RowReferencePageManager pageManager)
{
long currentTotalGroups = groupIdToHeapBuffer.getTotalGroups();
groupIdToHeapBuffer.allocateGroupIfNeeded(groupIds.getGroupCount());

for (int position = 0; position < newPage.getPositionCount(); position++) {
long groupId = groupIds.getGroupId(position);
if (groupId >= currentTotalGroups || calculateRootRowNumber(groupId) < topN) {
return position;
}
long heapRootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
if (heapRootNodeIndex == UNKNOWN_INDEX) {
return position;
}
long rowId = heapNodeBuffer.getRowId(heapRootNodeIndex);
Page rightPage = pageManager.getPage(rowId);
int rightPosition = pageManager.getPosition(rowId);
if (comparator.compareTo(newPage, position, rightPage, rightPosition) < 0) {
return position;
}
}
return -1;
}

/**
* Add the specified row to this accumulator.
* <p>
Expand Down Expand Up @@ -325,7 +350,7 @@ void verifyIntegrity()
for (long groupId = 0; groupId < groupIdToHeapBuffer.getTotalGroups(); groupId++) {
long heapSize = groupIdToHeapBuffer.getHeapSize(groupId);
long rootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
verify(rootNodeIndex == UNKNOWN_INDEX || calculateRootRowNumber(rootNodeIndex) <= topN, "Max heap has more values than needed");
verify(rootNodeIndex == UNKNOWN_INDEX || calculateRootRowNumber(groupId) <= topN, "Max heap has more values than needed");
IntegrityStats integrityStats = verifyHeapIntegrity(rootNodeIndex);
verify(integrityStats.getNodeCount() == heapSize, "Recorded heap size does not match actual heap size");
totalHeapNodes += integrityStats.getNodeCount();
Expand Down Expand Up @@ -411,9 +436,12 @@ private static class GroupIdToHeapBuffer

public void allocateGroupIfNeeded(long groupId)
{
if (totalGroups > groupId) {
return;
}
// Group IDs generated by GroupByHash are always generated consecutively starting from 0, so observing a
// group ID N means groups [0, N] inclusive must exist.
totalGroups = max(groupId + 1, totalGroups);
totalGroups = groupId + 1;
heapIndexBuffer.ensureCapacity(totalGroups);
sizeBuffer.ensureCapacity(totalGroups);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ public class GroupedTopNRowNumberBuilder
private final GroupByHash groupByHash;
private final RowReferencePageManager pageManager = new RowReferencePageManager();
private final GroupedTopNRowNumberAccumulator groupedTopNRowNumberAccumulator;
private final PageWithPositionComparator comparator;

public GroupedTopNRowNumberBuilder(
List<Type> sourceTypes,
Expand All @@ -57,7 +58,7 @@ public GroupedTopNRowNumberBuilder(
this.produceRowNumber = produceRowNumber;
this.groupByHash = requireNonNull(groupByHash, "groupByHash is null");

requireNonNull(comparator, "comparator is null");
this.comparator = requireNonNull(comparator, "comparator is null");
groupedTopNRowNumberAccumulator = new GroupedTopNRowNumberAccumulator(
(leftRowId, rightRowId) -> {
Page leftPage = pageManager.getPage(leftRowId);
Expand Down Expand Up @@ -98,8 +99,13 @@ public long getEstimatedSizeInBytes()

private void processPage(Page newPage, GroupByIdBlock groupIds)
{
try (LoadCursor loadCursor = pageManager.add(newPage)) {
for (int position = 0; position < newPage.getPositionCount(); position++) {
int firstPositionToAdd = groupedTopNRowNumberAccumulator.findFirstPositionToAdd(newPage, groupIds, comparator, pageManager);
if (firstPositionToAdd < 0) {
return;
}

try (LoadCursor loadCursor = pageManager.add(newPage, firstPositionToAdd)) {
for (int position = firstPositionToAdd; position < newPage.getPositionCount(); position++) {
long groupId = groupIds.getGroupId(position);
loadCursor.advance();
groupedTopNRowNumberAccumulator.add(groupId, loadCursor);
Expand Down
17 changes: 9 additions & 8 deletions core/trino-main/src/main/java/io/trino/operator/IdRegistry.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
* <p>
* This class may recycle deallocated IDs for new allocations.
*/
public class IdRegistry<T>
public final class IdRegistry<T>
{
private static final long INSTANCE_SIZE = instanceSize(IdRegistry.class);

Expand All @@ -38,18 +38,19 @@ public class IdRegistry<T>
*
* @return ID referencing the provided object
*/
public int allocateId(IntFunction<T> factory)
public T allocateId(IntFunction<T> factory)
{
int newId;
T result;
if (!emptySlots.isEmpty()) {
newId = emptySlots.dequeueInt();
objects.set(newId, factory.apply(newId));
int id = emptySlots.dequeueInt();
result = factory.apply(id);
objects.set(id, result);
}
else {
newId = objects.size();
objects.add(factory.apply(newId));
result = factory.apply(objects.size());
objects.add(result);
}
return newId;
return result;
}

public void deallocate(int id)
Expand Down