Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding agg-to-semi-join rule #1

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ private RelOptRules() {
CoreRules.JOIN_CONDITION_PUSH,
AbstractConverter.ExpandConversionRule.INSTANCE,
CoreRules.JOIN_COMMUTE,
CoreRules.AGGREGATE_TO_SEMI_JOIN,
CoreRules.PROJECT_TO_SEMI_JOIN,
CoreRules.JOIN_ON_UNIQUE_TO_SEMI_JOIN,
CoreRules.JOIN_TO_SEMI_JOIN,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,14 @@ private CoreRules() {}
public static final AggregateJoinTransposeRule AGGREGATE_JOIN_TRANSPOSE_EXTENDED =
AggregateJoinTransposeRule.Config.EXTENDED.toRule();

/** Rule that creates a {@link Join#isSemiJoin semi-join} from a
* {@link Aggregate} on top of a {@link Join} with an {@link Aggregate} as its
* right input.
*
* @see #JOIN_TO_SEMI_JOIN */
public static final SemiJoinRule.AggregateToSemiJoinRule AGGREGATE_TO_SEMI_JOIN =
SemiJoinRule.AggregateToSemiJoinRule.AggregateToSemiJoinRuleConfig.DEFAULT.toRule();

/** Rule that pushes an {@link Aggregate}
* past a non-distinct {@link Union}. */
public static final AggregateUnionTransposeRule AGGREGATE_UNION_TRANSPOSE =
Expand Down
83 changes: 77 additions & 6 deletions core/src/main/java/org/apache/calcite/rel/rules/SemiJoinRule.java
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,12 @@ protected SemiJoinRule(Config config) {
super(config);
}

protected void perform(RelOptRuleCall call, @Nullable Project project,
protected void perform(RelOptRuleCall call, @Nullable RelNode topRel,
Join join, RelNode left, Aggregate aggregate) {
final RelOptCluster cluster = join.getCluster();
final RexBuilder rexBuilder = cluster.getRexBuilder();
if (project != null) {
final ImmutableBitSet bits =
RelOptUtil.InputFinder.bits(project.getProjects(), null);
if (topRel != null) {
final ImmutableBitSet bits = findBits(topRel);
final ImmutableBitSet rightBits =
ImmutableBitSet.range(left.getRowType().getFieldCount(),
join.getRowType().getFieldCount());
Expand Down Expand Up @@ -123,13 +122,85 @@ protected void perform(RelOptRuleCall call, @Nullable Project project,
default:
throw new AssertionError(join.getJoinType());
}
if (project != null) {
relBuilder.project(project.getProjects(), project.getRowType().getFieldNames());
if (topRel != null) {
if (topRel instanceof Project) {
Project topProject = (Project) topRel;
relBuilder.project(topProject.getProjects(), topProject.getRowType().getFieldNames());
} else if (topRel instanceof Aggregate) {
Aggregate topAgg = (Aggregate) topRel;
relBuilder.aggregate(relBuilder.groupKey(topAgg.getGroupSet()), topAgg.getAggCallList());
}
}
final RelNode relNode = relBuilder.build();
call.transformTo(relNode);
}

private static ImmutableBitSet findBits(RelNode topRel) {
if (topRel instanceof Project) {
Project project = (Project) topRel;
return RelOptUtil.InputFinder.bits(project.getProjects(), null);
} else if (topRel instanceof Aggregate) {
Aggregate aggregate = (Aggregate) topRel;
return ImmutableBitSet.of(RelOptUtil.getAllFields(aggregate));
Comment on lines +143 to +144
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be sufficient to detect whether field access is only on the left?

} else {
return ImmutableBitSet.of();
}
}

/** SemiJoinRule that matches a Aggregate on top of a Join with an Aggregate
* as its right child.
*
* @see CoreRules#AGGREGATE_TO_SEMI_JOIN */
public static class AggregateToSemiJoinRule extends SemiJoinRule {
/** Creates a AggregateToSemiJoinRule. */
protected AggregateToSemiJoinRule(AggregateToSemiJoinRuleConfig config) {
super(config);
}

@Deprecated // to be removed before 2.0
public AggregateToSemiJoinRule(Class<Aggregate> topAggClass,
Class<Join> joinClass, Class<Aggregate> rightAggClass,
RelBuilderFactory relBuilderFactory, String description) {
this(AggregateToSemiJoinRuleConfig.DEFAULT.withRelBuilderFactory(relBuilderFactory)
.withDescription(description)
.as(AggregateToSemiJoinRuleConfig.class)
.withOperandFor(topAggClass, joinClass, rightAggClass));
}

@Override public void onMatch(RelOptRuleCall call) {
final Aggregate topAgg = call.rel(0);
final Join join = call.rel(1);
final RelNode left = call.rel(2);
final Aggregate rightAgg = call.rel(3);
perform(call, topAgg, join, left, rightAgg);
}

/** Rule configuration. */
@Value.Immutable
public interface AggregateToSemiJoinRuleConfig extends SemiJoinRule.Config {
AggregateToSemiJoinRuleConfig DEFAULT = ImmutableAggregateToSemiJoinRuleConfig.of()
.withDescription("SemiJoinRule:aggregate")
.withOperandFor(Aggregate.class, Join.class, Aggregate.class);

@Override default AggregateToSemiJoinRule toRule() {
return new AggregateToSemiJoinRule(this);
}

/** Defines an operand tree for the given classes. */
default AggregateToSemiJoinRuleConfig withOperandFor(Class<? extends Aggregate> topAggClass,
Class<? extends Join> joinClass,
Class<? extends Aggregate> rightAggClass) {
return withOperandSupplier(b ->
b.operand(topAggClass).oneInput(b2 ->
b2.operand(joinClass)
.predicate(SemiJoinRule::isJoinTypeSupported).inputs(
b3 -> b3.operand(RelNode.class).anyInputs(),
b4 -> b4.operand(rightAggClass).anyInputs())))
.as(AggregateToSemiJoinRuleConfig.class);
}
}
}

/** SemiJoinRule that matches a Project on top of a Join with an Aggregate
* as its right child.
*
Expand Down
1 change: 1 addition & 0 deletions core/src/main/java/org/apache/calcite/tools/Programs.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ public class Programs {
EnumerableRules.ENUMERABLE_VALUES_RULE,
EnumerableRules.ENUMERABLE_WINDOW_RULE,
EnumerableRules.ENUMERABLE_MATCH_RULE,
CoreRules.AGGREGATE_TO_SEMI_JOIN,
CoreRules.PROJECT_TO_SEMI_JOIN,
CoreRules.JOIN_ON_UNIQUE_TO_SEMI_JOIN,
CoreRules.JOIN_TO_SEMI_JOIN,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ public final class SortRemoveRuleTest {
@Test void removeSortOverEnumerableSemiJoin() throws Exception {
RuleSet prepareRules =
RuleSets.ofList(CoreRules.SORT_PROJECT_TRANSPOSE,
CoreRules.AGGREGATE_TO_SEMI_JOIN,
CoreRules.PROJECT_TO_SEMI_JOIN,
CoreRules.JOIN_TO_SEMI_JOIN,
EnumerableRules.ENUMERABLE_PROJECT_RULE,
Expand Down
27 changes: 26 additions & 1 deletion core/src/test/java/org/apache/calcite/test/InterpreterTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ private static void assertRows(Interpreter interpreter,
.returnsRows("[1, a]", "[3, c]");
}

@Test void testInterpretSemiJoin() {
@Test void testInterpretProjectSemiJoin() {
final String sql = "select x, y from (values (1, 'a'), (2, 'b'), (3, 'c')) as t(x, y)\n"
+ "where x in\n"
+ "(select x from (values (1, 'd'), (3, 'g')) as t2(x, y))";
Expand All @@ -537,6 +537,31 @@ private static void assertRows(Interpreter interpreter,
}
}

@Test void testInterpretAggregateSemiJoin() {
final String sql = "select y, sum(x) from (values (1, 'a'), (2, 'a'), (3, 'c')) as t(x, y)\n"
+ "where x in\n"
+ "(select x from (values (1, 'd'), (2, 'a'), (3, 'g')) as t2(x, y))\n"
+ "group by y";
try (Planner planner =
sql(sql).withSqlToRel(c -> c.withExpand(true)).createPlanner()) {
SqlNode validate = planner.validate(planner.parse(sql));
RelNode convert = planner.rel(validate).rel;
final HepProgram program = new HepProgramBuilder()
.addRuleInstance(CoreRules.AGGREGATE_TO_SEMI_JOIN)
.build();
final HepPlanner hepPlanner = new HepPlanner(program);
hepPlanner.setRoot(convert);
final RelNode relNode = hepPlanner.findBestExp();
final MyDataContext dataContext =
new MyDataContext(rootSchema, relNode);
assertInterpret(relNode, dataContext, true, "[a, 3]", "[c, 3]");
} catch (ValidationException
| SqlParseException
| RelConversionException e) {
throw Util.throwAsRuntime(e);
}
}

@Test void testInterpretAntiJoin() {
final String sql = "select x, y from (values (1, 'a'), (2, 'b'), (3, 'c')) as t(x, y)\n"
+ "where x not in\n"
Expand Down
28 changes: 26 additions & 2 deletions core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1207,7 +1207,7 @@ private void checkSemiOrAntiJoinProjectTranspose(JoinRelType type) {
.check();
}

@Test void testSemiJoinRule() {
@Test void testProjectSemiJoinRule() {
final String sql = "select dept.* from dept join (\n"
+ " select distinct deptno from emp\n"
+ " where sal > 100) using (deptno)";
Expand All @@ -1221,6 +1221,23 @@ private void checkSemiOrAntiJoinProjectTranspose(JoinRelType type) {
.check();
}

@Test void testAggregateSemiJoinRule() {
final String sql = "select dept.deptno, count(*)\n"
+ "from dept join (\n"
+ " select distinct deptno from emp\n"
+ " where sal > 100) using (deptno)\n"
+ "group by dept.deptno";
sql(sql)
.withDecorrelate(true)
.withTrim(true)
.withPreRule(CoreRules.FILTER_PROJECT_TRANSPOSE,
CoreRules.AGGREGATE_PROJECT_MERGE,
CoreRules.FILTER_INTO_JOIN,
CoreRules.PROJECT_MERGE)
.withRule(CoreRules.AGGREGATE_TO_SEMI_JOIN)
.check();
}

@Test void testSemiJoinRuleDoNotMatchAggregate() {
final String sql = "select *\n"
+ "from emp\n"
Expand Down Expand Up @@ -1293,7 +1310,7 @@ private void checkSemiOrAntiJoinProjectTranspose(JoinRelType type) {
.checkUnchanged();
}

/** Similar to {@link #testSemiJoinRule()} but LEFT. */
/** Similar to {@link #testProjectSemiJoinRule()} but LEFT. */
@Test void testSemiJoinRuleLeft() {
final String sql = "select name from dept left join (\n"
+ " select distinct deptno from emp\n"
Expand Down Expand Up @@ -6070,6 +6087,13 @@ private void checkSwapJoinShouldNotMatch(JoinRelType type) {
checkSemiJoinRuleOnAntiJoin(CoreRules.PROJECT_TO_SEMI_JOIN);
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-4621">[CALCITE-4621]
* SemiJoinRule throws AssertionError on ANTI join</a>. */
@Test void testAggregateToSemiJoinRuleOnAntiJoin() {
checkSemiJoinRuleOnAntiJoin(CoreRules.AGGREGATE_TO_SEMI_JOIN);
}

private void checkSemiJoinRuleOnAntiJoin(RelOptRule rule) {
final Function<RelBuilder, RelNode> relFn = b -> b
.scan("DEPT")
Expand Down
96 changes: 69 additions & 27 deletions core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
Original file line number Diff line number Diff line change
Expand Up @@ -1031,6 +1031,48 @@ LogicalProject(MGR=[$0], SUM_SAL=[$2])
LogicalAggregate(group=[{0, 1}], SUM_SAL=[SUM($2)])
LogicalProject(MGR=[$3], DEPTNO=[$7], SAL=[$5])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
<TestCase name="testAggregateSemiJoinRule">
<Resource name="sql">
<![CDATA[select dept.deptno, count(*)
from dept join (
select distinct deptno from emp
where sal > 100) using (deptno)
group by dept.deptno]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalAggregate(group=[{0}], EXPR$1=[COUNT()])
LogicalJoin(condition=[=($0, $1)], joinType=[inner])
LogicalProject(DEPTNO=[$0])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
LogicalAggregate(group=[{7}])
LogicalFilter(condition=[>($5, 100)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
<Resource name="planAfter">
<![CDATA[
LogicalAggregate(group=[{0}], EXPR$1=[COUNT()])
LogicalJoin(condition=[=($0, $8)], joinType=[semi])
LogicalProject(DEPTNO=[$0])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
LogicalFilter(condition=[>($5, 100)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
<TestCase name="testAggregateToSemiJoinRuleOnAntiJoin">
<Resource name="planBefore">
<![CDATA[
LogicalProject(DNAME=[$1])
LogicalJoin(condition=[=($0, $3)], joinType=[anti])
LogicalTableScan(table=[[scott, DEPT]])
LogicalAggregate(group=[{0}])
LogicalProject(DEPTNO=[$7])
LogicalTableScan(table=[[scott, EMP]])
]]>
</Resource>
</TestCase>
Expand Down Expand Up @@ -7523,6 +7565,33 @@ from (
LogicalProject(EXPR$0=[ROW_NUMBER() OVER (ORDER BY $0)], COL1=[$1])
LogicalProject(DEPTNO=[$7], COL1=[SUM(100) OVER (PARTITION BY $7 ORDER BY $5)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
<TestCase name="testProjectSemiJoinRule">
<Resource name="sql">
<![CDATA[select dept.* from dept join (
select distinct deptno from emp
where sal > 100) using (deptno)]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalProject(DEPTNO=[$0], NAME=[$1])
LogicalJoin(condition=[=($0, $2)], joinType=[inner])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
LogicalAggregate(group=[{0}])
LogicalProject(DEPTNO=[$7])
LogicalFilter(condition=[>($5, 100)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
<Resource name="planAfter">
<![CDATA[
LogicalJoin(condition=[=($0, $2)], joinType=[semi])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
LogicalProject(DEPTNO=[$7])
LogicalFilter(condition=[>($5, 100)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
Expand Down Expand Up @@ -12906,33 +12975,6 @@ LogicalProject(SAL=[$0])
LogicalFilter(condition=[=($0, 100)])
LogicalProject(SAL=[$5], DEPTNO=[$7])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
<TestCase name="testSemiJoinRule">
<Resource name="sql">
<![CDATA[select dept.* from dept join (
select distinct deptno from emp
where sal > 100) using (deptno)]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalProject(DEPTNO=[$0], NAME=[$1])
LogicalJoin(condition=[=($0, $2)], joinType=[inner])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
LogicalAggregate(group=[{0}])
LogicalProject(DEPTNO=[$7])
LogicalFilter(condition=[>($5, 100)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
<Resource name="planAfter">
<![CDATA[
LogicalJoin(condition=[=($0, $2)], joinType=[semi])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
LogicalProject(DEPTNO=[$7])
LogicalFilter(condition=[>($5, 100)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
Expand Down