From 378a0d26648a6b3d743313a1609344c267cd97cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Grzegorz=20Kokosi=C5=84ski?= Date: Wed, 30 Aug 2017 08:05:18 +0200 Subject: [PATCH] Pass proper captures to members of RuleSet --- .../iterative/rule/test/RuleTester.java | 28 +++++++++++++++---- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleTester.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleTester.java index ed2f6ee581de..e9edaa287b39 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleTester.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleTester.java @@ -17,6 +17,7 @@ import com.facebook.presto.connector.ConnectorId; import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Match; import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.security.AccessControl; @@ -35,6 +36,7 @@ import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.google.common.collect.Iterables.getOnlyElement; +import static java.util.Objects.requireNonNull; import static java.util.Optional.empty; import static java.util.stream.Collectors.toSet; @@ -96,7 +98,7 @@ public ConnectorId getCurrentConnectorId() return queryRunner.inTransaction(transactionSession -> metadata.getCatalogHandle(transactionSession, session.getCatalog().get())).get(); } - private class RuleSetAdapter + private static class RuleSetAdapter implements Rule { private final RuleSet ruleSet; @@ -116,19 +118,33 @@ public Pattern getPattern() public Optional apply(PlanNode node, Captures captures, Context context) { PlanNodeMatcher planNodeMatcher = new PlanNodeMatcher(context.getLookup()); - Set matching = ruleSet.rules().stream() - .filter(rule -> matches(rule.getPattern(), node, planNodeMatcher)) + Set matching = ruleSet.rules().stream() + .map(rule -> new RuleMatch(rule, planNodeMatcher.match(rule.getPattern(), node))) + .filter(ruleMatch -> ruleMatch.match.isPresent()) .collect(toSet()); + if (matching.size() == 0) { return empty(); } - return getOnlyElement(matching).apply(node, captures, context); + return getOnlyElement(matching).apply(context); } - private boolean matches(Pattern pattern, PlanNode node, PlanNodeMatcher planNodeMatcher) + private static class RuleMatch { - return planNodeMatcher.match(pattern, node).isPresent(); + private final Rule rule; + private final Match match; + + private RuleMatch(Rule rule, Match match) + { + this.rule = requireNonNull(rule, "rule is null"); + this.match = requireNonNull(match, "match is null"); + } + + private Optional apply(Context context) + { + return rule.apply(match.value(), match.captures(), context); + } } } }