diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/DomainTranslator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/DomainTranslator.java index e83339673033d..87f625ba58417 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/DomainTranslator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/DomainTranslator.java @@ -21,6 +21,7 @@ import com.facebook.presto.spi.predicate.NullableValue; import com.facebook.presto.spi.predicate.Range; import com.facebook.presto.spi.predicate.Ranges; +import com.facebook.presto.spi.predicate.SortedRangeSet; import com.facebook.presto.spi.predicate.TupleDomain; import com.facebook.presto.spi.predicate.ValueSet; import com.facebook.presto.spi.type.Type; @@ -44,6 +45,7 @@ import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.PeekingIterator; import com.google.common.math.DoubleMath; import javax.annotation.Nullable; @@ -72,6 +74,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.Iterables.getOnlyElement; +import static com.google.common.collect.Iterators.peekingIterator; import static java.math.RoundingMode.CEILING; import static java.math.RoundingMode.FLOOR; import static java.util.Objects.requireNonNull; @@ -124,54 +127,97 @@ private static Expression toPredicate(Domain domain, QualifiedNameReference refe return combineDisjunctsWithDefault(disjuncts, TRUE_LITERAL); } + private static Expression processRange(Type type, Range range, QualifiedNameReference reference) + { + if (range.isAll()) { + return TRUE_LITERAL; + } + + if (isBetween(range)) { + // specialize the range with BETWEEN expression if possible b/c it is currently more efficient + return new BetweenPredicate(reference, toExpression(range.getLow().getValue(), type), toExpression(range.getHigh().getValue(), type)); + } + + List rangeConjuncts = new ArrayList<>(); + if (!range.getLow().isLowerUnbounded()) { + switch (range.getLow().getBound()) { + case ABOVE: + rangeConjuncts.add(new ComparisonExpression(GREATER_THAN, reference, toExpression(range.getLow().getValue(), type))); + break; + case EXACTLY: + rangeConjuncts.add(new ComparisonExpression(GREATER_THAN_OR_EQUAL, reference, toExpression(range.getLow().getValue(), + type))); + break; + case BELOW: + throw new IllegalStateException("Low Marker should never use BELOW bound: " + range); + default: + throw new AssertionError("Unhandled bound: " + range.getLow().getBound()); + } + } + if (!range.getHigh().isUpperUnbounded()) { + switch (range.getHigh().getBound()) { + case ABOVE: + throw new IllegalStateException("High Marker should never use ABOVE bound: " + range); + case EXACTLY: + rangeConjuncts.add(new ComparisonExpression(LESS_THAN_OR_EQUAL, reference, toExpression(range.getHigh().getValue(), type))); + break; + case BELOW: + rangeConjuncts.add(new ComparisonExpression(LESS_THAN, reference, toExpression(range.getHigh().getValue(), type))); + break; + default: + throw new AssertionError("Unhandled bound: " + range.getHigh().getBound()); + } + } + // If rangeConjuncts is null, then the range was ALL, which should already have been checked for + checkState(!rangeConjuncts.isEmpty()); + return combineConjuncts(rangeConjuncts); + } + + private static Expression combineRangeWithExcludedPoints(Type type, QualifiedNameReference reference, Range range, List excludedPoints) + { + if (excludedPoints.isEmpty()) { + return processRange(type, range, reference); + } + + Expression excludedPointsExpression = new NotExpression(new InPredicate(reference, new InListExpression(excludedPoints))); + if (excludedPoints.size() == 1) { + excludedPointsExpression = new ComparisonExpression(NOT_EQUAL, reference, getOnlyElement(excludedPoints)); + } + + return combineConjuncts(processRange(type, range, reference), excludedPointsExpression); + } + private static List extractDisjuncts(Type type, Ranges ranges, QualifiedNameReference reference) { List disjuncts = new ArrayList<>(); List singleValues = new ArrayList<>(); - for (Range range : ranges.getOrderedRanges()) { - checkState(!range.isAll()); // Already checked + List orderedRanges = ranges.getOrderedRanges(); + + SortedRangeSet sortedRangeSet = SortedRangeSet.copyOf(type, orderedRanges); + SortedRangeSet complement = sortedRangeSet.complement(); + + List singleValueExclusionsList = complement.getOrderedRanges().stream().filter(Range::isSingleValue).collect(toList()); + List originalUnionSingleValues = SortedRangeSet.copyOf(type, singleValueExclusionsList).union(sortedRangeSet).getOrderedRanges(); + PeekingIterator singleValueExclusions = peekingIterator(singleValueExclusionsList.iterator()); + + for (Range range : originalUnionSingleValues) { if (range.isSingleValue()) { singleValues.add(toExpression(range.getSingleValue(), type)); + continue; } - else if (isBetween(range)) { - // Specialize the range with BETWEEN expression if possible b/c it is currently more efficient - disjuncts.add(new BetweenPredicate(reference, toExpression(range.getLow().getValue(), type), toExpression(range.getHigh().getValue(), type))); + + // attempt to optimize ranges that can be coalesced as long as single value points are excluded + List singleValuesInRange = new ArrayList<>(); + while (singleValueExclusions.hasNext() && range.contains(singleValueExclusions.peek())) { + singleValuesInRange.add(toExpression(singleValueExclusions.next().getSingleValue(), type)); } - else { - List rangeConjuncts = new ArrayList<>(); - if (!range.getLow().isLowerUnbounded()) { - switch (range.getLow().getBound()) { - case ABOVE: - rangeConjuncts.add(new ComparisonExpression(GREATER_THAN, reference, toExpression(range.getLow().getValue(), type))); - break; - case EXACTLY: - rangeConjuncts.add(new ComparisonExpression(GREATER_THAN_OR_EQUAL, reference, toExpression(range.getLow().getValue(), - type))); - break; - case BELOW: - throw new IllegalStateException("Low Marker should never use BELOW bound: " + range); - default: - throw new AssertionError("Unhandled bound: " + range.getLow().getBound()); - } - } - if (!range.getHigh().isUpperUnbounded()) { - switch (range.getHigh().getBound()) { - case ABOVE: - throw new IllegalStateException("High Marker should never use ABOVE bound: " + range); - case EXACTLY: - rangeConjuncts.add(new ComparisonExpression(LESS_THAN_OR_EQUAL, reference, toExpression(range.getHigh().getValue(), type))); - break; - case BELOW: - rangeConjuncts.add(new ComparisonExpression(LESS_THAN, reference, toExpression(range.getHigh().getValue(), type))); - break; - default: - throw new AssertionError("Unhandled bound: " + range.getHigh().getBound()); - } - } - // If rangeConjuncts is null, then the range was ALL, which should already have been checked for - checkState(!rangeConjuncts.isEmpty()); - disjuncts.add(combineConjuncts(rangeConjuncts)); + + if (!singleValuesInRange.isEmpty()) { + disjuncts.add(combineRangeWithExcludedPoints(type, reference, range, singleValuesInRange)); + continue; } + + disjuncts.add(processRange(type, range, reference)); } // Add back all of the possible single values either as an equality or an IN predicate diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDomainTranslator.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDomainTranslator.java index 8eeda59ace4e6..6f93099f1af87 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDomainTranslator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDomainTranslator.java @@ -157,6 +157,60 @@ public void testRoundTrip() assertEquals(result.getTupleDomain(), tupleDomain); } + @Test + public void testInOptimization() + throws Exception + { + Domain testDomain = Domain.create( + ValueSet.all(BIGINT) + .subtract(ValueSet.ofRanges( + Range.equal(BIGINT, 1L), Range.equal(BIGINT, 2L), Range.equal(BIGINT, 3L))), false); + + TupleDomain tupleDomain = withColumnDomains(ImmutableMap.builder().put(A, testDomain).build()); + assertEquals(toPredicate(tupleDomain), not(in(A, ImmutableList.of(1L, 2L, 3L)))); + + testDomain = Domain.create( + ValueSet.ofRanges( + Range.lessThan(BIGINT, 4L)).intersect( + ValueSet.all(BIGINT) + .subtract(ValueSet.ofRanges(Range.equal(BIGINT, 1L), Range.equal(BIGINT, 2L), Range.equal(BIGINT, 3L)))), false); + + tupleDomain = withColumnDomains(ImmutableMap.builder().put(A, testDomain).build()); + assertEquals(toPredicate(tupleDomain), and(lessThan(A, longLiteral(4L)), not(in(A, ImmutableList.of(1L, 2L, 3L))))); + + testDomain = Domain.create(ValueSet.ofRanges( + Range.range(BIGINT, 1L, true, 3L, true), + Range.range(BIGINT, 5L, true, 7L, true), + Range.range(BIGINT, 9L, true, 11L, true)), + false); + + tupleDomain = withColumnDomains(ImmutableMap.builder().put(A, testDomain).build()); + assertEquals(toPredicate(tupleDomain), + or(between(A, longLiteral(1L), longLiteral(3L)), (between(A, longLiteral(5L), longLiteral(7L))), (between(A, longLiteral(9L), longLiteral(11L))))); + + testDomain = Domain.create( + ValueSet.ofRanges( + Range.lessThan(BIGINT, 4L)) + .intersect(ValueSet.all(BIGINT) + .subtract(ValueSet.ofRanges(Range.equal(BIGINT, 1L), Range.equal(BIGINT, 2L), Range.equal(BIGINT, 3L)))) + .union(ValueSet.ofRanges(Range.range(BIGINT, 7L, true, 9L, true))), false); + + tupleDomain = withColumnDomains(ImmutableMap.builder().put(A, testDomain).build()); + assertEquals(toPredicate(tupleDomain), or(and(lessThan(A, longLiteral(4L)), not(in(A, ImmutableList.of(1L, 2L, 3L)))), between(A, longLiteral(7L), longLiteral(9L)))); + + testDomain = Domain.create( + ValueSet.ofRanges(Range.lessThan(BIGINT, 4L)) + .intersect(ValueSet.all(BIGINT) + .subtract(ValueSet.ofRanges(Range.equal(BIGINT, 1L), Range.equal(BIGINT, 2L), Range.equal(BIGINT, 3L)))) + .union(ValueSet.ofRanges(Range.range(BIGINT, 7L, false, 9L, false), Range.range(BIGINT, 11L, false, 13L, false))), false); + + tupleDomain = withColumnDomains(ImmutableMap.builder().put(A, testDomain).build()); + assertEquals(toPredicate(tupleDomain), or( + and(lessThan(A, longLiteral(4L)), not(in(A, ImmutableList.of(1L, 2L, 3L)))), + and(greaterThan(A, longLiteral(7L)), lessThan(A, longLiteral(9L))), + and(greaterThan(A, longLiteral(11L)), lessThan(A, longLiteral(13L))))); + } + @Test public void testToPredicateNone() throws Exception diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index 1c3dd74e81ed5..1003a50af8332 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -3608,6 +3608,7 @@ public void testIn() .mapToObj(Long::toString) .collect(joining(", ")); assertQuery("SELECT orderkey FROM orders WHERE orderkey IN (" + longValues + ")"); + assertQuery("SELECT orderkey FROM orders WHERE orderkey NOT IN (" + longValues + ")"); String arrayValues = range(0, 5000).asLongStream() .mapToObj(i -> format("ARRAY[%s, %s, %s]", i, i + 1, i + 2))