diff --git a/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/ShardPredicate.java b/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/ShardPredicate.java index 64bc27c0de9b..1f8057bad33c 100644 --- a/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/ShardPredicate.java +++ b/presto-raptor/src/main/java/com/facebook/presto/raptor/metadata/ShardPredicate.java @@ -39,7 +39,6 @@ import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.collect.Iterables.getOnlyElement; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -48,6 +47,7 @@ class ShardPredicate private final String predicate; private final List types; private final List values; + private static final int MAX_RANGE_COUNT = 100; private ShardPredicate(String predicate, List types, List values) { @@ -108,50 +108,55 @@ public static ShardPredicate create(TupleDomain tupleDomain) continue; } + StringJoiner columnPredicate = new StringJoiner(" OR ", "(", ")").setEmptyValue("true"); Ranges ranges = domain.getValues().getRanges(); - // TODO: support multiple ranges - if (ranges.getRangeCount() != 1) { + // prevent generating complicated metadata queries + if (ranges.getRangeCount() > MAX_RANGE_COUNT) { continue; } - Range range = getOnlyElement(ranges.getOrderedRanges()); - Object minValue = null; - Object maxValue = null; - if (range.isSingleValue()) { - minValue = range.getSingleValue(); - maxValue = range.getSingleValue(); - } - else { - if (!range.getLow().isLowerUnbounded()) { - minValue = range.getLow().getValue(); + for (Range range : ranges.getOrderedRanges()) { + Object minValue = null; + Object maxValue = null; + if (range.isSingleValue()) { + minValue = range.getSingleValue(); + maxValue = range.getSingleValue(); } - if (!range.getHigh().isUpperUnbounded()) { - maxValue = range.getHigh().getValue(); + else { + if (!range.getLow().isLowerUnbounded()) { + minValue = range.getLow().getValue(); + } + if (!range.getHigh().isUpperUnbounded()) { + maxValue = range.getHigh().getValue(); + } } - } - String min; - String max; - if (handle.isBucketNumber()) { - min = "bucket_number"; - max = "bucket_number"; - } - else { - min = minColumn(handle.getColumnId()); - max = maxColumn(handle.getColumnId()); - } + String min; + String max; + if (handle.isBucketNumber()) { + min = "bucket_number"; + max = "bucket_number"; + } + else { + min = minColumn(handle.getColumnId()); + max = maxColumn(handle.getColumnId()); + } - if (minValue != null) { - predicate.add(format("(%s >= ? OR %s IS NULL)", max, max)); - types.add(jdbcType); - values.add(minValue); - } - if (maxValue != null) { - predicate.add(format("(%s <= ? OR %s IS NULL)", min, min)); - types.add(jdbcType); - values.add(maxValue); + StringJoiner rangePredicate = new StringJoiner(" AND ", "(", ")").setEmptyValue("true"); + if (minValue != null) { + rangePredicate.add(format("(%s >= ? OR %s IS NULL)", max, max)); + types.add(jdbcType); + values.add(minValue); + } + if (maxValue != null) { + rangePredicate.add(format("(%s <= ? OR %s IS NULL)", min, min)); + types.add(jdbcType); + values.add(maxValue); + } + columnPredicate.add(rangePredicate.toString()); } + predicate.add(columnPredicate.toString()); } return new ShardPredicate(predicate.toString(), types.build(), values.build()); } diff --git a/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestDatabaseShardManager.java b/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestDatabaseShardManager.java index 06e0e17e5c1b..5b120573f57e 100644 --- a/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestDatabaseShardManager.java +++ b/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestDatabaseShardManager.java @@ -71,6 +71,7 @@ import static com.facebook.presto.spi.predicate.Range.greaterThan; import static com.facebook.presto.spi.predicate.Range.greaterThanOrEqual; import static com.facebook.presto.spi.predicate.Range.lessThan; +import static com.facebook.presto.spi.predicate.Range.range; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.spi.type.DateType.DATE; @@ -570,10 +571,27 @@ public void testShardPruning() shardAssertion(tableId).range(c6, lessThan(BOOLEAN, true)).expected(shards); shardAssertion(tableId).range(c6, lessThan(BOOLEAN, false)).expected(shard1, shard3); - // TODO: support multiple ranges + // Test multiple ranges shardAssertion(tableId) .domain(c1, createDomain(lessThan(BIGINT, 0L), greaterThan(BIGINT, 25L))) - .expected(shards); + .expected(); + + shardAssertion(tableId) + .domain(c1, createDomain(range(BIGINT, 3L, true, 4L, true), range(BIGINT, 16L, true, 18L, true))) + .expected(shard2, shard3); + + shardAssertion(tableId) + .domain(c5, createDomain( + range(createVarcharType(10), utf8Slice("gum"), true, utf8Slice("happy"), true), + range(createVarcharType(10), utf8Slice("pear"), true, utf8Slice("wall"), true))) + .expected(shard1, shard3); + + shardAssertion(tableId) + .domain(c1, createDomain(range(BIGINT, 3L, true, 4L, true), range(BIGINT, 16L, true, 18L, true))) + .domain(c5, createDomain( + range(createVarcharType(10), utf8Slice("gum"), true, utf8Slice("happy"), true), + range(createVarcharType(10), utf8Slice("pear"), true, utf8Slice("wall"), true))) + .expected(shard3); } @Test diff --git a/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestShardPredicate.java b/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestShardPredicate.java index 30241e541c6e..e28edfaff7c6 100644 --- a/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestShardPredicate.java +++ b/presto-raptor/src/test/java/com/facebook/presto/raptor/metadata/TestShardPredicate.java @@ -98,13 +98,41 @@ public void testRangeShardUuidPredicate() } @Test - public void testBucketNumber() + public void testBucketNumberSingleRange() { TupleDomain tupleDomain = withColumnDomains(ImmutableMap.of( bucketNumberColumnHandle("test"), create(SortedRangeSet.copyOf(INTEGER, ImmutableList.of(equal(INTEGER, 1L))), false))); ShardPredicate shardPredicate = ShardPredicate.create(tupleDomain); - assertEquals(shardPredicate.getPredicate(), "(bucket_number >= ? OR bucket_number IS NULL) AND (bucket_number <= ? OR bucket_number IS NULL)"); + assertEquals(shardPredicate.getPredicate(), "(((bucket_number >= ? OR bucket_number IS NULL) AND (bucket_number <= ? OR bucket_number IS NULL)))"); + } + + @Test + public void testBucketNumberMultipleRanges() + { + TupleDomain tupleDomain = withColumnDomains(ImmutableMap.of( + bucketNumberColumnHandle("test"), + create(SortedRangeSet.copyOf(INTEGER, ImmutableList.of(equal(INTEGER, 1L), equal(INTEGER, 3L))), false))); + + ShardPredicate shardPredicate = ShardPredicate.create(tupleDomain); + assertEquals(shardPredicate.getPredicate(), + "(((bucket_number >= ? OR bucket_number IS NULL) AND (bucket_number <= ? OR bucket_number IS NULL))" + + " OR ((bucket_number >= ? OR bucket_number IS NULL) AND (bucket_number <= ? OR bucket_number IS NULL)))"); + } + + @Test + public void testMultipleColumnsMultipleRanges() + { + TupleDomain tupleDomain = withColumnDomains(ImmutableMap.of( + bucketNumberColumnHandle("test"), + create(SortedRangeSet.copyOf(INTEGER, ImmutableList.of(equal(INTEGER, 1L), equal(INTEGER, 3L))), false), + new RaptorColumnHandle("test", "col", 1, INTEGER), + create(SortedRangeSet.copyOf(INTEGER, ImmutableList.of(equal(INTEGER, 1L), equal(INTEGER, 3L))), false))); + ShardPredicate shardPredicate = ShardPredicate.create(tupleDomain); + assertEquals(shardPredicate.getPredicate(), "(((c1_max >= ? OR c1_max IS NULL) AND (c1_min <= ? OR c1_min IS NULL)) " + + "OR ((c1_max >= ? OR c1_max IS NULL) AND (c1_min <= ? OR c1_min IS NULL))) " + + "AND (((bucket_number >= ? OR bucket_number IS NULL) AND (bucket_number <= ? OR bucket_number IS NULL)) " + + "OR ((bucket_number >= ? OR bucket_number IS NULL) AND (bucket_number <= ? OR bucket_number IS NULL)))"); } }