Skip to content

Commit

Permalink
Preserve unary '+' in AST
Browse files Browse the repository at this point in the history
  • Loading branch information
martint committed Jan 8, 2015
1 parent 44d32d3 commit 0feb144
Show file tree
Hide file tree
Showing 13 changed files with 143 additions and 86 deletions.
Expand Up @@ -33,7 +33,7 @@
import com.facebook.presto.sql.tree.LikePredicate; import com.facebook.presto.sql.tree.LikePredicate;
import com.facebook.presto.sql.tree.Literal; import com.facebook.presto.sql.tree.Literal;
import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.NegativeExpression; import com.facebook.presto.sql.tree.ArithmeticUnaryExpression;
import com.facebook.presto.sql.tree.Node; import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.NotExpression; import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.NullIfExpression; import com.facebook.presto.sql.tree.NullIfExpression;
Expand Down Expand Up @@ -343,7 +343,7 @@ protected Boolean visitQualifiedNameReference(QualifiedNameReference node, Void
} }


@Override @Override
protected Boolean visitNegativeExpression(NegativeExpression node, Void context) protected Boolean visitArithmeticUnary(ArithmeticUnaryExpression node, Void context)
{ {
return process(node.getValue(), context); return process(node.getValue(), context);
} }
Expand Down
Expand Up @@ -48,7 +48,7 @@
import com.facebook.presto.sql.tree.LikePredicate; import com.facebook.presto.sql.tree.LikePredicate;
import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.NegativeExpression; import com.facebook.presto.sql.tree.ArithmeticUnaryExpression;
import com.facebook.presto.sql.tree.Node; import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.NotExpression; import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.NullIfExpression; import com.facebook.presto.sql.tree.NullIfExpression;
Expand Down Expand Up @@ -437,9 +437,16 @@ protected Type visitCoalesceExpression(CoalesceExpression node, AnalysisContext
} }


@Override @Override
protected Type visitNegativeExpression(NegativeExpression node, AnalysisContext context) protected Type visitArithmeticUnary(ArithmeticUnaryExpression node, AnalysisContext context)
{ {
return getOperator(context, node, OperatorType.NEGATION, node.getValue()); if (node.getSign() == ArithmeticUnaryExpression.Sign.MINUS) {
return getOperator(context, node, OperatorType.NEGATION, node.getValue());
}

Type type = process(node.getValue(), context);
expressionTypes.put(node, type);

return type;
} }


