Skip to content

Commit

Permalink
Push TopN through outer Join
Browse files Browse the repository at this point in the history
  • Loading branch information
kasiafi authored and kokosing committed Mar 26, 2019
1 parent bd14140 commit ec57ed6
Show file tree
Hide file tree
Showing 9 changed files with 348 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
import io.prestosql.sql.planner.iterative.rule.PushProjectionThroughUnion;
import io.prestosql.sql.planner.iterative.rule.PushRemoteExchangeThroughAssignUniqueId;
import io.prestosql.sql.planner.iterative.rule.PushTableWriteThroughUnion;
import io.prestosql.sql.planner.iterative.rule.PushTopNThroughOuterJoin;
import io.prestosql.sql.planner.iterative.rule.PushTopNThroughProject;
import io.prestosql.sql.planner.iterative.rule.PushTopNThroughUnion;
import io.prestosql.sql.planner.iterative.rule.RemoveEmptyDelete;
Expand Down Expand Up @@ -438,6 +439,7 @@ public PlanOptimizers(
ImmutableSet.of(
new CreatePartialTopN(),
new PushTopNThroughProject(),
new PushTopNThroughOuterJoin(),
new PushTopNThroughUnion())));
builder.add(new IterativeOptimizer(
ruleStats,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Range;
import io.prestosql.matching.Capture;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.iterative.Lookup;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.TopNNode;

import java.util.List;

import static io.prestosql.matching.Capture.newCapture;
import static io.prestosql.sql.planner.optimizations.QueryCardinalityUtil.extractCardinality;
import static io.prestosql.sql.planner.plan.JoinNode.Type.FULL;
import static io.prestosql.sql.planner.plan.JoinNode.Type.LEFT;
import static io.prestosql.sql.planner.plan.JoinNode.Type.RIGHT;
import static io.prestosql.sql.planner.plan.Patterns.Join.type;
import static io.prestosql.sql.planner.plan.Patterns.TopN.step;
import static io.prestosql.sql.planner.plan.Patterns.join;
import static io.prestosql.sql.planner.plan.Patterns.source;
import static io.prestosql.sql.planner.plan.Patterns.topN;
import static io.prestosql.sql.planner.plan.TopNNode.Step.PARTIAL;

/**
* Transforms:
* <pre>
* - TopN (partial)
* - Join (left, right or full)
* - left source
* - right source
* </pre>
* Into:
* <pre>
* - Join
* - TopN (present if Join is left or outer, not already limited, and orderBy symbols come from left source)
* - left source
* - TopN (present if Join is right or outer, not already limited, and orderBy symbols come from right source)
* - right source
* </pre>
*/
public class PushTopNThroughOuterJoin
implements Rule<TopNNode>
{
private static final Capture<JoinNode> JOIN_CHILD = newCapture();

private static final Pattern<TopNNode> PATTERN =
topN().with(step().equalTo(PARTIAL))
.with(source().matching(
join().capturedAs(JOIN_CHILD).with(type().matching(type -> type == LEFT || type == RIGHT || type == FULL))));

@Override
public Pattern<TopNNode> getPattern()
{
return PATTERN;
}

@Override
public Result apply(TopNNode parent, Captures captures, Context context)
{
JoinNode joinNode = captures.get(JOIN_CHILD);

List<Symbol> orderBySymbols = parent.getOrderingScheme().getOrderBy();

PlanNode left = joinNode.getLeft();
PlanNode right = joinNode.getRight();
JoinNode.Type type = joinNode.getType();

if ((type == LEFT || type == FULL)
&& ImmutableSet.copyOf(left.getOutputSymbols()).containsAll(orderBySymbols)
&& !isLimited(left, context.getLookup(), parent.getCount())) {
return Result.ofPlanNode(
joinNode.replaceChildren(ImmutableList.of(
parent.replaceChildren(ImmutableList.of(left)),
right)));
}

if ((type == RIGHT || type == FULL)
&& ImmutableSet.copyOf(right.getOutputSymbols()).containsAll(orderBySymbols)
&& !isLimited(right, context.getLookup(), parent.getCount())) {
return Result.ofPlanNode(
joinNode.replaceChildren(ImmutableList.of(
left,
parent.replaceChildren(ImmutableList.of(right)))));
}

return Result.empty();
}

private static boolean isLimited(PlanNode node, Lookup lookup, long limit)
{
Range<Long> cardinality = extractCardinality(node, lookup);
return cardinality.hasUpperBound() && cardinality.upperEndpoint() <= limit;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.PlanVisitor;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.planner.plan.TopNNode;
import io.prestosql.sql.planner.plan.ValuesNode;

import static com.google.common.collect.Iterables.getOnlyElement;
Expand Down Expand Up @@ -143,13 +144,23 @@ public Range<Long> visitValues(ValuesNode node, Void context)
@Override
public Range<Long> visitLimit(LimitNode node, Void context)
{
Range<Long> sourceCardinalityRange = node.getSource().accept(this, null);
long upper = node.getCount();
return applyLimit(node.getSource(), node.getCount());
}

@Override
public Range<Long> visitTopN(TopNNode node, Void context)
{
return applyLimit(node.getSource(), node.getCount());
}

private Range<Long> applyLimit(PlanNode source, long limit)
{
Range<Long> sourceCardinalityRange = source.accept(this, null);
if (sourceCardinalityRange.hasUpperBound()) {
upper = min(sourceCardinalityRange.upperEndpoint(), node.getCount());
limit = min(sourceCardinalityRange.upperEndpoint(), limit);
}
long lower = min(upper, sourceCardinalityRange.lowerEndpoint());
return Range.closed(lower, upper);
long lower = min(limit, sourceCardinalityRange.lowerEndpoint());
return Range.closed(lower, limit);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import io.prestosql.sql.planner.plan.SemiJoinNode;
import io.prestosql.sql.planner.plan.StatisticsWriterNode;
import io.prestosql.sql.planner.plan.TableScanNode;
import io.prestosql.sql.planner.plan.TopNNode;
import io.prestosql.sql.planner.plan.ValuesNode;
import io.prestosql.sql.tree.LongLiteral;
import io.prestosql.tests.QueryTemplate;
Expand Down Expand Up @@ -77,6 +78,7 @@
import static io.prestosql.sql.planner.assertions.PlanMatchPattern.sort;
import static io.prestosql.sql.planner.assertions.PlanMatchPattern.strictTableScan;
import static io.prestosql.sql.planner.assertions.PlanMatchPattern.tableScan;
import static io.prestosql.sql.planner.assertions.PlanMatchPattern.topN;
import static io.prestosql.sql.planner.assertions.PlanMatchPattern.values;
import static io.prestosql.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static io.prestosql.sql.planner.plan.AggregationNode.Step.FINAL;
Expand All @@ -92,6 +94,7 @@
import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER;
import static io.prestosql.sql.planner.plan.JoinNode.Type.LEFT;
import static io.prestosql.sql.tree.SortItem.NullOrdering.LAST;
import static io.prestosql.sql.tree.SortItem.Ordering.ASCENDING;
import static io.prestosql.sql.tree.SortItem.Ordering.DESCENDING;
import static io.prestosql.tests.QueryTemplate.queryTemplate;
import static io.prestosql.util.MorePredicates.isInstanceOfAny;
Expand Down Expand Up @@ -255,6 +258,22 @@ public void testJoinWithOrderBySameKey()
tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey"))))));
}

@Test
public void testTopNPushdownToJoinSource()
{
assertPlan("SELECT n.name, r.name FROM nation n LEFT JOIN region r ON n.regionkey = r.regionkey ORDER BY n.comment LIMIT 1",
anyTree(
project(
topN(1, ImmutableList.of(sort("N_COMM", ASCENDING, LAST)), TopNNode.Step.FINAL,
anyTree(
join(LEFT, ImmutableList.of(equiJoinClause("N_KEY", "R_KEY")),
project(
topN(1, ImmutableList.of(sort("N_COMM", ASCENDING, LAST)), TopNNode.Step.PARTIAL,
tableScan("nation", ImmutableMap.of("N_NAME", "name", "N_KEY", "regionkey", "N_COMM", "comment")))),
anyTree(
tableScan("region", ImmutableMap.of("R_NAME", "name", "R_KEY", "regionkey")))))))));
}

@Test
public void testUncorrelatedSubqueries()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,12 @@ public static PlanMatchPattern sort(List<Ordering> orderBy, PlanMatchPattern sou

public static PlanMatchPattern topN(long count, List<Ordering> orderBy, PlanMatchPattern source)
{
return node(TopNNode.class, source).with(new TopNMatcher(count, orderBy));
return topN(count, orderBy, TopNNode.Step.SINGLE, source);
}

public static PlanMatchPattern topN(long count, List<Ordering> orderBy, TopNNode.Step step, PlanMatchPattern source)
{
return node(TopNNode.class, source).with(new TopNMatcher(count, orderBy, step));
}

public static PlanMatchPattern output(PlanMatchPattern source)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.prestosql.sql.planner.assertions.PlanMatchPattern.Ordering;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.TopNNode;
import io.prestosql.sql.planner.plan.TopNNode.Step;

import java.util.List;

Expand All @@ -35,11 +36,13 @@ public class TopNMatcher
{
private final long count;
private final List<Ordering> orderBy;
private final Step step;

public TopNMatcher(long count, List<Ordering> orderBy)
public TopNMatcher(long count, List<Ordering> orderBy, Step step)
{
this.count = count;
this.orderBy = ImmutableList.copyOf(requireNonNull(orderBy, "orderBy is null"));
this.step = requireNonNull(step, "step is null");
}

@Override
Expand All @@ -62,6 +65,10 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses
return NO_MATCH;
}

if (topNNode.getStep() != step) {
return NO_MATCH;
}

return match();
}

Expand All @@ -71,6 +78,7 @@ public String toString()
return toStringHelper(this)
.add("count", count)
.add("orderBy", orderBy)
.add("step", step)
.toString();
}
}

0 comments on commit ec57ed6

Please sign in to comment.