Skip to content

Commit

Permalink
Pass proper captures to members of RuleSet
Browse files Browse the repository at this point in the history
  • Loading branch information
kokosing committed Sep 7, 2017
1 parent 435d259 commit 378a0d2
Showing 1 changed file with 22 additions and 6 deletions.
Expand Up @@ -17,6 +17,7 @@
import com.facebook.presto.connector.ConnectorId; import com.facebook.presto.connector.ConnectorId;
import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.cost.CostCalculator;
import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Match;
import com.facebook.presto.matching.Pattern; import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.security.AccessControl; import com.facebook.presto.security.AccessControl;
Expand All @@ -35,6 +36,7 @@


import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.collect.Iterables.getOnlyElement;
import static java.util.Objects.requireNonNull;
import static java.util.Optional.empty; import static java.util.Optional.empty;
import static java.util.stream.Collectors.toSet; import static java.util.stream.Collectors.toSet;


Expand Down Expand Up @@ -96,7 +98,7 @@ public ConnectorId getCurrentConnectorId()
return queryRunner.inTransaction(transactionSession -> metadata.getCatalogHandle(transactionSession, session.getCatalog().get())).get(); return queryRunner.inTransaction(transactionSession -> metadata.getCatalogHandle(transactionSession, session.getCatalog().get())).get();
} }


private class RuleSetAdapter private static class RuleSetAdapter
implements Rule<PlanNode> implements Rule<PlanNode>
{ {
private final RuleSet ruleSet; private final RuleSet ruleSet;
Expand All @@ -116,19 +118,33 @@ public Pattern<PlanNode> getPattern()
public Optional<PlanNode> apply(PlanNode node, Captures captures, Context context) public Optional<PlanNode> apply(PlanNode node, Captures captures, Context context)
{ {
PlanNodeMatcher planNodeMatcher = new PlanNodeMatcher(context.getLookup()); PlanNodeMatcher planNodeMatcher = new PlanNodeMatcher(context.getLookup());
Set<Rule> matching = ruleSet.rules().stream() Set<RuleMatch> matching = ruleSet.rules().stream()
.filter(rule -> matches(rule.getPattern(), node, planNodeMatcher)) .map(rule -> new RuleMatch(rule, planNodeMatcher.match(rule.getPattern(), node)))
.filter(ruleMatch -> ruleMatch.match.isPresent())
.collect(toSet()); .collect(toSet());

if (matching.size() == 0) { if (matching.size() == 0) {
return empty(); return empty();
} }


return getOnlyElement(matching).apply(node, captures, context); return getOnlyElement(matching).apply(context);
} }


private boolean matches(Pattern pattern, PlanNode node, PlanNodeMatcher planNodeMatcher) private static class RuleMatch<T>
{ {
return planNodeMatcher.match(pattern, node).isPresent(); private final Rule<T> rule;
private final Match<T> match;

private RuleMatch(Rule<T> rule, Match<T> match)
{
this.rule = requireNonNull(rule, "rule is null");
this.match = requireNonNull(match, "match is null");
}

private Optional<PlanNode> apply(Context context)
{
return rule.apply(match.value(), match.captures(), context);
}
} }
} }
} }

0 comments on commit 378a0d2

Please sign in to comment.