Skip to content

Commit

Permalink
Push projections through exchanges
Browse files Browse the repository at this point in the history
  • Loading branch information
erichwang committed Oct 7, 2015
1 parent 1763d07 commit 631eb42
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 25 deletions.
Expand Up @@ -20,6 +20,8 @@

import java.util.Map;

import static com.google.common.base.Preconditions.checkState;

public class ExpressionSymbolInliner
extends ExpressionRewriter<Void>
{
Expand All @@ -33,6 +35,8 @@ public ExpressionSymbolInliner(Map<Symbol, ? extends Expression> mappings)
@Override
public Expression rewriteQualifiedNameReference(QualifiedNameReference node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
return mappings.get(Symbol.fromQualifiedName(node.getName()));
Expression expression = mappings.get(Symbol.fromQualifiedName(node.getName()));
checkState(expression != null, "Cannot resolve symbol %s", node.getName());
return expression;
}
}
Expand Up @@ -97,6 +97,7 @@ public PlanOptimizersFactory(Metadata metadata, SqlParser sqlParser, IndexManage
builder.add(new PickLayout(metadata));

builder.add(new PredicatePushDown(metadata, sqlParser)); // Run predicate push down one more time in case we can leverage new information from layouts' effective predicate
builder.add(new ProjectionPushDown());
builder.add(new MergeProjections());
builder.add(new UnaliasSymbolReferences()); // Run unalias after merging projections to simplify projections more efficiently
builder.add(new PruneUnreferencedOutputs());
Expand Down
Expand Up @@ -15,26 +15,28 @@

import com.facebook.presto.Session;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.ExpressionSymbolInliner;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanRewriter;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.UnionNode;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.QualifiedNameReference;
import com.facebook.presto.util.ImmutableCollectors;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;

import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;

public class ProjectionPushDown
Expand Down Expand Up @@ -64,21 +66,22 @@ public Rewriter(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator
this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null");
}

/**
* Convert a plan of the shape ... -> Project -> Union -> ... to ... -> Union -> Project -> ...
*/
@Override
public PlanNode visitProject(ProjectNode node, RewriteContext<Void> context)
{
// If we have a Project on a Union, push the Project through the Union
PlanNode source = context.rewrite(node.getSource());

if (!(source instanceof UnionNode)) {
return context.replaceChildren(node, ImmutableList.of(source));
if (source instanceof UnionNode) {
return pushProjectionThrough(node, (UnionNode) source);
}
else if (source instanceof ExchangeNode) {
return pushProjectionThrough(node, (ExchangeNode) source);
}
return context.replaceChildren(node, ImmutableList.of(source));
}

UnionNode unionNode = (UnionNode) source;

private PlanNode pushProjectionThrough(ProjectNode node, UnionNode source)
{
// OutputLayout of the resultant Union, will be same as the layout of the Project
List<Symbol> outputLayout = node.getOutputSymbols();

Expand All @@ -88,8 +91,8 @@ public PlanNode visitProject(ProjectNode node, RewriteContext<Void> context)
// sources for the resultant UnionNode
ImmutableList.Builder<PlanNode> outputSources = ImmutableList.builder();

for (int i = 0; i < unionNode.getSources().size(); i++) {
Map<Symbol, QualifiedNameReference> outputToInput = unionNode.sourceSymbolMap(i); // Map: output of union -> input of this source to the union
for (int i = 0; i < source.getSources().size(); i++) {
Map<Symbol, QualifiedNameReference> outputToInput = source.sourceSymbolMap(i); // Map: output of union -> input of this source to the union
ImmutableMap.Builder<Symbol, Expression> assignments = ImmutableMap.builder(); // assignments for the new ProjectNode

// mapping from current ProjectNode to new ProjectNode, used to identify the output layout
Expand All @@ -103,26 +106,54 @@ public PlanNode visitProject(ProjectNode node, RewriteContext<Void> context)
assignments.put(symbol, translatedExpression);
projectSymbolMapping.put(entry.getKey(), symbol);
}
outputSources.add(new ProjectNode(idAllocator.getNextId(), unionNode.getSources().get(i), assignments.build()));
outputSources.add(new ProjectNode(idAllocator.getNextId(), source.getSources().get(i), assignments.build()));
outputLayout.forEach(symbol -> mappings.put(symbol, projectSymbolMapping.get(symbol)));
}

return new UnionNode(node.getId(), outputSources.build(), mappings.build());
}

private PlanNode pushProjectionThrough(ProjectNode node, ExchangeNode exchange)
{
ImmutableList.Builder<PlanNode> newSourceBuilder = ImmutableList.builder();
for (int i = 0; i < exchange.getSources().size(); i++) {
Map<Symbol, QualifiedNameReference> outputToInputMap = extractExchangeOutputToInput(exchange, i);

Map<Symbol, Expression> projections = new LinkedHashMap<>(); // Use LinkedHashMap to make output symbol order deterministic
if (exchange.getHashSymbol().isPresent()) {
// Need to retain the hash symbol for the exchange
projections.put(exchange.getHashSymbol().get(), exchange.getHashSymbol().get().toQualifiedNameReference());
}
for (Map.Entry<Symbol, Expression> projection : node.getAssignments().entrySet()) {
projections.put(projection.getKey(), translateExpression(projection.getValue(), outputToInputMap));
}
newSourceBuilder.add(new ProjectNode(idAllocator.getNextId(), exchange.getSources().get(i), projections));
}
List<PlanNode> newSources = newSourceBuilder.build();
return new ExchangeNode(
exchange.getId(),
exchange.getType(),
exchange.getPartitionKeys(),
exchange.getHashSymbol(),
newSources,
newSources.get(0).getOutputSymbols(), // All sources will have same output symbols
newSources.stream()
.map(PlanNode::getOutputSymbols)
.collect(ImmutableCollectors.toImmutableList()));
}
}

private static Expression translateExpression(Expression inputExpression, Map<Symbol, QualifiedNameReference> symbolMapping)
private static Map<Symbol, QualifiedNameReference> extractExchangeOutputToInput(ExchangeNode exchange, int sourceIndex)
{
return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Void>()
{
@Override
public Expression rewriteQualifiedNameReference(QualifiedNameReference node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
QualifiedNameReference qualifiedNameReference = symbolMapping.get(Symbol.fromQualifiedName(node.getName()));
checkState(qualifiedNameReference != null, "Cannot resolve symbol %s", node.getName());
Map<Symbol, QualifiedNameReference> outputToInputMap = new HashMap<>();
for (int i = 0; i < exchange.getOutputSymbols().size(); i++) {
outputToInputMap.put(exchange.getOutputSymbols().get(i), exchange.getInputs().get(sourceIndex).get(i).toQualifiedNameReference());
}
return outputToInputMap;
}

return qualifiedNameReference;
}
}, inputExpression);
private static Expression translateExpression(Expression inputExpression, Map<Symbol, QualifiedNameReference> symbolMapping)
{
return ExpressionTreeRewriter.rewriteWith(new ExpressionSymbolInliner(symbolMapping), inputExpression);
}
}
Expand Up @@ -3628,6 +3628,18 @@ public void testTopNByMultipleFields()
"SELECT orderkey, custkey, orderstatus FROM orders ORDER BY nullif(orderkey, 3) ASC NULLS LAST, custkey ASC LIMIT 10");
}

@Test
public void testExchangeWithProjectionPushDown()
throws Exception
{
assertQuery(
"SELECT * FROM \n" +
" (SELECT orderkey + 1 orderkey FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 100)) o \n" +
"JOIN \n" +
" (SELECT orderkey + 1 orderkey FROM (SELECT * FROM orders ORDER BY orderkey LIMIT 100)) o1 \n" +
"ON (o.orderkey = o1.orderkey)");
}

@Test
public void testUnionWithProjectionPushDown()
throws Exception
Expand Down

0 comments on commit 631eb42

Please sign in to comment.