Skip to content

Commit

Permalink
Add support for LEFT/RIGHT/FULL/INNER lateral join
Browse files Browse the repository at this point in the history
  • Loading branch information
kasiafi authored and martint committed Apr 1, 2019
1 parent c9d187b commit 79a70d1
Show file tree
Hide file tree
Showing 19 changed files with 242 additions and 50 deletions.
Expand Up @@ -59,9 +59,11 @@
import io.prestosql.sql.tree.InPredicate;
import io.prestosql.sql.tree.Intersect;
import io.prestosql.sql.tree.Join;
import io.prestosql.sql.tree.JoinCriteria;
import io.prestosql.sql.tree.JoinUsing;
import io.prestosql.sql.tree.LambdaArgumentDeclaration;
import io.prestosql.sql.tree.Lateral;
import io.prestosql.sql.tree.NaturalJoin;
import io.prestosql.sql.tree.NodeRef;
import io.prestosql.sql.tree.QualifiedName;
import io.prestosql.sql.tree.Query;
Expand Down Expand Up @@ -89,8 +91,10 @@
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.prestosql.sql.analyzer.SemanticExceptions.notSupportedException;
import static io.prestosql.sql.planner.plan.AggregationNode.singleGroupingSet;
import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static io.prestosql.sql.tree.Join.Type.INNER;
import static java.util.Objects.requireNonNull;

Expand Down Expand Up @@ -219,9 +223,6 @@ protected RelationPlan visitJoin(Join node, Void context)

Optional<Lateral> lateral = getLateral(node.getRight());
if (lateral.isPresent()) {
if (node.getType() != Join.Type.CROSS && node.getType() != Join.Type.IMPLICIT) {
throw notSupportedException(lateral.get(), "LATERAL on other than the right side of CROSS JOIN");
}
return planLateralJoin(node, leftPlan, lateral.get());
}

Expand Down Expand Up @@ -537,7 +538,47 @@ private RelationPlan planLateralJoin(Join join, RelationPlan leftPlan, Lateral l
PlanBuilder leftPlanBuilder = initializePlanBuilder(leftPlan);
PlanBuilder rightPlanBuilder = initializePlanBuilder(rightPlan);

PlanBuilder planBuilder = subqueryPlanner.appendLateralJoin(leftPlanBuilder, rightPlanBuilder, lateral.getQuery(), true, LateralJoinNode.Type.INNER);
Expression filterExpression;
if (!join.getCriteria().isPresent()) {
filterExpression = TRUE_LITERAL;
}
else {
JoinCriteria criteria = join.getCriteria().get();
if (criteria instanceof JoinUsing || criteria instanceof NaturalJoin) {
throw notSupportedException(join, "Lateral join with criteria other than ON");
}
filterExpression = (Expression) getOnlyElement(criteria.getNodes());
}

List<Symbol> rewriterOutputSymbols = ImmutableList.<Symbol>builder()
.addAll(leftPlan.getFieldMappings())
.addAll(rightPlan.getFieldMappings())
.build();

// this node is not used in the plan. It is only used for creating the TranslationMap.
PlanNode dummy = new ValuesNode(
idAllocator.getNextId(),
ImmutableList.<Symbol>builder()
.addAll(leftPlanBuilder.getRoot().getOutputSymbols())
.addAll(rightPlanBuilder.getRoot().getOutputSymbols())
.build(),
ImmutableList.of());

RelationPlan intermediateRelationPlan = new RelationPlan(dummy, analysis.getScope(join), rewriterOutputSymbols);
TranslationMap translationMap = new TranslationMap(intermediateRelationPlan, analysis, lambdaDeclarationToSymbolMap);
translationMap.setFieldMappings(rewriterOutputSymbols);
translationMap.putExpressionMappingsFrom(leftPlanBuilder.getTranslations());
translationMap.putExpressionMappingsFrom(rightPlanBuilder.getTranslations());

Expression rewrittenFilterCondition = translationMap.rewrite(filterExpression);

PlanBuilder planBuilder = subqueryPlanner.appendLateralJoin(
leftPlanBuilder,
rightPlanBuilder,
lateral.getQuery(),
true,
LateralJoinNode.Type.typeConvert(join.getType()),
rewrittenFilterCondition);

List<Symbol> outputSymbols = ImmutableList.<Symbol>builder()
.addAll(leftPlan.getRoot().getOutputSymbols())
Expand Down
Expand Up @@ -227,10 +227,10 @@ private PlanBuilder appendScalarSubqueryApplyNode(PlanBuilder subPlan, SubqueryE
}

// The subquery's EnforceSingleRowNode always produces a row, so the join is effectively INNER
return appendLateralJoin(subPlan, subqueryPlan, scalarSubquery.getQuery(), correlationAllowed, LateralJoinNode.Type.INNER);
return appendLateralJoin(subPlan, subqueryPlan, scalarSubquery.getQuery(), correlationAllowed, LateralJoinNode.Type.INNER, TRUE_LITERAL);
}

