Skip to content
Permalink
Browse files Browse the repository at this point in the history
Merge pull request from GHSA-7vrm-3jc8-5wwm
* add more tests for string comparison

explicitly test the codepath with <= 32 bytes

* refactor keccak256 helper a bit

* fix bytestring equality

existing bytestring equality checks neither verify that the lengths are
equal nor account for dirty bytes.
  • Loading branch information
charles-cooper committed Apr 2, 2022
1 parent 0807a60 commit 2c73f83
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 104 deletions.
30 changes: 30 additions & 0 deletions tests/parser/functions/test_slice.py
Expand Up @@ -210,6 +210,36 @@ def ret10_slice() -> Bytes[10]:
assert c.ret10_slice() == b"A"


def test_slice_equality(get_contract):
# Regression test: `==` on strings must not be affected by dirty bytes.
# The contract sets `dirty_bytes` to "abcd" and then overwrites it with its
# own 3-character slice; presumably the stale "d" is left in the buffer past
# the new length (TODO confirm against the slice codegen). The result must
# still compare equal to a freshly initialized "abc".
code = """
@external
def assert_eq() -> bool:
dirty_bytes: String[4] = "abcd"
dirty_bytes = slice(dirty_bytes, 0, 3)
clean_bytes: String[4] = "abc"
return dirty_bytes == clean_bytes
"""

# `assert_eq` returns the result of `dirty_bytes == clean_bytes`; it must be truthy.
c = get_contract(code)
assert c.assert_eq()


def test_slice_inequality(get_contract):
# Regression test: `!=` on strings must take the string length into account.
# The contract sets `dirty_bytes` to "abcd" and then overwrites it with its
# own 3-character slice; presumably the stale "d" is left in the buffer past
# the new length (TODO confirm against the slice codegen). The 3-character
# result must compare as not equal to the 4-character "abcd".
code = """
@external
def assert_ne() -> bool:
dirty_bytes: String[4] = "abcd"
dirty_bytes = slice(dirty_bytes, 0, 3)
clean_bytes: String[4] = "abcd"
return dirty_bytes != clean_bytes
"""

# `assert_ne` returns the result of `dirty_bytes != clean_bytes`; it must be truthy.
c = get_contract(code)
assert c.assert_ne()


def test_slice_convert(get_contract):
# test slice of converting between bytes32 and Bytes
code = """
Expand Down
159 changes: 101 additions & 58 deletions tests/parser/types/test_string.py
Expand Up @@ -168,120 +168,165 @@ def test(a: uint256, b: String[50] = "foo") -> Bytes[100]:
assert c.test(12345, "bar")[-3:] == b"bar"


def test_string_equality(get_contract_with_gas_estimation):
code = """
_compA: String[100]
_compB: String[100]
string_equality_tests = [
(
100,
"The quick brown fox jumps over the lazy dog",
"The quick brown fox jumps over the lazy hog",
),
# check <= 32 codepath
(32, "abc", "abc\0"),
(32, "abc", "abc\1"), # use a_init dirty bytes
(32, "abc\2", "abc"), # use b_init dirty bytes
(32, "", "\0"),
(32, "", "\1"),
(33, "", "\1"),
(33, "", "\0"),
]


