diff --git a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/SqlJoinPrefilterBuildsideBenchmark.java b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/SqlJoinPrefilterBuildsideBenchmark.java
new file mode 100644
index 000000000000..097fc2dff741
--- /dev/null
+++ b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/SqlJoinPrefilterBuildsideBenchmark.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.benchmark;
+
+import com.facebook.presto.testing.LocalQueryRunner;
+import com.google.common.collect.ImmutableMap;
+
+import static com.facebook.presto.benchmark.BenchmarkQueryRunner.createLocalQueryRunner;
+
+/**
+ * SQL benchmark for a {@code part JOIN lineitem} count query. {@link #main} runs it twice: once on a
+ * default query runner and once with the {@code join_prefilter_build_side} property set to {@code true}.
+ */
+public class SqlJoinPrefilterBuildsideBenchmark
+        extends AbstractSqlBenchmark
+{
+    public SqlJoinPrefilterBuildsideBenchmark(LocalQueryRunner localQueryRunner)
+    {
+        super(
+                localQueryRunner,
+                "join_prefilter_build_side",
+                4,
+                5,
+                "select count(1) from part join lineitem using (partkey) where part.name like '%x%'");
+    }
+
+    public static void main(String[] args)
+    {
+        // Baseline run: query runner created without extra properties
+        LocalQueryRunner baselineRunner = createLocalQueryRunner();
+        new SqlJoinPrefilterBuildsideBenchmark(baselineRunner).runBenchmark(new SimpleLineBenchmarkResultWriter(System.out));
+
+        // Same query on a runner created with join_prefilter_build_side=true
+        LocalQueryRunner prefilterRunner = createLocalQueryRunner(ImmutableMap.of("join_prefilter_build_side", "true"));
+        new SqlJoinPrefilterBuildsideBenchmark(prefilterRunner).runBenchmark(new SimpleLineBenchmarkResultWriter(System.out));
+    }
+}
diff --git 
a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveIntegrationSmokeTest.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveIntegrationSmokeTest.java
index b853262748fd..650fec9aaa51 100644
--- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveIntegrationSmokeTest.java
+++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveIntegrationSmokeTest.java
@@ -6173,6 +6173,34 @@ public void testGroupByLimitPartitionKeys()
         assertTrue(((String) plan.getOnlyValue()).toUpperCase().indexOf("MAP_AGG") >= 0);
     }
 
+    @Test
+    public void testJoinPrefilterPartitionKeys()
+    {
+        Session prefilter = Session.builder(getSession())
+                .setSystemProperty("join_prefilter_build_side", "true")
+                .build();
+
+        @Language("SQL") String createTable = "" +
+                "CREATE TABLE join_prefilter_test " +
+                "WITH (" +
+                "partitioned_by = ARRAY[ 'orderstatus' ]" +
+                ") " +
+                "AS " +
+                "SELECT custkey, orderkey, orderstatus FROM tpch.tiny.orders";
+
+        assertUpdate(prefilter, createTable, 15000);
+        try {
+            MaterializedResult result = computeActual(prefilter, "explain(type distributed) select 1 from join_prefilter_test join customer using(custkey) where orderstatus='O'");
+            // The prefilter clones the left-side scan of join_prefilter_test, so the partition-key domain
+            // for 'O' must be printed at least twice in the plan: on the original scan and on its copy
+            String plan = (String) result.getMaterializedRows().get(0).getField(0);
+            assertNotEquals(plan.lastIndexOf(":: [[\"O\"]]"), plan.indexOf(":: [[\"O\"]]"));
+        }
+        finally {
+            assertUpdate(prefilter, "DROP TABLE join_prefilter_test");
+        }
+    }
+
     @Test
     public void testAddTableConstraints()
     {
diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java
index 01d3432042e4..641ba55ee014 100644
--- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java
+++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java
@@ -345,6 +345,7 @@ public final class SystemSessionProperties
     public static final String NATIVE_EXECUTION_PROCESS_REUSE_ENABLED = "native_execution_process_reuse_enabled";
     public static final String NATIVE_DEBUG_VALIDATE_OUTPUT_FROM_OPERATORS = "native_debug_validate_output_from_operators";
     public static final String DEFAULT_VIEW_SECURITY_MODE = "default_view_security_mode";
+    public static final String JOIN_PREFILTER_BUILD_SIDE = "join_prefilter_build_side";
 
     private final List<PropertyMetadata<?>> sessionProperties;
 
@@ -1925,7 +1926,12 @@ public SystemSessionProperties(
                         featuresConfig.getDefaultViewSecurityMode(),
                         false,
                         value -> CreateView.Security.valueOf(((String) value).toUpperCase()),
-                        CreateView.Security::name));
+                        CreateView.Security::name),
+                booleanProperty(
+                        JOIN_PREFILTER_BUILD_SIDE,
+                        "Prefiltering the build/inner side of a join with keys from the other side",
+                        false,
+                        false));
     }
 
     public static boolean isSpoolingOutputBufferEnabled(Session session)
