Skip to content

Commit

Permalink
Remove unnecessary remote exchange from scalar correlated subquery plan
Browse files Browse the repository at this point in the history
Eliminate exchanges between LeftJoin and MarkDistinctNode when planning scalar
correlated subqueries.

Here is a sample query.

SELECT name, (SELECT name FROM region WHERE regionkey = nation.regionkey)
FROM nation

The plan before the change included remote and local exchanges between
LeftJoin and MarkDistinctNode.

- FilterProject
    - MarkDistinct[distinct=unique:bigint marker=is_distinct]
        - LocalExchange[HASH][$hashvalue] ("unique")
            - RemoteExchange[REPARTITION][$hashvalue_15]
                - LeftJoin[("regionkey" = "regionkey_0")]
                    - AssignUniqueId
                    - (build)

After this change the plan no longer includes the exchanges.

- FilterProject
    - MarkDistinct[distinct=unique:bigint marker=is_distinct]
        - LeftJoin[("regionkey" = "regionkey_0")]
            - AssignUniqueId
            - (build)

The change contains two parts:

- update TransformCorrelatedScalarSubquery to include all probe symbols in
  MarkDistinctNode#distinctSymbols;
- add logic to AddLocalExchanges rule to drop everything but `unique` symbol from
  MarkDistinctNode#distinctSymbols.

First change allows AddExchanges rule to see that partitioning required by
MarkDistinctNode is already satified: partitioned_on(join_key) implies
partitioned_on(all_probe_rows). Second change avoids increasing the cost of
MarkDistinctNode.
  • Loading branch information
mbasmanova committed Aug 10, 2018
1 parent 2c6ba08 commit 575a9a6
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context
context.getIdAllocator().getNextId(),
rewrittenLateralJoinNode,
isDistinct,
ImmutableList.of(unique),
rewrittenLateralJoinNode.getInput().getOutputSymbols(),
Optional.empty());

FilterNode filterNode = new FilterNode(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.facebook.presto.Session;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.ConstantProperty;
import com.facebook.presto.spi.GroupingProperty;
import com.facebook.presto.spi.LocalProperty;
import com.facebook.presto.spi.SortingProperty;
Expand Down Expand Up @@ -366,8 +367,73 @@ public PlanWithProperties visitWindow(WindowNode node, StreamPreferredProperties
public PlanWithProperties visitMarkDistinct(MarkDistinctNode node, StreamPreferredProperties parentPreferences)
{
// mark distinct requires that all data partitioned
StreamPreferredProperties requiredProperties = parentPreferences.withDefaultParallelism(session).withPartitioning(node.getDistinctSymbols());
return planAndEnforceChildren(node, requiredProperties, requiredProperties);
StreamPreferredProperties childRequirements = parentPreferences
.constrainTo(node.getSource().getOutputSymbols())
.withDefaultParallelism(session)
.withPartitioning(node.getDistinctSymbols());

PlanWithProperties child = planAndEnforce(node.getSource(), childRequirements, childRequirements);

MarkDistinctNode result = new MarkDistinctNode(
node.getId(),
child.getNode(),
node.getMarkerSymbol(),
pruneMarkDistinctSymbols(node, child.getProperties().getLocalProperties()),
node.getHashSymbol());

return deriveProperties(result, child.getProperties());
}

/**
* Prune redundant distinct symbols to reduce CPU cost of hashing corresponding values and amount of memory
* needed to store all the distinct values.
*
* Consider the following plan,
*
* - MarkDistinctNode (unique, c1, c2)
* - Join
* - AssignUniqueId (unique)
* - probe (c1, c2)
* - build
*
* In this case MarkDistinctNode (unique, c1, c2) is equivalent to MarkDistinctNode (unique),
* because if two rows match on `unique`, they must match on `c1` and `c2` as well.
*
* More generally, any distinct symbol that is functionally dependent on a subset of
* other distinct symbols can be dropped.
*
* Ideally, this logic would be encapsulated in a separate rule, but currently no rule other
* than AddLocalExchanges can reason about local properties.
*/
private List<Symbol> pruneMarkDistinctSymbols(MarkDistinctNode node, List<LocalProperty<Symbol>> localProperties)
{
if (localProperties.isEmpty()) {
return node.getDistinctSymbols();
}

// Identify functional dependencies between distinct symbols: in the list of local properties any constant
// symbol is functionally dependent on the set of symbols that appears earlier.
ImmutableSet.Builder<Symbol> redundantSymbolsBuilder = ImmutableSet.builder();
for (LocalProperty<Symbol> property : localProperties) {
if (property instanceof ConstantProperty) {
redundantSymbolsBuilder.add(((ConstantProperty<Symbol>) property).getColumn());
}
else if (!node.getDistinctSymbols().containsAll(property.getColumns())) {
// Ran into a non-distinct symbol. There will be no more symbols that are functionally dependent on distinct symbols exclusively.
break;
}
}

Set<Symbol> redundantSymbols = redundantSymbolsBuilder.build();
List<Symbol> remainingSymbols = node.getDistinctSymbols().stream()
.filter(symbol -> !redundantSymbols.contains(symbol))
.collect(toImmutableList());
if (remainingSymbols.isEmpty()) {
// This happens when all distinct symbols are constants.
// In that case, keep the first symbol (don't drop them all).
return ImmutableList.of(node.getDistinctSymbols().get(0));
}
return remainingSymbols;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import static com.facebook.presto.SystemSessionProperties.FORCE_SINGLE_NODE_OUTPUT;
import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE;
import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_HASH_GENERATION;
import static com.facebook.presto.spi.StandardErrorCode.SUBQUERY_MULTIPLE_ROWS;
import static com.facebook.presto.spi.predicate.Domain.singleValue;
import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.spi.type.VarcharType.createVarcharType;
Expand All @@ -68,6 +69,7 @@
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.markDistinct;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.output;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project;
Expand All @@ -94,6 +96,7 @@
import static com.facebook.presto.tests.QueryTemplate.queryTemplate;
import static com.facebook.presto.util.MorePredicates.isInstanceOfAny;
import static io.airlift.slice.Slices.utf8Slice;
import static java.lang.String.format;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;

Expand Down Expand Up @@ -159,6 +162,18 @@ public void testDistinctLimitOverInequalityJoin()
.withExactOutputs(ImmutableList.of("O_ORDERKEY"))))));
}

