Skip to content

Commit

Permalink
Optimize some TupleDomains to produce fewer expressions
Browse files Browse the repository at this point in the history
Existing optimizations convert NOT IN queries to discontinuous ranges. For
example, NOT IN (1, 2, 3) turns into <1 || (>1 && <2) || (>2 && <3) || >3.
Queries with a large number of values in the IN list become slow (or, in
some cases, generate classes that are too large) as a result of this.

Added an optimization to convert ranges into expressions with NOT INs
if doing so would produce fewer expressions.
  • Loading branch information
Raghav Sethi committed Feb 29, 2016
1 parent 9178d08 commit 43adc41
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 39 deletions.
Expand Up @@ -21,6 +21,7 @@
import com.facebook.presto.spi.predicate.NullableValue;
import com.facebook.presto.spi.predicate.Range;
import com.facebook.presto.spi.predicate.Ranges;
import com.facebook.presto.spi.predicate.SortedRangeSet;
import com.facebook.presto.spi.predicate.TupleDomain;
import com.facebook.presto.spi.predicate.ValueSet;
import com.facebook.presto.spi.type.Type;
Expand All @@ -44,6 +45,7 @@
import com.facebook.presto.type.TypeRegistry;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.PeekingIterator;
import com.google.common.math.DoubleMath;

import javax.annotation.Nullable;
Expand Down Expand Up @@ -72,6 +74,7 @@
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.common.collect.Iterators.peekingIterator;
import static java.math.RoundingMode.CEILING;
import static java.math.RoundingMode.FLOOR;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -124,54 +127,97 @@ private static Expression toPredicate(Domain domain, QualifiedNameReference refe
return combineDisjunctsWithDefault(disjuncts, TRUE_LITERAL);
}

/**
 * Builds the predicate expression that constrains {@code reference} to a single {@code range}.
 *
 * Returns TRUE_LITERAL for an unbounded (ALL) range, a BetweenPredicate when
 * isBetween(range) holds, and otherwise the conjunction of the comparisons
 * implied by whichever of the low/high markers are bounded.
 *
 * @throws IllegalStateException if the low marker uses a BELOW bound, the high marker
 *         uses an ABOVE bound, or no comparison could be produced
 */
private static Expression processRange(Type type, Range range, QualifiedNameReference reference)
{
    // A range covering everything places no constraint on the reference
    if (range.isAll()) {
        return TRUE_LITERAL;
    }

    // Prefer a BETWEEN expression whenever possible because it is currently more efficient
    if (isBetween(range)) {
        Expression lowerValue = toExpression(range.getLow().getValue(), type);
        Expression upperValue = toExpression(range.getHigh().getValue(), type);
        return new BetweenPredicate(reference, lowerValue, upperValue);
    }

    List<Expression> conjuncts = new ArrayList<>();

    // Lower bound: only emitted when the low marker is bounded
    if (!range.getLow().isLowerUnbounded()) {
        switch (range.getLow().getBound()) {
            case BELOW:
                throw new IllegalStateException("Low Marker should never use BELOW bound: " + range);
            case EXACTLY: {
                Expression inclusiveLow = toExpression(range.getLow().getValue(), type);
                conjuncts.add(new ComparisonExpression(GREATER_THAN_OR_EQUAL, reference, inclusiveLow));
                break;
            }
            case ABOVE: {
                Expression exclusiveLow = toExpression(range.getLow().getValue(), type);
                conjuncts.add(new ComparisonExpression(GREATER_THAN, reference, exclusiveLow));
                break;
            }
            default:
                throw new AssertionError("Unhandled bound: " + range.getLow().getBound());
        }
    }

    // Upper bound: only emitted when the high marker is bounded
    if (!range.getHigh().isUpperUnbounded()) {
        switch (range.getHigh().getBound()) {
            case ABOVE:
                throw new IllegalStateException("High Marker should never use ABOVE bound: " + range);
            case BELOW: {
                Expression exclusiveHigh = toExpression(range.getHigh().getValue(), type);
                conjuncts.add(new ComparisonExpression(LESS_THAN, reference, exclusiveHigh));
                break;
            }
            case EXACTLY: {
                Expression inclusiveHigh = toExpression(range.getHigh().getValue(), type);
                conjuncts.add(new ComparisonExpression(LESS_THAN_OR_EQUAL, reference, inclusiveHigh));
                break;
            }
            default:
                throw new AssertionError("Unhandled bound: " + range.getHigh().getBound());
        }
    }

    // An empty list would mean neither side was bounded, i.e. the range was ALL, which was already handled above
    checkState(!conjuncts.isEmpty());
    return combineConjuncts(conjuncts);
}

