Skip to content

Commit

Permalink
Fix WindowOperator with proper DISTINCT peer group semantics
Browse files Browse the repository at this point in the history
  • Loading branch information
erichwang authored and sopel39 committed Jan 7, 2021
1 parent 03c9ce3 commit 93ea855
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 42 deletions.
Expand Up @@ -393,7 +393,7 @@ public void sort(List<Integer> sortChannels, List<SortOrder> sortOrders, int sta
createPagesIndexComparator(sortChannels, sortOrders).sort(this, startPosition, endPosition);
}

public boolean positionEqualsPosition(PagesHashStrategy partitionHashStrategy, int leftPosition, int rightPosition)
public boolean positionNotDistinctFromPosition(PagesHashStrategy partitionHashStrategy, int leftPosition, int rightPosition)
{
long leftAddress = valueAddresses.getLong(leftPosition);
int leftPageIndex = decodeSliceIndex(leftAddress);
Expand All @@ -403,16 +403,16 @@ public boolean positionEqualsPosition(PagesHashStrategy partitionHashStrategy, i
int rightPageIndex = decodeSliceIndex(rightAddress);
int rightPagePosition = decodePosition(rightAddress);

return partitionHashStrategy.positionEqualsPosition(leftPageIndex, leftPagePosition, rightPageIndex, rightPagePosition);
return partitionHashStrategy.positionNotDistinctFromPosition(leftPageIndex, leftPagePosition, rightPageIndex, rightPagePosition);
}

public boolean positionEqualsRow(PagesHashStrategy pagesHashStrategy, int indexPosition, int rightPosition, Page rightPage)
public boolean positionNotDistinctFromRow(PagesHashStrategy pagesHashStrategy, int indexPosition, int rightPosition, Page rightPage)
{
long pageAddress = valueAddresses.getLong(indexPosition);
int pageIndex = decodeSliceIndex(pageAddress);
int pagePosition = decodePosition(pageAddress);

return pagesHashStrategy.positionEqualsRow(pageIndex, pagePosition, rightPosition, rightPage);
return pagesHashStrategy.positionNotDistinctFromRow(pageIndex, pagePosition, rightPosition, rightPage);
}

private PagesIndexOrdering createPagesIndexComparator(List<Integer> sortChannels, List<SortOrder> sortOrders)
Expand Down
Expand Up @@ -818,12 +818,12 @@ private int updatePagesIndex(PagesIndexWithHashStrategies pagesIndexWithHashStra
PagesIndex pagesIndex = pagesIndexWithHashStrategies.pagesIndex;
PagesHashStrategy preGroupedPartitionHashStrategy = pagesIndexWithHashStrategies.preGroupedPartitionHashStrategy;
if (currentSpillGroupRowPage.isPresent()) {
if (!preGroupedPartitionHashStrategy.rowEqualsRow(0, currentSpillGroupRowPage.get().getColumns(pagesIndexWithHashStrategies.preGroupedPartitionChannels), startPosition, preGroupedPage)) {
if (!preGroupedPartitionHashStrategy.rowNotDistinctFromRow(0, currentSpillGroupRowPage.get().getColumns(pagesIndexWithHashStrategies.preGroupedPartitionChannels), startPosition, preGroupedPage)) {
return startPosition;
}
}

if (pagesIndex.getPositionCount() == 0 || pagesIndex.positionEqualsRow(preGroupedPartitionHashStrategy, 0, startPosition, preGroupedPage)) {
if (pagesIndex.getPositionCount() == 0 || pagesIndex.positionNotDistinctFromRow(preGroupedPartitionHashStrategy, 0, startPosition, preGroupedPage)) {
// Find the position where the pre-grouped columns change
int groupEnd = findGroupEnd(preGroupedPage, preGroupedPartitionHashStrategy, startPosition);

Expand Down Expand Up @@ -859,7 +859,7 @@ private static int findGroupEnd(Page page, PagesHashStrategy pagesHashStrategy,
checkArgument(page.getPositionCount() > 0, "Must have at least one position");
checkPositionIndex(startPosition, page.getPositionCount(), "startPosition out of bounds");

return findEndPosition(startPosition, page.getPositionCount(), (firstPosition, secondPosition) -> pagesHashStrategy.rowEqualsRow(firstPosition, page, secondPosition, page));
return findEndPosition(startPosition, page.getPositionCount(), (firstPosition, secondPosition) -> pagesHashStrategy.rowNotDistinctFromRow(firstPosition, page, secondPosition, page));
}

// Assumes input grouped on relevant pagesHashStrategy columns
Expand All @@ -868,7 +868,7 @@ private static int findGroupEnd(PagesIndex pagesIndex, PagesHashStrategy pagesHa
checkArgument(pagesIndex.getPositionCount() > 0, "Must have at least one position");
checkPositionIndex(startPosition, pagesIndex.getPositionCount(), "startPosition out of bounds");

return findEndPosition(startPosition, pagesIndex.getPositionCount(), (firstPosition, secondPosition) -> pagesIndex.positionEqualsPosition(pagesHashStrategy, firstPosition, secondPosition));
return findEndPosition(startPosition, pagesIndex.getPositionCount(), (firstPosition, secondPosition) -> pagesIndex.positionNotDistinctFromPosition(pagesHashStrategy, firstPosition, secondPosition));
}

/**
Expand Down
Expand Up @@ -117,15 +117,15 @@ public WindowPartition(

seekGroupStart = position -> {
requireNonNull(position, "position is null");
while (position > 0 && pagesIndex.positionEqualsPosition(peerGroupHashStrategy, partitionStart + position, partitionStart + position - 1)) {
while (position > 0 && pagesIndex.positionNotDistinctFromPosition(peerGroupHashStrategy, partitionStart + position, partitionStart + position - 1)) {
position--;
}
return position;
};

seekGroupEnd = position -> {
requireNonNull(position, "position is null");
while (position < partitionEnd - 1 - partitionStart && pagesIndex.positionEqualsPosition(peerGroupHashStrategy, partitionStart + position, partitionStart + position + 1)) {
while (position < partitionEnd - 1 - partitionStart && pagesIndex.positionNotDistinctFromPosition(peerGroupHashStrategy, partitionStart + position, partitionStart + position + 1)) {
position++;
}
return position;
Expand Down Expand Up @@ -241,7 +241,7 @@ private void updatePeerGroup()
peerGroupStart = currentPosition;
// find end of peer group
peerGroupEnd = peerGroupStart + 1;
while ((peerGroupEnd < partitionEnd) && pagesIndex.positionEqualsPosition(peerGroupHashStrategy, peerGroupStart, peerGroupEnd)) {
while ((peerGroupEnd < partitionEnd) && pagesIndex.positionNotDistinctFromPosition(peerGroupHashStrategy, peerGroupStart, peerGroupEnd)) {
peerGroupEnd++;
}
}
Expand Down Expand Up @@ -334,7 +334,7 @@ private Range getFrameRange(FrameInfo frameInfo, Range recentRange, PagesIndexCo
frameInfo.getStartType() == CURRENT_ROW && frameInfo.getEndType() == UNBOUNDED_FOLLOWING ||
frameInfo.getStartType() == UNBOUNDED_PRECEDING && frameInfo.getEndType() == CURRENT_ROW) {
// same peer group as recent row
if (currentPosition == partitionStart || pagesIndex.positionEqualsPosition(peerGroupHashStrategy, currentPosition - 1, currentPosition)) {
if (currentPosition == partitionStart || pagesIndex.positionNotDistinctFromPosition(peerGroupHashStrategy, currentPosition - 1, currentPosition)) {
return recentRange;
}
// next peer group
Expand Down
Expand Up @@ -56,7 +56,6 @@
import static io.trino.operator.OperatorAssertion.toMaterializedResult;
import static io.trino.operator.OperatorAssertion.toPages;
import static io.trino.operator.WindowFunctionDefinition.window;
import static io.trino.spi.connector.SortOrder.ASC_NULLS_FIRST;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.DoubleType.DOUBLE;
Expand Down Expand Up @@ -308,46 +307,65 @@ public void testRowNumberArbitraryWithSpill()
assertOperatorEquals(operatorFactory, driverContext, input, expected);
}

@Test
public void testRank()
@Test(dataProvider = "spillEnabled")
public void testDistinctPartitionAndPeers(boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimit)
{
List<Page> input = rowPagesBuilder(BIGINT, DOUBLE)
.row(1L, null)
.row(2L, 0.2)
.row(2L, Double.NaN)
.row(3L, 0.1)
.row(3L, 0.91)
List<Page> input = rowPagesBuilder(DOUBLE, DOUBLE)
.row(1.0, 1.0)
.row(1.0, 0.0)
.row(1.0, Double.NaN)
.row(1.0, null)
.row(2.0, 2.0)
.row(2.0, Double.NaN)
.row(Double.NaN, Double.NaN)
.row(Double.NaN, Double.NaN)
.row(null, null)
.row(null, 1.0)
.row(null, null)
.pageBreak()
.row(1L, 0.4)
.pageBreak()
.row(1L, null)
.row(2L, 0.7)
.row(2L, Double.NaN)
.row(1.0, Double.NaN)
.row(1.0, null)
.row(2.0, 2.0)
.row(2.0, null)
.row(Double.NaN, 3.0)
.row(Double.NaN, null)
.row(null, 2.0)
.row(null, null)
.build();

WindowOperatorFactory operatorFactory = createFactoryUnbounded(
ImmutableList.of(BIGINT, DOUBLE),
Ints.asList(1, 0),
ImmutableList.of(DOUBLE, DOUBLE),
Ints.asList(0, 1),
RANK,
Ints.asList(0),
Ints.asList(1),
ImmutableList.copyOf(new SortOrder[] {ASC_NULLS_FIRST}),
false);
ImmutableList.copyOf(new SortOrder[] {SortOrder.ASC_NULLS_LAST}),
spillEnabled);

DriverContext driverContext = createDriverContext();
MaterializedResult expected = resultBuilder(driverContext.getSession(), DOUBLE, BIGINT, BIGINT)
.row(null, 1L, 1L)
.row(null, 1L, 1L)
.row(0.4, 1L, 3L)
.row(0.2, 2L, 1L)
.row(0.7, 2L, 2L)
.row(Double.NaN, 2L, 3L)
.row(Double.NaN, 2L, 4L)
.row(0.1, 3L, 1L)
.row(0.91, 3L, 2L)
DriverContext driverContext = createDriverContext(memoryLimit);
MaterializedResult expected = resultBuilder(driverContext.getSession(), DOUBLE, DOUBLE, BIGINT)
.row(1.0, 0.0, 1L)
.row(1.0, 1.0, 2L)
.row(1.0, Double.NaN, 3L)
.row(1.0, Double.NaN, 3L)
.row(1.0, null, 5L)
.row(1.0, null, 5L)
.row(2.0, 2.0, 1L)
.row(2.0, 2.0, 1L)
.row(2.0, Double.NaN, 3L)
.row(2.0, null, 4L)
.row(Double.NaN, 3.0, 1L)
.row(Double.NaN, Double.NaN, 2L)
.row(Double.NaN, Double.NaN, 2L)
.row(Double.NaN, null, 4L)
.row(null, 1.0, 1L)
.row(null, 2.0, 2L)
.row(null, null, 3L)
.row(null, null, 3L)
.row(null, null, 3L)
.build();

assertOperatorEquals(operatorFactory, driverContext, input, expected);
assertOperatorEquals(operatorFactory, driverContext, input, expected, revokeMemoryWhenAddingPages);
}

@Test(expectedExceptions = ExceededMemoryLimitException.class, expectedExceptionsMessageRegExp = "Query exceeded per-node user memory limit of 10B.*")
Expand Down
Expand Up @@ -29,6 +29,70 @@
public abstract class AbstractTestWindowQueries
extends AbstractTestQueryFramework
{
@Test
public void testDistinctWindowPartitionAndPeerGroups()
{
MaterializedResult actual = computeActual("" +
"SELECT x, y, z, rank() OVER (PARTITION BY x ORDER BY y) rnk\n" +
"FROM (\n" +
" VALUES " +
" (1.0, 0.1, 'a'), " +
" (2.0, 0.1, 'a'), " +
" (nan(), 0.1, 'a'), " +
" (NULL, 0.1, 'a'), " +
" (1.0, 0.1, 'b'), " +
" (2.0, 0.1, 'b'), " +
" (nan(), 0.1, 'b'), " +
" (NULL, 0.1, 'b'), " +
" (1.0, nan(), 'a'), " +
" (2.0, nan(), 'a'), " +
" (nan(), nan(), 'a'), " +
" (NULL, nan(), 'a'), " +
" (1.0, nan(), 'b'), " +
" (2.0, nan(), 'b'), " +
" (nan(), nan(), 'b'), " +
" (NULL, nan(), 'b'), " +
" (1.0, NULL, 'a'), " +
" (2.0, NULL, 'a'), " +
" (nan(), NULL, 'a'), " +
" (NULL, NULL, 'a'), " +
" (1.0, NULL, 'b'), " +
" (2.0, NULL, 'b'), " +
" (nan(), NULL, 'b'), " +
" (NULL, NULL, 'b') " +
") a(x, y, z)" +
"ORDER BY x, y, z");

MaterializedResult expected = resultBuilder(getSession(), VARCHAR, VARCHAR, DOUBLE, BIGINT)
.row(1.0, 0.1, "a", 1L)
.row(1.0, 0.1, "b", 1L)
.row(1.0, Double.NaN, "a", 3L)
.row(1.0, Double.NaN, "b", 3L)
.row(1.0, null, "a", 5L)
.row(1.0, null, "b", 5L)
.row(2.0, 0.1, "a", 1L)
.row(2.0, 0.1, "b", 1L)
.row(2.0, Double.NaN, "a", 3L)
.row(2.0, Double.NaN, "b", 3L)
.row(2.0, null, "a", 5L)
.row(2.0, null, "b", 5L)
.row(Double.NaN, 0.1, "a", 1L)
.row(Double.NaN, 0.1, "b", 1L)
.row(Double.NaN, Double.NaN, "a", 3L)
.row(Double.NaN, Double.NaN, "b", 3L)
.row(Double.NaN, null, "a", 5L)
.row(Double.NaN, null, "b", 5L)
.row(null, 0.1, "a", 1L)
.row(null, 0.1, "b", 1L)
.row(null, Double.NaN, "a", 3L)
.row(null, Double.NaN, "b", 3L)
.row(null, null, "a", 5L)
.row(null, null, "b", 5L)
.build();

assertEquals(actual.getMaterializedRows(), expected.getMaterializedRows());
}

@Test
public void testRowFieldAccessorInWindowFunction()
{
Expand Down

0 comments on commit 93ea855

Please sign in to comment.