Skip to content

Commit

Permalink
Migrate type calculation to BigInteger to avoid overflow
Browse files Browse the repository at this point in the history
  • Loading branch information
fiedukow authored and martint committed Jan 14, 2017
1 parent 42a2522 commit b088854
Showing 1 changed file with 28 additions and 25 deletions.
Expand Up @@ -32,6 +32,7 @@
import org.antlr.v4.runtime.atn.PredictionMode; import org.antlr.v4.runtime.atn.PredictionMode;
import org.antlr.v4.runtime.misc.ParseCancellationException; import org.antlr.v4.runtime.misc.ParseCancellationException;


import java.math.BigInteger;
import java.util.Map; import java.util.Map;


import static com.facebook.presto.type.TypeCalculationParser.ASTERISK; import static com.facebook.presto.type.TypeCalculationParser.ASTERISK;
Expand Down Expand Up @@ -62,7 +63,9 @@ public static Long calculateLiteralValue(
{ {
try { try {
ParserRuleContext tree = parseTypeCalculation(calculation); ParserRuleContext tree = parseTypeCalculation(calculation);
return new CalculateTypeVisitor(inputs).visit(tree); CalculateTypeVisitor visitor = new CalculateTypeVisitor(inputs);
BigInteger result = visitor.visit(tree);
return result.longValueExact();
} }
catch (StackOverflowError e) { catch (StackOverflowError e) {
throw new ParsingException("Type calculation is too large (stack overflow while parsing)"); throw new ParsingException("Type calculation is too large (stack overflow while parsing)");
Expand Down Expand Up @@ -127,7 +130,7 @@ protected Boolean aggregateResult(Boolean aggregate, Boolean nextResult)
} }


private static class CalculateTypeVisitor private static class CalculateTypeVisitor
extends TypeCalculationBaseVisitor<Long> extends TypeCalculationBaseVisitor<BigInteger>
{ {
private final Map<String, Long> inputs; private final Map<String, Long> inputs;


Expand All @@ -137,82 +140,82 @@ public CalculateTypeVisitor(Map<String, Long> inputs)
} }


@Override @Override
public Long visitTypeCalculation(TypeCalculationContext ctx) public BigInteger visitTypeCalculation(TypeCalculationContext ctx)
{ {
return visit(ctx.expression()); return visit(ctx.expression());
} }


@Override @Override
public Long visitArithmeticBinary(ArithmeticBinaryContext ctx) public BigInteger visitArithmeticBinary(ArithmeticBinaryContext ctx)
{ {
Long left = visit(ctx.left); BigInteger left = visit(ctx.left);
Long right = visit(ctx.right); BigInteger right = visit(ctx.right);
switch (ctx.operator.getType()) { switch (ctx.operator.getType()) {
case PLUS: case PLUS:
return left + right; return left.add(right);
case MINUS: case MINUS:
return left - right; return left.subtract(right);
case ASTERISK: case ASTERISK:
return left * right; return left.multiply(right);
case SLASH: case SLASH:
return left / right; return left.divide(right);
default: default:
throw new IllegalStateException("Unsupported binary operator " + ctx.operator.getText()); throw new IllegalStateException("Unsupported binary operator " + ctx.operator.getText());
} }
} }


@Override @Override
public Long visitArithmeticUnary(ArithmeticUnaryContext ctx) public BigInteger visitArithmeticUnary(ArithmeticUnaryContext ctx)
{ {
Long value = visit(ctx.expression()); BigInteger value = visit(ctx.expression());
switch (ctx.operator.getType()) { switch (ctx.operator.getType()) {
case PLUS: case PLUS:
return value; return value;
case MINUS: case MINUS:
return -1L * value; return value.negate();
default: default:
throw new IllegalStateException("Unsupported unary operator " + ctx.operator.getText()); throw new IllegalStateException("Unsupported unary operator " + ctx.operator.getText());
} }
} }


@Override @Override
public Long visitBinaryFunction(BinaryFunctionContext ctx) public BigInteger visitBinaryFunction(BinaryFunctionContext ctx)
{ {
Long left = visit(ctx.left); BigInteger left = visit(ctx.left);
Long right = visit(ctx.right); BigInteger right = visit(ctx.right);
switch (ctx.binaryFunctionName().name.getType()) { switch (ctx.binaryFunctionName().name.getType()) {
case MIN: case MIN:
return Math.min(left, right); return left.min(right);
case MAX: case MAX:
return Math.max(left, right); return left.max(right);
default: default:
throw new IllegalArgumentException("Unsupported binary function " + ctx.binaryFunctionName().getText()); throw new IllegalArgumentException("Unsupported binary function " + ctx.binaryFunctionName().getText());
} }
} }


@Override @Override
public Long visitNumericLiteral(NumericLiteralContext ctx) public BigInteger visitNumericLiteral(NumericLiteralContext ctx)
{ {
return Long.parseLong(ctx.INTEGER_VALUE().getText()); return new BigInteger(ctx.INTEGER_VALUE().getText());
} }


@Override @Override
public Long visitNullLiteral(NullLiteralContext ctx) public BigInteger visitNullLiteral(NullLiteralContext ctx)
{ {
return 0L; return BigInteger.ZERO;
} }


@Override @Override
public Long visitIdentifier(IdentifierContext ctx) public BigInteger visitIdentifier(IdentifierContext ctx)
{ {
String identifier = ctx.getText(); String identifier = ctx.getText();
Long value = inputs.get(identifier); Long value = inputs.get(identifier);
checkState(value != null, "value for variable '%s' is not specified in the inputs", identifier); checkState(value != null, "value for variable '%s' is not specified in the inputs", identifier);
return value; return BigInteger.valueOf(value);
} }


@Override @Override
public Long visitParenthesizedExpression(ParenthesizedExpressionContext ctx) public BigInteger visitParenthesizedExpression(ParenthesizedExpressionContext ctx)
{ {
return visit(ctx.expression()); return visit(ctx.expression());
} }
Expand Down

0 comments on commit b088854

Please sign in to comment.