Skip to content

Commit

Permalink
Merge pull request #1790 from iamdefinitelyahuman/minmax-signing
Browse files Browse the repository at this point in the history
Minmax signing
  • Loading branch information
fubuloubu committed Dec 28, 2019
2 parents 2448287 + 56c5626 commit fb7ac96
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 7 deletions.
52 changes: 52 additions & 0 deletions tests/parser/functions/test_minmax.py
Expand Up @@ -185,3 +185,55 @@ def foo() -> uint256:
lambda: get_contract_with_gas_estimation(code_2),
TypeMismatchException
)


def test_unsigned(get_contract_with_gas_estimation):
code = """
@public
def foo1() -> uint256:
return min(0, 2**255)
@public
def foo2() -> uint256:
return min(2**255, 0)
@public
def foo3() -> uint256:
return max(0, 2**255)
@public
def foo4() -> uint256:
return max(2**255, 0)
"""

c = get_contract_with_gas_estimation(code)
assert c.foo1() == 0
assert c.foo2() == 0
assert c.foo3() == 2**255
assert c.foo4() == 2**255


def test_signed(get_contract_with_gas_estimation):
code = """
@public
def foo1() -> int128:
return min(MIN_INT128, MAX_INT128)
@public
def foo2() -> int128:
return min(MAX_INT128, MIN_INT128)
@public
def foo3() -> int128:
return max(MIN_INT128, MAX_INT128)
@public
def foo4() -> int128:
return max(MAX_INT128, MIN_INT128)
"""

c = get_contract_with_gas_estimation(code)
assert c.foo1() == -2**127
assert c.foo2() == -2**127
assert c.foo3() == 2**127-1
assert c.foo4() == 2**127-1
13 changes: 6 additions & 7 deletions vyper/functions/functions.py
Expand Up @@ -1162,15 +1162,15 @@ def create_forwarder_to(expr, args, kwargs, context):

@signature(('int128', 'decimal', 'uint256'), ('int128', 'decimal', 'uint256'))
def _min(expr, args, kwargs, context):
return minmax(expr, args, kwargs, context, True)
return minmax(expr, args, kwargs, context, 'gt')


@signature(('int128', 'decimal', 'uint256'), ('int128', 'decimal', 'uint256'))
def _max(expr, args, kwargs, context):
return minmax(expr, args, kwargs, context, False)
return minmax(expr, args, kwargs, context, 'lt')


def minmax(expr, args, kwargs, context, is_min):
def minmax(expr, args, kwargs, context, comparator):
def _can_compare_with_uint256(operand):
if operand.typ.typ == 'uint256':
return True
Expand All @@ -1181,11 +1181,10 @@ def _can_compare_with_uint256(operand):
left, right = args[0], args[1]
if not are_units_compatible(left.typ, right.typ) and not are_units_compatible(right.typ, left.typ): # noqa: E501
raise TypeMismatchException("Units must be compatible", expr)
if left.typ.typ == 'uint256':
comparator = 'gt' if is_min else 'lt'
else:
comparator = 'sgt' if is_min else 'slt'
if left.typ.typ == right.typ.typ:
if left.typ.typ != 'uint256':
# if comparing like types that are not uint256, use SLT or SGT
comparator = f's{comparator}'
o = ['if', [comparator, '_l', '_r'], '_r', '_l']
otyp = left.typ
otyp.is_literal = False
Expand Down

0 comments on commit fb7ac96

Please sign in to comment.