Skip to content

Commit

Permalink
Optimize join when build side is large
Browse files Browse the repository at this point in the history
  • Loading branch information
kaikalur committed May 5, 2024
1 parent 80ac016 commit 8cf9cdc
Show file tree
Hide file tree
Showing 6 changed files with 239 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6173,6 +6173,28 @@ public void testGroupByLimitPartitionKeys()
assertTrue(((String) plan.getOnlyValue()).toUpperCase().indexOf("MAP_AGG") >= 0);
}

@Test
public void testJoinPrefilterPartitionKeys()
{
Session prefilter = Session.builder(getSession())
.setSystemProperty("join_prefilter_enabled", "true")
.build();

@Language("SQL") String createTable = "" +
"CREATE TABLE join_prefilter_test " +
"WITH (" +
"partitioned_by = ARRAY[ 'orderstatus' ]" +
") " +
"AS " +
"SELECT custkey, orderkey, orderstatus FROM tpch.tiny.orders";

assertUpdate(prefilter, createTable, 15000);
MaterializedResult result = computeActual(prefilter, "explain(type distributed) select 1 from join_prefilter_test join customer using(custkey) where orderstatus='O'");
// Make sure the layout of the copied table matches the original
String plan = (String) result.getMaterializedRows().get(0).getField(0);
assertNotEquals(plan.lastIndexOf(":: [[\"O\"]]"), plan.indexOf(":: [[\"O\"]]"));
}

@Test
public void testAddTableConstraints()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ public final class SystemSessionProperties
public static final String NATIVE_EXECUTION_PROCESS_REUSE_ENABLED = "native_execution_process_reuse_enabled";
public static final String NATIVE_DEBUG_VALIDATE_OUTPUT_FROM_OPERATORS = "native_debug_validate_output_from_operators";
public static final String DEFAULT_VIEW_SECURITY_MODE = "default_view_security_mode";
public static final String JOIN_PREFILTER_ENABLED = "join_prefilter_enabled";

private final List<PropertyMetadata<?>> sessionProperties;

Expand Down Expand Up @@ -1925,7 +1926,12 @@ public SystemSessionProperties(
featuresConfig.getDefaultViewSecurityMode(),
false,
value -> CreateView.Security.valueOf(((String) value).toUpperCase()),
CreateView.Security::name));
CreateView.Security::name),
booleanProperty(
JOIN_PREFILTER_ENABLED,
"Prefiltering the build/inner side of a join with keys from the other side",
false,
false));
}

public static boolean isSpoolingOutputBufferEnabled(Session session)
Expand Down Expand Up @@ -3207,4 +3213,9 @@ public static CreateView.Security getDefaultViewSecurityMode(Session session)
{
return session.getSystemProperty(DEFAULT_VIEW_SECURITY_MODE, CreateView.Security.class);
}

