Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prefilter join build side when it's too large #22667

Merged
merged 1 commit into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.benchmark;

import com.facebook.presto.testing.LocalQueryRunner;
import com.google.common.collect.ImmutableMap;

import static com.facebook.presto.benchmark.BenchmarkQueryRunner.createLocalQueryRunner;

/**
 * Benchmark for a single-key left join between {@code orders} and {@code lineitem}
 * with a selective predicate on {@code custkey}.
 *
 * <p>{@link #main(String[])} runs the same query twice: once on a default local query
 * runner and once on a runner configured with {@code join_prefilter_build_side=true},
 * so the two result lines can be compared.
 */
public class SqlJoinPrefilterBuildsideBenchmark
        extends AbstractSqlBenchmark
{
    // Used both as the benchmark name handed to the superclass and as the
    // property key passed to createLocalQueryRunner in main().
    private static final String JOIN_PREFILTER_BUILD_SIDE = "join_prefilter_build_side";

    public SqlJoinPrefilterBuildsideBenchmark(LocalQueryRunner localQueryRunner)
    {
        // NOTE(review): 4 and 5 are presumably the warm-up and measured iteration
        // counts expected by AbstractSqlBenchmark -- confirm against that class.
        super(
                localQueryRunner,
                JOIN_PREFILTER_BUILD_SIDE,
                4,
                5,
                "select t1.*, partkey from orders t1 left join lineitem t2 using (orderkey) where custkey < 10");
    }

    public static void main(String[] args)
    {
        // Baseline run on a runner created without extra properties.
        runBenchmarkWith(createLocalQueryRunner());
        // Second run with the prefilter property switched on.
        runBenchmarkWith(createLocalQueryRunner(ImmutableMap.of(JOIN_PREFILTER_BUILD_SIDE, "true")));
    }

    /**
     * Runs this benchmark on the given query runner and writes the result line to standard output.
     */
    private static void runBenchmarkWith(LocalQueryRunner queryRunner)
    {
        new SqlJoinPrefilterBuildsideBenchmark(queryRunner).runBenchmark(new SimpleLineBenchmarkResultWriter(System.out));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -6173,6 +6173,28 @@ public void testGroupByLimitPartitionKeys()
assertTrue(((String) plan.getOnlyValue()).toUpperCase().indexOf("MAP_AGG") >= 0);
}

/**
 * Checks that, with {@code join_prefilter_build_side} enabled, the partition constraint on
 * {@code orderstatus = 'O'} is printed more than once in the distributed plan, i.e. it is
 * present on the original scan of the partitioned table as well as on the copy of that scan.
 * The test table is dropped afterwards so repeated runs do not collide.
 */
@Test
public void testJoinPrefilterPartitionKeys()
{
    Session prefilter = Session.builder(getSession())
            .setSystemProperty("join_prefilter_build_side", "true")
            .build();

    @Language("SQL") String createTable = "" +
            "CREATE TABLE join_prefilter_test " +
            "WITH (" +
            "partitioned_by = ARRAY[ 'orderstatus' ]" +
            ") " +
            "AS " +
            "SELECT custkey, orderkey, orderstatus FROM tpch.tiny.orders";

    // Created outside the try block: if creation fails there is nothing to drop.
    assertUpdate(prefilter, createTable, 15000);
    try {
        MaterializedResult result = computeActual(prefilter, "explain(type distributed) select 1 from join_prefilter_test join customer using(custkey) where orderstatus='O'");
        String plan = (String) result.getMaterializedRows().get(0).getField(0);

        // Make sure the layout of the copied table matches the original: the partition
        // constraint text has to occur at two different positions in the plan.
        String partitionConstraint = ":: [[\"O\"]]";
        int firstOccurrence = plan.indexOf(partitionConstraint);
        assertTrue(firstOccurrence >= 0, "partition constraint not found in plan:\n" + plan);
        assertNotEquals(plan.lastIndexOf(partitionConstraint), firstOccurrence);
    }
    finally {
        // Always clean up so the table does not leak into other tests or a re-run.
        assertUpdate(prefilter, "DROP TABLE join_prefilter_test");
    }
}

@Test
public void testAddTableConstraints()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ public final class SystemSessionProperties
public static final String NATIVE_EXECUTION_PROCESS_REUSE_ENABLED = "native_execution_process_reuse_enabled";
public static final String NATIVE_DEBUG_VALIDATE_OUTPUT_FROM_OPERATORS = "native_debug_validate_output_from_operators";
public static final String DEFAULT_VIEW_SECURITY_MODE = "default_view_security_mode";
public static final String JOIN_PREFILTER_BUILD_SIDE = "join_prefilter_build_side";

private final List<PropertyMetadata<?>> sessionProperties;

Expand Down Expand Up @@ -1925,7 +1926,12 @@ public SystemSessionProperties(
featuresConfig.getDefaultViewSecurityMode(),
false,
value -> CreateView.Security.valueOf(((String) value).toUpperCase()),
CreateView.Security::name));
CreateView.Security::name),
booleanProperty(
JOIN_PREFILTER_BUILD_SIDE,
"Prefiltering the build/inner side of a join with keys from the other side",
false,
false));
}

