Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove cross join if one side of input is single row constant #23081

Merged
merged 1 commit into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ public final class SystemSessionProperties
public static final String GENERATE_DOMAIN_FILTERS = "generate_domain_filters";
public static final String REWRITE_EXPRESSION_WITH_CONSTANT_EXPRESSION = "rewrite_expression_with_constant_expression";
public static final String PRINT_ESTIMATED_STATS_FROM_CACHE = "print_estimated_stats_from_cache";
public static final String REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT = "remove_cross_join_with_constant_single_row_input";

// TODO: Native execution related session properties that are temporarily put here. They will be relocated in the future.
public static final String NATIVE_SIMPLIFIED_EXPRESSION_EVALUATION_ENABLED = "native_simplified_expression_evaluation_enabled";
Expand Down Expand Up @@ -1917,6 +1918,11 @@ public SystemSessionProperties(
"get stats from a cache that was populated during query optimization rather than recalculating the stats on the final plan.",
featuresConfig.isPrintEstimatedStatsFromCache(),
false),
booleanProperty(
REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT,
"If one input of the cross join is a single row with constant value, remove this cross join and replace with a project node",
featuresConfig.isRemoveCrossJoinWithSingleConstantRow(),
false),
new PropertyMetadata<>(
DEFAULT_VIEW_SECURITY_MODE,
format("Set default view security mode. Options are: %s",
Expand Down Expand Up @@ -3225,6 +3231,11 @@ public static boolean isPrintEstimatedStatsFromCacheEnabled(Session session)
return session.getSystemProperty(PRINT_ESTIMATED_STATS_FROM_CACHE, Boolean.class);
}

/**
 * Reads the {@code remove_cross_join_with_constant_single_row_input} system property from the given session.
 *
 * @param session session whose system properties are consulted
 * @return the boolean value of the property for this session
 */
public static boolean isRemoveCrossJoinWithConstantSingleRowInputEnabled(Session session)
{
    Boolean enabled = session.getSystemProperty(REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT, Boolean.class);
    return enabled;
}

public static boolean shouldOptimizerUseHistograms(Session session)
{
return session.getSystemProperty(OPTIMIZER_USE_HISTOGRAMS, Boolean.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ public class FeaturesConfig
private boolean limitNumberOfGroupsForKHyperLogLogAggregations = true;
private boolean generateDomainFilters;
private boolean printEstimatedStatsFromCache;
private boolean removeCrossJoinWithSingleConstantRow = true;
private CreateView.Security defaultViewSecurityMode = DEFINER;
private boolean useHistograms;

Expand Down Expand Up @@ -3107,6 +3108,19 @@ public FeaturesConfig setPrintEstimatedStatsFromCache(boolean printEstimatedStat
return this;
}

/**
 * Returns the current value of the {@code removeCrossJoinWithSingleConstantRow} flag.
 *
 * @return {@code true} when the flag is set
 */
public boolean isRemoveCrossJoinWithSingleConstantRow()
{
    return removeCrossJoinWithSingleConstantRow;
}

/**
 * Sets the {@code removeCrossJoinWithSingleConstantRow} flag, bound to the
 * {@code optimizer.remove-cross-join-with-single-constant-row} configuration key.
 *
 * @param removeCrossJoinWithSingleConstantRow new value of the flag
 * @return this config object, so calls can be chained
 */
@Config("optimizer.remove-cross-join-with-single-constant-row")
@ConfigDescription("If one input of the cross join is a single row with constant value, remove this cross join and replace with a project node")
public FeaturesConfig setRemoveCrossJoinWithSingleConstantRow(boolean removeCrossJoinWithSingleConstantRow)
{
this.removeCrossJoinWithSingleConstantRow = removeCrossJoinWithSingleConstantRow;
return this;
}

public boolean isUseHistograms()
{
return useHistograms;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
import com.facebook.presto.sql.planner.iterative.rule.PushRemoteExchangeThroughGroupId;
import com.facebook.presto.sql.planner.iterative.rule.PushTableWriteThroughUnion;
import com.facebook.presto.sql.planner.iterative.rule.PushTopNThroughUnion;
import com.facebook.presto.sql.planner.iterative.rule.RemoveCrossJoinWithConstantInput;
import com.facebook.presto.sql.planner.iterative.rule.RemoveEmptyDelete;
import com.facebook.presto.sql.planner.iterative.rule.RemoveFullSample;
import com.facebook.presto.sql.planner.iterative.rule.RemoveIdentityProjectionsBelowProjection;
Expand Down Expand Up @@ -505,6 +506,13 @@ public PlanOptimizers(
ImmutableSet.<Rule<?>>builder().add(new RemoveRedundantCastToVarcharInJoinClause(metadata.getFunctionAndTypeManager()))
.addAll(new RemoveMapCastRule(metadata.getFunctionAndTypeManager()).rules()).build()));

builder.add(new IterativeOptimizer(
metadata,
ruleStats,
statsCalculator,
estimatedExchangesCostCalculator,
ImmutableSet.of(new RemoveCrossJoinWithConstantInput(metadata.getFunctionAndTypeManager()))));

builder.add(new IterativeOptimizer(
metadata,
ruleStats,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.ValuesNode;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.IntStream;

import static com.facebook.presto.SystemSessionProperties.isRemoveCrossJoinWithConstantSingleRowInputEnabled;
import static com.facebook.presto.sql.planner.PlannerUtils.addProjections;
import static com.facebook.presto.sql.planner.RowExpressionVariableInliner.inlineVariables;
import static com.facebook.presto.sql.planner.plan.Patterns.join;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableMap.toImmutableMap;

/**
* When one side of a cross join is one single row of constant, we can remove the cross join and replace it with a project.
* <pre>
* - Cross Join
* - table scan
* left_field
* - values // only one row
* right_field := 1
* </pre>
* into
* <pre>
* - project
* left_field := left_field
* right_field := 1
* - table scan
* left_field
* </pre>
*/
public class RemoveCrossJoinWithConstantInput
implements Rule<JoinNode>
{
private final RowExpressionDeterminismEvaluator rowExpressionDeterminismEvaluator;

public RemoveCrossJoinWithConstantInput(FunctionAndTypeManager functionAndTypeManager)
{
this.rowExpressionDeterminismEvaluator = new RowExpressionDeterminismEvaluator(functionAndTypeManager);
}

@Override
public Pattern<JoinNode> getPattern()
{
return join().matching(x -> x.getType().equals(JoinType.INNER) && x.getCriteria().isEmpty());
feilong-liu marked this conversation as resolved.
Show resolved Hide resolved
}

@Override
public boolean isEnabled(Session session)
{
return isRemoveCrossJoinWithConstantSingleRowInputEnabled(session);
}

@Override
public Result apply(JoinNode node, Captures captures, Context context)
{
PlanNode singleValueInput;
PlanNode joinInput;
PlanNode leftInput = context.getLookup().resolve(node.getLeft());
PlanNode rightInput = context.getLookup().resolve(node.getRight());
if (isOutputSingleConstantRow(rightInput, context)) {
singleValueInput = rightInput;
joinInput = leftInput;
}
else if (isOutputSingleConstantRow(leftInput, context)) {
singleValueInput = leftInput;
joinInput = rightInput;
}
else {
return Result.empty();
}
Optional<Map<VariableReferenceExpression, RowExpression>> mapping = getConstantAssignments(singleValueInput, context);
if (!mapping.isPresent()) {
return Result.empty();
}
PlanNode resultNode = addProjections(joinInput, context.getIdAllocator(), mapping.get());
if (node.getFilter().isPresent()) {
resultNode = new FilterNode(node.getSourceLocation(), context.getIdAllocator().getNextId(), resultNode, node.getFilter().get());
}
return Result.ofPlanNode(resultNode);
}

private boolean isOutputSingleConstantRow(PlanNode planNode, Context context)
{
while (planNode instanceof ProjectNode) {
planNode = context.getLookup().resolve(((ProjectNode) planNode).getSource());
}
if (planNode instanceof ValuesNode) {
return ((ValuesNode) planNode).getRows().size() == 1;
}
return false;
}

private Optional<Map<VariableReferenceExpression, RowExpression>> getConstantAssignments(PlanNode planNode, Context context)
{
List<VariableReferenceExpression> outputVariables = planNode.getOutputVariables();
Map<VariableReferenceExpression, RowExpression> mapping = outputVariables.stream().collect(toImmutableMap(Function.identity(), Function.identity()));
while (planNode instanceof ProjectNode) {
Map<VariableReferenceExpression, RowExpression> assignments = ((ProjectNode) planNode).getAssignments().getMap();
mapping = updateAssignments(mapping, assignments);
planNode = context.getLookup().resolve(((ProjectNode) planNode).getSource());
}

checkState(planNode instanceof ValuesNode);
ValuesNode valuesNode = (ValuesNode) planNode;
if (!valuesNode.getOutputVariables().isEmpty()) {
Map<VariableReferenceExpression, RowExpression> assignments = IntStream.range(0, valuesNode.getOutputVariables().size()).boxed()
.collect(toImmutableMap(idx -> valuesNode.getOutputVariables().get(idx), idx -> valuesNode.getRows().get(0).get(idx)));
mapping = updateAssignments(mapping, assignments);
}
boolean allDeterministic = mapping.values().stream().allMatch(rowExpressionDeterminismEvaluator::isDeterministic);
if (allDeterministic) {
return Optional.of(mapping);
}
return Optional.empty();
}

private static Map<VariableReferenceExpression, RowExpression> updateAssignments(Map<VariableReferenceExpression, RowExpression> mapping, Map<VariableReferenceExpression, RowExpression> newAssignments)
{
return mapping.entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue() instanceof VariableReferenceExpression ? newAssignments.get(entry.getValue()) : entry.getValue()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ public void testDefaults()
.setCteHeuristicReplicationThreshold(4)
.setLegacyJsonCast(true)
.setPrintEstimatedStatsFromCache(false)
.setRemoveCrossJoinWithSingleConstantRow(true)
.setUseHistograms(false)
.setUseNewNanDefinition(true));
}
Expand Down Expand Up @@ -487,6 +488,7 @@ public void testExplicitPropertyMappings()
.put("default-view-security-mode", INVOKER.name())
.put("cte-heuristic-replication-threshold", "2")
.put("optimizer.print-estimated-stats-from-cache", "true")
.put("optimizer.remove-cross-join-with-single-constant-row", "false")
.put("optimizer.use-histograms", "true")
.put("use-new-nan-definition", "false")
.build();
Expand Down Expand Up @@ -701,6 +703,7 @@ public void testExplicitPropertyMappings()
.setCteHeuristicReplicationThreshold(2)
.setLegacyJsonCast(false)
.setPrintEstimatedStatsFromCache(true)
.setRemoveCrossJoinWithSingleConstantRow(false)
.setUseHistograms(true)
.setUseNewNanDefinition(false);
assertFullMapping(properties, expected);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import java.util.Optional;

import static com.facebook.presto.SystemSessionProperties.REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT;
import static com.facebook.presto.SystemSessionProperties.REWRITE_EXPRESSION_WITH_CONSTANT_EXPRESSION;
import static com.facebook.presto.spi.plan.JoinType.INNER;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree;
Expand All @@ -54,13 +55,27 @@ public void testJoin()
// This optimization will optimize out the projection below, hence disable it
Session.builder(this.getQueryRunner().getDefaultSession())
.setSystemProperty(REWRITE_EXPRESSION_WITH_CONSTANT_EXPRESSION, "false")
.setSystemProperty(REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT, "false")
.build(),
anyTree(
join(INNER, ImmutableList.of(), Optional.empty(),
project(
ImmutableMap.of("X", expression("BIGINT '1'")),
values(ImmutableMap.of())),
values(ImmutableMap.of()))));

assertPlan(
"SELECT *\n" +
"FROM (\n" +
" SELECT EXTRACT(DAY FROM DATE '2017-01-01')\n" +
") t\n" +
"CROSS JOIN (VALUES 1)",
// This optimization will optimize out the projection below, hence disable it
Session.builder(this.getQueryRunner().getDefaultSession())
.setSystemProperty(REWRITE_EXPRESSION_WITH_CONSTANT_EXPRESSION, "false")
.setSystemProperty(REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT, "true")
.build(),
anyTree(values()));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
import static com.facebook.presto.SystemSessionProperties.OFFSET_CLAUSE_ENABLED;
import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_HASH_GENERATION;
import static com.facebook.presto.SystemSessionProperties.PUSH_REMOTE_EXCHANGE_THROUGH_GROUP_ID;
import static com.facebook.presto.SystemSessionProperties.REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT;
import static com.facebook.presto.SystemSessionProperties.SIMPLIFY_PLAN_WITH_EMPTY_INPUT;
import static com.facebook.presto.SystemSessionProperties.TASK_CONCURRENCY;
import static com.facebook.presto.SystemSessionProperties.getMaxLeafNodesInPlan;
Expand Down Expand Up @@ -1080,6 +1081,14 @@ public void testUsesDistributedJoinIfNaturallyPartitionedOnProbeSymbols()
.setSystemProperty(OPTIMIZE_HASH_GENERATION, Boolean.toString(false))
.build();

Session disableRemoveCrossJoin = Session.builder(broadcastJoin)
.setSystemProperty(REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT, "false")
.build();

Session enableRemoveCrossJoin = Session.builder(broadcastJoin)
.setSystemProperty(REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT, "true")
.build();

// replicated join with naturally partitioned and distributed probe side is rewritten to partitioned join
assertPlanWithSession(
"SELECT r1.regionkey FROM (SELECT regionkey FROM region GROUP BY regionkey) r1, region r2 WHERE r2.regionkey = r1.regionkey",
Expand All @@ -1106,7 +1115,7 @@ public void testUsesDistributedJoinIfNaturallyPartitionedOnProbeSymbols()
// replicated join is preserved if probe side is single node
assertPlanWithSession(
"SELECT * FROM (SELECT * FROM (VALUES 1) t(a)) t, region r WHERE r.regionkey = t.a",
broadcastJoin,
disableRemoveCrossJoin,
false,
anyTree(
node(JoinNode.class,
Expand All @@ -1116,6 +1125,12 @@ public void testUsesDistributedJoinIfNaturallyPartitionedOnProbeSymbols()
exchange(REMOTE_STREAMING, GATHER,
node(TableScanNode.class))))));

assertPlanWithSession(
"SELECT * FROM (SELECT * FROM (VALUES 1) t(a)) t, region r WHERE r.regionkey = t.a",
enableRemoveCrossJoin,
false,
anyTree(node(TableScanNode.class)));

// replicated join is preserved if there are no equality criteria
assertPlanWithSession(
"SELECT * FROM (SELECT regionkey FROM region GROUP BY regionkey) r1, region r2 WHERE r2.regionkey > r1.regionkey",
Expand Down
Loading
Loading