Skip to content

Commit

Permalink
Add rewriter to desugar TRY using internal try function
Browse files Browse the repository at this point in the history
  • Loading branch information
shixuan-fan committed Sep 15, 2017
1 parent 6e25209 commit 46ee498
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 0 deletions.
@@ -0,0 +1,48 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.facebook.presto.sql.planner;

import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.LambdaExpression;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.TryExpression;
import com.google.common.collect.ImmutableList;

public class DesugarTryExpressionRewriter
{
private DesugarTryExpressionRewriter() {}

public static Expression rewrite(Expression expression)
{
return ExpressionTreeRewriter.rewriteWith(new Visitor(), expression);
}

private static class Visitor
extends ExpressionRewriter<Void>
{
@Override
public Expression rewriteTryExpression(TryExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
Expression expression = treeRewriter.rewrite(node.getInnerExpression(), context);

return new FunctionCall(
QualifiedName.of("$internal$try"),
ImmutableList.of(new LambdaExpression(ImmutableList.of(), expression)));
}
}
}
Expand Up @@ -23,6 +23,7 @@
import com.facebook.presto.sql.planner.iterative.rule.CreatePartialTopN;
import com.facebook.presto.sql.planner.iterative.rule.DesugarAtTimeZone;
import com.facebook.presto.sql.planner.iterative.rule.DesugarLambdaExpression;
import com.facebook.presto.sql.planner.iterative.rule.DesugarTryExpression;
import com.facebook.presto.sql.planner.iterative.rule.EliminateCrossJoins;
import com.facebook.presto.sql.planner.iterative.rule.EvaluateZeroLimit;
import com.facebook.presto.sql.planner.iterative.rule.EvaluateZeroSample;
Expand Down Expand Up @@ -188,6 +189,7 @@ public PlanOptimizers(Metadata metadata, SqlParser sqlParser, FeaturesConfig fea
ImmutableSet.<Rule<?>>builder()
.addAll(new DesugarLambdaExpression().rules())
.addAll(new DesugarAtTimeZone(metadata, sqlParser).rules())
.addAll(new DesugarTryExpression().rules())
.build()),
new IterativeOptimizer(
stats,
Expand Down
@@ -0,0 +1,25 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.sql.planner.DesugarTryExpressionRewriter;

public class DesugarTryExpression
extends ExpressionRewriteRuleSet
{
public DesugarTryExpression()
{
super((expression, context) -> DesugarTryExpressionRewriter.rewrite(expression));
}
}
Expand Up @@ -18,6 +18,7 @@
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.DesugarAtTimeZoneRewriter;
import com.facebook.presto.sql.planner.DesugarTryExpressionRewriter;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
Expand Down Expand Up @@ -189,6 +190,7 @@ public PlanNode visitApply(ApplyNode node, RewriteContext<Void> context)

private Expression desugar(Expression expression)
{
expression = DesugarTryExpressionRewriter.rewrite(expression);
expression = LambdaCaptureDesugaringRewriter.rewrite(expression, symbolAllocator.getTypes(), symbolAllocator);
expression = DesugarAtTimeZoneRewriter.rewrite(expression, session, metadata, sqlParser, symbolAllocator);

Expand Down
@@ -0,0 +1,51 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner;

import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.facebook.presto.sql.tree.ArithmeticBinaryExpression;
import com.facebook.presto.sql.tree.DecimalLiteral;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.LambdaExpression;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.TryExpression;
import com.google.common.collect.ImmutableList;
import org.testng.annotations.Test;

import static com.facebook.presto.sql.tree.ArithmeticBinaryExpression.Type.ADD;
import static org.testng.Assert.assertEquals;

public class TestDesugarTryExpressionRewriter
extends BaseRuleTest
{
@Test
public void testTryExpressionDesugaringRewriter()
{
// 1 + try(2)
Expression before = new ArithmeticBinaryExpression(
ADD,
new DecimalLiteral("1"),
new TryExpression(new DecimalLiteral("2")));

// 1 + try_function(() -> 2)
Expression after = new ArithmeticBinaryExpression(
ADD,
new DecimalLiteral("1"),
new FunctionCall(
QualifiedName.of("$internal$try"),
ImmutableList.of(new LambdaExpression(ImmutableList.of(), new DecimalLiteral("2")))));
assertEquals(DesugarTryExpressionRewriter.rewrite(before), after);
}
}

0 comments on commit 46ee498

Please sign in to comment.