Skip to content

Commit

Permalink
Convert left correlated join to inner join during transformation
Browse files Browse the repository at this point in the history
It's possible to simplify outer join to inner
join when subquery side is at least scalar.
This makes sure dynamic filters are propagated
to left join side and more efficient (cross join)
implementaiton is used.
  • Loading branch information
sopel39 committed Nov 27, 2023
1 parent 69759e6 commit 5696215
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.trino.matching.Pattern;
import io.trino.spi.type.Type;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.CorrelatedJoinNode;
Expand All @@ -36,6 +37,7 @@
import static com.google.common.base.Preconditions.checkState;
import static io.trino.matching.Pattern.empty;
import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType;
import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.extractCardinality;
import static io.trino.sql.planner.plan.CorrelatedJoinNode.Type.FULL;
import static io.trino.sql.planner.plan.CorrelatedJoinNode.Type.INNER;
import static io.trino.sql.planner.plan.CorrelatedJoinNode.Type.LEFT;
Expand Down Expand Up @@ -64,7 +66,8 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co
return Result.ofPlanNode(rewriteToJoin(
correlatedJoinNode,
correlatedJoinNode.getType().toJoinNodeType(),
correlatedJoinNode.getFilter()));
correlatedJoinNode.getFilter(),
context.getLookup()));
}

checkState(
Expand All @@ -79,7 +82,7 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co
else {
type = JoinNode.Type.LEFT;
}
JoinNode joinNode = rewriteToJoin(correlatedJoinNode, type, TRUE_LITERAL);
JoinNode joinNode = rewriteToJoin(correlatedJoinNode, type, TRUE_LITERAL, context.getLookup());

if (correlatedJoinNode.getFilter().equals(TRUE_LITERAL)) {
return Result.ofPlanNode(joinNode);
Expand Down Expand Up @@ -109,8 +112,12 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co
return Result.empty();
}

private JoinNode rewriteToJoin(CorrelatedJoinNode parent, JoinNode.Type type, Expression filter)
private JoinNode rewriteToJoin(CorrelatedJoinNode parent, JoinNode.Type type, Expression filter, Lookup lookup)
{
if (type == JoinNode.Type.LEFT && extractCardinality(parent.getSubquery(), lookup).isAtLeastScalar() && filter.equals(TRUE_LITERAL)) {
// input rows will always be matched against subquery rows
type = JoinNode.Type.INNER;
}
return new JoinNode(
parent.getId(),
type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,11 @@ public List<Symbol> getOutputSymbols()
.build();
}

public CorrelatedJoinNode withType(Type type)
{
return new CorrelatedJoinNode(getId(), input, subquery, correlation, type, filter, originSubquery);
}

@Override
public PlanNode replaceChildren(List<PlanNode> newChildren)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import static io.trino.sql.planner.assertions.PlanMatchPattern.project;
import static io.trino.sql.planner.assertions.PlanMatchPattern.values;
import static io.trino.sql.planner.plan.CorrelatedJoinNode.Type.FULL;
import static io.trino.sql.planner.plan.CorrelatedJoinNode.Type.INNER;
import static io.trino.sql.planner.plan.CorrelatedJoinNode.Type.LEFT;
import static io.trino.sql.planner.plan.CorrelatedJoinNode.Type.RIGHT;
import static io.trino.sql.planner.plan.JoinNode.Type;
Expand All @@ -36,6 +35,26 @@
public class TestTransformUncorrelatedSubqueryToJoin
extends BaseRuleTest
{
@Test
public void testRewriteLeftCorrelatedJoinWithScalarSubquery()
{
tester().assertThat(new TransformUncorrelatedSubqueryToJoin())
.on(p -> {
Symbol a = p.symbol("a");
Symbol b = p.symbol("b");
return p.correlatedJoin(
emptyList(),
p.values(a),
LEFT,
TRUE_LITERAL,
p.values(1, b));
})
.matches(
join(Type.INNER, builder -> builder
.left(values("a"))
.right(values("b"))));
}

@Test
public void testRewriteInnerCorrelatedJoin()
{
Expand All @@ -46,15 +65,15 @@ public void testRewriteInnerCorrelatedJoin()
return p.correlatedJoin(
emptyList(),
p.values(a),
INNER,
LEFT,
new ComparisonExpression(
GREATER_THAN,
b.toSymbolReference(),
a.toSymbolReference()),
p.values(b));
})
.matches(
join(Type.INNER, builder -> builder
join(Type.LEFT, builder -> builder
.filter("b > a")
.left(values("a"))
.right(values("b"))));
Expand Down

0 comments on commit 5696215

Please sign in to comment.