Skip to content

Commit

Permalink
Alter ExpressionExtractor to only return RowExpression
Browse files Browse the repository at this point in the history
Rather than having an equivalent RowExpressionExtractor, we always
return RowExpression from ExpressionExtractor in order to serve a single
source of truth when collecting row expressions or expressions from a
PlanNode.
  • Loading branch information
highker committed Apr 17, 2019
1 parent 1b05879 commit 4e98987
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 32 deletions.
Expand Up @@ -13,6 +13,7 @@
*/
package com.facebook.presto.sql.planner;

import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.sql.planner.iterative.GroupReference;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.plan.AggregationNode;
Expand All @@ -22,36 +23,36 @@
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.ValuesNode;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.google.common.collect.ImmutableList;

import java.util.List;

import static com.facebook.presto.sql.planner.iterative.Lookup.noLookup;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;

public class ExpressionExtractor
{
public static List<Expression> extractExpressions(PlanNode plan)
public static List<RowExpression> extractExpressions(PlanNode plan)
{
return extractExpressions(plan, noLookup());
}

public static List<Expression> extractExpressions(PlanNode plan, Lookup lookup)
public static List<RowExpression> extractExpressions(PlanNode plan, Lookup lookup)
{
requireNonNull(plan, "plan is null");
requireNonNull(lookup, "lookup is null");

ImmutableList.Builder<Expression> expressionsBuilder = ImmutableList.builder();
ImmutableList.Builder<RowExpression> expressionsBuilder = ImmutableList.builder();
plan.accept(new Visitor(true, lookup), expressionsBuilder);
return expressionsBuilder.build();
}

public static List<Expression> extractExpressionsNonRecursive(PlanNode plan)
public static List<RowExpression> extractExpressionsNonRecursive(PlanNode plan)
{
ImmutableList.Builder<Expression> expressionsBuilder = ImmutableList.builder();
ImmutableList.Builder<RowExpression> expressionsBuilder = ImmutableList.builder();
plan.accept(new Visitor(false, noLookup()), expressionsBuilder);
return expressionsBuilder.build();
}
Expand All @@ -61,7 +62,7 @@ private ExpressionExtractor()
}

private static class Visitor
extends SimplePlanVisitor<ImmutableList.Builder<Expression>>
extends SimplePlanVisitor<ImmutableList.Builder<RowExpression>>
{
private final boolean recursive;
private final Lookup lookup;
Expand All @@ -73,7 +74,7 @@ private static class Visitor
}

@Override
protected Void visitPlan(PlanNode node, ImmutableList.Builder<Expression> context)
protected Void visitPlan(PlanNode node, ImmutableList.Builder<RowExpression> context)
{
if (recursive) {
return super.visitPlan(node, context);
Expand All @@ -82,55 +83,55 @@ protected Void visitPlan(PlanNode node, ImmutableList.Builder<Expression> contex
}

@Override
public Void visitGroupReference(GroupReference node, ImmutableList.Builder<Expression> context)
public Void visitGroupReference(GroupReference node, ImmutableList.Builder<RowExpression> context)
{
return lookup.resolve(node).accept(this, context);
}

@Override
public Void visitAggregation(AggregationNode node, ImmutableList.Builder<Expression> context)
public Void visitAggregation(AggregationNode node, ImmutableList.Builder<RowExpression> context)
{
node.getAggregations().values()
.forEach(aggregation -> context.add(aggregation.getCall()));
.forEach(aggregation -> context.add(castToRowExpression(aggregation.getCall())));
return super.visitAggregation(node, context);
}

@Override
public Void visitFilter(FilterNode node, ImmutableList.Builder<Expression> context)
public Void visitFilter(FilterNode node, ImmutableList.Builder<RowExpression> context)
{
context.add(node.getPredicate());
context.add(castToRowExpression(node.getPredicate()));
return super.visitFilter(node, context);
}

@Override
public Void visitProject(ProjectNode node, ImmutableList.Builder<Expression> context)
public Void visitProject(ProjectNode node, ImmutableList.Builder<RowExpression> context)
{
context.addAll(node.getAssignments().getExpressions());
context.addAll(node.getAssignments().getExpressions().stream().map(OriginalExpressionUtils::castToRowExpression).collect(toImmutableList()));
return super.visitProject(node, context);
}

@Override
public Void visitJoin(JoinNode node, ImmutableList.Builder<Expression> context)
public Void visitJoin(JoinNode node, ImmutableList.Builder<RowExpression> context)
{
node.getFilter().ifPresent(context::add);
node.getFilter().map(OriginalExpressionUtils::castToRowExpression).ifPresent(context::add);
return super.visitJoin(node, context);
}

@Override
public Void visitValues(ValuesNode node, ImmutableList.Builder<Expression> context)
public Void visitValues(ValuesNode node, ImmutableList.Builder<RowExpression> context)
{
node.getRows().forEach(rowExpressions -> rowExpressions.forEach(rowExpression -> {
if (isExpression(rowExpression)) {
context.add(castToExpression(rowExpression));
}
}));
node.getRows().forEach(context::addAll);
return super.visitValues(node, context);
}

@Override
public Void visitApply(ApplyNode node, ImmutableList.Builder<Expression> context)
public Void visitApply(ApplyNode node, ImmutableList.Builder<RowExpression> context)
{
context.addAll(node.getSubqueryAssignments().getExpressions());
context.addAll(node.getSubqueryAssignments()
.getExpressions()
.stream()
.map(OriginalExpressionUtils::castToRowExpression)
.collect(toImmutableList()));
return super.visitApply(node, context);
}
}
Expand Down
Expand Up @@ -27,6 +27,7 @@
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.planner.plan.ValuesNode;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.DefaultExpressionTraversalVisitor;
import com.facebook.presto.sql.tree.DereferenceExpression;
Expand Down Expand Up @@ -504,6 +505,8 @@ private Set<Expression> extractOuterColumnReferences(PlanNode planNode)
// when reference expression is not rewritten that means it cannot be satisfied within given PlaNode
// see that TranslationMap only resolves (local) fields in current scope
return ExpressionExtractor.extractExpressions(planNode).stream()
.filter(OriginalExpressionUtils::isExpression)
.map(OriginalExpressionUtils::castToExpression)
.flatMap(expression -> extractColumnReferences(expression, analysis.getColumnReferences()).stream())
.collect(toImmutableSet());
}
Expand Down
Expand Up @@ -36,6 +36,8 @@
import static com.facebook.presto.sql.planner.ExpressionExtractor.extractExpressionsNonRecursive;
import static com.facebook.presto.sql.planner.iterative.Lookup.noLookup;
import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression;
import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toSet;
Expand All @@ -47,23 +49,23 @@ private SymbolsExtractor() {}
public static Set<Symbol> extractUnique(PlanNode node)
{
ImmutableSet.Builder<Symbol> uniqueSymbols = ImmutableSet.builder();
extractExpressions(node).forEach(expression -> uniqueSymbols.addAll(extractUnique(expression)));
extractExpressions(node).forEach(expression -> uniqueSymbols.addAll(extractUniqueInternal(expression)));

return uniqueSymbols.build();
}