public static boolean isSpoolingOutputBufferEnabled(Session session)
Expand Down Expand Up @@ -3207,4 +3213,9 @@ public static CreateView.Security getDefaultViewSecurityMode(Session session)
{
return session.getSystemProperty(DEFAULT_VIEW_SECURITY_MODE, CreateView.Security.class);
}

public static boolean isJoinPrefilterEnabled(Session session)
kaikalur marked this conversation as resolved.
Show resolved Hide resolved
{
return session.getSystemProperty(JOIN_PREFILTER_BUILD_SIDE, Boolean.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@
import com.facebook.presto.sql.planner.optimizations.HistoricalStatisticsEquivalentPlanMarkingOptimizer;
import com.facebook.presto.sql.planner.optimizations.ImplementIntersectAndExceptAsUnion;
import com.facebook.presto.sql.planner.optimizations.IndexJoinOptimizer;
import com.facebook.presto.sql.planner.optimizations.JoinPrefilter;
import com.facebook.presto.sql.planner.optimizations.KeyBasedSampler;
import com.facebook.presto.sql.planner.optimizations.LimitPushDown;
import com.facebook.presto.sql.planner.optimizations.LogicalCteOptimizer;
Expand Down Expand Up @@ -650,6 +651,8 @@ public PlanOptimizers(
.addAll(new InlineSqlFunctions(metadata, sqlParser).rules())
.build()));

builder.add(new JoinPrefilter(metadata));

builder.add(
new IterativeOptimizer(
metadata,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.facebook.presto.metadata.TableLayout;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.SourceLocation;
import com.facebook.presto.spi.TableHandle;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
Expand Down Expand Up @@ -341,14 +342,21 @@ private static TableScanNode cloneTableScan(TableScanNode scanNode, Session sess
List<VariableReferenceExpression> newOutputVariables = outputVariablesBuilder.build();
ImmutableMap<VariableReferenceExpression, ColumnHandle> newAssignments = assignmentsBuilder.build();

TableHandle oldTableHandle = scanNode.getTable();
TableHandle newTableHandle = new TableHandle(
oldTableHandle.getConnectorId(),
oldTableHandle.getConnectorHandle(),
oldTableHandle.getTransaction(),
oldTableHandle.getLayout());

return new TableScanNode(
scanNode.getSourceLocation(),
planNodeIdAllocator.getNextId(),
scanLayout.getNewTableHandle(),
newTableHandle,
newOutputVariables,
newAssignments,
scanNode.getTableConstraints(),
scanLayout.getPredicate(),
scanNode.getCurrentConstraint(),
scanNode.getEnforcedConstraint());
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.SystemSessionProperties.isJoinPrefilterEnabled;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet;
import static com.facebook.presto.spi.plan.JoinType.INNER;
import static com.facebook.presto.spi.plan.JoinType.LEFT;
import static com.facebook.presto.sql.planner.PlannerUtils.clonePlanNode;
import static com.facebook.presto.sql.planner.PlannerUtils.isScanFilterProject;
import static com.facebook.presto.sql.planner.PlannerUtils.projectExpressions;
import static com.facebook.presto.sql.planner.plan.ChildReplacer.replaceChildren;
import static java.util.Objects.requireNonNull;

public class JoinPrefilter
implements PlanOptimizer
kaikalur marked this conversation as resolved.
Show resolved Hide resolved
{
private final Metadata metadata;
private boolean isEnabledForTesting;

public JoinPrefilter(Metadata metadata)
{
this.metadata = requireNonNull(metadata, "metadata is null");
}

@Override
public void setEnabledForTesting(boolean isSet)
{
isEnabledForTesting = isSet;
}

@Override
public boolean isEnabled(Session session)
{
return isEnabledForTesting || isJoinPrefilterEnabled(session);
}

@Override
public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector)
{
if (isEnabled(session)) {
Rewriter rewriter = new Rewriter(session, metadata, idAllocator, variableAllocator);
PlanNode rewritten = SimplePlanRewriter.rewriteWith(rewriter, plan, null);
return PlanOptimizerResult.optimizerResult(rewritten, rewriter.isPlanChanged());
}

return PlanOptimizerResult.optimizerResult(plan, false);
}

private static class Rewriter
extends SimplePlanRewriter<Void>
{
private final Session session;
private final Metadata metadata;
private final PlanNodeIdAllocator idAllocator;
private final VariableAllocator variableAllocator;
private boolean planChanged;

private Rewriter(Session session, Metadata metadata, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator)
{
this.session = requireNonNull(session, "session is null");
this.metadata = requireNonNull(metadata, "functionAndTypeManager is null");
this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
this.variableAllocator = requireNonNull(variableAllocator, "idAllocator is null");
}

@Override
public PlanNode visitJoin(JoinNode node, RewriteContext<Void> context)
{
PlanNode left = node.getLeft();
PlanNode right = node.getRight();

PlanNode rewrittenLeft = rewriteWith(this, left);
PlanNode rewrittenRight = rewriteWith(this, right);
List<EquiJoinClause> equiJoinClause = node.getCriteria();

// We apply this for only left and inner join and the left side of the join is a simple scan and the join is on one key
if (equiJoinClause.size() == 1 &&
(node.getType() == LEFT || node.getType() == INNER) &&
isScanFilterProject(rewrittenLeft)) {
kaikalur marked this conversation as resolved.
Show resolved Hide resolved
VariableReferenceExpression leftKey = equiJoinClause.stream().map(x -> x.getLeft()).findFirst().get();
VariableReferenceExpression rightKey = equiJoinClause.stream().map(x -> x.getRight()).findFirst().get();

// First create a SELECT DISTINCT leftKey FROM left
Map<VariableReferenceExpression, VariableReferenceExpression> leftVarMap = new HashMap();
PlanNode leftKeys = clonePlanNode(rewrittenLeft, session, metadata, idAllocator, ImmutableList.of(leftKey), leftVarMap);
PlanNode projectNode = projectExpressions(leftKeys, idAllocator, variableAllocator, ImmutableList.of(leftVarMap.get(leftKey)), ImmutableList.of());

// DISTINCT on the leftkey
PlanNode filteringSource = new AggregationNode(
kaikalur marked this conversation as resolved.
Show resolved Hide resolved
leftKey.getSourceLocation(),
idAllocator.getNextId(),
projectNode,
ImmutableMap.of(),
singleGroupingSet(projectNode.getOutputVariables()),
projectNode.getOutputVariables(),
AggregationNode.Step.SINGLE,
Optional.empty(),
Optional.empty(),
Optional.empty());

// There should be only one output variable. Project that
filteringSource = projectExpressions(filteringSource, idAllocator, variableAllocator, ImmutableList.of(filteringSource.getOutputVariables().get(0)), ImmutableList.of());

// Now we add a semijoin as the right side
VariableReferenceExpression semiJoinOutput = variableAllocator.newVariable("semiJoinOutput", BOOLEAN);
SemiJoinNode semiJoinNode = new SemiJoinNode(
rightKey.getSourceLocation(),
idAllocator.getNextId(),
node.getStatsEquivalentPlanNode(),
rewrittenRight,
filteringSource,
rightKey,
filteringSource.getOutputVariables().get(0),
semiJoinOutput,
Optional.empty(),
Optional.empty(),
Optional.empty(),
ImmutableMap.of());

rewrittenRight = new FilterNode(semiJoinNode.getSourceLocation(), idAllocator.getNextId(), semiJoinNode, semiJoinOutput);
}

if (rewrittenLeft != node.getLeft() || rewrittenRight != node.getRight()) {
planChanged = true;
return replaceChildren(node, ImmutableList.of(rewrittenLeft, rewrittenRight));
}

return node;
}

public boolean isPlanChanged()
{
return planChanged;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import static com.facebook.presto.SystemSessionProperties.FIELD_NAMES_IN_JSON_CAST_ENABLED;
import static com.facebook.presto.SystemSessionProperties.GENERATE_DOMAIN_FILTERS;
import static com.facebook.presto.SystemSessionProperties.HASH_PARTITION_COUNT;
import static com.facebook.presto.SystemSessionProperties.JOIN_PREFILTER_BUILD_SIDE;
import static com.facebook.presto.SystemSessionProperties.KEY_BASED_SAMPLING_ENABLED;
import static com.facebook.presto.SystemSessionProperties.KEY_BASED_SAMPLING_FUNCTION;
import static com.facebook.presto.SystemSessionProperties.KEY_BASED_SAMPLING_PERCENTAGE;
Expand Down Expand Up @@ -7527,4 +7528,24 @@ public void testLambdaInAggregation()
assertQueryFails("SELECT id, reduce_agg(value, array[id, value], (a, b) -> a || b, (a, b) -> a || b) FROM ( VALUES (1, 2), (1, 3), (1, 4), (2, 20), (2, 30), (2, 40) ) AS t(id, value) GROUP BY id",
".*REDUCE_AGG only supports non-NULL literal as the initial value.*");
}

@Test
public void testJoinPrefilter()
{
// Orig
String testQuery = "SELECT 1 from region join nation using(regionkey)";
MaterializedResult result = computeActual("explain(type distributed) " + testQuery);
assertTrue(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin") == -1);
result = computeActual(testQuery);
kaikalur marked this conversation as resolved.
Show resolved Hide resolved
assertTrue(result.getRowCount() == 25);
kaikalur marked this conversation as resolved.
Show resolved Hide resolved

// With feature
kaikalur marked this conversation as resolved.
Show resolved Hide resolved
Session session = Session.builder(getSession())
.setSystemProperty(JOIN_PREFILTER_BUILD_SIDE, String.valueOf(true))
.build();
result = computeActual(session, "explain(type distributed) " + testQuery);
assertTrue(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin") != -1);
result = computeActual(session, testQuery);
assertTrue(result.getRowCount() == 25);
kaikalur marked this conversation as resolved.
Show resolved Hide resolved
}
}
Loading