Skip to content

Commit

Permalink
Change Assignments to use VariableReferenceExpression
Browse files Browse the repository at this point in the history
  • Loading branch information
rongrong committed Jun 11, 2019
1 parent 4c09072 commit dc140d2
Show file tree
Hide file tree
Showing 105 changed files with 701 additions and 646 deletions.
Expand Up @@ -15,7 +15,7 @@

import com.facebook.presto.Session;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.plan.ProjectNode;
Expand Down Expand Up @@ -53,7 +53,7 @@ protected Optional<PlanNodeStatsEstimate> doCalculate(ProjectNode node, StatsPro
PlanNodeStatsEstimate.Builder calculatedStats = PlanNodeStatsEstimate.builder()
.setOutputRowCount(sourceStats.getOutputRowCount());

for (Map.Entry<Symbol, Expression> entry : node.getAssignments().entrySet()) {
for (Map.Entry<VariableReferenceExpression, Expression> entry : node.getAssignments().entrySet()) {
calculatedStats.addSymbolStatistics(entry.getKey(), scalarStatsCalculator.calculate(entry.getValue(), sourceStats, session, types));
}
return Optional.of(calculatedStats.build());
Expand Down
Expand Up @@ -74,7 +74,7 @@ public class EffectivePredicateExtractor
private static final Predicate<Map.Entry<Symbol, ? extends Expression>> SYMBOL_MATCHES_EXPRESSION =
entry -> entry.getValue().equals(entry.getKey().toSymbolReference());

private static final Function<Map.Entry<Symbol, ? extends Expression>, Expression> ENTRY_TO_EQUALITY =
private static final Function<Map.Entry<Symbol, ? extends Expression>, Expression> SYMBOL_ENTRY_TO_EQUALITY =
entry -> {
SymbolReference reference = entry.getKey().toSymbolReference();
Expression expression = entry.getValue();
Expand All @@ -83,6 +83,18 @@ public class EffectivePredicateExtractor
return new ComparisonExpression(ComparisonExpression.Operator.EQUAL, reference, expression);
};

private static final Predicate<Map.Entry<VariableReferenceExpression, ? extends Expression>> VARIABLE_MATCHES_EXPRESSION =
entry -> entry.getValue().equals(new SymbolReference(entry.getKey().getName()));

private static final Function<Map.Entry<VariableReferenceExpression, ? extends Expression>, Expression> VARIABLE_ENTRY_TO_EQUALITY =
entry -> {
SymbolReference reference = new SymbolReference(entry.getKey().getName());
Expression expression = entry.getValue();
// TODO: this is not correct with respect to NULLs ('reference IS NULL' would be correct, rather than 'reference = NULL')
// TODO: switch this to 'IS NOT DISTINCT FROM' syntax when EqualityInference properly supports it
return new ComparisonExpression(ComparisonExpression.Operator.EQUAL, reference, expression);
};

private final ExpressionDomainTranslator domainTranslator;

public EffectivePredicateExtractor(ExpressionDomainTranslator domainTranslator)
Expand Down Expand Up @@ -163,8 +175,8 @@ public Expression visitProject(ProjectNode node, Void context)
Expression underlyingPredicate = node.getSource().accept(this, context);

List<Expression> projectionEqualities = node.getAssignments().entrySet().stream()
.filter(SYMBOL_MATCHES_EXPRESSION.negate())
.map(ENTRY_TO_EQUALITY)
.filter(VARIABLE_MATCHES_EXPRESSION.negate())
.map(VARIABLE_ENTRY_TO_EQUALITY)
.collect(toImmutableList());

return pullExpressionThroughSymbols(combineConjuncts(
Expand Down Expand Up @@ -313,7 +325,7 @@ private Expression deriveCommonPredicates(PlanNode node, Function<Integer, Colle

List<Expression> equalities = mapping.apply(i).stream()
.filter(SYMBOL_MATCHES_EXPRESSION.negate())
.map(ENTRY_TO_EQUALITY)
.map(SYMBOL_ENTRY_TO_EQUALITY)
.collect(toImmutableList());

sourceOutputConjuncts.add(ImmutableSet.copyOf(extractConjuncts(pullExpressionThroughSymbols(combineConjuncts(
Expand All @@ -335,13 +347,6 @@ private Expression deriveCommonPredicates(PlanNode node, Function<Integer, Colle
return combineConjuncts(potentialOutputConjuncts);
}

private static List<Expression> pullExpressionsThroughSymbols(List<Expression> expressions, Collection<Symbol> symbols)
{
return expressions.stream()
.map(expression -> pullExpressionThroughSymbols(expression, symbols))
.collect(toImmutableList());
}

private static Expression pullExpressionThroughSymbols(Expression expression, Collection<Symbol> symbols)
{
EqualityInference equalityInference = createEqualityInference(expression);
Expand Down
Expand Up @@ -251,6 +251,7 @@
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
import static com.facebook.presto.sql.planner.optimizations.AddExchanges.toVariableReferences;
import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.FINAL;
import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.PARTIAL;
import static com.facebook.presto.sql.planner.plan.JoinNode.Type.FULL;
Expand Down Expand Up @@ -279,7 +280,6 @@
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.common.collect.Range.closedOpen;
import static io.airlift.units.DataSize.Unit.BYTE;
import static java.lang.String.format;
import static java.util.Collections.emptyList;
import static java.util.Objects.requireNonNull;
import static java.util.stream.IntStream.range;
Expand Down Expand Up @@ -1145,9 +1145,9 @@ public PhysicalOperation visitFilter(FilterNode node, LocalExecutionPlanContext
PlanNode sourceNode = node.getSource();

RowExpression filterExpression = node.getPredicate();
List<Symbol> outputSymbols = node.getOutputSymbols();
List<VariableReferenceExpression> outputVariables = toVariableReferences(node.getOutputSymbols(), context.getTypes());

return visitScanFilterAndProject(context, node.getId(), sourceNode, Optional.of(filterExpression), Assignments.identity(outputSymbols), outputSymbols);
return visitScanFilterAndProject(context, node.getId(), sourceNode, Optional.of(filterExpression), Assignments.identity(outputVariables), outputVariables);
}

@Override
Expand All @@ -1164,9 +1164,7 @@ public PhysicalOperation visitProject(ProjectNode node, LocalExecutionPlanContex
sourceNode = node.getSource();
}

List<Symbol> outputSymbols = node.getOutputSymbols();

return visitScanFilterAndProject(context, node.getId(), sourceNode, filterExpression, node.getAssignments(), outputSymbols);
return visitScanFilterAndProject(context, node.getId(), sourceNode, filterExpression, node.getAssignments(), node.getOutputVariables());
}

// TODO: This should be refactored, so that there's an optimizer that merges scan-filter-project into a single PlanNode
Expand All @@ -1176,7 +1174,7 @@ private PhysicalOperation visitScanFilterAndProject(
PlanNode sourceNode,
Optional<RowExpression> filterExpression,
Assignments assignments,
List<Symbol> outputSymbols)
List<VariableReferenceExpression> outputVariables)
{
// if source is a table scan we fold it directly into the filter and project
// otherwise we plan it as a normal operator
Expand All @@ -1197,8 +1195,6 @@ private PhysicalOperation visitScanFilterAndProject(
Symbol symbol = new Symbol(variable.getName());
sourceLayout.put(symbol, input);

Type type = requireNonNull(context.getTypes().get(symbol), format("No type for symbol %s", symbol));

channel++;
}
}
Expand All @@ -1216,17 +1212,17 @@ private PhysicalOperation visitScanFilterAndProject(

// build output mapping
ImmutableMap.Builder<Symbol, Integer> outputMappingsBuilder = ImmutableMap.builder();
for (int i = 0; i < outputSymbols.size(); i++) {
Symbol symbol = outputSymbols.get(i);
outputMappingsBuilder.put(symbol, i);
for (int i = 0; i < outputVariables.size(); i++) {
VariableReferenceExpression variable = outputVariables.get(i);
outputMappingsBuilder.put(new Symbol(variable.getName()), i);
}
Map<Symbol, Integer> outputMappings = outputMappingsBuilder.build();

// compiler uses inputs instead of symbols, so rewrite the expressions first

List<Expression> projections = new ArrayList<>();
for (Symbol symbol : outputSymbols) {
projections.add(assignments.get(symbol));
for (VariableReferenceExpression variable : outputVariables) {
projections.add(assignments.get(variable));
}

Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypes(
Expand Down
Expand Up @@ -355,7 +355,7 @@ private RelationPlan createInsertPlan(Analysis analysis, Insert insertStatement)
if (column.isHidden()) {
continue;
}
Symbol output = symbolAllocator.newSymbol(column.getName(), column.getType());
VariableReferenceExpression output = symbolAllocator.newVariable(column.getName(), column.getType());
int index = insert.getColumns().indexOf(columns.get(column.getName()));
if (index < 0) {
Expression cast = new Cast(new NullLiteral(), column.getType().getTypeSignature().toString());
Expand Down
Expand Up @@ -104,14 +104,14 @@ public PlanBuilder appendProjections(Iterable<Expression> expressions, SymbolAll

// add an identity projection for underlying plan
for (Symbol symbol : getRoot().getOutputSymbols()) {
projections.put(symbol, symbol.toSymbolReference());
projections.put(symbolAllocator.toVariableReference(symbol), symbol.toSymbolReference());
}

ImmutableMap.Builder<Symbol, Expression> newTranslations = ImmutableMap.builder();
for (Expression expression : expressions) {
Symbol symbol = symbolAllocator.newSymbol(expression, getAnalysis().getTypeWithCoercions(expression));
projections.put(symbol, translations.rewrite(expression));
newTranslations.put(symbol, expression);
VariableReferenceExpression variable = symbolAllocator.newVariable(expression, getAnalysis().getTypeWithCoercions(expression));
projections.put(variable, translations.rewrite(expression));
newTranslations.put(new Symbol(variable.getName()), expression);
}
// Now append the new translations into the TranslationMap
for (Map.Entry<Symbol, Expression> entry : newTranslations.build().entrySet()) {
Expand Down

0 comments on commit dc140d2

Please sign in to comment.