@Override @Override
Expand Down
Expand Up @@ -41,7 +41,7 @@
import com.facebook.presto.sql.tree.LikePredicate; import com.facebook.presto.sql.tree.LikePredicate;
import com.facebook.presto.sql.tree.Literal; import com.facebook.presto.sql.tree.Literal;
import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.NegativeExpression; import com.facebook.presto.sql.tree.ArithmeticUnaryExpression;
import com.facebook.presto.sql.tree.Node; import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.NotExpression; import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.NullIfExpression; import com.facebook.presto.sql.tree.NullIfExpression;
Expand Down Expand Up @@ -406,30 +406,37 @@ else if (!found && (Boolean) invokeOperator(OperatorType.EQUAL, types(node.getVa
} }


@Override @Override
protected Object visitNegativeExpression(NegativeExpression node, Object context) protected Object visitArithmeticUnary(ArithmeticUnaryExpression node, Object context)
{ {
Object value = process(node.getValue(), context); Object value = process(node.getValue(), context);
if (value == null) { if (value == null) {
return null; return null;
} }
if (value instanceof Expression) { if (value instanceof Expression) {
return new NegativeExpression(toExpression(value, expressionTypes.get(node.getValue()))); return new ArithmeticUnaryExpression(node.getSign(), toExpression(value, expressionTypes.get(node.getValue())));
} }


FunctionInfo operatorInfo = metadata.resolveOperator(OperatorType.NEGATION, types(node.getValue())); switch (node.getSign()) {
case PLUS:
return value;
case MINUS:
FunctionInfo operatorInfo = metadata.resolveOperator(OperatorType.NEGATION, types(node.getValue()));


MethodHandle handle = operatorInfo.getMethodHandle(); MethodHandle handle = operatorInfo.getMethodHandle();
if (handle.type().parameterCount() > 0 && handle.type().parameterType(0) == ConnectorSession.class) { if (handle.type().parameterCount() > 0 && handle.type().parameterType(0) == ConnectorSession.class) {
handle = handle.bindTo(session); handle = handle.bindTo(session);
} }
try { try {
return handle.invokeWithArguments(value); return handle.invokeWithArguments(value);
} }
catch (Throwable throwable) { catch (Throwable throwable) {
Throwables.propagateIfInstanceOf(throwable, RuntimeException.class); Throwables.propagateIfInstanceOf(throwable, RuntimeException.class);
Throwables.propagateIfInstanceOf(throwable, Error.class); Throwables.propagateIfInstanceOf(throwable, Error.class);
throw new RuntimeException(throwable.getMessage(), throwable); throw new RuntimeException(throwable.getMessage(), throwable);
}
} }

throw new UnsupportedOperationException("Unsupported unary operator: " + node.getSign());
} }


@Override @Override
Expand Down
Expand Up @@ -30,7 +30,7 @@
import com.facebook.presto.sql.tree.IntervalLiteral; import com.facebook.presto.sql.tree.IntervalLiteral;
import com.facebook.presto.sql.tree.Literal; import com.facebook.presto.sql.tree.Literal;
import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.NegativeExpression; import com.facebook.presto.sql.tree.ArithmeticUnaryExpression;
import com.facebook.presto.sql.tree.NullLiteral; import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.StringLiteral; import com.facebook.presto.sql.tree.StringLiteral;
Expand Down Expand Up @@ -110,7 +110,7 @@ public static Expression toExpression(Object object, Type type)
return new FunctionCall(new QualifiedName("nan"), ImmutableList.<Expression>of()); return new FunctionCall(new QualifiedName("nan"), ImmutableList.<Expression>of());
} }
else if (value.equals(Double.NEGATIVE_INFINITY)) { else if (value.equals(Double.NEGATIVE_INFINITY)) {
return new NegativeExpression(new FunctionCall(new QualifiedName("infinity"), ImmutableList.<Expression>of())); return ArithmeticUnaryExpression.negative(new FunctionCall(new QualifiedName("infinity"), ImmutableList.<Expression>of()));
} }
else if (value.equals(Double.POSITIVE_INFINITY)) { else if (value.equals(Double.POSITIVE_INFINITY)) {
return new FunctionCall(new QualifiedName("infinity"), ImmutableList.<Expression>of()); return new FunctionCall(new QualifiedName("infinity"), ImmutableList.<Expression>of());
Expand Down
Expand Up @@ -21,6 +21,7 @@
import com.facebook.presto.spi.type.TypeSignature; import com.facebook.presto.spi.type.TypeSignature;
import com.facebook.presto.sql.relational.optimizer.ExpressionOptimizer; import com.facebook.presto.sql.relational.optimizer.ExpressionOptimizer;
import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; import com.facebook.presto.sql.tree.ArithmeticBinaryExpression;
import com.facebook.presto.sql.tree.ArithmeticUnaryExpression;
import com.facebook.presto.sql.tree.ArrayConstructor; import com.facebook.presto.sql.tree.ArrayConstructor;
import com.facebook.presto.sql.tree.AstVisitor; import com.facebook.presto.sql.tree.AstVisitor;
import com.facebook.presto.sql.tree.BetweenPredicate; import com.facebook.presto.sql.tree.BetweenPredicate;
Expand All @@ -42,7 +43,6 @@
import com.facebook.presto.sql.tree.LikePredicate; import com.facebook.presto.sql.tree.LikePredicate;
import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.NegativeExpression;
import com.facebook.presto.sql.tree.NotExpression; import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.NullIfExpression; import com.facebook.presto.sql.tree.NullIfExpression;
import com.facebook.presto.sql.tree.NullLiteral; import com.facebook.presto.sql.tree.NullLiteral;
Expand Down Expand Up @@ -282,14 +282,21 @@ protected RowExpression visitArithmeticBinary(ArithmeticBinaryExpression node, V
} }


@Override @Override
protected RowExpression visitNegativeExpression(NegativeExpression node, Void context) protected RowExpression visitArithmeticUnary(ArithmeticUnaryExpression node, Void context)
{ {
RowExpression expression = process(node.getValue(), context); RowExpression expression = process(node.getValue(), context);


return call( switch (node.getSign()) {
arithmeticNegationSignature(types.get(node), expression.getType()), case PLUS:
types.get(node), return expression;
expression); case MINUS:
return call(
arithmeticNegationSignature(types.get(node), expression.getType()),
types.get(node),
expression);
}

throw new UnsupportedOperationException("Unsupported unary operator: " + node.getSign());
} }


@Override @Override
Expand Down
Expand Up @@ -40,7 +40,7 @@
import com.facebook.presto.sql.tree.LikePredicate; import com.facebook.presto.sql.tree.LikePredicate;
import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.NegativeExpression; import com.facebook.presto.sql.tree.ArithmeticUnaryExpression;
import com.facebook.presto.sql.tree.Node; import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.NotExpression; import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.NullIfExpression; import com.facebook.presto.sql.tree.NullIfExpression;
Expand Down Expand Up @@ -309,11 +309,20 @@ protected String visitCoalesceExpression(CoalesceExpression node, Void context)
} }