/**
 * Builds the predicate for {@code range} with the given individual points carved out of it.
 *
 * With no excluded points this is just the range predicate. A single excluded point is
 * expressed as a NOT_EQUAL comparison; several are expressed as NOT (reference IN (...)).
 * In both of those cases the exclusion is ANDed with the range predicate.
 */
private static Expression combineRangeWithExcludedPoints(Type type, QualifiedNameReference reference, Range range, List<Expression> excludedPoints)
{
    Expression rangePredicate = processRange(type, range, reference);

    switch (excludedPoints.size()) {
        case 0:
            // Nothing to exclude: the range predicate stands alone
            return rangePredicate;
        case 1:
            // A one-element NOT IN reads better (and is simpler) as an inequality
            return combineConjuncts(
                    rangePredicate,
                    new ComparisonExpression(NOT_EQUAL, reference, getOnlyElement(excludedPoints)));
        default:
            return combineConjuncts(
                    rangePredicate,
                    new NotExpression(new InPredicate(reference, new InListExpression(excludedPoints))));
    }
}

private static List<Expression> extractDisjuncts(Type type, Ranges ranges, QualifiedNameReference reference)
{
List<Expression> disjuncts = new ArrayList<>();
List<Expression> singleValues = new ArrayList<>();
for (Range range : ranges.getOrderedRanges()) {
checkState(!range.isAll()); // Already checked
List<Range> orderedRanges = ranges.getOrderedRanges();

SortedRangeSet sortedRangeSet = SortedRangeSet.copyOf(type, orderedRanges);
SortedRangeSet complement = sortedRangeSet.complement();

List<Range> singleValueExclusionsList = complement.getOrderedRanges().stream().filter(Range::isSingleValue).collect(toList());
List<Range> originalUnionSingleValues = SortedRangeSet.copyOf(type, singleValueExclusionsList).union(sortedRangeSet).getOrderedRanges();
PeekingIterator<Range> singleValueExclusions = peekingIterator(singleValueExclusionsList.iterator());

for (Range range : originalUnionSingleValues) {
if (range.isSingleValue()) {
singleValues.add(toExpression(range.getSingleValue(), type));
continue;
}
else if (isBetween(range)) {
// Specialize the range with BETWEEN expression if possible b/c it is currently more efficient
disjuncts.add(new BetweenPredicate(reference, toExpression(range.getLow().getValue(), type), toExpression(range.getHigh().getValue(), type)));

// attempt to optimize ranges that can be coalesced as long as single value points are excluded
List<Expression> singleValuesInRange = new ArrayList<>();
while (singleValueExclusions.hasNext() && range.contains(singleValueExclusions.peek())) {
singleValuesInRange.add(toExpression(singleValueExclusions.next().getSingleValue(), type));
}
else {
List<Expression> rangeConjuncts = new ArrayList<>();
if (!range.getLow().isLowerUnbounded()) {
switch (range.getLow().getBound()) {
case ABOVE:
rangeConjuncts.add(new ComparisonExpression(GREATER_THAN, reference, toExpression(range.getLow().getValue(), type)));
break;
case EXACTLY:
rangeConjuncts.add(new ComparisonExpression(GREATER_THAN_OR_EQUAL, reference, toExpression(range.getLow().getValue(),
type)));
break;
case BELOW:
throw new IllegalStateException("Low Marker should never use BELOW bound: " + range);
default:
throw new AssertionError("Unhandled bound: " + range.getLow().getBound());
}
}
if (!range.getHigh().isUpperUnbounded()) {
switch (range.getHigh().getBound()) {
case ABOVE:
throw new IllegalStateException("High Marker should never use ABOVE bound: " + range);
case EXACTLY:
rangeConjuncts.add(new ComparisonExpression(LESS_THAN_OR_EQUAL, reference, toExpression(range.getHigh().getValue(), type)));
break;
case BELOW:
rangeConjuncts.add(new ComparisonExpression(LESS_THAN, reference, toExpression(range.getHigh().getValue(), type)));
break;
default:
throw new AssertionError("Unhandled bound: " + range.getHigh().getBound());
}
}
// If rangeConjuncts is null, then the range was ALL, which should already have been checked for
checkState(!rangeConjuncts.isEmpty());
disjuncts.add(combineConjuncts(rangeConjuncts));

if (!singleValuesInRange.isEmpty()) {
disjuncts.add(combineRangeWithExcludedPoints(type, reference, range, singleValuesInRange));
continue;
}

disjuncts.add(processRange(type, range, reference));
}

// Add back all of the possible single values either as an equality or an IN predicate
Expand Down
Expand Up @@ -157,6 +157,60 @@ public void testRoundTrip()
assertEquals(result.getTupleDomain(), tupleDomain);
}

// Checks that domains with isolated excluded points are rendered with NOT IN
// rather than as a disjunction of many small ranges.
@Test
public void testInOptimization()
        throws Exception
{
    // Building blocks shared by the scenarios below
    ValueSet oneTwoThree = ValueSet.ofRanges(Range.equal(BIGINT, 1L), Range.equal(BIGINT, 2L), Range.equal(BIGINT, 3L));
    ValueSet everythingButOneTwoThree = ValueSet.all(BIGINT).subtract(oneTwoThree);
    ValueSet belowFourButOneTwoThree = ValueSet.ofRanges(Range.lessThan(BIGINT, 4L)).intersect(everythingButOneTwoThree);

    // Scenario 1: all values except three points
    Domain notInDomain = Domain.create(everythingButOneTwoThree, false);
    TupleDomain<Symbol> notInTupleDomain = withColumnDomains(ImmutableMap.<Symbol, Domain>builder().put(A, notInDomain).build());
    assertEquals(
            toPredicate(notInTupleDomain),
            not(in(A, ImmutableList.of(1L, 2L, 3L))));

    // Scenario 2: an upper-bounded range with three points removed
    Domain boundedNotInDomain = Domain.create(belowFourButOneTwoThree, false);
    TupleDomain<Symbol> boundedNotInTupleDomain = withColumnDomains(ImmutableMap.<Symbol, Domain>builder().put(A, boundedNotInDomain).build());
    assertEquals(
            toPredicate(boundedNotInTupleDomain),
            and(lessThan(A, longLiteral(4L)), not(in(A, ImmutableList.of(1L, 2L, 3L)))));

    // Scenario 3: closed ranges separated by real gaps
    Domain disjointRangesDomain = Domain.create(
            ValueSet.ofRanges(
                    Range.range(BIGINT, 1L, true, 3L, true),
                    Range.range(BIGINT, 5L, true, 7L, true),
                    Range.range(BIGINT, 9L, true, 11L, true)),
            false);
    TupleDomain<Symbol> disjointRangesTupleDomain = withColumnDomains(ImmutableMap.<Symbol, Domain>builder().put(A, disjointRangesDomain).build());
    assertEquals(
            toPredicate(disjointRangesTupleDomain),
            or(
                    between(A, longLiteral(1L), longLiteral(3L)),
                    between(A, longLiteral(5L), longLiteral(7L)),
                    between(A, longLiteral(9L), longLiteral(11L))));

    // Scenario 4: scenario 2 unioned with one closed range
    Domain notInPlusClosedDomain = Domain.create(
            belowFourButOneTwoThree.union(ValueSet.ofRanges(Range.range(BIGINT, 7L, true, 9L, true))),
            false);
    TupleDomain<Symbol> notInPlusClosedTupleDomain = withColumnDomains(ImmutableMap.<Symbol, Domain>builder().put(A, notInPlusClosedDomain).build());
    assertEquals(
            toPredicate(notInPlusClosedTupleDomain),
            or(
                    and(lessThan(A, longLiteral(4L)), not(in(A, ImmutableList.of(1L, 2L, 3L)))),
                    between(A, longLiteral(7L), longLiteral(9L))));

    // Scenario 5: scenario 2 unioned with two open ranges
    Domain notInPlusOpenDomain = Domain.create(
            belowFourButOneTwoThree.union(ValueSet.ofRanges(
                    Range.range(BIGINT, 7L, false, 9L, false),
                    Range.range(BIGINT, 11L, false, 13L, false))),
            false);
    TupleDomain<Symbol> notInPlusOpenTupleDomain = withColumnDomains(ImmutableMap.<Symbol, Domain>builder().put(A, notInPlusOpenDomain).build());
    assertEquals(
            toPredicate(notInPlusOpenTupleDomain),
            or(
                    and(lessThan(A, longLiteral(4L)), not(in(A, ImmutableList.of(1L, 2L, 3L)))),
                    and(greaterThan(A, longLiteral(7L)), lessThan(A, longLiteral(9L))),
                    and(greaterThan(A, longLiteral(11L)), lessThan(A, longLiteral(13L)))));
}

@Test
public void testToPredicateNone()
throws Exception
Expand Down
Expand Up @@ -3608,6 +3608,7 @@ public void testIn()
.mapToObj(Long::toString)
.collect(joining(", "));
assertQuery("SELECT orderkey FROM orders WHERE orderkey IN (" + longValues + ")");
assertQuery("SELECT orderkey FROM orders WHERE orderkey NOT IN (" + longValues + ")");

String arrayValues = range(0, 5000).asLongStream()
.mapToObj(i -> format("ARRAY[%s, %s, %s]", i, i + 1, i + 2))
Expand Down

0 comments on commit 43adc41

Please sign in to comment.