Skip to content

Commit

Permalink
Update AddLocalExchanges to plan streaming aggregations
Browse files Browse the repository at this point in the history
Also, fixed dormant bug in local property derivations for cross joins.
  • Loading branch information
mbasmanova committed Jun 20, 2018
1 parent 85c7b3d commit 7775b23
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 6 deletions.
Expand Up @@ -255,7 +255,7 @@ public PlanWithProperties visitAggregation(AggregationNode node, StreamPreferred
return planAndEnforceChildren(node, singleStream(), defaultParallelism(session));
}

StreamPreferredProperties requiredProperties = parentPreferences.withDefaultParallelism(session).withPartitioning(node.getGroupingKeys());
List<Symbol> groupingKeys = node.getGroupingKeys();
if (node.hasDefaultOutput()) {
checkState(node.isDecomposable(metadata.getFunctionRegistry()));

Expand All @@ -267,13 +267,36 @@ public PlanWithProperties visitAggregation(AggregationNode node, StreamPreferred
idAllocator.getNextId(),
LOCAL,
child.getNode(),
node.getGroupingKeys(),
groupingKeys,
Optional.empty()),
child.getProperties());
return rebaseAndDeriveProperties(node, ImmutableList.of(exchange));
}

return planAndEnforceChildren(node, requiredProperties, requiredProperties);
StreamPreferredProperties childRequirements = parentPreferences
.constrainTo(node.getSource().getOutputSymbols())
.withDefaultParallelism(session)
.withPartitioning(groupingKeys);

PlanWithProperties child = planAndEnforce(node.getSource(), childRequirements, childRequirements);

List<Symbol> preGroupedSymbols = ImmutableList.of();
if (!LocalProperties.match(child.getProperties().getLocalProperties(), LocalProperties.grouped(groupingKeys)).get(0).isPresent()) {
// !isPresent() indicates the property was satisfied completely
preGroupedSymbols = groupingKeys;
}

AggregationNode result = new AggregationNode(
node.getId(),
child.getNode(),
node.getAggregations(),
node.getGroupingSets(),
preGroupedSymbols,
node.getStep(),
node.getHashSymbol(),
node.getGroupIdSymbol());

return deriveProperties(result, child.getProperties());
}

@Override
Expand Down
Expand Up @@ -383,6 +383,16 @@ public ActualProperties visitJoin(JoinNode node, List<ActualProperties> inputPro
constants.putAll(probeProperties.getConstants());
constants.putAll(buildProperties.getConstants());

if (node.isCrossJoin()) {
// Cross join preserves only constants from probe and build sides.
// Cross join doesn't preserve sorting or grouping local properties on either side.
return ActualProperties.builder()
.global(probeProperties)
.local(ImmutableList.of())
.constants(constants)
.build();
}

return ActualProperties.builderFrom(probeProperties)
.constants(constants)
.unordered(unordered)
Expand Down
Expand Up @@ -73,6 +73,7 @@
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;
import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.SINGLE;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.GATHER;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION;
Expand Down Expand Up @@ -355,6 +356,42 @@ public void testCorrelatedSubqueries()
tableScan("orders", ImmutableMap.of("X", "orderkey")))));
}

@Test
public void testStreamingAggregationOverJoin()
{
// "orders" table is naturally grouped on orderkey
// this grouping should survive an inner join and allow for streaming aggregation later
// this grouping should not survive a cross join
assertPlan("SELECT o.orderkey, count(*) FROM orders o, lineitem l WHERE o.orderkey=l.orderkey GROUP BY 1",
anyTree(
aggregation(
ImmutableList.of(ImmutableList.of("o_orderkey")),
ImmutableMap.of(Optional.empty(), functionCall("count", ImmutableList.of())),
ImmutableList.of("o_orderkey"), // streaming
ImmutableMap.of(),
Optional.empty(),
SINGLE,
join(INNER, ImmutableList.of(equiJoinClause("o_orderkey", "l_orderkey")),
anyTree(
tableScan("orders", ImmutableMap.of("o_orderkey", "orderkey"))),
anyTree(
tableScan("lineitem", ImmutableMap.of("l_orderkey", "orderkey")))))));

assertPlan("SELECT o.orderkey, count(*) FROM orders o, lineitem l GROUP BY 1",
anyTree(
aggregation(
ImmutableList.of(ImmutableList.of("orderkey")),
ImmutableMap.of(Optional.empty(), functionCall("count", ImmutableList.of())),
ImmutableList.of(), // not streaming
ImmutableMap.of(),
Optional.empty(),
SINGLE,
join(INNER, ImmutableList.of(),
tableScan("orders", ImmutableMap.of("orderkey", "orderkey")),
anyTree(
node(TableScanNode.class))))));
}

/**
* Handling of correlated IN pulls up everything possible to the generated outer join condition.
* This test ensures uncorrelated conditions are pushed back down.
Expand Down
Expand Up @@ -38,13 +38,15 @@ public class AggregationMatcher
{
private final Map<Symbol, Symbol> masks;
private final List<List<String>> groupingSets;
private final List<String> preGroupedSymbols;
private final Optional<Symbol> groupId;
private final Step step;

public AggregationMatcher(List<List<String>> groupingSets, Map<Symbol, Symbol> masks, Optional<Symbol> groupId, Step step)
public AggregationMatcher(List<List<String>> groupingSets, List<String> preGroupedSymbols, Map<Symbol, Symbol> masks, Optional<Symbol> groupId, Step step)
{
this.masks = masks;
this.groupingSets = groupingSets;
this.preGroupedSymbols = preGroupedSymbols;
this.groupId = groupId;
this.step = step;
}
Expand Down Expand Up @@ -96,7 +98,7 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses
return NO_MATCH;
}

if (!aggregationNode.getPreGroupedSymbols().isEmpty()) {
if (!matches(preGroupedSymbols, aggregationNode.getPreGroupedSymbols(), symbolAliases)) {
return NO_MATCH;
}

Expand Down Expand Up @@ -126,6 +128,7 @@ public String toString()
{
return toStringHelper(this)
.add("groupingSets", groupingSets)
.add("preGroupedSymbols", preGroupedSymbols)
.add("masks", masks)
.add("groudId", groupId)
.add("step", step)
Expand Down
Expand Up @@ -218,7 +218,19 @@ public static PlanMatchPattern aggregation(
Step step,
PlanMatchPattern source)
{
PlanMatchPattern result = node(AggregationNode.class, source).with(new AggregationMatcher(groupingSets, masks, groupId, step));
return aggregation(groupingSets, aggregations, ImmutableList.of(), masks, groupId, step, source);
}

public static PlanMatchPattern aggregation(
List<List<String>> groupingSets,
Map<Optional<String>, ExpectedValueProvider<FunctionCall>> aggregations,
List<String> preGroupedSymbols,
Map<Symbol, Symbol> masks,
Optional<Symbol> groupId,
Step step,
PlanMatchPattern source)
{
PlanMatchPattern result = node(AggregationNode.class, source).with(new AggregationMatcher(groupingSets, preGroupedSymbols, masks, groupId, step));
aggregations.entrySet().forEach(
aggregation -> result.withAlias(aggregation.getKey(), new AggregationFunctionMatcher(aggregation.getValue())));
return result;
Expand Down

0 comments on commit 7775b23

Please sign in to comment.