Skip to content

Commit

Permalink
Remove DISTINCT inside IN subquery expression
Browse files Browse the repository at this point in the history
  • Loading branch information
Praveen2112 authored and kokosing committed Apr 1, 2019
1 parent 377fca0 commit 79083f7
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 2 deletions.
Expand Up @@ -90,6 +90,7 @@
import io.prestosql.sql.planner.iterative.rule.PushTopNThroughOuterJoin;
import io.prestosql.sql.planner.iterative.rule.PushTopNThroughProject;
import io.prestosql.sql.planner.iterative.rule.PushTopNThroughUnion;
import io.prestosql.sql.planner.iterative.rule.RemoveAggregationInSemiJoin;
import io.prestosql.sql.planner.iterative.rule.RemoveEmptyDelete;
import io.prestosql.sql.planner.iterative.rule.RemoveFullSample;
import io.prestosql.sql.planner.iterative.rule.RemoveRedundantIdentityProjections;
Expand Down Expand Up @@ -355,7 +356,8 @@ public PlanOptimizers(
ImmutableSet.of(
new InlineProjections(),
new RemoveRedundantIdentityProjections(),
new TransformCorrelatedSingleRowSubqueryToProject())),
new TransformCorrelatedSingleRowSubqueryToProject(),
new RemoveAggregationInSemiJoin())),
new CheckSubqueryNodesAreRewritten(),
predicatePushDown,
new IterativeOptimizer(
Expand Down
@@ -0,0 +1,56 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import io.prestosql.matching.Capture;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.SemiJoinNode;

import static com.google.common.collect.Iterables.getOnlyElement;
import static io.prestosql.matching.Capture.newCapture;
import static io.prestosql.sql.planner.plan.Patterns.SemiJoin.getFilteringSource;
import static io.prestosql.sql.planner.plan.Patterns.aggregation;
import static io.prestosql.sql.planner.plan.Patterns.semiJoin;

/**
* Remove the aggregation node that produces distinct rows from the Filtering source of a Semi join.
*/
public class RemoveAggregationInSemiJoin
implements Rule<SemiJoinNode>
{
private static final Capture<AggregationNode> CHILD = newCapture();

private static final Pattern<SemiJoinNode> PATTERN = semiJoin()
.with(getFilteringSource()
.matching(aggregation()
.capturedAs(CHILD).matching(AggregationNode::producesDistinctRows)));

@Override
public Pattern<SemiJoinNode> getPattern()
{
return PATTERN;
}

@Override
public Result apply(SemiJoinNode semiJoinNode, Captures captures, Context context)
{
AggregationNode filteringSource = captures.get(CHILD);
return Result.ofPlanNode(semiJoinNode
.replaceChildren(ImmutableList.of(semiJoinNode.getSource(), getOnlyElement(filteringSource.getSources()))));
}
}
Expand Up @@ -275,4 +275,21 @@ public static Property<ValuesNode, Lookup, List<List<Expression>>> rows()
return property("rows", ValuesNode::getRows);
}
}

public static class SemiJoin
{
public static Property<SemiJoinNode, Lookup, PlanNode> getSource()
{
return property(
"source",
(SemiJoinNode semiJoin, Lookup lookup) -> lookup.resolve(semiJoin.getSource()));
}

public static Property<SemiJoinNode, Lookup, PlanNode> getFilteringSource()
{
return property(
"filteringSource",
(SemiJoinNode semiJoin, Lookup lookup) -> lookup.resolve(semiJoin.getFilteringSource()));
}
}
}
Expand Up @@ -430,10 +430,15 @@ public void testJoinOutputPruning()
}

private void assertPlanContainsNoApplyOrAnyJoin(String sql)
{
assertPlanDoesNotContain(sql, ApplyNode.class, JoinNode.class, IndexJoinNode.class, SemiJoinNode.class, LateralJoinNode.class);
}

private void assertPlanDoesNotContain(String sql, Class... classes)
{
assertFalse(
searchFrom(plan(sql, LogicalPlanner.Stage.OPTIMIZED).getRoot())
.where(isInstanceOfAny(ApplyNode.class, JoinNode.class, IndexJoinNode.class, SemiJoinNode.class, LateralJoinNode.class))
.where(isInstanceOfAny(classes))
.matches(),
"Unexpected node for query: " + sql);
}
Expand Down Expand Up @@ -820,4 +825,12 @@ public void testDistributedSort()
tableScan("orders", ImmutableMap.of(
"ORDERKEY", "orderkey")))))));
}

@Test
public void testRemoveAggregationInSemiJoin()
{
assertPlanDoesNotContain(
"SELECT custkey FROM orders WHERE custkey IN (SELECT distinct custkey FROM customer)",
AggregationNode.class);
}
}
@@ -0,0 +1,84 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest;
import io.prestosql.sql.planner.iterative.rule.test.PlanBuilder;
import io.prestosql.sql.planner.plan.PlanNode;
import org.testng.annotations.Test;

import java.util.Optional;

import static io.prestosql.spi.type.BigintType.BIGINT;
import static io.prestosql.sql.planner.assertions.PlanMatchPattern.semiJoin;
import static io.prestosql.sql.planner.assertions.PlanMatchPattern.values;
import static io.prestosql.sql.planner.iterative.rule.test.PlanBuilder.expression;

public class TestRemoveAggregationInSemiJoin
extends BaseRuleTest
{
@Test
public void test()
{
tester().assertThat(new RemoveAggregationInSemiJoin())
.on(TestRemoveAggregationInSemiJoin::semiJoinWithDistinctAsFilteringSource)
.matches(
semiJoin("leftKey", "rightKey", "match",
values("leftKey"),
values("rightKey")));
}

@Test
public void testDoesNotFire()
{
tester().assertThat(new RemoveAggregationInSemiJoin())
.on(TestRemoveAggregationInSemiJoin::semiJoinWithAggregationAsFilteringSource)
.doesNotFire();
}

private static PlanNode semiJoinWithDistinctAsFilteringSource(PlanBuilder p)
{
Symbol leftKey = p.symbol("leftKey");
Symbol rightKey = p.symbol("rightKey");
return p.semiJoin(
leftKey,
rightKey,
p.symbol("match"),
Optional.empty(),
Optional.empty(),
p.values(leftKey),
p.aggregation(builder -> builder
.singleGroupingSet(rightKey)
.source(p.values(rightKey))));
}

private static PlanNode semiJoinWithAggregationAsFilteringSource(PlanBuilder p)
{
Symbol leftKey = p.symbol("leftKey");
Symbol rightKey = p.symbol("rightKey");
return p.semiJoin(
leftKey,
rightKey,
p.symbol("match"),
Optional.empty(),
Optional.empty(),
p.values(leftKey),
p.aggregation(builder -> builder
.globalGrouping()
.addAggregation(rightKey, expression("count(rightValue)"), ImmutableList.of(BIGINT))
.source(p.values(p.symbol("rightValue")))));
}
}

0 comments on commit 79083f7

Please sign in to comment.