Skip to content

Commit

Permalink
Introduce ApplyNode representation of subquery
Browse files Browse the repository at this point in the history
For more design details please see:
https://docs.google.com/document/d/18HN7peS2eR8lZsErqcmnoWyMEPb6p4OQeidH1JP_EkA

ApplyNode is a generic representation of a subquery that abstracts the
correlation between the outer and nested queries.
A subquery is first rewritten into an ApplyNode; then, if the ApplyNode does
not have any correlation, it is removed (rewritten to some form of join) by
UncorrelatedInPredicateApplyRemover or UncorrelatedScalarApplyRemover.
  • Loading branch information
kokosing authored and martint committed May 11, 2016
1 parent e247675 commit b5d8bf5
Show file tree
Hide file tree
Showing 25 changed files with 772 additions and 146 deletions.
Expand Up @@ -13,6 +13,7 @@
*/
package com.facebook.presto.sql.planner;

import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.tree.DefaultExpressionTraversalVisitor;
import com.facebook.presto.sql.tree.DereferenceExpression;
import com.facebook.presto.sql.tree.Expression;
Expand All @@ -24,12 +25,21 @@
import java.util.List;
import java.util.Set;

import static com.facebook.presto.sql.planner.ExpressionExtractor.extractExpressions;
import static java.util.Objects.requireNonNull;