@Override @Override
protected String visitNegativeExpression(NegativeExpression node, Void context) protected String visitArithmeticUnary(ArithmeticUnaryExpression node, Void context)
{ {
String value = process(node.getValue(), null); String value = process(node.getValue(), null);
String separator = value.startsWith("-") ? " " : "";
return "-" + separator + value; switch (node.getSign()) {
case MINUS:
// this is to avoid turning a sequence of "-" into a comment (i.e., "-- comment")
String separator = value.startsWith("-") ? " " : "";
return "-" + separator + value;
case PLUS:
return "+" + value;
default:
throw new UnsupportedOperationException("Unsupported sign: " + node.getSign());
}
} }


@Override @Override
Expand Down
Expand Up @@ -56,7 +56,7 @@
import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.NaturalJoin; import com.facebook.presto.sql.tree.NaturalJoin;
import com.facebook.presto.sql.tree.NegativeExpression; import com.facebook.presto.sql.tree.ArithmeticUnaryExpression;
import com.facebook.presto.sql.tree.Node; import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.NotExpression; import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.NullIfExpression; import com.facebook.presto.sql.tree.NullIfExpression;
Expand Down Expand Up @@ -664,13 +664,16 @@ public Node visitExists(@NotNull SqlBaseParser.ExistsContext context)
@Override @Override
public Node visitArithmeticUnary(@NotNull SqlBaseParser.ArithmeticUnaryContext context) public Node visitArithmeticUnary(@NotNull SqlBaseParser.ArithmeticUnaryContext context)
{ {
Expression result = (Expression) visit(context.valueExpression()); Expression child = (Expression) visit(context.valueExpression());


if (context.operator.getType() == SqlBaseLexer.MINUS) { switch (context.operator.getType()) {
result = new NegativeExpression(result); case SqlBaseLexer.MINUS:
return ArithmeticUnaryExpression.negative(child);
case SqlBaseLexer.PLUS:
return ArithmeticUnaryExpression.positive(child);
default:
throw new UnsupportedOperationException("Unsupported sign: " + context.operator.getText());
} }

return result;
} }


@Override @Override
Expand Down
Expand Up @@ -13,25 +13,52 @@
*/ */
package com.facebook.presto.sql.tree; package com.facebook.presto.sql.tree;


public class NegativeExpression import static com.google.common.base.Preconditions.checkNotNull;