@pytest.mark.parametrize("len_,a,b", string_equality_tests)
def test_string_equality(get_contract_with_gas_estimation, len_, a, b):
# fixtures to initialize strings with dirty bytes
a_init = "\\1" * len_
b_init = "\\2" * len_
string1 = a.encode("unicode_escape").decode("utf-8")
string2 = b.encode("unicode_escape").decode("utf-8")
code = f"""
a: String[{len_}]
b: String[{len_}]
@external
def equal_true() -> bool:
compA: String[100] = "The quick brown fox jumps over the lazy dog"
compB: String[100] = "The quick brown fox jumps over the lazy dog"
return compA == compB
a: String[{len_}] = "{a_init}"
b: String[{len_}] = "{b_init}"
a = "{string1}"
b = "{string1}"
return a == b
@external
def equal_false() -> bool:
compA: String[100] = "The quick brown fox jumps over the lazy dog"
compB: String[100] = "The quick brown fox jumps over the lazy hog"
return compA == compB
a: String[{len_}] = "{a_init}"
b: String[{len_}] = "{b_init}"
a = "{string1}"
b = "{string2}"
return a == b
@external
def not_equal_true() -> bool:
compA: String[100] = "The quick brown fox jumps over the lazy dog"
compB: String[100] = "The quick brown fox jumps over the lazy hog"
return compA != compB
a: String[{len_}] = "{a_init}"
b: String[{len_}] = "{b_init}"
a = "{string1}"
b = "{string2}"
return a != b
@external
def not_equal_false() -> bool:
compA: String[100] = "The quick brown fox jumps over the lazy dog"
compB: String[100] = "The quick brown fox jumps over the lazy dog"
return compA != compB
a: String[{len_}] = "{a_init}"
b: String[{len_}] = "{b_init}"
a = "{string1}"
b = "{string1}"
return a != b
@external
def literal_equal_true() -> bool:
return "The quick brown fox jumps over the lazy dog" == \
"The quick brown fox jumps over the lazy dog"
return "{string1}" == "{string1}"
@external
def literal_equal_false() -> bool:
return "The quick brown fox jumps over the lazy dog" == \
"The quick brown fox jumps over the lazy hog"
return "{string1}" == "{string2}"
@external
def literal_not_equal_true() -> bool:
return "The quick brown fox jumps over the lazy dog" != \
"The quick brown fox jumps over the lazy hog"
return "{string1}" != "{string2}"
@external
def literal_not_equal_false() -> bool:
return "The quick brown fox jumps over the lazy dog" != \
"The quick brown fox jumps over the lazy dog"
return "{string1}" != "{string1}"
@external
def storage_equal_true() -> bool:
self._compA = "The quick brown fox jumps over the lazy dog"
self._compB = "The quick brown fox jumps over the lazy dog"
return self._compA == self._compB
self.a = "{a_init}"
self.b = "{b_init}"
self.a = "{string1}"
self.b = "{string1}"
return self.a == self.b
@external
def storage_equal_false() -> bool:
self._compA = "The quick brown fox jumps over the lazy dog"
self._compB = "The quick brown fox jumps over the lazy hog"
return self._compA == self._compB
self.a = "{a_init}"
self.b = "{b_init}"
self.a = "{string1}"
self.b = "{string2}"
return self.a == self.b
@external
def storage_not_equal_true() -> bool:
self._compA = "The quick brown fox jumps over the lazy dog"
self._compB = "The quick brown fox jumps over the lazy hog"
return self._compA != self._compB
self.a = "{a_init}"
self.b = "{b_init}"
self.a = "{string1}"
self.b = "{string2}"
return self.a != self.b
@external
def storage_not_equal_false() -> bool:
self._compA = "The quick brown fox jumps over the lazy dog"
self._compB = "The quick brown fox jumps over the lazy dog"
return self._compA != self._compB
self.a = "{a_init}"
self.b = "{b_init}"
self.a = "{string1}"
self.b = "{string1}"
return self.a != self.b
@external
def string_compare_equal(str1: String[100], str2: String[100]) -> bool:
def string_compare_equal(str1: String[{len_}], str2: String[{len_}]) -> bool:
return str1 == str2
@external
def string_compare_not_equal(str1: String[100], str2: String[100]) -> bool:
def string_compare_not_equal(str1: String[{len_}], str2: String[{len_}]) -> bool:
return str1 != str2
@external
def compare_passed_storage_equal(str: String[100]) -> bool:
self._compA = "The quick brown fox jumps over the lazy dog"
return self._compA == str
def compare_passed_storage_equal(str_: String[{len_}]) -> bool:
self.a = "{a_init}"
self.a = "{string1}"
return self.a == str_
@external
def compare_passed_storage_not_equal(str: String[100]) -> bool:
self._compA = "The quick brown fox jumps over the lazy dog"
return self._compA != str
def compare_passed_storage_not_equal(str_: String[{len_}]) -> bool:
self.a = "{a_init}"
self.a = "{string1}"
return self.a != str_
@external
def compare_var_storage_equal_true() -> bool:
self._compA = "The quick brown fox jumps over the lazy dog"
compB: String[100] = "The quick brown fox jumps over the lazy dog"
return self._compA == compB
self.a = "{a_init}"
b: String[{len_}] = "{b_init}"
self.a = "{string1}"
b = "{string1}"
return self.a == b
@external
def compare_var_storage_equal_false() -> bool:
self._compA = "The quick brown fox jumps over the lazy dog"
compB: String[100] = "The quick brown fox jumps over the lazy hog"
return self._compA == compB
self.a = "{a_init}"
b: String[{len_}] = "{b_init}"
self.a = "{string1}"
b = "{string2}"
return self.a == b
@external
def compare_var_storage_not_equal_true() -> bool:
self._compA = "The quick brown fox jumps over the lazy dog"
compB: String[100] = "The quick brown fox jumps over the lazy hog"
return self._compA != compB
self.a = "{a_init}"
b: String[{len_}] = "{b_init}"
self.a = "{string1}"
b = "{string2}"
return self.a != b
@external
def compare_var_storage_not_equal_false() -> bool:
self._compA = "The quick brown fox jumps over the lazy dog"
compB: String[100] = "The quick brown fox jumps over the lazy dog"
return self._compA != compB
self.a = "{a_init}"
b: String[{len_}] = "{b_init}"
self.a = "{string1}"
b = "{string1}"
return self.a != b
"""

c = get_contract_with_gas_estimation(code)
Expand All @@ -298,8 +343,6 @@ def compare_var_storage_not_equal_false() -> bool:
assert c.storage_not_equal_true() is True
assert c.storage_not_equal_false() is False