public final class DependencyExtractor
{
// Private, empty constructor: this class cannot be instantiated from outside.
private DependencyExtractor() {}

/**
 * Collects the distinct symbols found in the expressions that
 * {@code ExpressionExtractor.extractExpressions} returns for the given plan node.
 *
 * @param node plan node handed to the expression extractor
 * @return immutable set of the symbols extracted from those expressions
 */
public static Set<Symbol> extractUnique(PlanNode node)
{
    ImmutableSet.Builder<Symbol> symbols = ImmutableSet.builder();
    // Reuse the per-expression overload for every expression found in the plan
    for (Expression expression : extractExpressions(node)) {
        symbols.addAll(extractUnique(expression));
    }
    return symbols.build();
}

public static Set<Symbol> extractUnique(Expression expression)
{
return ImmutableSet.copyOf(extractAll(expression));
Expand Down
@@ -0,0 +1,65 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner;

import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.TableScanNode;
import com.facebook.presto.sql.planner.plan.ValuesNode;
import com.facebook.presto.sql.tree.Expression;
import com.google.common.collect.ImmutableList;

import java.util.List;

/**
 * Plan visitor that gathers the expressions held by the plan nodes it visits.
 * <p>
 * Expressions are taken from filter predicates, project assignments, the original
 * constraint of table scans (when one is set) and the rows of values nodes. Every
 * override then defers to the {@link SimplePlanVisitor} implementation for the same
 * node type.
 */
public class ExpressionExtractor
        extends SimplePlanVisitor<ImmutableList.Builder<Expression>>
{
    /**
     * Starts a new extractor on {@code plan} and returns everything it collected.
     *
     * @param plan node on which the visitor is started
     * @return immutable list of the collected expressions
     */
    public static List<Expression> extractExpressions(PlanNode plan)
    {
        ImmutableList.Builder<Expression> collected = ImmutableList.builder();
        ExpressionExtractor extractor = new ExpressionExtractor();
        plan.accept(extractor, collected);
        return collected.build();
    }

    /** Records the filter predicate, then delegates to the superclass. */
    @Override
    public Void visitFilter(FilterNode node, ImmutableList.Builder<Expression> context)
    {
        Expression predicate = node.getPredicate();
        context.add(predicate);
        return super.visitFilter(node, context);
    }

    /** Records every assignment expression of the projection, then delegates to the superclass. */
    @Override
    public Void visitProject(ProjectNode node, ImmutableList.Builder<Expression> context)
    {
        context.addAll(node.getAssignments().values());
        return super.visitProject(node, context);
    }

    /** Records the scan's original constraint if it is non-null, then delegates to the superclass. */
    @Override
    public Void visitTableScan(TableScanNode node, ImmutableList.Builder<Expression> context)
    {
        Expression originalConstraint = node.getOriginalConstraint();
        // The constraint may be absent; only a present one is recorded
        if (originalConstraint != null) {
            context.add(originalConstraint);
        }
        return super.visitTableScan(node, context);
    }

    /** Records the expressions of every row, row by row, then delegates to the superclass. */
    @Override
    public Void visitValues(ValuesNode node, ImmutableList.Builder<Expression> context)
    {
        for (List<Expression> row : node.getRows()) {
            context.addAll(row);
        }
        return super.visitValues(node, context);
    }
}
Expand Up @@ -22,6 +22,11 @@
public class ExpressionNodeInliner
extends ExpressionRewriter<Void>
{
/**
 * Rewrites {@code expression} using an {@link ExpressionNodeInliner} built from {@code mappings}.
 *
 * @param expression expression tree to rewrite
 * @param mappings expression-to-expression mappings handed to the inliner
 * @return the result of {@code ExpressionTreeRewriter.rewriteWith} for that inliner and expression
 */
public static Expression replaceExpression(Expression expression, Map<? extends Expression, ? extends Expression> mappings)
{
    ExpressionNodeInliner inliner = new ExpressionNodeInliner(mappings);
    return ExpressionTreeRewriter.rewriteWith(inliner, expression);
}

private final Map<? extends Expression, ? extends Expression> mappings;

public ExpressionNodeInliner(Map<? extends Expression, ? extends Expression> mappings)
Expand Down
Expand Up @@ -96,9 +96,6 @@ public Plan plan(Analysis analysis)
{
PlanNode root = planStatement(analysis, analysis.getStatement());

// make sure we produce a valid plan. This is mainly to catch programming errors
PlanSanityChecker.validate(root);

for (PlanOptimizer optimizer : planOptimizers) {
root = optimizer.optimize(root, session, symbolAllocator.getTypes(), symbolAllocator, idAllocator);
requireNonNull(root, format("%s returned a null plan", optimizer.getClass().getName()));
Expand Down
Expand Up @@ -41,6 +41,8 @@
import com.facebook.presto.sql.planner.optimizations.SetFlatteningOptimizer;
import com.facebook.presto.sql.planner.optimizations.SimplifyExpressions;
import com.facebook.presto.sql.planner.optimizations.SingleDistinctOptimizer;
import com.facebook.presto.sql.planner.optimizations.TransformUncorrelatedInPredicateSubqueryToSemiJoin;
import com.facebook.presto.sql.planner.optimizations.TransformUncorrelatedScalarToJoin;
import com.facebook.presto.sql.planner.optimizations.UnaliasSymbolReferences;
import com.facebook.presto.sql.planner.optimizations.WindowFilterPushDown;
import com.google.common.collect.ImmutableList;
Expand All @@ -66,6 +68,8 @@ public PlanOptimizersFactory(Metadata metadata, SqlParser sqlParser, FeaturesCon
ImmutableList.Builder<PlanOptimizer> builder = ImmutableList.builder();

builder.add(new DesugaringOptimizer(metadata, sqlParser), // Clean up all the sugar in expressions, e.g. AtTimeZone, must be run before all the other optimizers
new TransformUncorrelatedScalarToJoin(),
new TransformUncorrelatedInPredicateSubqueryToSemiJoin(),
new ImplementSampleAsFilter(),
new CanonicalizeExpressions(),
new SimplifyExpressions(metadata, sqlParser),
Expand Down
Expand Up @@ -25,16 +25,15 @@
import com.facebook.presto.sql.analyzer.Field;
import com.facebook.presto.sql.analyzer.RelationType;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.ApplyNode;
import com.facebook.presto.sql.planner.plan.DeleteNode;
import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.LimitNode;
import com.facebook.presto.sql.planner.plan.MarkDistinctNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.SortNode;
import com.facebook.presto.sql.planner.plan.TableScanNode;
import com.facebook.presto.sql.planner.plan.TableWriterNode.DeleteHandle;
Expand Down Expand Up @@ -639,13 +638,13 @@ private PlanBuilder handleSubqueries(PlanBuilder subPlan, Node node, Iterable<Ex

private PlanBuilder handleSubqueries(PlanBuilder builder, Expression expression, Node node)
{
builder = appendSemiJoins(
builder = appendInPredicateApplyNodes(
builder,
analysis.getInPredicateSubqueries(node)
.stream()
.filter(inPredicate -> nodeContains(expression, inPredicate.getValueList()))
.collect(toImmutableSet()));
builder = appendScalarSubqueryJoins(
builder = appendScalarSubqueryApplyNodes(
builder,
analysis.getScalarSubqueries(node)
.stream()
Expand All @@ -654,50 +653,32 @@ private PlanBuilder handleSubqueries(PlanBuilder builder, Expression expression,
return builder;
}

private PlanBuilder appendSemiJoins(PlanBuilder subPlan, Set<InPredicate> inPredicates)
private PlanBuilder appendInPredicateApplyNodes(PlanBuilder subPlan, Set<InPredicate> inPredicates)
{
for (InPredicate inPredicate : inPredicates) {
subPlan = appendSemiJoin(subPlan, inPredicate);
subPlan = appendInPredicateApplyNode(subPlan, inPredicate);
}
return subPlan;
}

/**
* Semijoins are planned as follows:
* 1) SQL constructs that need to be semijoined are extracted during Analysis phase (currently only InPredicates so far)
* 2) Create a new SemiJoinNode that connects the semijoin lookup field with the planned subquery and have it output a new boolean
* symbol for the result of the semijoin.
* 3) Add an entry to the TranslationMap that notes to map the InPredicate into semijoin output symbol
* <p/>
* Currently, we only support semijoins deriving from InPredicates, but we will probably need
* to add support for more SQL constructs in the future.
*/
private PlanBuilder appendSemiJoin(PlanBuilder subPlan, InPredicate inPredicate)
private PlanBuilder appendInPredicateApplyNode(PlanBuilder subPlan, InPredicate inPredicate)
{
TranslationMap translations = copyTranslations(subPlan);

subPlan = appendProjections(subPlan, ImmutableList.of(inPredicate.getValue()));
Symbol sourceJoinSymbol = subPlan.translate(inPredicate.getValue());

checkState(inPredicate.getValueList() instanceof SubqueryExpression);
SubqueryExpression subqueryExpression = (SubqueryExpression) inPredicate.getValueList();
RelationPlanner relationPlanner = new RelationPlanner(analysis, symbolAllocator, idAllocator, metadata, session);
RelationPlan valueListRelation = relationPlanner.process(subqueryExpression.getQuery(), null);
Symbol filteringSourceJoinSymbol = getOnlyElement(valueListRelation.getRoot().getOutputSymbols());
RelationPlan valueListRelation = createRelationPlan((SubqueryExpression) inPredicate.getValueList());

Symbol semiJoinOutputSymbol = symbolAllocator.newSymbol("semijoinresult", BOOLEAN);
TranslationMap translationMap = copyTranslations(subPlan);
QualifiedNameReference valueList = getOnlyElement(valueListRelation.getOutputSymbols()).toQualifiedNameReference();
translationMap.setExpressionAsAlreadyTranslated(valueList);
translationMap.put(inPredicate, new InPredicate(inPredicate.getValue(), valueList));

translations.put(inPredicate, semiJoinOutputSymbol);

return new PlanBuilder(translations,
new SemiJoinNode(idAllocator.getNextId(),
return new PlanBuilder(translationMap,
// TODO handle correlation
new ApplyNode(idAllocator.getNextId(),
subPlan.getRoot(),
valueListRelation.getRoot(),
sourceJoinSymbol,
filteringSourceJoinSymbol,
semiJoinOutputSymbol,
Optional.empty(),
Optional.empty()),
ImmutableList.of()),
subPlan.getSampleWeight());
}

Expand All @@ -708,46 +689,43 @@ private TranslationMap copyTranslations(PlanBuilder subPlan)
return translations;
}

/**
 * Plans the query wrapped by {@code subqueryExpression} with a newly created
 * {@link RelationPlanner} constructed from the enclosing planner's analysis,
 * symbol allocator, id allocator, metadata and session.
 *
 * @param subqueryExpression subquery whose query is planned
 * @return the relation plan produced for the subquery's query
 */
private RelationPlan createRelationPlan(SubqueryExpression subqueryExpression)
{
    RelationPlanner relationPlanner = new RelationPlanner(analysis, symbolAllocator, idAllocator, metadata, session);
    // null is passed as the visitor context, exactly as in the chained form
    return relationPlanner.process(subqueryExpression.getQuery(), null);
}

private PlanBuilder appendScalarSubqueryJoins(PlanBuilder builder, Set<SubqueryExpression> scalarSubqueries)
private PlanBuilder appendScalarSubqueryApplyNodes(PlanBuilder builder, Set<SubqueryExpression> scalarSubqueries)
{
for (SubqueryExpression scalarSubquery : scalarSubqueries) {
builder = appendScalarSubqueryJoin(builder, scalarSubquery);
builder = appendScalarSubqueryApplyNode(builder, scalarSubquery);
}
return builder;
}

private PlanBuilder appendScalarSubqueryJoin(PlanBuilder builder, SubqueryExpression scalarSubquery)
private PlanBuilder appendScalarSubqueryApplyNode(PlanBuilder builder, SubqueryExpression scalarSubquery)
{
EnforceSingleRowNode enforceSingleRowNode = new EnforceSingleRowNode(idAllocator.getNextId(), createRelationPlan(scalarSubquery).getRoot());

TranslationMap translations = copyTranslations(builder);
translations.put(scalarSubquery, getOnlyElement(enforceSingleRowNode.getOutputSymbols()));

// Cross join current (root) relation with subquery
PlanNode root = builder.getRoot();
if (root.getOutputSymbols().isEmpty()) {
// there is nothing to join with - e.g. SELECT (SELECT 1)
return new PlanBuilder(translations, enforceSingleRowNode, builder.getSampleWeight());
}
else {
return new PlanBuilder(translations,
new JoinNode(idAllocator.getNextId(),
JoinNode.Type.FULL,
// TODO handle parameter list
new ApplyNode(idAllocator.getNextId(),
root,
enforceSingleRowNode,
ImmutableList.of(),
Optional.empty(),
Optional.empty()),
ImmutableList.of()),
builder.getSampleWeight());
}
}

/**
 * Plans the query wrapped by {@code subqueryExpression} with a newly created
 * {@link RelationPlanner} constructed from the enclosing planner's analysis,
 * symbol allocator, id allocator, metadata and session; {@code null} is passed
 * as the context argument of {@code process}.
 */
private RelationPlan createRelationPlan(SubqueryExpression subqueryExpression)
{
return new RelationPlanner(analysis, symbolAllocator, idAllocator, metadata, session)
.process(subqueryExpression.getQuery(), null);
}

private PlanBuilder distinct(PlanBuilder subPlan, QuerySpecification node, List<Expression> outputs, List<Expression> orderBy)
{
if (node.getSelect().isDistinct()) {
Expand Down
Expand Up @@ -26,13 +26,13 @@
import com.facebook.presto.sql.analyzer.RelationType;
import com.facebook.presto.sql.analyzer.SemanticException;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.ApplyNode;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.IntersectNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SampleNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.TableScanNode;
import com.facebook.presto.sql.planner.plan.UnionNode;
import com.facebook.presto.sql.planner.plan.UnnestNode;
Expand Down Expand Up @@ -82,7 +82,6 @@
import java.util.Set;

import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.spi.type.BooleanType.BOOLEAN;
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED;
import static com.facebook.presto.sql.planner.ExpressionInterpreter.evaluateConstantExpression;
import static com.facebook.presto.sql.tree.Join.Type.INNER;
Expand Down Expand Up @@ -276,8 +275,8 @@ else if (firstDependencies.stream().allMatch(right.canResolvePredicate()) && sec

// Add semi joins if necessary
if (joinInPredicates != null) {
leftPlanBuilder = appendSemiJoins(leftPlanBuilder, joinInPredicates.getLeftInPredicates());
rightPlanBuilder = appendSemiJoins(rightPlanBuilder, joinInPredicates.getRightInPredicates());
leftPlanBuilder = appendInPredicateApplyNodes(leftPlanBuilder, joinInPredicates.getLeftInPredicates());
rightPlanBuilder = appendInPredicateApplyNodes(rightPlanBuilder, joinInPredicates.getRightInPredicates());
}

// Add projections for join criteria
Expand Down Expand Up @@ -328,7 +327,7 @@ else if (firstDependencies.stream().allMatch(right.canResolvePredicate()) && sec
translationMap.putExpressionMappingsFrom(leftPlanBuilder.getTranslations());
translationMap.putExpressionMappingsFrom(rightPlanBuilder.getTranslations());
PlanBuilder rootPlanBuilder = new PlanBuilder(translationMap, root, sampleWeight);
rootPlanBuilder = appendSemiJoins(rootPlanBuilder, analysis.getInPredicateSubqueries(node));
rootPlanBuilder = appendInPredicateApplyNodes(rootPlanBuilder, analysis.getInPredicateSubqueries(node));
for (Expression expression : complexJoinExpressions) {
postInnerJoinConditions.add(rootPlanBuilder.rewrite(expression));
}
Expand Down Expand Up @@ -682,41 +681,34 @@ private PlanBuilder appendProjections(PlanBuilder subPlan, Iterable<Expression>
return new PlanBuilder(translations, new ProjectNode(idAllocator.getNextId(), subPlan.getRoot(), projections.build()), subPlan.getSampleWeight());
}

private PlanBuilder appendSemiJoins(PlanBuilder subPlan, Iterable<InPredicate> inPredicates)
private PlanBuilder appendInPredicateApplyNodes(PlanBuilder subPlan, Iterable<InPredicate> inPredicates)
{
for (InPredicate inPredicate : inPredicates) {
subPlan = appendSemiJoin(subPlan, inPredicate);
subPlan = appendInPredicateApplyNode(subPlan, inPredicate);
}
return subPlan;
}

private PlanBuilder appendSemiJoin(PlanBuilder subPlan, InPredicate inPredicate)
private PlanBuilder appendInPredicateApplyNode(PlanBuilder subPlan, InPredicate inPredicate)
{
TranslationMap translations = new TranslationMap(subPlan.getRelationPlan(), analysis);
translations.copyMappingsFrom(subPlan.getTranslations());

subPlan = appendProjections(subPlan, ImmutableList.of(inPredicate.getValue()));
Symbol sourceJoinSymbol = subPlan.translate(inPredicate.getValue());

checkState(inPredicate.getValueList() instanceof SubqueryExpression);
SubqueryExpression subqueryExpression = (SubqueryExpression) inPredicate.getValueList();
RelationPlanner relationPlanner = new RelationPlanner(analysis, symbolAllocator, idAllocator, metadata, session);
RelationPlan valueListRelation = relationPlanner.process(subqueryExpression.getQuery(), null);
Symbol filteringSourceJoinSymbol = Iterables.getOnlyElement(valueListRelation.getRoot().getOutputSymbols());
PlanNode valueListRelation = relationPlanner.process(subqueryExpression.getQuery(), null).getRoot();

Symbol semiJoinOutputSymbol = symbolAllocator.newSymbol("semijoinresult", BOOLEAN);
subPlan = appendProjections(subPlan, ImmutableList.of(inPredicate.getValue()));

translations.put(inPredicate, semiJoinOutputSymbol);
TranslationMap translations = new TranslationMap(subPlan.getRelationPlan(), analysis);
translations.copyMappingsFrom(subPlan.getTranslations());
QualifiedNameReference valueList = Iterables.getOnlyElement(valueListRelation.getOutputSymbols()).toQualifiedNameReference();
translations.setExpressionAsAlreadyTranslated(valueList);
translations.put(inPredicate, new InPredicate(inPredicate.getValue(), valueList));

return new PlanBuilder(translations,
new SemiJoinNode(idAllocator.getNextId(),
new ApplyNode(idAllocator.getNextId(),
subPlan.getRoot(),
valueListRelation.getRoot(),
sourceJoinSymbol,
filteringSourceJoinSymbol,
semiJoinOutputSymbol,
Optional.empty(),
Optional.empty()),
valueListRelation,
ImmutableList.of()),
subPlan.getSampleWeight());
}

Expand Down

0 comments on commit b5d8bf5

Please sign in to comment.