From 4e98987e44bf7a5eaa2fde6de499e833cffa39b3 Mon Sep 17 00:00:00 2001 From: James Sun Date: Wed, 3 Apr 2019 22:53:20 -0700 Subject: [PATCH] Alter ExpressionExtractor to only return RowExpression Rather than having an equivalent RowExpressionExtractor, we always return RowExpression from ExpressionExtractor in order to serve a single source of truth when collecting row expressions or expressions from a PlanNode. --- .../sql/planner/ExpressionExtractor.java | 55 ++++++++++--------- .../presto/sql/planner/SubqueryPlanner.java | 3 + .../presto/sql/planner/SymbolsExtractor.java | 19 ++++++- .../sanity/NoIdentifierLeftChecker.java | 11 +++- .../NoSubqueryExpressionLeftChecker.java | 11 +++- 5 files changed, 67 insertions(+), 32 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionExtractor.java index ac998ebda517..3ca1568ba69c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionExtractor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionExtractor.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.sql.planner.iterative.GroupReference; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.AggregationNode; @@ -22,36 +23,36 @@ import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.ValuesNode; -import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.google.common.collect.ImmutableList; import java.util.List; import static com.facebook.presto.sql.planner.iterative.Lookup.noLookup; -import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; -import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; public class ExpressionExtractor { - public static List extractExpressions(PlanNode plan) + public static List extractExpressions(PlanNode plan) { return extractExpressions(plan, noLookup()); } - public static List extractExpressions(PlanNode plan, Lookup lookup) + public static List extractExpressions(PlanNode plan, Lookup lookup) { requireNonNull(plan, "plan is null"); requireNonNull(lookup, "lookup is null"); - ImmutableList.Builder expressionsBuilder = ImmutableList.builder(); + ImmutableList.Builder expressionsBuilder = ImmutableList.builder(); plan.accept(new Visitor(true, lookup), expressionsBuilder); return expressionsBuilder.build(); } - public static List extractExpressionsNonRecursive(PlanNode plan) + public static List extractExpressionsNonRecursive(PlanNode plan) { - ImmutableList.Builder expressionsBuilder = ImmutableList.builder(); + ImmutableList.Builder expressionsBuilder = ImmutableList.builder(); plan.accept(new Visitor(false, noLookup()), expressionsBuilder); return expressionsBuilder.build(); } @@ -61,7 +62,7 @@ private ExpressionExtractor() } private static class Visitor - extends SimplePlanVisitor> + extends SimplePlanVisitor> { private final boolean recursive; private final Lookup lookup; @@ -73,7 +74,7 @@ private static class Visitor } @Override - protected Void visitPlan(PlanNode node, ImmutableList.Builder context) + protected Void visitPlan(PlanNode node, ImmutableList.Builder context) { if (recursive) { return super.visitPlan(node, context); @@ -82,55 +83,55 @@ protected Void visitPlan(PlanNode node, ImmutableList.Builder contex } @Override - public Void visitGroupReference(GroupReference node, ImmutableList.Builder context) + public Void visitGroupReference(GroupReference node, ImmutableList.Builder context) { return lookup.resolve(node).accept(this, context); } @Override - public Void visitAggregation(AggregationNode node, ImmutableList.Builder context) + public Void visitAggregation(AggregationNode node, ImmutableList.Builder context) { node.getAggregations().values() - .forEach(aggregation -> context.add(aggregation.getCall())); + .forEach(aggregation -> context.add(castToRowExpression(aggregation.getCall()))); return super.visitAggregation(node, context); } @Override - public Void visitFilter(FilterNode node, ImmutableList.Builder context) + public Void visitFilter(FilterNode node, ImmutableList.Builder context) { - context.add(node.getPredicate()); + context.add(castToRowExpression(node.getPredicate())); return super.visitFilter(node, context); } @Override - public Void visitProject(ProjectNode node, ImmutableList.Builder context) + public Void visitProject(ProjectNode node, ImmutableList.Builder context) { - context.addAll(node.getAssignments().getExpressions()); + context.addAll(node.getAssignments().getExpressions().stream().map(OriginalExpressionUtils::castToRowExpression).collect(toImmutableList())); return super.visitProject(node, context); } @Override - public Void visitJoin(JoinNode node, ImmutableList.Builder context) + public Void visitJoin(JoinNode node, ImmutableList.Builder context) { - node.getFilter().ifPresent(context::add); + node.getFilter().map(OriginalExpressionUtils::castToRowExpression).ifPresent(context::add); return super.visitJoin(node, context); } @Override - public Void visitValues(ValuesNode node, ImmutableList.Builder context) + public Void visitValues(ValuesNode node, ImmutableList.Builder context) { - node.getRows().forEach(rowExpressions -> rowExpressions.forEach(rowExpression -> { - if (isExpression(rowExpression)) { - context.add(castToExpression(rowExpression)); - } - })); + node.getRows().forEach(context::addAll); return super.visitValues(node, context); } @Override - public Void visitApply(ApplyNode node, ImmutableList.Builder context) + public Void visitApply(ApplyNode node, ImmutableList.Builder context) { - context.addAll(node.getSubqueryAssignments().getExpressions()); + context.addAll(node.getSubqueryAssignments() + .getExpressions() + .stream() + .map(OriginalExpressionUtils::castToRowExpression) + .collect(toImmutableList())); return super.visitApply(node, context); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java index 0dbf9bdc552c..de816fb31d00 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java @@ -27,6 +27,7 @@ import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.planner.plan.ValuesNode; +import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.facebook.presto.sql.tree.BooleanLiteral; import com.facebook.presto.sql.tree.DefaultExpressionTraversalVisitor; import com.facebook.presto.sql.tree.DereferenceExpression; @@ -504,6 +505,8 @@ private Set extractOuterColumnReferences(PlanNode planNode) // when reference expression is not rewritten that means it cannot be satisfied within given PlaNode // see that TranslationMap only resolves (local) fields in current scope return ExpressionExtractor.extractExpressions(planNode).stream() + .filter(OriginalExpressionUtils::isExpression) + .map(OriginalExpressionUtils::castToExpression) .flatMap(expression -> extractColumnReferences(expression, analysis.getColumnReferences()).stream()) .collect(toImmutableSet()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolsExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolsExtractor.java index 1571b40e954f..561c183e5d9c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolsExtractor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolsExtractor.java @@ -36,6 +36,8 @@ import static com.facebook.presto.sql.planner.ExpressionExtractor.extractExpressionsNonRecursive; import static com.facebook.presto.sql.planner.iterative.Lookup.noLookup; import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toSet; @@ -47,7 +49,7 @@ private SymbolsExtractor() {} public static Set extractUnique(PlanNode node) { ImmutableSet.Builder uniqueSymbols = ImmutableSet.builder(); - extractExpressions(node).forEach(expression -> uniqueSymbols.addAll(extractUnique(expression))); + extractExpressions(node).forEach(expression -> uniqueSymbols.addAll(extractUniqueInternal(expression))); return uniqueSymbols.build(); } @@ -55,7 +57,7 @@ public static Set extractUnique(PlanNode node) public static Set extractUniqueNonRecursive(PlanNode node) { ImmutableSet.Builder uniqueSymbols = ImmutableSet.builder(); - extractExpressionsNonRecursive(node).forEach(expression -> uniqueSymbols.addAll(extractUnique(expression))); + extractExpressionsNonRecursive(node).forEach(expression -> uniqueSymbols.addAll(extractUniqueInternal(expression))); return uniqueSymbols.build(); } @@ -63,7 +65,7 @@ public static Set extractUniqueNonRecursive(PlanNode node) public static Set extractUnique(PlanNode node, Lookup lookup) { ImmutableSet.Builder uniqueSymbols = ImmutableSet.builder(); - extractExpressions(node, lookup).forEach(expression -> uniqueSymbols.addAll(extractUnique(expression))); + extractExpressions(node, lookup).forEach(expression -> uniqueSymbols.addAll(extractUniqueInternal(expression))); return uniqueSymbols.build(); } @@ -124,6 +126,17 @@ public static Set extractOutputSymbols(PlanNode planNode, Lookup lookup) .collect(toImmutableSet()); } + /** + * {@param expression} could be an OriginalExpression + */ + private static Set extractUniqueInternal(RowExpression expression) + { + if (isExpression(expression)) { + return extractUnique(castToExpression(expression)); + } + return extractUnique(expression); + } + private static class SymbolBuilderVisitor extends DefaultExpressionTraversalVisitor> { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/NoIdentifierLeftChecker.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/NoIdentifierLeftChecker.java index cec8bb662a81..0f14a7306289 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/NoIdentifierLeftChecker.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/NoIdentifierLeftChecker.java @@ -21,17 +21,26 @@ import com.facebook.presto.sql.planner.ExpressionExtractor; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.facebook.presto.sql.tree.Identifier; import java.util.List; +import static com.google.common.collect.ImmutableList.toImmutableList; + public final class NoIdentifierLeftChecker implements PlanSanityChecker.Checker { @Override public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) { - List identifiers = ExpressionTreeUtils.extractExpressions(ExpressionExtractor.extractExpressions(plan), Identifier.class); + List identifiers = ExpressionTreeUtils.extractExpressions( + ExpressionExtractor.extractExpressions(plan) + .stream() + .filter(OriginalExpressionUtils::isExpression) + .map(OriginalExpressionUtils::castToExpression) + .collect(toImmutableList()), + Identifier.class); if (!identifiers.isEmpty()) { throw new IllegalStateException("Unexpected identifier in logical plan: " + identifiers.get(0)); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/NoSubqueryExpressionLeftChecker.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/NoSubqueryExpressionLeftChecker.java index 983cdad6d4a7..39bdcc211a7a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/NoSubqueryExpressionLeftChecker.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/NoSubqueryExpressionLeftChecker.java @@ -20,10 +20,14 @@ import com.facebook.presto.sql.planner.ExpressionExtractor; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.facebook.presto.sql.tree.DefaultTraversalVisitor; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.SubqueryExpression; +import java.util.List; + +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.String.format; public final class NoSubqueryExpressionLeftChecker @@ -32,7 +36,12 @@ public final class NoSubqueryExpressionLeftChecker @Override public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) { - for (Expression expression : ExpressionExtractor.extractExpressions(plan)) { + List expressions = ExpressionExtractor.extractExpressions(plan) + .stream() + .filter(OriginalExpressionUtils::isExpression) + .map(OriginalExpressionUtils::castToExpression) + .collect(toImmutableList()); + for (Expression expression : expressions) { new DefaultTraversalVisitor() { @Override