Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: edge cases in safepow #2983

Merged
merged 6 commits into from
Jul 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions tests/fuzzing/test_exponents.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_exp_uint256(get_contract, assert_tx_failed, power):
def foo(a: uint256) -> uint256:
return a ** {power}
"""
max_base = calculate_largest_base(power, 256, False)
_min_base, max_base = calculate_largest_base(power, 256, False)
assert max_base ** power < 2 ** 256
assert (max_base + 1) ** power >= 2 ** 256

Expand All @@ -31,21 +31,46 @@ def test_exp_int128(get_contract, assert_tx_failed, power):
def foo(a: int128) -> int128:
return a ** {power}
"""
max_base = calculate_largest_base(power, 128, True)
min_base, max_base = calculate_largest_base(power, 128, True)

assert -(2 ** 127) <= max_base ** power < 2 ** 127
assert -(2 ** 127) <= (-max_base) ** power < 2 ** 127
assert -(2 ** 127) <= min_base ** power < 2 ** 127

assert not -(2 ** 127) <= (max_base + 1) ** power < 2 ** 127
assert not -(2 ** 127) <= (-(max_base + 1)) ** power < 2 ** 127
assert not -(2 ** 127) <= (min_base - 1) ** power < 2 ** 127

c = get_contract(code)

c.foo(max_base)
c.foo(-max_base)
c.foo(min_base)

assert_tx_failed(lambda: c.foo(max_base + 1))
assert_tx_failed(lambda: c.foo(-max_base - 1))
assert_tx_failed(lambda: c.foo(min_base - 1))


@pytest.mark.fuzzing
@pytest.mark.parametrize("power", range(2, 15))
def test_exp_int16(get_contract, assert_tx_failed, power):
code = f"""
@external
def foo(a: int16) -> int16:
return a ** {power}
"""
min_base, max_base = calculate_largest_base(power, 16, True)

assert -(2 ** 15) <= max_base ** power < 2 ** 15
assert -(2 ** 15) <= min_base ** power < 2 ** 15

assert not -(2 ** 15) <= (max_base + 1) ** power < 2 ** 15
assert not -(2 ** 15) <= (min_base - 1) ** power < 2 ** 15

c = get_contract(code)

c.foo(max_base)
c.foo(min_base)

assert_tx_failed(lambda: c.foo(max_base + 1))
assert_tx_failed(lambda: c.foo(min_base - 1))


@pytest.mark.fuzzing
Expand Down
54 changes: 54 additions & 0 deletions tests/parser/types/numbers/test_signed_ints.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,60 @@ def foo(x: {typ}) -> {typ}:
assert c.foo(hi) == 1


def test_exponent_base_minus_one(get_contract):
# #2986
code = """
@external
def foo() -> int256:
x: int256 = 4
y: int256 = -1 ** x
return y
"""
c = get_contract(code)
assert c.foo() == -1


# TODO: make this test pass
@pytest.mark.parametrize("base", (0, 1))
def test_exponent_negative_power(get_contract, assert_tx_failed, base):
# #2985
code = f"""
@external
def bar() -> int16:
x: int16 = -2
return {base} ** x
"""
get_contract(code)
# known bug: 2985
# assert_tx_failed(lambda: c.bar())


def test_exponent_min_int16(get_contract):
# #2987
code = """
@external
def foo() -> int16:
x: int16 = -8
y: int16 = x ** 5
return y
"""
c = get_contract(code)
assert c.foo() == -(2 ** 15)


@pytest.mark.parametrize("power", [0, 1])
def test_exponent_power_zero_one(get_contract, power):
# #2989
code = f"""
@external
def foo() -> int256:
x: int256 = 2
return x ** {power}
"""
c = get_contract(code)
assert c.foo() == 2 ** power