a = "The quick brown fox jumps over the lazy dog"
b = "The quick brown fox jumps over the lazy hog"
assert c.string_compare_equal(a, a) is True
assert c.string_compare_equal(a, b) is False
assert c.string_compare_not_equal(b, a) is True
Expand Down
2 changes: 1 addition & 1 deletion vyper/codegen/core.py
Expand Up @@ -186,7 +186,7 @@ def copy_bytes(dst, src, length, length_bound):

with src.cache_when_complex("src") as (b1, src), length.cache_when_complex(
"copy_bytes_count"
) as (b2, length,), dst.cache_when_complex("dst") as (b3, dst):
) as (b2, length), dst.cache_when_complex("dst") as (b3, dst):

# fast code for common case where num bytes is small
# TODO expand this for more cases where num words is less than ~8
Expand Down
31 changes: 7 additions & 24 deletions vyper/codegen/expr.py
Expand Up @@ -5,8 +5,6 @@
from vyper.address_space import DATA, IMMUTABLES, MEMORY, STORAGE
from vyper.codegen import external_call, self_call
from vyper.codegen.core import (
LOAD,
bytes_data_ptr,
clamp_basetype,
ensure_in_memory,
get_dyn_array_count,
Expand Down Expand Up @@ -801,30 +799,15 @@ def parse_Compare(self):
left = Expr(self.expr.left, self.context).ir_node
right = Expr(self.expr.right, self.context).ir_node

length_mismatch = left.typ.maxlen != right.typ.maxlen
left_over_32 = left.typ.maxlen > 32
right_over_32 = right.typ.maxlen > 32

if length_mismatch or left_over_32 or right_over_32:
left_keccak = keccak256_helper(self.expr, left, self.context)
right_keccak = keccak256_helper(self.expr, right, self.context)

if op == "eq" or op == "ne":
return IRnode.from_list([op, left_keccak, right_keccak], typ="bool")

else:
return
left_keccak = keccak256_helper(self.expr, left, self.context)
right_keccak = keccak256_helper(self.expr, right, self.context)

if op not in ("eq", "ne"):
return # raises
else:

def load_bytearray(side):
return LOAD(bytes_data_ptr(side))

return IRnode.from_list(
# CMC 2022-03-24 TODO investigate this.
[op, load_bytearray(left), load_bytearray(right)],
typ="bool",
)
# use hash even for Bytes[N<=32], because there could be dirty
# bytes past the bytes data.
return IRnode.from_list([op, left_keccak, right_keccak], typ="bool")

# Compare other types.
elif is_numeric_type(left.typ) and is_numeric_type(right.typ):
Expand Down
41 changes: 20 additions & 21 deletions vyper/codegen/keccak256_helper.py
@@ -1,6 +1,6 @@
from math import ceil

from vyper.codegen.core import ensure_in_memory
from vyper.codegen.core import bytes_data_ptr, ensure_in_memory, get_bytearray_length
from vyper.codegen.ir_node import IRnode
from vyper.codegen.types import BaseType, ByteArrayLike, is_base_type
from vyper.exceptions import CompilerPanic
Expand All @@ -21,37 +21,36 @@ def _gas_bound(num_words):
return SHA3_BASE + num_words * SHA3_PER_WORD


def keccak256_helper(expr, ir_arg, context):
sub = ir_arg # TODO get rid of useless variable
_check_byteslike(sub.typ, expr)
def keccak256_helper(expr, to_hash, context):
_check_byteslike(to_hash.typ, expr)

# Can hash literals
# TODO this is dead code.
if isinstance(sub, bytes):
return IRnode.from_list(bytes_to_int(keccak256(sub)), typ=BaseType("bytes32"))
if isinstance(to_hash, bytes):
return IRnode.from_list(bytes_to_int(keccak256(to_hash)), typ=BaseType("bytes32"))

# Can hash bytes32 objects
if is_base_type(sub.typ, "bytes32"):
if is_base_type(to_hash.typ, "bytes32"):
return IRnode.from_list(
[
"seq",
["mstore", MemoryPositions.FREE_VAR_SPACE, sub],
["mstore", MemoryPositions.FREE_VAR_SPACE, to_hash],
["sha3", MemoryPositions.FREE_VAR_SPACE, 32],
],
typ=BaseType("bytes32"),
add_gas_estimate=_gas_bound(1),
)

sub = ensure_in_memory(sub, context)

return IRnode.from_list(
[
"with",
"_buf",
sub,
["sha3", ["add", "_buf", 32], ["mload", "_buf"]],
],
typ=BaseType("bytes32"),
annotation="keccak256",
add_gas_estimate=_gas_bound(ceil(sub.typ.maxlen / 32)),
)
to_hash = ensure_in_memory(to_hash, context)

with to_hash.cache_when_complex("buf") as (b1, to_hash):
data = bytes_data_ptr(to_hash)
len_ = get_bytearray_length(to_hash)
return b1.resolve(
IRnode.from_list(
["sha3", data, len_],
typ="bytes32",
annotation="keccak256",
add_gas_estimate=_gas_bound(ceil(to_hash.typ.maxlen / 32)),
)
)

0 comments on commit 2c73f83

Please sign in to comment.