From 9b0302512b9b292ce65e4241149c81156ad87e99 Mon Sep 17 00:00:00 2001 From: Brad Larsen Date: Thu, 16 Apr 2020 18:21:20 -0400 Subject: [PATCH] Add type hints to several parts of Manticore (#1667) * Add type hints * Add type hints in manticore/core/smtlib/solver.py * Add type hints to the `unsigned_hexlify` helper function * Add type hints & more robust error checking to a Solver internal method * Add type hints in native memory code * Add type hints to Linux platform test methods * Add type hints for Linux `sys_access` * Add type hints for AbstractCPU; rephrase a couple bits to appease mypy * Add type hints to manticore/platforms/linux.py * Add type hints for Workspace members * Add type hints for `Operand` * Add a return type hint per @ekilmer Co-Authored-By: Eric Kilmer --- manticore/core/smtlib/constraints.py | 10 ++--- manticore/core/smtlib/solver.py | 15 +++++--- manticore/core/workspace.py | 8 ++-- manticore/native/cpu/abstractcpu.py | 25 ++++++------- manticore/native/memory.py | 56 +++++++++++++--------------- manticore/platforms/linux.py | 21 +++++------ tests/ethereum/test_detectors.py | 6 +-- tests/native/test_linux.py | 35 +++++++++-------- 8 files changed, 85 insertions(+), 91 deletions(-) diff --git a/manticore/core/smtlib/constraints.py b/manticore/core/smtlib/constraints.py index 37240996c..0d0ee1793 100644 --- a/manticore/core/smtlib/constraints.py +++ b/manticore/core/smtlib/constraints.py @@ -56,7 +56,7 @@ def __reduce__(self): }, ) - def __enter__(self): + def __enter__(self) -> "ConstraintSet": assert self._child is None self._child = self.__class__() self._child._parent = self @@ -64,11 +64,11 @@ def __enter__(self): self._child._declarations = dict(self._declarations) return self._child - def __exit__(self, ty, value, traceback): + def __exit__(self, ty, value, traceback) -> None: self._child._parent = None self._child = None - def __len__(self): + def __len__(self) -> int: if self._parent is not None: return len(self._constraints) + len(self._parent) return len(self._constraints) @@ -107,7 +107,7 @@ def add(self, constraint, check=False): if not solver.check(self): raise ValueError("Added an impossible constraint") - def _get_sid(self): + def _get_sid(self) -> int: """ Returns a unique id. """ assert self._child is None self._sid += 1 @@ -263,7 +263,7 @@ def _make_unique_name(self, name="VAR"): name = f"{name}_{self._get_sid()}" return name - def is_declared(self, expression_var): + def is_declared(self, expression_var) -> bool: """ True if expression_var is declared in this constraint set """ if not isinstance(expression_var, Variable): raise ValueError(f"Expression must be a Variable (not a {type(expression_var)})") diff --git a/manticore/core/smtlib/solver.py b/manticore/core/smtlib/solver.py index 78435d199..caf30794a 100644 --- a/manticore/core/smtlib/solver.py +++ b/manticore/core/smtlib/solver.py @@ -292,7 +292,7 @@ def _reset(self, constraints=None): if constraints is not None: self._send(constraints) - def _send(self, cmd: Union[str, ConstraintSet]): + def _send(self, cmd: str) -> None: """ Send a string to the solver. @@ -330,8 +330,13 @@ def _recv(self) -> str: return buf - def __readline_and_count(self): - buf = self._proc.stdout.readline() + def __readline_and_count(self) -> Tuple[str, int, int]: + stdout = self._proc.stdout + if stdout is None: + raise SolverError("Could not read from stdout: file descriptor is None") + buf = stdout.readline() + if buf is None: + raise SolverError("Could not read from stdout") return buf, buf.count("("), buf.count(")") # UTILS: check-sat get-value @@ -411,7 +416,7 @@ def _pop(self): """Recall the last pushed constraint store and state.""" self._send("(pop 1)") - def can_be_true(self, constraints: ConstraintSet, expression=True): + def can_be_true(self, constraints: ConstraintSet, expression: Union[bool, Bool] = True) -> bool: """Check if two potentially symbolic values can be equal""" if isinstance(expression, bool): if not expression: @@ -592,7 +597,7 @@ def optimize(self, constraints: ConstraintSet, x: BitVec, goal: str, max_iter=10 return last_value raise SolverError("Optimizing error, unsat or unknown core") - def get_value(self, constraints, *expressions): + def get_value(self, constraints: ConstraintSet, *expressions): """ Ask the solver for one possible result of given expressions using given set of constraints. diff --git a/manticore/core/workspace.py b/manticore/core/workspace.py index e31d1042f..d0031b52c 100644 --- a/manticore/core/workspace.py +++ b/manticore/core/workspace.py @@ -421,12 +421,12 @@ class Workspace: def __init__(self, store_or_desc=None): if isinstance(store_or_desc, Store): - self._store = store_or_desc + self._store: Store = store_or_desc else: self._store = Store.fromdescriptor(store_or_desc) - self._serializer = PickleSerializer() - self._prefix = "state_" - self._suffix = ".pkl" + self._serializer: StateSerializer = PickleSerializer() + self._prefix: str = "state_" + self._suffix: str = ".pkl" @property def uri(self): diff --git a/manticore/native/cpu/abstractcpu.py b/manticore/native/cpu/abstractcpu.py index 3a1b36a64..3af0bfb2d 100644 --- a/manticore/native/cpu/abstractcpu.py +++ b/manticore/native/cpu/abstractcpu.py @@ -10,7 +10,7 @@ from .disasm import init_disassembler from ..memory import ConcretizeMemory, InvalidMemoryAccess, FileMap, AnonMap -from ..memory import LazySMemory +from ..memory import LazySMemory, Memory from ...core.smtlib import Operators, Constant, issymbolic from ...core.smtlib import visitors from ...core.smtlib.solver import Z3Solver @@ -264,9 +264,9 @@ class Abi: Used for function call and system call models. """ - def __init__(self, cpu): + def __init__(self, cpu: "Cpu"): """ - :param manticore.core.cpu.Cpu cpu: CPU to initialize with + :param CPU to initialize with """ self._cpu = cpu @@ -392,7 +392,7 @@ def invoke(self, model, prefix_args=None): platform_logger = logging.getLogger("manticore.platforms.platform") -def unsigned_hexlify(i): +def unsigned_hexlify(i: Any) -> Any: if type(i) is int: if i < 0: return hex((1 << 64) + i) @@ -497,7 +497,7 @@ class Cpu(Eventful): "execute_syscall", } - def __init__(self, regfile: RegisterFile, memory, **kwargs): + def __init__(self, regfile: RegisterFile, memory: Memory, **kwargs): assert isinstance(regfile, RegisterFile) self._disasm = kwargs.pop("disasm", "capstone") super().__init__(**kwargs) @@ -645,7 +645,7 @@ def emulate_until(self, target: int): ############################# # Memory access @property - def memory(self): + def memory(self) -> Memory: return self._memory def write_int(self, where, expression, size=None, force=False): @@ -680,8 +680,7 @@ def _raw_read(self, where: int, size=1) -> bytes: """ map = self.memory.map_containing(where) start = map._get_offset(where) - mapType = type(map) - if mapType is FileMap: + if isinstance(map, FileMap): end = map._get_offset(where + size) if end > map._mapped_size: @@ -699,7 +698,7 @@ def _raw_read(self, where: int, size=1) -> bytes: data += map._overlay[offset] data += raw_data[len(data) :] - elif mapType is AnonMap: + elif isinstance(map, AnonMap): data = bytes(map._data[start : start + size]) else: data = b"".join(self.memory[where : where + size]) @@ -743,15 +742,13 @@ def write_bytes(self, where: int, data, force: bool = False) -> None: # At the very least, using it in non-concrete mode will break the symbolic strcmp/strlen models. The 1024 byte # minimum is intended to minimize the potential effects of this by ensuring that if there _are_ any other # issues, they'll only crop up when we're doing very large writes, which are fairly uncommon. - can_write_raw = ( - type(mp) is AnonMap + if ( + isinstance(mp, AnonMap) and isinstance(data, (str, bytes)) and (mp.end - mp.start + 1) >= len(data) >= 1024 and not issymbolic(data) and self._concrete - ) - - if can_write_raw: + ): logger.debug("Using fast write") offset = mp._get_offset(where) if isinstance(data, str): diff --git a/manticore/native/memory.py b/manticore/native/memory.py index 5ccc14e4e..167e0fe64 100644 --- a/manticore/native/memory.py +++ b/manticore/native/memory.py @@ -16,11 +16,11 @@ from ..utils.helpers import interval_intersection from ..utils import config -from typing import Dict, Optional - import functools import logging +from typing import Dict, Generator, Iterable, List, MutableMapping, Optional, Set + logger = logging.getLogger(__name__) consts = config.get_group("native") @@ -37,7 +37,7 @@ class MemoryException(Exception): Memory exceptions """ - def __init__(self, message, address=None): + def __init__(self, message: str, address=None): """ Builds a memory exception. @@ -73,10 +73,9 @@ def __init__(self, mem, address, size, message=None, policy="MINMAX"): class InvalidMemoryAccess(MemoryException): _message = "Invalid memory access" - def __init__(self, address, mode): + def __init__(self, address, mode: str): assert mode in "rwx" - suffix = f" (mode:{mode})" - message = self._message + suffix + message = f"{self._message} (mode:{mode})" super(InvalidMemoryAccess, self).__init__(message, address) self.mode = mode @@ -84,7 +83,7 @@ def __init__(self, address, mode): class InvalidSymbolicMemoryAccess(InvalidMemoryAccess): _message = "Invalid symbolic memory access" - def __init__(self, address, mode, size, constraint): + def __init__(self, address, mode: str, size, constraint): super(InvalidSymbolicMemoryAccess, self).__init__(address, mode) # the crashing constraint you need to assert self.constraint = constraint @@ -193,7 +192,7 @@ def __iter__(self): """ return iter(range(self._start, self._end)) - def __eq__(self, other): + def __eq__(self, other) -> bool: return ( self.start == other.start and self.end == other.end @@ -211,7 +210,7 @@ def __lt__(self, other): return self.perms < other.perms return self.name < other.name - def __hash__(self): + def __hash__(self) -> int: return object.__hash__(self) def _in_range(self, index) -> bool: @@ -577,18 +576,18 @@ class Memory(object, metaclass=ABCMeta): This class handles all virtual memory mappings and symbolic chunks. """ - def __init__(self, maps=None, cpu=StubCPU()): + def __init__(self, maps: Optional[Iterable[Map]] = None, cpu=StubCPU()): """ Builds a memory manager. """ super().__init__() if maps is None: - self._maps = set() + self._maps: Set[Map] = set() else: self._maps = set(maps) self.cpu = cpu - self._page2map = WeakValueDictionary() # {page -> ref{MAP}} - self._recording_stack = [] + self._page2map: MutableMapping[int, Map] = WeakValueDictionary() # {page -> ref{MAP}} + self._recording_stack: List = [] for m in self._maps: for i in range(self._page(m.start), self._page(m.end)): assert i not in self._page2map @@ -599,38 +598,37 @@ def __reduce__(self): @property @abstractmethod - def memory_bit_size(self): + def memory_bit_size(self) -> int: return 32 @property @abstractmethod - def page_bit_size(self): + def page_bit_size(self) -> int: return 12 @property - def memory_size(self): + def memory_size(self) -> int: return 1 << self.memory_bit_size @property - def page_size(self): + def page_size(self) -> int: return 1 << self.page_bit_size @property - def memory_mask(self): + def memory_mask(self) -> int: return self.memory_size - 1 @property - def page_mask(self): + def page_mask(self) -> int: return self.page_size - 1 @property - def maps(self): + def maps(self) -> Set[Map]: return self._maps def _ceil(self, address) -> int: """ Returns the smallest page boundary value not less than the address. - :rtype: int :param address: the address to calculate its ceil. :return: the ceil of C{address}. """ @@ -642,7 +640,6 @@ def _floor(self, address) -> int: :param address: the address to calculate its floor. :return: the floor of C{address}. - :rtype: int """ return address & ~self.page_mask @@ -652,7 +649,6 @@ def _page(self, address) -> int: :param address: the address to calculate its page number. :return: the page number of address. - :rtype: int """ return address >> self.page_bit_size @@ -665,7 +661,6 @@ def _search(self, size, start=None, counter=0) -> int: :param counter: internal parameter to know if all the memory was already scanned. :return: the address of an available space to map C{size} bytes. :raises MemoryException: if there is no space available to allocate the desired memory. - :rtype: int todo: Document what happens when you try to allocate something that goes round the address 32/64 bit representation. @@ -811,7 +806,7 @@ def _del(self, m: Map) -> None: # remove m from the maps set self._maps.remove(m) - def map_containing(self, address): + def map_containing(self, address: int) -> Map: """ Returns the L{MMap} object containing the address. @@ -843,7 +838,7 @@ def mappings(self): return sorted(result) - def __str__(self): + def __str__(self) -> str: return "\n".join( [ f'{start:016x}-{end:016x} {p:>4s} {offset:08x} {name or ""}' @@ -851,7 +846,7 @@ def __str__(self): ] ) - def proc_self_mappings(self): + def proc_self_mappings(self) -> List[ProcSelfMapInfo]: """ Returns a sorted list of all the mappings for this memory for /proc/self/maps. Device, inode, and private/shared permissions are unsupported. @@ -860,7 +855,6 @@ def proc_self_mappings(self): Pathname is substituted by filename :return: a list of mappings. - :rtype: list """ result = [] # TODO: Device, inode, and private/shared permissions are unsupported @@ -888,7 +882,7 @@ def __proc_self__(self): self.proc_self_mappings() return "\n".join([f"{m}" for m in self.proc_self_mappings()]) - def _maps_in_range(self, start, end): + def _maps_in_range(self, start: int, end: int) -> Generator[Map, None, None]: """ Generates the list of maps that overlaps with the range [start:end] """ @@ -1011,12 +1005,12 @@ def access_ok(self, index, access, force=False): return force or m.access_ok(access) # write and read potentially symbolic bytes at symbolic indexes - def read(self, addr, size, force=False): + def read(self, addr, size, force: bool = False) -> List[bytes]: if not self.access_ok(slice(addr, addr + size), "r", force): raise InvalidMemoryAccess(addr, "r") assert size > 0 - result = [] + result: List[bytes] = [] stop = addr + size p = addr while p < stop: diff --git a/manticore/platforms/linux.py b/manticore/platforms/linux.py index 356bc6125..a4f6858cf 100644 --- a/manticore/platforms/linux.py +++ b/manticore/platforms/linux.py @@ -1043,13 +1043,13 @@ def set_entry(self, entryPC): self.current.PC = elf_entry logger.debug(f"Entry point updated: {elf_entry:016x}") - def load(self, filename: str, env) -> None: + def load(self, filename: str, env_list: List) -> None: """ Loads and an ELF program in memory and prepares the initial CPU state. Creates the stack and loads the environment variables and the arguments in it. :param filename: pathname of the file to be executed. (used for auxv) - :param list env: A list of env variables. (used for extracting vars that control ld behavior) + :param env_list: A list of env variables. (used for extracting vars that control ld behavior) :raises error: - 'Not matching cpu': if the program is compiled for a different architecture - 'Not matching memory': if the program is compiled for a different address size @@ -1060,7 +1060,7 @@ def load(self, filename: str, env) -> None: cpu = self.current elf = self.elf arch = self.arch - env = dict(var.split("=", 1) for var in env if "=" in var) + env = dict(var.split("=", 1) for var in env_list if "=" in var) addressbitsize = {"x86": 32, "x64": 64, "ARM": 32, "AArch64": 64}[elf.get_machine_arch()] logger.debug("Loading %s as a %s elf", filename, arch) @@ -1497,7 +1497,7 @@ def sys_llseek( ) return -e.err - def sys_read(self, fd: int, buf, count) -> int: + def sys_read(self, fd: int, buf: int, count: int) -> int: data: bytes = bytes() if count != 0: # TODO check count bytes from buf @@ -1572,10 +1572,9 @@ def sys_fork(self) -> int: """ return -errno.ENOSYS - def sys_access(self, buf, mode) -> int: + def sys_access(self, buf: int, mode: int) -> int: """ Checks real user's permissions for a file - :rtype: int :param buf: a buffer containing the pathname to the file to check its permissions. :param mode: the access permissions to check. @@ -2259,7 +2258,7 @@ def sys_socket(self, domain, socket_type, protocol): fd = self._open(f) return fd - def _is_sockfd(self, sockfd): + def _is_sockfd(self, sockfd: int) -> int: try: fd = self.files[sockfd] if not isinstance(fd, SocketDesc): @@ -2268,19 +2267,19 @@ def _is_sockfd(self, sockfd): except IndexError: return -errno.EBADF - def sys_bind(self, sockfd, address, address_len): + def sys_bind(self, sockfd: int, address, address_len) -> int: return self._is_sockfd(sockfd) - def sys_listen(self, sockfd, backlog): + def sys_listen(self, sockfd: int, backlog) -> int: return self._is_sockfd(sockfd) - def sys_accept(self, sockfd, addr, addrlen): + def sys_accept(self, sockfd: int, addr, addrlen) -> int: """ https://github.com/torvalds/linux/blob/63bdf4284c38a48af21745ceb148a087b190cd21/net/socket.c#L1649-L1653 """ return self.sys_accept4(sockfd, addr, addrlen, 0) - def sys_accept4(self, sockfd, addr, addrlen, flags): + def sys_accept4(self, sockfd: int, addr, addrlen, flags) -> int: # TODO: ehennenfent - Only handles the flags=0 (sys_accept) case ret = self._is_sockfd(sockfd) if ret != 0: diff --git a/tests/ethereum/test_detectors.py b/tests/ethereum/test_detectors.py index 54c3d6771..acd795ea1 100644 --- a/tests/ethereum/test_detectors.py +++ b/tests/ethereum/test_detectors.py @@ -27,7 +27,7 @@ from manticore.utils import config, log -from typing import Type +from typing import Tuple, Type consts = config.get_group("core") consts.mprocessing = consts.mprocessing.single @@ -55,7 +55,7 @@ def tearDown(self): self.mevm = None shutil.rmtree(self.worksp) - def _test(self, name, should_find, use_ctor_sym_arg=False): + def _test(self, name: str, should_find, use_ctor_sym_arg=False): """ Tests DetectInvalid over the consensys benchmark suit """ @@ -65,7 +65,7 @@ def _test(self, name, should_find, use_ctor_sym_arg=False): filepath = os.path.join(dir, f"{name}.sol") if use_ctor_sym_arg: - ctor_arg = (mevm.make_symbolic_value(),) + ctor_arg: Tuple = (mevm.make_symbolic_value(),) else: ctor_arg = () diff --git a/tests/native/test_linux.py b/tests/native/test_linux.py index 7e3df7535..79a1b2a8d 100644 --- a/tests/native/test_linux.py +++ b/tests/native/test_linux.py @@ -21,26 +21,26 @@ class LinuxTest(unittest.TestCase): _multiprocess_can_split_ = True BIN_PATH = os.path.join(os.path.dirname(__file__), "binaries", "basic_linux_amd64") - def setUp(self): + def setUp(self) -> None: self.linux = linux.Linux(self.BIN_PATH) self.symbolic_linux_armv7 = linux.SLinux.empty_platform("armv7") self.symbolic_linux_aarch64 = linux.SLinux.empty_platform("aarch64") - def tearDown(self): + def tearDown(self) -> None: for f in ( self.linux.files + self.symbolic_linux_armv7.files + self.symbolic_linux_aarch64.files ): if isinstance(f, linux.File): f.close() - def test_regs_init_state_x86(self): + def test_regs_init_state_x86(self) -> None: x86_defaults = {"CS": 0x23, "SS": 0x2B, "DS": 0x2B, "ES": 0x2B} cpu = self.linux.current for reg, val in x86_defaults.items(): self.assertEqual(cpu.regfile.read(reg), val) - def test_stack_init(self): + def test_stack_init(self) -> None: argv = ["arg1", "arg2", "arg3"] real_argv = [self.BIN_PATH] + argv envp = ["env1", "env2", "env3"] @@ -58,7 +58,7 @@ def test_stack_init(self): for i, env in enumerate(envp): self.assertEqual(cpu.read_string(cpu.read_int(envp_ptr + i * 8)), env) - def test_load_maps(self): + def test_load_maps(self) -> None: mappings = self.linux.current.memory.mappings() # stack should be last @@ -73,7 +73,7 @@ def test_load_maps(self): self.assertEqual(first_map_name, "basic_linux_amd64") self.assertEqual(second_map_name, "basic_linux_amd64") - def test_load_proc_self_maps(self): + def test_load_proc_self_maps(self) -> None: proc_maps = self.linux.current.memory.proc_self_mappings() # check that proc self raises error when not being read created as read only @@ -107,7 +107,7 @@ def test_load_proc_self_maps(self): None, ) - def test_aarch64_syscall_write(self): + def test_aarch64_syscall_write(self) -> None: nr_write = 64 # Create a minimal state. @@ -138,7 +138,7 @@ def test_aarch64_syscall_write(self): self.assertEqual(res, s) @unittest.skip("Stat differs in different test environments") - def test_armv7_syscall_fstat(self): + def test_armv7_syscall_fstat(self) -> None: nr_fstat64 = 197 # Create a minimal state @@ -162,7 +162,7 @@ def test_armv7_syscall_fstat(self): hexlify(b"".join(platform.current.read_bytes(stat, 100))), ) - def test_armv7_linux_symbolic_files_workspace_files(self): + def test_armv7_linux_symbolic_files_workspace_files(self) -> None: fname = "symfile" platform = self.symbolic_linux_armv7 @@ -192,7 +192,7 @@ def test_armv7_linux_symbolic_files_workspace_files(self): self.assertIn(fname, files) self.assertEqual(len(files[fname]), 1) - def test_armv7_linux_workspace_files(self): + def test_armv7_linux_workspace_files(self) -> None: platform = self.symbolic_linux_armv7 platform.argv = ["arg1", "arg2"] @@ -206,7 +206,7 @@ def test_armv7_linux_workspace_files(self): self.assertIn("stderr", files) self.assertIn("net", files) - def test_armv7_syscall_events(self): + def test_armv7_syscall_events(self) -> None: nr_fstat64 = 197 class Receiver: @@ -276,7 +276,7 @@ def _armv7_create_openat_state(self): return platform, dir_path - def test_armv7_syscall_openat_concrete(self): + def test_armv7_syscall_openat_concrete(self) -> None: platform, temp_dir = self._armv7_create_openat_state() try: platform.syscall() @@ -284,7 +284,7 @@ def test_armv7_syscall_openat_concrete(self): finally: shutil.rmtree(temp_dir) - def test_armv7_syscall_openat_symbolic(self): + def test_armv7_syscall_openat_symbolic(self) -> None: platform, temp_dir = self._armv7_create_openat_state() try: platform.current.R0 = BitVecVariable(32, "fd") @@ -302,7 +302,7 @@ def test_armv7_syscall_openat_symbolic(self): finally: shutil.rmtree(temp_dir) - def test_armv7_chroot(self): + def test_armv7_chroot(self) -> None: # Create a minimal state platform = self.symbolic_linux_armv7 platform.current.memory.mmap(0x1000, 0x1000, "rw ") @@ -320,8 +320,7 @@ def test_armv7_chroot(self): fd = platform.sys_chroot(path) self.assertEqual(fd, -errno.EPERM) - def test_symbolic_argv_envp(self): - + def test_symbolic_argv_envp(self) -> None: dirname = os.path.dirname(__file__) self.m = Manticore.linux( os.path.join(dirname, "binaries", "arguments_linux_amd64"), @@ -341,7 +340,7 @@ def test_symbolic_argv_envp(self): self.assertEqual(mem[6], b"\0") self.assertTrue(issymbolic(mem[5])) - def test_serialize_state_with_closed_files(self): + def test_serialize_state_with_closed_files(self) -> None: # regression test: issue 954 platform = self.linux @@ -350,7 +349,7 @@ def test_serialize_state_with_closed_files(self): platform.sys_close(fd) pickle_dumps(platform) - def test_thumb_mode_entrypoint(self): + def test_thumb_mode_entrypoint(self) -> None: # thumb_mode_entrypoint is a binary with only one instruction # 0x1000: add.w r0, r1, r2 # which is a Thumb instruction, so the entrypoint is set to 0x1001