diff --git a/starknet_py/hash/address.py b/starknet_py/hash/address.py index 62bc31014..08765979d 100644 --- a/starknet_py/hash/address.py +++ b/starknet_py/hash/address.py @@ -1,7 +1,13 @@ from typing import Sequence from starknet_py.constants import CONTRACT_ADDRESS_PREFIX, L2_ADDRESS_UPPER_BOUND -from starknet_py.hash.utils import compute_hash_on_elements +from starknet_py.hash.utils import ( + HEX_PREFIX, + _starknet_keccak, + compute_hash_on_elements, + encode_uint, + get_bytes_length, +) def compute_address( @@ -33,3 +39,29 @@ def compute_address( ) return raw_address % L2_ADDRESS_UPPER_BOUND + + +def get_checksum_address(address: str) -> str: + if not address.lower().startswith(HEX_PREFIX): + raise ValueError(f"{address} is not a valid hexadecimal address.") + + int_address = int(address, 16) + string_address = address[2:].zfill(64) + + address_in_bytes = encode_uint(int_address, get_bytes_length(int_address)) + address_hash = _starknet_keccak(address_in_bytes) + + result = "".join( + ( + char.upper() + if char.isalpha() and (address_hash >> 256 - 4 * i - 1) & 1 + else char + ) + for i, char in enumerate(string_address) + ) + + return f"{HEX_PREFIX}{result}" + + +def is_checksum_address(address: str) -> bool: + return get_checksum_address(address) == address diff --git a/starknet_py/hash/address_test.py b/starknet_py/hash/address_test.py index af90f36d6..393d90577 100644 --- a/starknet_py/hash/address_test.py +++ b/starknet_py/hash/address_test.py @@ -1,4 +1,10 @@ -from starknet_py.hash.address import compute_address +import pytest + +from starknet_py.hash.address import ( + compute_address, + get_checksum_address, + is_checksum_address, +) def test_compute_address(): @@ -22,3 +28,46 @@ def test_compute_address_with_deployer_address(): ) == 3179899882984850239687045389724311807765146621017486664543269641150383510696 ) + + +@pytest.mark.parametrize( + "address, checksum_address", + [ + ( + "0x2fd23d9182193775423497fc0c472e156c57c69e4089a1967fb288a2d84e914", + "0x02Fd23d9182193775423497fc0c472E156C57C69E4089A1967fb288A2d84e914", + ), + ( + "0x00abcdefabcdefabcdefabcdefabcdefabcdefabcdefabcdefabcdefabcdefab", + "0x00AbcDefaBcdefabCDEfAbCDEfAbcdEFAbCDEfabCDefaBCdEFaBcDeFaBcDefAb", + ), + ( + "0xfedcbafedcbafedcbafedcbafedcbafedcbafedcbafedcbafedcbafedcbafe", + "0x00fEdCBafEdcbafEDCbAFedCBAFeDCbafEdCBAfeDcbaFeDCbAfEDCbAfeDcbAFE", + ), + ("0xa", "0x000000000000000000000000000000000000000000000000000000000000000A"), + ( + "0x0", + "0x0000000000000000000000000000000000000000000000000000000000000000", + ), + ], +) +def test_get_checksum_address(address, checksum_address): + assert get_checksum_address(address) == checksum_address + + +@pytest.mark.parametrize("address", ["", "0xx", "0123"]) +def test_get_checksum_address_raises_on_invalid_address(address): + with pytest.raises(ValueError): + get_checksum_address(address) + + +@pytest.mark.parametrize( + "address, is_checksum", + [ + ("0x02Fd23d9182193775423497fc0c472E156C57C69E4089A1967fb288A2d84e914", True), + ("0x000000000000000000000000000000000000000000000000000000000000000a", False), + ], +) +def test_is_checksum_address(address, is_checksum): + assert is_checksum_address(address) == is_checksum diff --git a/starknet_py/hash/utils.py b/starknet_py/hash/utils.py index 5caa7a447..7f2c563dc 100644 --- a/starknet_py/hash/utils.py +++ b/starknet_py/hash/utils.py @@ -14,6 +14,7 @@ from starknet_py.constants import EC_ORDER MASK_250 = 2**250 - 1 +HEX_PREFIX = "0x" def _starknet_keccak(data: bytes) -> int: @@ -84,3 +85,7 @@ def encode_uint(value: int, bytes_length: int = 32) -> bytes: def encode_uint_list(data: List[int]) -> bytes: return b"".join(encode_uint(x) for x in data) + + +def get_bytes_length(value: int) -> int: + return (value.bit_length() + 7) // 8