public class ArithmeticUnaryExpression
extends Expression extends Expression
{ {
public enum Sign {
PLUS,
MINUS
}

private final Expression value; private final Expression value;
private final Sign sign;


public NegativeExpression(Expression value) public ArithmeticUnaryExpression(Sign sign, Expression value)
{ {
checkNotNull(value, "value is null");
checkNotNull(sign, "sign is null");

this.value = value; this.value = value;
this.sign = sign;
}

public static ArithmeticUnaryExpression positive(Expression value)
{
return new ArithmeticUnaryExpression(Sign.PLUS, value);
}

public static ArithmeticUnaryExpression negative(Expression value)
{
return new ArithmeticUnaryExpression(Sign.MINUS, value);
} }


public Expression getValue() public Expression getValue()
{ {
return value; return value;
} }


public Sign getSign()
{
return sign;
}

@Override @Override
public <R, C> R accept(AstVisitor<R, C> visitor, C context) public <R, C> R accept(AstVisitor<R, C> visitor, C context)
{ {
return visitor.visitNegativeExpression(this, context); return visitor.visitArithmeticUnary(this, context);
} }


@Override @Override
Expand All @@ -44,8 +71,11 @@ public boolean equals(Object o)
return false; return false;
} }


NegativeExpression that = (NegativeExpression) o; ArithmeticUnaryExpression that = (ArithmeticUnaryExpression) o;


if (sign != that.sign) {
return false;
}
if (!value.equals(that.value)) { if (!value.equals(that.value)) {
return false; return false;
} }
Expand All @@ -56,6 +86,8 @@ public boolean equals(Object o)
@Override @Override
public int hashCode() public int hashCode()
{ {
return value.hashCode(); int result = value.hashCode();
result = 31 * result + sign.hashCode();
return result;
} }
} }
Expand Up @@ -272,7 +272,7 @@ protected R visitNullLiteral(NullLiteral node, C context)
return visitLiteral(node, context); return visitLiteral(node, context);
} }


protected R visitNegativeExpression(NegativeExpression node, C context) protected R visitArithmeticUnary(ArithmeticUnaryExpression node, C context)
{ {
return visitExpression(node, context); return visitExpression(node, context);
} }
Expand Down
Expand Up @@ -250,7 +250,7 @@ protected R visitIfExpression(IfExpression node, C context)
} }


@Override @Override
protected R visitNegativeExpression(NegativeExpression node, C context) protected R visitArithmeticUnary(ArithmeticUnaryExpression node, C context)
{ {
return process(node.getValue(), context); return process(node.getValue(), context);
} }
Expand Down
Expand Up @@ -20,7 +20,7 @@ public Expression rewriteExpression(Expression node, C context, ExpressionTreeRe
return null; return null;
} }


public Expression rewriteNegativeExpression(NegativeExpression node, C context, ExpressionTreeRewriter<C> treeRewriter) public Expression rewriteArithmeticUnary(ArithmeticUnaryExpression node, C context, ExpressionTreeRewriter<C> treeRewriter)
{ {
return rewriteExpression(node, context, treeRewriter); return rewriteExpression(node, context, treeRewriter);
} }
Expand Down
Expand Up @@ -69,18 +69,18 @@ protected Expression visitExpression(Expression node, Context<C> context)
} }


@Override @Override
protected Expression visitNegativeExpression(NegativeExpression node, Context<C> context) protected Expression visitArithmeticUnary(ArithmeticUnaryExpression node, Context<C> context)
{ {
if (!context.isDefaultRewrite()) { if (!context.isDefaultRewrite()) {
Expression result = rewriter.rewriteNegativeExpression(node, context.get(), ExpressionTreeRewriter.this); Expression result = rewriter.rewriteArithmeticUnary(node, context.get(), ExpressionTreeRewriter.this);
if (result != null) { if (result != null) {
return result; return result;
} }
} }


Expression child = rewrite(node.getValue(), context.get()); Expression child = rewrite(node.getValue(), context.get());
if (child != node.getValue()) { if (child != node.getValue()) {
return new NegativeExpression(child); return new ArithmeticUnaryExpression(node.getSign(), child);
} }


return node; return node;
Expand Down

0 comments on commit 0feb144

Please sign in to comment.