@@ -3207,4 +3213,9 @@ public static CreateView.Security getDefaultViewSecurityMode(Session session)
     {
         return session.getSystemProperty(DEFAULT_VIEW_SECURITY_MODE, CreateView.Security.class);
     }
+
+    public static boolean isJoinPrefilterEnabled(Session session)
+    {
+        return session.getSystemProperty(JOIN_PREFILTER_BUILD_SIDE, Boolean.class);
+    }
 }
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
index 1b01d78e3987..131c68293a9a 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java
@@ -150,6 +150,7 @@
 import com.facebook.presto.sql.planner.optimizations.HistoricalStatisticsEquivalentPlanMarkingOptimizer;
 import com.facebook.presto.sql.planner.optimizations.ImplementIntersectAndExceptAsUnion;
 import com.facebook.presto.sql.planner.optimizations.IndexJoinOptimizer;
+import com.facebook.presto.sql.planner.optimizations.JoinPrefilter;
 import com.facebook.presto.sql.planner.optimizations.KeyBasedSampler;
 import com.facebook.presto.sql.planner.optimizations.LimitPushDown;
 import com.facebook.presto.sql.planner.optimizations.LogicalCteOptimizer;
@@ -650,6 +651,8 @@ public PlanOptimizers(
                         .addAll(new InlineSqlFunctions(metadata, sqlParser).rules())
                         .build()));
 