public static boolean isJoinPrefilterEnabled(Session session)
{
return session.getSystemProperty(JOIN_PREFILTER_ENABLED, Boolean.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@
import com.facebook.presto.sql.planner.optimizations.HistoricalStatisticsEquivalentPlanMarkingOptimizer;
import com.facebook.presto.sql.planner.optimizations.ImplementIntersectAndExceptAsUnion;
import com.facebook.presto.sql.planner.optimizations.IndexJoinOptimizer;
import com.facebook.presto.sql.planner.optimizations.JoinPrefilter;
import com.facebook.presto.sql.planner.optimizations.KeyBasedSampler;
import com.facebook.presto.sql.planner.optimizations.LimitPushDown;
import com.facebook.presto.sql.planner.optimizations.LogicalCteOptimizer;
Expand Down Expand Up @@ -650,6 +651,8 @@ public PlanOptimizers(
.addAll(new InlineSqlFunctions(metadata, sqlParser).rules())
.build()));

builder.add(new JoinPrefilter(metadata));

builder.add(
new IterativeOptimizer(
metadata,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.facebook.presto.metadata.TableLayout;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.SourceLocation;
import com.facebook.presto.spi.TableHandle;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
Expand Down Expand Up @@ -341,14 +342,21 @@ private static TableScanNode cloneTableScan(TableScanNode scanNode, Session sess
List<VariableReferenceExpression> newOutputVariables = outputVariablesBuilder.build();
ImmutableMap<VariableReferenceExpression, ColumnHandle> newAssignments = assignmentsBuilder.build();

TableHandle oldTableHandle = scanNode.getTable();
TableHandle newTableHandle = new TableHandle(
oldTableHandle.getConnectorId(),
oldTableHandle.getConnectorHandle(),
oldTableHandle.getTransaction(),
oldTableHandle.getLayout());

return new TableScanNode(
scanNode.getSourceLocation(),
planNodeIdAllocator.getNextId(),
scanLayout.getNewTableHandle(),
newTableHandle,
newOutputVariables,
newAssignments,
scanNode.getTableConstraints(),
scanLayout.getPredicate(),
scanNode.getCurrentConstraint(),
scanNode.getEnforcedConstraint());
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.SystemSessionProperties.isJoinPrefilterEnabled;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet;
import static com.facebook.presto.spi.plan.JoinType.INNER;
import static com.facebook.presto.spi.plan.JoinType.LEFT;
import static com.facebook.presto.sql.planner.PlannerUtils.clonePlanNode;
import static com.facebook.presto.sql.planner.PlannerUtils.isScanFilterProject;
import static com.facebook.presto.sql.planner.PlannerUtils.projectExpressions;
import static com.facebook.presto.sql.planner.plan.ChildReplacer.replaceChildren;
import static java.util.Objects.requireNonNull;

public class JoinPrefilter
implements PlanOptimizer
{
private final Metadata metadata;
private boolean isEnabledForTesting;

public JoinPrefilter(Metadata metadata)
{
this.metadata = requireNonNull(metadata, "metadata is null");
}

@Override
public void setEnabledForTesting(boolean isSet)
{
isEnabledForTesting = isSet;
}

@Override
public boolean isEnabled(Session session)
{
return isEnabledForTesting || isJoinPrefilterEnabled(session);
}

@Override
public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector)
{
if (isEnabled(session)) {
Rewriter rewriter = new Rewriter(session, metadata, idAllocator, variableAllocator);
PlanNode rewritten = SimplePlanRewriter.rewriteWith(rewriter, plan, null);
return PlanOptimizerResult.optimizerResult(rewritten, rewriter.isPlanChanged());
}

return PlanOptimizerResult.optimizerResult(plan, false);
}

private static class Rewriter
extends SimplePlanRewriter<Void>
{
private final Session session;
private final Metadata metadata;
private final PlanNodeIdAllocator idAllocator;
private final VariableAllocator variableAllocator;
private boolean planChanged;

private Rewriter(Session session, Metadata metadata, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator)
{
this.session = requireNonNull(session, "session is null");
this.metadata = requireNonNull(metadata, "functionAndTypeManager is null");
this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
this.variableAllocator = requireNonNull(variableAllocator, "idAllocator is null");
}

@Override
public PlanNode visitJoin(JoinNode node, RewriteContext<Void> context)
{
PlanNode left = node.getLeft();
PlanNode right = node.getRight();

PlanNode rewrittenLeft = rewriteWith(this, left);
PlanNode rewrittenRight = rewriteWith(this, right);
List<EquiJoinClause> equiJoinClause = node.getCriteria();

// We apply this for only left and inner join and the right side of the join is a simple scan and the join is on one key
if (equiJoinClause.size() == 1 &&
(node.getType() == LEFT || node.getType() == INNER) &&
isScanFilterProject(rewrittenLeft)) {
VariableReferenceExpression leftKey = equiJoinClause.stream().map(x -> x.getLeft()).findFirst().get();
VariableReferenceExpression rightKey = equiJoinClause.stream().map(x -> x.getRight()).findFirst().get();

// First create a SELECT DISTINCT leftKey FROM left
Map<VariableReferenceExpression, VariableReferenceExpression> leftVarMap = new HashMap();
PlanNode leftKeys = clonePlanNode(rewrittenLeft, session, metadata, idAllocator, ImmutableList.of(leftKey), leftVarMap);
PlanNode projectNode = projectExpressions(leftKeys, idAllocator, variableAllocator, ImmutableList.of(leftVarMap.get(leftKey)), ImmutableList.of());

// DISTINCT on the leftkey
PlanNode filteringSource = new AggregationNode(
leftKey.getSourceLocation(),
idAllocator.getNextId(),
projectNode,
ImmutableMap.of(),
singleGroupingSet(projectNode.getOutputVariables()),
projectNode.getOutputVariables(),
AggregationNode.Step.SINGLE,
Optional.empty(),
Optional.empty(),
Optional.empty());

// There should be only one output variable. Project that
filteringSource = projectExpressions(filteringSource, idAllocator, variableAllocator, ImmutableList.of(filteringSource.getOutputVariables().get(0)), ImmutableList.of());

// Now we add a semijoin as the right side
VariableReferenceExpression semiJoinOutput = variableAllocator.newVariable("semiJoinOutput", BOOLEAN);
SemiJoinNode semiJoinNode = new SemiJoinNode(
rightKey.getSourceLocation(),
idAllocator.getNextId(),
node.getStatsEquivalentPlanNode(),
rewrittenRight,
filteringSource,
rightKey,
filteringSource.getOutputVariables().get(0),
semiJoinOutput,
Optional.empty(),
Optional.empty(),
Optional.empty(),
ImmutableMap.of());

rewrittenRight = new FilterNode(semiJoinNode.getSourceLocation(), idAllocator.getNextId(), semiJoinNode, semiJoinOutput);
}

if (rewrittenLeft != node.getLeft() || rewrittenRight != node.getRight()) {
planChanged = true;
return replaceChildren(node, ImmutableList.of(rewrittenLeft, rewrittenRight));
}

return node;
}

public boolean isPlanChanged()
{
return planChanged;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import static com.facebook.presto.SystemSessionProperties.FIELD_NAMES_IN_JSON_CAST_ENABLED;
import static com.facebook.presto.SystemSessionProperties.GENERATE_DOMAIN_FILTERS;
import static com.facebook.presto.SystemSessionProperties.HASH_PARTITION_COUNT;
import static com.facebook.presto.SystemSessionProperties.JOIN_PREFILTER_ENABLED;
import static com.facebook.presto.SystemSessionProperties.KEY_BASED_SAMPLING_ENABLED;
import static com.facebook.presto.SystemSessionProperties.KEY_BASED_SAMPLING_FUNCTION;
import static com.facebook.presto.SystemSessionProperties.KEY_BASED_SAMPLING_PERCENTAGE;
Expand Down Expand Up @@ -7527,4 +7528,24 @@ public void testLambdaInAggregation()
assertQueryFails("SELECT id, reduce_agg(value, array[id, value], (a, b) -> a || b, (a, b) -> a || b) FROM ( VALUES (1, 2), (1, 3), (1, 4), (2, 20), (2, 30), (2, 40) ) AS t(id, value) GROUP BY id",
".*REDUCE_AGG only supports non-NULL literal as the initial value.*");
}

@Test
public void testJoinPrefilter()
{
// Orig
String testQuery = "SELECT 1 from region join nation using(regionkey)";
MaterializedResult result = computeActual("explain(type distributed) " + testQuery);
assertTrue(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin") == -1);
result = computeActual(testQuery);
assertTrue(result.getRowCount() == 25);

// With feature
Session session = Session.builder(getSession())
.setSystemProperty(JOIN_PREFILTER_ENABLED, String.valueOf(true))
.build();
result = computeActual(session, "explain(type distributed) " + testQuery);
assertTrue(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin") != -1);
result = computeActual(session, testQuery);
assertTrue(result.getRowCount() == 25);
}
}

0 comments on commit 8cf9cdc

Please sign in to comment.