Skip to content

Commit

Permalink
Remove cross join if one side of input is single row constant
Browse files Browse the repository at this point in the history
  • Loading branch information
feilong-liu committed Jun 26, 2024
1 parent 95da85b commit 8e24401
Show file tree
Hide file tree
Showing 10 changed files with 430 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ public final class SystemSessionProperties
public static final String GENERATE_DOMAIN_FILTERS = "generate_domain_filters";
public static final String REWRITE_EXPRESSION_WITH_CONSTANT_EXPRESSION = "rewrite_expression_with_constant_expression";
public static final String PRINT_ESTIMATED_STATS_FROM_CACHE = "print_estimated_stats_from_cache";
public static final String REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT = "remove_cross_join_with_constant_single_row_input";

// TODO: Native execution related session properties that are temporarily put here. They will be relocated in the future.
public static final String NATIVE_SIMPLIFIED_EXPRESSION_EVALUATION_ENABLED = "native_simplified_expression_evaluation_enabled";
Expand Down Expand Up @@ -1917,6 +1918,11 @@ public SystemSessionProperties(
"get stats from a cache that was populated during query optimization rather than recalculating the stats on the final plan.",
featuresConfig.isPrintEstimatedStatsFromCache(),
false),
booleanProperty(
REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT,
"If one input of the cross join is a single row with constant value, remove this cross join and replace with a project node",
featuresConfig.isRemoveCrossJoinWithSingleConstantRow(),
false),
new PropertyMetadata<>(
DEFAULT_VIEW_SECURITY_MODE,
format("Set default view security mode. Options are: %s",
Expand Down Expand Up @@ -3225,6 +3231,11 @@ public static boolean isPrintEstimatedStatsFromCacheEnabled(Session session)
return session.getSystemProperty(PRINT_ESTIMATED_STATS_FROM_CACHE, Boolean.class);
}

public static boolean isRemoveCrossJoinWithConstantSingleRowInputEnabled(Session session)
{
return session.getSystemProperty(REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT, Boolean.class);
}

public static boolean shouldOptimizerUseHistograms(Session session)
{
return session.getSystemProperty(OPTIMIZER_USE_HISTOGRAMS, Boolean.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ public class FeaturesConfig
private boolean limitNumberOfGroupsForKHyperLogLogAggregations = true;
private boolean generateDomainFilters;
private boolean printEstimatedStatsFromCache;
private boolean removeCrossJoinWithSingleConstantRow = true;
private CreateView.Security defaultViewSecurityMode = DEFINER;
private boolean useHistograms;

Expand Down Expand Up @@ -3107,6 +3108,19 @@ public FeaturesConfig setPrintEstimatedStatsFromCache(boolean printEstimatedStat
return this;
}

public boolean isRemoveCrossJoinWithSingleConstantRow()
{
return this.removeCrossJoinWithSingleConstantRow;
}

@Config("optimizer.remove-cross-join-with-single-constant-row")
@ConfigDescription("If one input of the cross join is a single row with constant value, remove this cross join and replace with a project node")
public FeaturesConfig setRemoveCrossJoinWithSingleConstantRow(boolean removeCrossJoinWithSingleConstantRow)
{
this.removeCrossJoinWithSingleConstantRow = removeCrossJoinWithSingleConstantRow;
return this;
}

public boolean isUseHistograms()
{
return useHistograms;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
import com.facebook.presto.sql.planner.iterative.rule.PushRemoteExchangeThroughGroupId;
import com.facebook.presto.sql.planner.iterative.rule.PushTableWriteThroughUnion;
import com.facebook.presto.sql.planner.iterative.rule.PushTopNThroughUnion;
import com.facebook.presto.sql.planner.iterative.rule.RemoveCrossJoinWithConstantInput;
import com.facebook.presto.sql.planner.iterative.rule.RemoveEmptyDelete;
import com.facebook.presto.sql.planner.iterative.rule.RemoveFullSample;
import com.facebook.presto.sql.planner.iterative.rule.RemoveIdentityProjectionsBelowProjection;
Expand Down Expand Up @@ -505,6 +506,13 @@ public PlanOptimizers(
ImmutableSet.<Rule<?>>builder().add(new RemoveRedundantCastToVarcharInJoinClause(metadata.getFunctionAndTypeManager()))
.addAll(new RemoveMapCastRule(metadata.getFunctionAndTypeManager()).rules()).build()));

builder.add(new IterativeOptimizer(
metadata,
ruleStats,
statsCalculator,
estimatedExchangesCostCalculator,
ImmutableSet.of(new RemoveCrossJoinWithConstantInput())));

builder.add(new IterativeOptimizer(
metadata,
ruleStats,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.ValuesNode;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.JoinNode;

import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.IntStream;

import static com.facebook.presto.SystemSessionProperties.isRemoveCrossJoinWithConstantSingleRowInputEnabled;
import static com.facebook.presto.sql.planner.PlannerUtils.addProjections;
import static com.facebook.presto.sql.planner.plan.Patterns.join;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableMap.toImmutableMap;

/**
* When one side of a cross join is one single row of constant, we can remove the cross join and replace it with a project.
* <pre>
* - Cross Join
* - table scan
* leftkey
* - values // only one row
* rightkey := 1
* </pre>
* into
* <pre>
* - project
* leftkey := leftkey
* rightkey := 1
* - table scan
* leftkey
* </pre>
*/
public class RemoveCrossJoinWithConstantInput
implements Rule<JoinNode>
{
@Override
public Pattern<JoinNode> getPattern()
{
return join().matching(x -> x.getType().equals(JoinType.INNER) && x.getCriteria().isEmpty());
}

@Override
public boolean isEnabled(Session session)
{
return isRemoveCrossJoinWithConstantSingleRowInputEnabled(session);
}

@Override
public Result apply(JoinNode node, Captures captures, Context context)
{
PlanNode singleValueInput;
PlanNode joinInput;
if (outputSingleConstantRow(context.getLookup().resolve(node.getRight()), context)) {
singleValueInput = context.getLookup().resolve(node.getRight());
joinInput = context.getLookup().resolve(node.getLeft());
}
else if (outputSingleConstantRow((context.getLookup().resolve(node.getLeft())), context)) {
singleValueInput = context.getLookup().resolve(node.getLeft());
joinInput = context.getLookup().resolve(node.getRight());
}
else {
return Result.empty();
}
Map<VariableReferenceExpression, RowExpression> mapping = getConstantAssignments(singleValueInput, context);
return Result.ofPlanNode(addProjections(joinInput, context.getIdAllocator(), mapping));
}

private boolean outputSingleConstantRow(PlanNode planNode, Context context)
{
while (planNode instanceof ProjectNode) {
planNode = context.getLookup().resolve(((ProjectNode) planNode).getSource());
}
if (planNode instanceof ValuesNode) {
return ((ValuesNode) planNode).getRows().size() == 1;
}
return false;
}

private Map<VariableReferenceExpression, RowExpression> getConstantAssignments(PlanNode planNode, Context context)
{
List<VariableReferenceExpression> outputVariables = planNode.getOutputVariables();
Map<VariableReferenceExpression, RowExpression> mapping = outputVariables.stream().collect(toImmutableMap(Function.identity(), Function.identity()));
while (planNode instanceof ProjectNode) {
Map<VariableReferenceExpression, RowExpression> assignments = ((ProjectNode) planNode).getAssignments().getMap();
mapping = mapping.entrySet().stream().collect(toImmutableMap(
Map.Entry::getKey, entry -> entry.getValue() instanceof VariableReferenceExpression ? assignments.get(entry.getValue()) : entry.getValue()));
planNode = context.getLookup().resolve(((ProjectNode) planNode).getSource());
}
checkState(planNode instanceof ValuesNode);
ValuesNode valuesNode = (ValuesNode) planNode;
if (!valuesNode.getOutputVariables().isEmpty()) {
Map<VariableReferenceExpression, RowExpression> assignments = IntStream.range(0, valuesNode.getOutputVariables().size()).boxed()
.collect(toImmutableMap(idx -> valuesNode.getOutputVariables().get(idx), idx -> valuesNode.getRows().get(0).get(idx)));
mapping = mapping.entrySet().stream().collect(toImmutableMap(
Map.Entry::getKey, entry -> entry.getValue() instanceof VariableReferenceExpression ? assignments.get(entry.getValue()) : entry.getValue()));
}
checkState(mapping.entrySet().stream().allMatch(entry -> entry.getValue() instanceof ConstantExpression));
return mapping;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ public void testDefaults()
.setCteHeuristicReplicationThreshold(4)
.setLegacyJsonCast(true)
.setPrintEstimatedStatsFromCache(false)
.setRemoveCrossJoinWithSingleConstantRow(true)
.setUseHistograms(false)
.setUseNewNanDefinition(true));
}
Expand Down Expand Up @@ -487,6 +488,7 @@ public void testExplicitPropertyMappings()
.put("default-view-security-mode", INVOKER.name())
.put("cte-heuristic-replication-threshold", "2")
.put("optimizer.print-estimated-stats-from-cache", "true")
.put("optimizer.remove-cross-join-with-single-constant-row", "false")
.put("optimizer.use-histograms", "true")
.put("use-new-nan-definition", "false")
.build();
Expand Down Expand Up @@ -701,6 +703,7 @@ public void testExplicitPropertyMappings()
.setCteHeuristicReplicationThreshold(2)
.setLegacyJsonCast(false)
.setPrintEstimatedStatsFromCache(true)
.setRemoveCrossJoinWithSingleConstantRow(false)
.setUseHistograms(true)
.setUseNewNanDefinition(false);
assertFullMapping(properties, expected);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import java.util.Optional;

import static com.facebook.presto.SystemSessionProperties.REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT;
import static com.facebook.presto.SystemSessionProperties.REWRITE_EXPRESSION_WITH_CONSTANT_EXPRESSION;
import static com.facebook.presto.spi.plan.JoinType.INNER;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree;
Expand All @@ -54,13 +55,27 @@ public void testJoin()
// This optimization will optimize out the projection below, hence disable it
Session.builder(this.getQueryRunner().getDefaultSession())
.setSystemProperty(REWRITE_EXPRESSION_WITH_CONSTANT_EXPRESSION, "false")
.setSystemProperty(REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT, "false")
.build(),
anyTree(
join(INNER, ImmutableList.of(), Optional.empty(),
project(
ImmutableMap.of("X", expression("BIGINT '1'")),
values(ImmutableMap.of())),
values(ImmutableMap.of()))));

assertPlan(
"SELECT *\n" +
"FROM (\n" +
" SELECT EXTRACT(DAY FROM DATE '2017-01-01')\n" +
") t\n" +
"CROSS JOIN (VALUES 1)",
// This optimization will optimize out the projection below, hence disable it
Session.builder(this.getQueryRunner().getDefaultSession())
.setSystemProperty(REWRITE_EXPRESSION_WITH_CONSTANT_EXPRESSION, "false")
.setSystemProperty(REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT, "true")
.build(),
anyTree(values()));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
import static com.facebook.presto.SystemSessionProperties.OFFSET_CLAUSE_ENABLED;
import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_HASH_GENERATION;
import static com.facebook.presto.SystemSessionProperties.PUSH_REMOTE_EXCHANGE_THROUGH_GROUP_ID;
import static com.facebook.presto.SystemSessionProperties.REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT;
import static com.facebook.presto.SystemSessionProperties.SIMPLIFY_PLAN_WITH_EMPTY_INPUT;
import static com.facebook.presto.SystemSessionProperties.TASK_CONCURRENCY;
import static com.facebook.presto.SystemSessionProperties.getMaxLeafNodesInPlan;
Expand Down Expand Up @@ -1080,6 +1081,14 @@ public void testUsesDistributedJoinIfNaturallyPartitionedOnProbeSymbols()
.setSystemProperty(OPTIMIZE_HASH_GENERATION, Boolean.toString(false))
.build();

Session disableRemoveCrossJoin = Session.builder(broadcastJoin)
.setSystemProperty(REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT, "false")
.build();

Session enableRemoveCrossJoin = Session.builder(broadcastJoin)
.setSystemProperty(REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT, "true")
.build();

// replicated join with naturally partitioned and distributed probe side is rewritten to partitioned join
assertPlanWithSession(
"SELECT r1.regionkey FROM (SELECT regionkey FROM region GROUP BY regionkey) r1, region r2 WHERE r2.regionkey = r1.regionkey",
Expand All @@ -1106,7 +1115,7 @@ public void testUsesDistributedJoinIfNaturallyPartitionedOnProbeSymbols()
// replicated join is preserved if probe side is single node
assertPlanWithSession(
"SELECT * FROM (SELECT * FROM (VALUES 1) t(a)) t, region r WHERE r.regionkey = t.a",
broadcastJoin,
disableRemoveCrossJoin,
false,
anyTree(
node(JoinNode.class,
Expand All @@ -1116,6 +1125,12 @@ public void testUsesDistributedJoinIfNaturallyPartitionedOnProbeSymbols()
exchange(REMOTE_STREAMING, GATHER,
node(TableScanNode.class))))));

assertPlanWithSession(
"SELECT * FROM (SELECT * FROM (VALUES 1) t(a)) t, region r WHERE r.regionkey = t.a",
enableRemoveCrossJoin,
false,
anyTree(node(TableScanNode.class)));

// replicated join is preserved if there are no equality criteria
assertPlanWithSession(
"SELECT * FROM (SELECT regionkey FROM region GROUP BY regionkey) r1, region r2 WHERE r2.regionkey > r1.regionkey",
Expand Down
Loading

0 comments on commit 8e24401

Please sign in to comment.