public PlanBuilder appendLateralJoin(PlanBuilder subPlan, PlanBuilder subqueryPlan, Query query, boolean correlationAllowed, LateralJoinNode.Type type)
public PlanBuilder appendLateralJoin(PlanBuilder subPlan, PlanBuilder subqueryPlan, Query query, boolean correlationAllowed, LateralJoinNode.Type type, Expression filterCondition)
{
PlanNode subqueryNode = subqueryPlan.getRoot();
Map<Expression, Expression> correlation = extractCorrelation(subPlan, subqueryNode);
Expand All @@ -247,6 +247,7 @@ public PlanBuilder appendLateralJoin(PlanBuilder subPlan, PlanBuilder subqueryPl
subqueryNode,
ImmutableList.copyOf(SymbolsExtractor.extractUnique(correlation.values())),
type,
filterCondition,
query),
analysis.getParameters());
}
Expand Down
Expand Up @@ -21,12 +21,15 @@
import io.prestosql.sql.planner.plan.PlanNode;

import static io.prestosql.sql.planner.optimizations.QueryCardinalityUtil.isScalar;
import static io.prestosql.sql.planner.plan.Patterns.LateralJoin.filter;
import static io.prestosql.sql.planner.plan.Patterns.lateralJoin;
import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL;

public class RemoveUnreferencedScalarLateralNodes
implements Rule<LateralJoinNode>
{
private static final Pattern<LateralJoinNode> PATTERN = lateralJoin();
private static final Pattern<LateralJoinNode> PATTERN = lateralJoin()
.with(filter().equalTo(TRUE_LITERAL));

@Override
public Pattern<LateralJoinNode> getPattern()
Expand Down
Expand Up @@ -22,12 +22,15 @@
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.LateralJoinNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.tree.Expression;

import java.util.Optional;

import static io.prestosql.matching.Pattern.nonEmpty;
import static io.prestosql.sql.ExpressionUtils.combineConjuncts;
import static io.prestosql.sql.planner.plan.Patterns.LateralJoin.correlation;
import static io.prestosql.sql.planner.plan.Patterns.lateralJoin;
import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL;

/**
* Tries to decorrelate subquery and rewrite it using normal join.
Expand All @@ -53,18 +56,24 @@ public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context
PlanNodeDecorrelator planNodeDecorrelator = new PlanNodeDecorrelator(context.getIdAllocator(), context.getLookup());
Optional<DecorrelatedNode> decorrelatedNodeOptional = planNodeDecorrelator.decorrelateFilters(subquery, lateralJoinNode.getCorrelation());

return decorrelatedNodeOptional.map(decorrelatedNode ->
Result.ofPlanNode(new JoinNode(
context.getIdAllocator().getNextId(),
lateralJoinNode.getType().toJoinNodeType(),
lateralJoinNode.getInput(),
decorrelatedNode.getNode(),
ImmutableList.of(),
lateralJoinNode.getOutputSymbols(),
decorrelatedNode.getCorrelatedPredicates(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty()))).orElseGet(Result::empty);
return decorrelatedNodeOptional
.map(decorrelatedNode -> {
Expression joinFilter = combineConjuncts(
decorrelatedNode.getCorrelatedPredicates().orElse(TRUE_LITERAL),
lateralJoinNode.getFilter());
return Result.ofPlanNode(new JoinNode(
context.getIdAllocator().getNextId(),
lateralJoinNode.getType().toJoinNodeType(),
lateralJoinNode.getInput(),
decorrelatedNode.getNode(),
ImmutableList.of(),
lateralJoinNode.getOutputSymbols(),
joinFilter.equals(TRUE_LITERAL) ? Optional.empty() : Optional.of(joinFilter),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty()));
})
.orElseGet(Result::empty);
}
}
Expand Up @@ -31,7 +31,9 @@
import static io.prestosql.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static io.prestosql.sql.planner.optimizations.QueryCardinalityUtil.isScalar;
import static io.prestosql.sql.planner.plan.Patterns.LateralJoin.correlation;
import static io.prestosql.sql.planner.plan.Patterns.LateralJoin.filter;
import static io.prestosql.sql.planner.plan.Patterns.lateralJoin;
import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static io.prestosql.util.MorePredicates.isInstanceOfAny;
import static java.util.Objects.requireNonNull;

Expand Down Expand Up @@ -67,7 +69,8 @@ public class TransformCorrelatedScalarAggregationToJoin
implements Rule<LateralJoinNode>
{
private static final Pattern<LateralJoinNode> PATTERN = lateralJoin()
.with(nonEmpty(correlation()));
.with(nonEmpty(correlation()))
.with(filter().equalTo(TRUE_LITERAL)); // todo non-trivial join filter: adding filter/project on top of aggregation

@Override
public Pattern<LateralJoinNode> getPattern()
Expand Down
Expand Up @@ -46,6 +46,7 @@
import static io.prestosql.sql.planner.optimizations.QueryCardinalityUtil.extractCardinality;
import static io.prestosql.sql.planner.plan.LateralJoinNode.Type.LEFT;
import static io.prestosql.sql.planner.plan.Patterns.LateralJoin.correlation;
import static io.prestosql.sql.planner.plan.Patterns.LateralJoin.filter;
import static io.prestosql.sql.planner.plan.Patterns.lateralJoin;
import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL;

Expand Down Expand Up @@ -81,7 +82,8 @@ public class TransformCorrelatedScalarSubquery
implements Rule<LateralJoinNode>
{
private static final Pattern<LateralJoinNode> PATTERN = lateralJoin()
.with(nonEmpty(correlation()));
.with(nonEmpty(correlation()))
.with(filter().equalTo(TRUE_LITERAL));

@Override
public Pattern getPattern()
Expand Down Expand Up @@ -116,6 +118,7 @@ public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context
rewrittenSubquery,
lateralJoinNode.getCorrelation(),
producesSingleRow ? lateralJoinNode.getType() : LEFT,
lateralJoinNode.getFilter(),
lateralJoinNode.getOriginSubquery()));
}

Expand All @@ -130,6 +133,7 @@ public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context
rewrittenSubquery,
lateralJoinNode.getCorrelation(),
LEFT,
lateralJoinNode.getFilter(),
lateralJoinNode.getOriginSubquery());

Symbol isDistinct = context.getSymbolAllocator().newSymbol("is_distinct", BooleanType.BOOLEAN);
Expand Down
Expand Up @@ -25,7 +25,9 @@
import java.util.List;

import static io.prestosql.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static io.prestosql.sql.planner.plan.Patterns.LateralJoin.filter;
import static io.prestosql.sql.planner.plan.Patterns.lateralJoin;
import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL;

/**
* This optimizer can rewrite correlated single row subquery to projection in a way described here:
Expand All @@ -47,7 +49,8 @@
public class TransformCorrelatedSingleRowSubqueryToProject
implements Rule<LateralJoinNode>
{
private static final Pattern<LateralJoinNode> PATTERN = lateralJoin();
private static final Pattern<LateralJoinNode> PATTERN = lateralJoin()
.with(filter().equalTo(TRUE_LITERAL));

@Override
public Pattern<LateralJoinNode> getPattern()
Expand Down
Expand Up @@ -145,6 +145,7 @@ private Optional<PlanNode> rewriteToNonDefaultAggregation(ApplyNode applyNode, C
subquery,
applyNode.getCorrelation(),
LEFT,
TRUE_LITERAL,
applyNode.getOriginSubquery()),
assignments.build()));
}
Expand All @@ -171,6 +172,7 @@ private PlanNode rewriteToDefaultAggregation(ApplyNode parent, Context context)
Assignments.of(exists, new ComparisonExpression(GREATER_THAN, count.toSymbolReference(), new Cast(new LongLiteral("0"), BIGINT.toString())))),
parent.getCorrelation(),
INNER,
TRUE_LITERAL,
parent.getOriginSubquery());
}
}
Expand Up @@ -20,12 +20,14 @@
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.LateralJoinNode;
import io.prestosql.sql.tree.Expression;

import java.util.Optional;

import static io.prestosql.matching.Pattern.empty;
import static io.prestosql.sql.planner.plan.Patterns.LateralJoin.correlation;
import static io.prestosql.sql.planner.plan.Patterns.lateralJoin;
import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL;

public class TransformUncorrelatedLateralToJoin
implements Rule<LateralJoinNode>
Expand All @@ -52,10 +54,19 @@ public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context
.addAll(lateralJoinNode.getInput().getOutputSymbols())
.addAll(lateralJoinNode.getSubquery().getOutputSymbols())
.build(),
Optional.empty(),
filter(lateralJoinNode.getFilter()),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty()));
}

private Optional<Expression> filter(Expression lateralJoinFilter)
{
if (lateralJoinFilter.equals(TRUE_LITERAL)) {
return Optional.empty();
}

return Optional.of(lateralJoinFilter);
}
}
Expand Up @@ -86,7 +86,12 @@
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.concat;
import static com.google.common.collect.Sets.intersection;
import static io.prestosql.sql.planner.optimizations.QueryCardinalityUtil.isAtMostScalar;
import static io.prestosql.sql.planner.optimizations.QueryCardinalityUtil.isScalar;
import static io.prestosql.sql.planner.plan.LateralJoinNode.Type.INNER;
import static io.prestosql.sql.planner.plan.LateralJoinNode.Type.LEFT;
import static io.prestosql.sql.planner.plan.LateralJoinNode.Type.RIGHT;
import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static java.util.Objects.requireNonNull;

/**
Expand Down Expand Up @@ -811,11 +816,25 @@ public PlanNode visitAssignUniqueId(AssignUniqueId node, RewriteContext<Set<Symb
@Override
public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext<Set<Symbol>> context)
{
PlanNode subquery = context.rewrite(node.getSubquery(), context.get());
Set<Symbol> expectedFilterSymbols = SymbolsExtractor.extractUnique(node.getFilter());

Set<Symbol> expectedFilterAndContextSymbols = ImmutableSet.<Symbol>builder()
.addAll(expectedFilterSymbols)
.addAll(context.get())
.build();

PlanNode subquery = context.rewrite(node.getSubquery(), expectedFilterAndContextSymbols);

// remove unused lateral nodes
if (intersection(ImmutableSet.copyOf(subquery.getOutputSymbols()), context.get()).isEmpty() && isScalar(subquery)) {
return context.rewrite(node.getInput(), context.get());
if (intersection(ImmutableSet.copyOf(subquery.getOutputSymbols()), context.get()).isEmpty()) {
// remove unused lateral subquery of inner join
if (node.getType() == INNER && isScalar(subquery) && node.getFilter().equals(TRUE_LITERAL)) {
return context.rewrite(node.getInput(), context.get());
}
// remove unused lateral subquery of left join
if (node.getType() == LEFT && isAtMostScalar(subquery)) {
return context.rewrite(node.getInput(), context.get());
}
}

// prune not used correlation symbols
Expand All @@ -824,18 +843,29 @@ public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext<Set<Symbol
.filter(subquerySymbols::contains)
.collect(toImmutableList());

Set<Symbol> inputContext = ImmutableSet.<Symbol>builder()
.addAll(context.get())
Set<Symbol> expectedCorrelationAndContextSymbols = ImmutableSet.<Symbol>builder()
.addAll(newCorrelation)
.addAll(context.get())
.build();
Set<Symbol> inputContext = ImmutableSet.<Symbol>builder()
.addAll(expectedCorrelationAndContextSymbols)
.addAll(expectedFilterSymbols)
.build();
PlanNode input = context.rewrite(node.getInput(), inputContext);

// remove unused lateral nodes
if (intersection(ImmutableSet.copyOf(input.getOutputSymbols()), inputContext).isEmpty() && isScalar(input)) {
return subquery;
// remove unused input nodes
if (intersection(ImmutableSet.copyOf(input.getOutputSymbols()), expectedCorrelationAndContextSymbols).isEmpty()) {
// remove unused input of inner join
if (node.getType() == INNER && isScalar(input) && node.getFilter().equals(TRUE_LITERAL)) {
return subquery;
}
// remove unused input of right join
if (node.getType() == RIGHT && isAtMostScalar(input)) {
return subquery;
}
}

return new LateralJoinNode(node.getId(), input, subquery, newCorrelation, node.getType(), node.getOriginSubquery());
return new LateralJoinNode(node.getId(), input, subquery, newCorrelation, node.getType(), node.getFilter(), node.getOriginSubquery());
}
}
}
Expand Up @@ -174,6 +174,7 @@ countNonNullValue, new Aggregation(
subqueryPlan,
node.getCorrelation(),
LateralJoinNode.Type.INNER,
TRUE_LITERAL,
node.getOriginSubquery());

Expression valueComparedToSubquery = rewriteUsingBounds(quantifiedComparison, minValue, maxValue, countAllValue, countNonNullValue);
Expand Down
Expand Up @@ -454,7 +454,7 @@ public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext<Void> cont
PlanNode subquery = context.rewrite(node.getSubquery());
List<Symbol> canonicalCorrelation = canonicalizeAndDistinct(node.getCorrelation());

return new LateralJoinNode(node.getId(), source, subquery, canonicalCorrelation, node.getType(), node.getOriginSubquery());
return new LateralJoinNode(node.getId(), source, subquery, canonicalCorrelation, node.getType(), canonicalize(node.getFilter()), node.getOriginSubquery());
}

@Override
Expand Down

0 comments on commit 79a70d1

Please sign in to comment.