+        builder.add(new JoinPrefilter(metadata));
+
         builder.add(
                 new IterativeOptimizer(
                         metadata,
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java
index 03711c4af7aa..18cc39bff431 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java
@@ -26,6 +26,7 @@
 import com.facebook.presto.metadata.TableLayout;
 import com.facebook.presto.spi.ColumnHandle;
 import com.facebook.presto.spi.SourceLocation;
+import com.facebook.presto.spi.TableHandle;
 import com.facebook.presto.spi.VariableAllocator;
 import com.facebook.presto.spi.plan.AggregationNode;
 import com.facebook.presto.spi.plan.Assignments;
@@ -341,14 +342,21 @@ private static TableScanNode cloneTableScan(TableScanNode scanNode, Session sess
         List<VariableReferenceExpression> newOutputVariables = outputVariablesBuilder.build();
         ImmutableMap<VariableReferenceExpression, ColumnHandle> newAssignments = assignmentsBuilder.build();
 
+        TableHandle oldTableHandle = scanNode.getTable();
+        TableHandle newTableHandle = new TableHandle(
+                oldTableHandle.getConnectorId(),
+                oldTableHandle.getConnectorHandle(),
+                oldTableHandle.getTransaction(),
+                oldTableHandle.getLayout());
+
         return new TableScanNode(
                 scanNode.getSourceLocation(),
                 planNodeIdAllocator.getNextId(),
-                scanLayout.getNewTableHandle(),
+                newTableHandle,
                 newOutputVariables,
                 newAssignments,
                 scanNode.getTableConstraints(),
-                scanLayout.getPredicate(),
+                scanNode.getCurrentConstraint(),
                 scanNode.getEnforcedConstraint());
     }
 
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/JoinPrefilter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/JoinPrefilter.java
new file mode 100644
index 000000000000..f4eedf17efca
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/JoinPrefilter.java
@@ -0,0 +1,188 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.optimizations;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.metadata.Metadata;
+import com.facebook.presto.spi.VariableAllocator;
+import com.facebook.presto.spi.WarningCollector;
+import com.facebook.presto.spi.plan.AggregationNode;
+import com.facebook.presto.spi.plan.EquiJoinClause;
+import com.facebook.presto.spi.plan.FilterNode;
+import com.facebook.presto.spi.plan.PlanNode;
+import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
+import com.facebook.presto.spi.relation.VariableReferenceExpression;
+import com.facebook.presto.sql.planner.TypeProvider;
+import com.facebook.presto.sql.planner.plan.JoinNode;
+import com.facebook.presto.sql.planner.plan.SemiJoinNode;
+import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import static com.facebook.presto.SystemSessionProperties.isJoinPrefilterEnabled;
+import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
+import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet;
+import static com.facebook.presto.spi.plan.JoinType.INNER;
+import static com.facebook.presto.spi.plan.JoinType.LEFT;
+import static com.facebook.presto.sql.planner.PlannerUtils.clonePlanNode;
+import static com.facebook.presto.sql.planner.PlannerUtils.isScanFilterProject;
+import static com.facebook.presto.sql.planner.PlannerUtils.projectExpressions;
+import static com.facebook.presto.sql.planner.plan.ChildReplacer.replaceChildren;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Plan optimizer that prefilters the right side of a join using the join keys of the left side.
+ * <p>
+ * For a LEFT or INNER join with exactly one equi-join clause whose left input is a
+ * scan/filter/project chain, the right input is replaced by
+ * <pre>
+ *     Filter(semiJoinOutput)
+ *         SemiJoin(right key against the distinct left keys)
+ *             right
+ *             SELECT DISTINCT leftKey FROM (clone of left)
+ * </pre>
+ * The rewrite runs only when the {@code join_prefilter_build_side} session property is set,
+ * or when the optimizer has been enabled for testing.
+ */
+public class JoinPrefilter
+        implements PlanOptimizer
+{
+    private final Metadata metadata;
+    private boolean isEnabledForTesting;
+
+    public JoinPrefilter(Metadata metadata)
+    {
+        this.metadata = requireNonNull(metadata, "metadata is null");
+    }
+
+    @Override
+    public void setEnabledForTesting(boolean isSet)
+    {
+        isEnabledForTesting = isSet;
+    }
+
+    @Override
+    public boolean isEnabled(Session session)
+    {
+        return isEnabledForTesting || isJoinPrefilterEnabled(session);
+    }
+
+    @Override
+    public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector)
+    {
+        if (!isEnabled(session)) {
+            return PlanOptimizerResult.optimizerResult(plan, false);
+        }
+
+        Rewriter rewriter = new Rewriter(session, metadata, idAllocator, variableAllocator);
+        PlanNode rewritten = SimplePlanRewriter.rewriteWith(rewriter, plan, null);
+        return PlanOptimizerResult.optimizerResult(rewritten, rewriter.isPlanChanged());
+    }
+
+    private static class Rewriter
+            extends SimplePlanRewriter<Void>
+    {
+        private final Session session;
+        private final Metadata metadata;
+        private final PlanNodeIdAllocator idAllocator;
+        private final VariableAllocator variableAllocator;
+        private boolean planChanged;
+
+        private Rewriter(Session session, Metadata metadata, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator)
+        {
+            this.session = requireNonNull(session, "session is null");
+            this.metadata = requireNonNull(metadata, "metadata is null");
+            this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
+            this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null");
+        }
+
+        @Override
+        public PlanNode visitJoin(JoinNode node, RewriteContext<Void> context)
+        {
+            PlanNode left = node.getLeft();
+            PlanNode right = node.getRight();
+
+            // Rewrite both children first, so joins nested below this one are visited before it
+            PlanNode rewrittenLeft = rewriteWith(this, left);
+            PlanNode rewrittenRight = rewriteWith(this, right);
+            List<EquiJoinClause> equiJoinClause = node.getCriteria();
+
+            // Apply only to LEFT and INNER joins on exactly one key whose left side is a scan/filter/project chain
+            if (equiJoinClause.size() == 1 &&
+                    (node.getType() == LEFT || node.getType() == INNER) &&
+                    isScanFilterProject(rewrittenLeft)) {
+                EquiJoinClause clause = equiJoinClause.get(0);
+                VariableReferenceExpression leftKey = clause.getLeft();
+                VariableReferenceExpression rightKey = clause.getRight();
+
+                // First create a SELECT DISTINCT leftKey FROM left.
+                // leftVarMap is handed to clonePlanNode and then used to look up the clone's variable for leftKey
+                Map<VariableReferenceExpression, VariableReferenceExpression> leftVarMap = new HashMap<>();
+                PlanNode leftKeys = clonePlanNode(rewrittenLeft, session, metadata, idAllocator, ImmutableList.of(leftKey), leftVarMap);
+                PlanNode projectNode = projectExpressions(leftKeys, idAllocator, variableAllocator, ImmutableList.of(leftVarMap.get(leftKey)), ImmutableList.of());
+
+                // DISTINCT on the left key: group by the projected key with an empty aggregation map
+                PlanNode filteringSource = new AggregationNode(
+                        leftKey.getSourceLocation(),
+                        idAllocator.getNextId(),
+                        projectNode,
+                        ImmutableMap.of(),
+                        singleGroupingSet(projectNode.getOutputVariables()),
+                        projectNode.getOutputVariables(),
+                        AggregationNode.Step.SINGLE,
+                        Optional.empty(),
+                        Optional.empty(),
+                        Optional.empty());
+
+                // There should be only one output variable. Project that
+                filteringSource = projectExpressions(filteringSource, idAllocator, variableAllocator, ImmutableList.of(filteringSource.getOutputVariables().get(0)), ImmutableList.of());
+
+                // Semi join the right side against the distinct left keys, then filter on the semi join output
+                VariableReferenceExpression semiJoinOutput = variableAllocator.newVariable("semiJoinOutput", BOOLEAN);
+                SemiJoinNode semiJoinNode = new SemiJoinNode(
+                        rightKey.getSourceLocation(),
+                        idAllocator.getNextId(),
+                        node.getStatsEquivalentPlanNode(),
+                        rewrittenRight,
+                        filteringSource,
+                        rightKey,
+                        filteringSource.getOutputVariables().get(0),
+                        semiJoinOutput,
+                        Optional.empty(),
+                        Optional.empty(),
+                        Optional.empty(),
+                        ImmutableMap.of());
+
+                rewrittenRight = new FilterNode(semiJoinNode.getSourceLocation(), idAllocator.getNextId(), semiJoinNode, semiJoinOutput);
+            }
+
+            if (rewrittenLeft != node.getLeft() || rewrittenRight != node.getRight()) {
+                planChanged = true;
+                return replaceChildren(node, ImmutableList.of(rewrittenLeft, rewrittenRight));
+            }
+
+            return node;
+        }
+
+        public boolean isPlanChanged()
+        {
+            return planChanged;
+        }
+    }
+}
diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java
index f0d70338058d..6162d3cc10a0 100644
--- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java
+++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java
@@ -55,6 +55,7 @@
 import static com.facebook.presto.SystemSessionProperties.FIELD_NAMES_IN_JSON_CAST_ENABLED;
 import static com.facebook.presto.SystemSessionProperties.GENERATE_DOMAIN_FILTERS;
 import static com.facebook.presto.SystemSessionProperties.HASH_PARTITION_COUNT;
+import static com.facebook.presto.SystemSessionProperties.JOIN_PREFILTER_BUILD_SIDE;
 import static com.facebook.presto.SystemSessionProperties.KEY_BASED_SAMPLING_ENABLED;
 import static com.facebook.presto.SystemSessionProperties.KEY_BASED_SAMPLING_FUNCTION;
import static com.facebook.presto.SystemSessionProperties.KEY_BASED_SAMPLING_PERCENTAGE; @@ -7527,4 +7528,24 @@ public void testLambdaInAggregation() assertQueryFails("SELECT id, reduce_agg(value, array[id, value], (a, b) -> a || b, (a, b) -> a || b) FROM ( VALUES (1, 2), (1, 3), (1, 4), (2, 20), (2, 30), (2, 40) ) AS t(id, value) GROUP BY id", ".*REDUCE_AGG only supports non-NULL literal as the initial value.*"); } + + @Test + public void testJoinPrefilter() + { + // Orig + String testQuery = "SELECT 1 from region join nation using(regionkey)"; + MaterializedResult result = computeActual("explain(type distributed) " + testQuery); + assertTrue(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin") == -1); + result = computeActual(testQuery); + assertTrue(result.getRowCount() == 25); + + // With feature + Session session = Session.builder(getSession()) + .setSystemProperty(JOIN_PREFILTER_BUILD_SIDE, String.valueOf(true)) + .build(); + result = computeActual(session, "explain(type distributed) " + testQuery); + assertTrue(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin") != -1); + result = computeActual(session, testQuery); + assertTrue(result.getRowCount() == 25); + } }