@pytest.mark.parametrize("typ,lo,hi,bits", PARAMS)
def test_exponent(get_contract, assert_tx_failed, typ, lo, hi, bits):
code = f"""
Expand Down
30 changes: 30 additions & 0 deletions tests/parser/types/numbers/test_unsigned_ints.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,36 @@ def foo(x: {typ}) -> {typ}:
assert c.foo(hi) == 1


@pytest.mark.parametrize("typ,lo,hi,bits", PARAMS)
def test_exponent_power_zero(get_contract, typ, lo, hi, bits):
# #2984
code = f"""
@external
def foo(x: {typ}) -> {typ}:
return x ** 0
"""
c = get_contract(code)
assert c.foo(0) == 1
assert c.foo(1) == 1
assert c.foo(42) == 1
assert c.foo(hi) == 1


@pytest.mark.parametrize("typ,lo,hi,bits", PARAMS)
def test_exponent_power_one(get_contract, typ, lo, hi, bits):
# #2984
code = f"""
@external
def foo(x: {typ}) -> {typ}:
return x ** 1
"""
c = get_contract(code)
assert c.foo(0) == 0
assert c.foo(1) == 1
assert c.foo(42) == 42
assert c.foo(hi) == hi


ARITHMETIC_OPS = {
"+": operator.add,
"-": operator.sub,
Expand Down
48 changes: 35 additions & 13 deletions vyper/codegen/arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import decimal
import math
from typing import Tuple

from vyper.codegen.core import clamp, clamp_basetype
from vyper.codegen.ir_node import IRnode
Expand Down Expand Up @@ -38,11 +39,12 @@ def calculate_largest_power(a: int, num_bits: int, is_signed: bool) -> int:
raise TypeCheckFailure("Value is too small and will always throw")

a_is_negative = a < 0
a = abs(a) # No longer need to know if it's signed or not

if a in (0, 1):
raise CompilerPanic("Exponential operation is useless!")

a = abs(a) # No longer need to know if it's signed or not

# NOTE: There is an edge case if `a` were left signed where the following
# operation would not work (`ln(a)` is undefined if `a <= 0`)
b = int(decimal.Decimal(value_bits) / (decimal.Decimal(a).ln() / decimal.Decimal(2).ln()))
Expand Down Expand Up @@ -76,7 +78,7 @@ def calculate_largest_power(a: int, num_bits: int, is_signed: bool) -> int:
return b # Exact


def calculate_largest_base(b: int, num_bits: int, is_signed: bool) -> int:
def calculate_largest_base(b: int, num_bits: int, is_signed: bool) -> Tuple[int, int]:
"""
For a given power `b`, compute the maximum base `a` that will not produce an
overflow in the equation `a ** b`
Expand All @@ -92,9 +94,13 @@ def calculate_largest_base(b: int, num_bits: int, is_signed: bool) -> int:

Returns
-------
int
Largest possible value for `a` where the result does not overflow
`num_bits`
Tuple[int, int]
Smallest and largest possible values for `a` where the result
does not overflow `num_bits`.

Note that the lower and upper bounds are not always negatives of
each other, due to lower/upper bounds for int_<value_bits> being
slightly asymmetric.
"""
if num_bits % 8:
raise CompilerPanic("Type is not a modulo of 8")
Expand All @@ -104,8 +110,8 @@ def calculate_largest_base(b: int, num_bits: int, is_signed: bool) -> int:
value_bits = num_bits - (1 if is_signed else 0)
if b > value_bits:
raise TypeCheckFailure("Value is too large and will always throw")
elif b < 2:
return 2 ** value_bits - 1 # Maximum value for type
if b < 2:
raise CompilerPanic("Exponential operation is useless!")

# CMC 2022-05-06 TODO we should be able to do this with algebra
# instead of looping):
Expand All @@ -126,7 +132,15 @@ def calculate_largest_base(b: int, num_bits: int, is_signed: bool) -> int:
a -= 1
num_iterations += 1
assert num_iterations < 10000
return a

if not is_signed:
return 0, a

if (a + 1) ** b == (2 ** value_bits):
# edge case: lower bound is slightly wider than upper bound
return -(a + 1), a
else:
return -a, a


# def safe_add(x: IRnode, y: IRnode) -> IRnode:
Expand Down Expand Up @@ -336,16 +350,24 @@ def safe_pow(x, y):
if x.value == 0:
return IRnode.from_list(["iszero", y])

upper_bound = calculate_largest_power(x.value, num_info.bits, num_info.is_signed) + 1
upper_bound = calculate_largest_power(x.value, num_info.bits, num_info.is_signed)
# for signed integers, this also prevents negative values
ok = ["lt", y, upper_bound]
ok = ["le", y, upper_bound]

elif y.is_literal:
upper_bound = calculate_largest_base(y.value, num_info.bits, num_info.is_signed) + 1
# cannot pass 1 or 0 to `calculate_largest_base`
if y.value == 1:
return x
if y.value == 0:
return IRnode.from_list([1])

lower_bound, upper_bound = calculate_largest_base(
y.value, num_info.bits, num_info.is_signed
)
if num_info.is_signed:
ok = ["and", ["slt", x, upper_bound], ["sgt", x, -upper_bound]]
ok = ["and", ["sge", x, lower_bound], ["sle", x, upper_bound]]
else:
ok = ["lt", x, upper_bound]
ok = ["le", x, upper_bound]
else:
# `a ** b` where neither `a` or `b` are known
# TODO this is currently unreachable, once we implement a way to do it safely
Expand Down