public static Set<Symbol> extractUniqueNonRecursive(PlanNode node)
{
ImmutableSet.Builder<Symbol> uniqueSymbols = ImmutableSet.builder();
extractExpressionsNonRecursive(node).forEach(expression -> uniqueSymbols.addAll(extractUnique(expression)));
extractExpressionsNonRecursive(node).forEach(expression -> uniqueSymbols.addAll(extractUniqueInternal(expression)));

return uniqueSymbols.build();
}

public static Set<Symbol> extractUnique(PlanNode node, Lookup lookup)
{
ImmutableSet.Builder<Symbol> uniqueSymbols = ImmutableSet.builder();
extractExpressions(node, lookup).forEach(expression -> uniqueSymbols.addAll(extractUnique(expression)));
extractExpressions(node, lookup).forEach(expression -> uniqueSymbols.addAll(extractUniqueInternal(expression)));

return uniqueSymbols.build();
}
Expand Down Expand Up @@ -124,6 +126,17 @@ public static Set<Symbol> extractOutputSymbols(PlanNode planNode, Lookup lookup)
.collect(toImmutableSet());
}

/**
* {@param expression} could be an OriginalExpression
*/
private static Set<Symbol> extractUniqueInternal(RowExpression expression)
{
if (isExpression(expression)) {
return extractUnique(castToExpression(expression));
}
return extractUnique(expression);
}

private static class SymbolBuilderVisitor
extends DefaultExpressionTraversalVisitor<Void, ImmutableList.Builder<Symbol>>
{
Expand Down
Expand Up @@ -21,17 +21,26 @@
import com.facebook.presto.sql.planner.ExpressionExtractor;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.facebook.presto.sql.tree.Identifier;

import java.util.List;

import static com.google.common.collect.ImmutableList.toImmutableList;

public final class NoIdentifierLeftChecker
implements PlanSanityChecker.Checker
{
@Override
public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector)
{
List<Identifier> identifiers = ExpressionTreeUtils.extractExpressions(ExpressionExtractor.extractExpressions(plan), Identifier.class);
List<Identifier> identifiers = ExpressionTreeUtils.extractExpressions(
ExpressionExtractor.extractExpressions(plan)
.stream()
.filter(OriginalExpressionUtils::isExpression)
.map(OriginalExpressionUtils::castToExpression)
.collect(toImmutableList()),
Identifier.class);
if (!identifiers.isEmpty()) {
throw new IllegalStateException("Unexpected identifier in logical plan: " + identifiers.get(0));
}
Expand Down
Expand Up @@ -20,10 +20,14 @@
import com.facebook.presto.sql.planner.ExpressionExtractor;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.facebook.presto.sql.tree.DefaultTraversalVisitor;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.SubqueryExpression;

import java.util.List;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.lang.String.format;

public final class NoSubqueryExpressionLeftChecker
Expand All @@ -32,7 +36,12 @@ public final class NoSubqueryExpressionLeftChecker
@Override
public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector)
{
for (Expression expression : ExpressionExtractor.extractExpressions(plan)) {
List<Expression> expressions = ExpressionExtractor.extractExpressions(plan)
.stream()
.filter(OriginalExpressionUtils::isExpression)
.map(OriginalExpressionUtils::castToExpression)
.collect(toImmutableList());
for (Expression expression : expressions) {
new DefaultTraversalVisitor<Void, Void>()
{
@Override
Expand Down

0 comments on commit 4e98987

Please sign in to comment.