@Test
public void testDistinctOverConstants()
{
assertPlan("SELECT count(*), count(distinct orderkey) FROM (SELECT * FROM orders WHERE orderkey = 1)",
anyTree(
markDistinct("is_distinct", ImmutableList.of("orderkey"), "hash",
anyTree(
project(ImmutableMap.of("hash", expression("combine_hash(bigint '0', coalesce(\"$operator$hash_code\"(orderkey), 0))")),
filter("orderkey = BIGINT '1'",
tableScan("orders", ImmutableMap.of("orderkey", "orderkey"))))))));
}

@Test
public void testInnerInequalityJoinNoEquiJoinConjuncts()
{
Expand Down Expand Up @@ -395,6 +410,22 @@ public void testCorrelatedSubqueries()
tableScan("orders", ImmutableMap.of("X", "orderkey")))));
}

@Test
public void testCorrelatedScalarSubqueryInSelect()
{
assertDistributedPlan("SELECT name, (SELECT name FROM region WHERE regionkey = nation.regionkey) FROM nation",
anyTree(
filter(format("CASE \"is_distinct\" WHEN true THEN true ELSE CAST(fail(%s, 'Scalar sub-query has returned multiple rows') AS boolean) END", SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode()),
project(
markDistinct("is_distinct", ImmutableList.of("unique"), "hash",
project(ImmutableMap.of("hash", expression("combine_hash(bigint '0', coalesce(\"$operator$hash_code\"(unique), 0))")),
join(LEFT, ImmutableList.of(equiJoinClause("n_regionkey", "r_regionkey")),
assignUniqueId("unique",
exchange(REMOTE, REPARTITION,
anyTree(tableScan("nation", ImmutableMap.of("n_regionkey", "regionkey"))))),
anyTree(tableScan("region", ImmutableMap.of("r_regionkey", "regionkey"))))))))));
}

@Test
public void testStreamingAggregationForCorrelatedSubquery()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@
public class TestTransformCorrelatedScalarSubquery
extends BaseRuleTest
{
private static final ImmutableList<List<Expression>> ONE_ROW = ImmutableList.of(ImmutableList.of());
private static final ImmutableList<List<Expression>> TWO_ROWS = ImmutableList.of(ImmutableList.of(), ImmutableList.of());
private static final ImmutableList<List<Expression>> ONE_ROW = ImmutableList.of(ImmutableList.of(new LongLiteral("1")));
private static final ImmutableList<List<Expression>> TWO_ROWS = ImmutableList.of(ImmutableList.of(new LongLiteral("1")), ImmutableList.of(new LongLiteral("2")));

private Rule rule = new TransformCorrelatedScalarSubquery();

Expand Down Expand Up @@ -101,7 +101,7 @@ public void rewritesOnSubqueryWithoutProjection()
ensureScalarSubquery(),
markDistinct(
"is_distinct",
ImmutableList.of("unique"),
ImmutableList.of("corr", "unique"),
lateral(
ImmutableList.of("corr"),
assignUniqueId(
Expand Down Expand Up @@ -131,7 +131,7 @@ public void rewritesOnSubqueryWithProjection()
ensureScalarSubquery(),
markDistinct(
"is_distinct",
ImmutableList.of("unique"),
ImmutableList.of("corr", "unique"),
lateral(
ImmutableList.of("corr"),
assignUniqueId(
Expand Down Expand Up @@ -163,7 +163,7 @@ public void rewritesOnSubqueryWithProjectionOnTopEnforceSingleNode()
ensureScalarSubquery(),
markDistinct(
"is_distinct",
ImmutableList.of("unique"),
ImmutableList.of("corr", "unique"),
lateral(
ImmutableList.of("corr"),
assignUniqueId(
Expand Down

0 comments on commit 575a9a6

Please sign in to comment.