diff --git a/manticore/native/manticore.py b/manticore/native/manticore.py index 95827b706..0d6620409 100644 --- a/manticore/native/manticore.py +++ b/manticore/native/manticore.py @@ -5,11 +5,12 @@ import os import shlex import time +from typing import Callable, Optional import sys from elftools.elf.elffile import ELFFile from elftools.elf.sections import SymbolTableSection -from .state import State +from .state import HookCallback, State from ..core.manticore import ManticoreBase from ..core.smtlib import ConstraintSet from ..core.smtlib.solver import SelectedSolver, issymbolic @@ -229,19 +230,28 @@ def decorator(f): return decorator - def add_hook(self, pc, callback, after=False): + def add_hook( + self, + pc: Optional[int], + callback: HookCallback, + after: bool = False, + state: Optional[State] = None, + ): """ Add a callback to be invoked on executing a program counter. Pass `None` for pc to invoke callback on every instruction. `callback` should be a callable that takes one :class:`~manticore.core.state.State` argument. :param pc: Address of instruction to hook - :type pc: int or None - :param callable callback: Hook function + :param callback: Hook function + :param after: Hook after PC executes? 
+ :param state: Optionally, add hook for this state only, else all states """ if not (isinstance(pc, int) or pc is None): raise TypeError(f"pc must be either an int or None, not {pc.__class__.__name__}") - else: + + if state is None: + # add hook to all states hooks, when, hook_callback = ( (self._hooks, "will_execute_instruction", self._hook_callback) if not after @@ -250,6 +260,9 @@ def add_hook(self, pc, callback, after=False): hooks.setdefault(pc, set()).add(callback) if hooks: self.subscribe(when, hook_callback) + else: + # only hook for the specified state + state.add_hook(pc, callback, after) def _hook_callback(self, state, pc, instruction): "Invoke all registered generic hooks" diff --git a/manticore/native/state.py b/manticore/native/state.py index da0f6930b..7dd0289f0 100644 --- a/manticore/native/state.py +++ b/manticore/native/state.py @@ -1,8 +1,15 @@ +import copy from collections import namedtuple -from typing import Any, NamedTuple +from typing import Any, Callable, Dict, NamedTuple, Optional, Set, Tuple, Union +from .cpu.disasm import Instruction +from .memory import ConcretizeMemory, MemoryException +from .. 
def __init__(self, *args, **kwargs):
    """
    Initialize the state with empty per-state hook tables.

    All arguments are forwarded unchanged to the parent initializer.
    """
    super().__init__(*args, **kwargs)
    # Both tables map a program counter (or None, meaning "every instruction")
    # to the set of callbacks registered for it on this state.
    self._hooks: Dict[Optional[int], Set["HookCallback"]] = {}
    self._after_hooks: Dict[Optional[int], Set["HookCallback"]] = {}


def __getstate__(self) -> Dict[str, Any]:
    """
    Extend the parent's serialized form with this state's hook tables.

    :return: The parent's state dictionary plus ``"hooks"`` and ``"after_hooks"`` entries
    """
    state = super().__getstate__()
    state["hooks"] = self._hooks
    state["after_hooks"] = self._after_hooks
    return state


def __setstate__(self, state: Dict[str, Any]) -> None:
    """
    Restore the hook tables saved by ``__getstate__`` and resubscribe their handlers.

    :param state: Dictionary produced by ``__getstate__``
    """
    super().__setstate__(state)
    self._hooks = state["hooks"]
    self._after_hooks = state["after_hooks"]
    self._resub_hooks()


def __enter__(self) -> "State":
    """
    Build the new state through the parent implementation and hand it its own
    copy of this state's hook tables.

    :return: The new state
    """
    new_state = super().__enter__()
    # A shallow copy of the dict would leave both states pointing at the same
    # callback sets, so a hook added or removed on one would leak into the other.
    new_state._hooks = self._copy_hook_table(self._hooks)
    new_state._after_hooks = self._copy_hook_table(self._after_hooks)
    return new_state


def _copy_hook_table(
    self, hook_table: Dict[Optional[int], Set["HookCallback"]]
) -> Dict[Optional[int], Set["HookCallback"]]:
    """
    Internal helper that copies a hook table one level deep.

    The dictionary and every callback set are new objects; the callbacks
    themselves are shared, not copied.

    :param hook_table: Mapping of pc (or None) to a set of callbacks
    :return: Independent copy of `hook_table`
    """
    return {pc: set(callbacks) for pc, callbacks in hook_table.items()}


def _get_hook_context(
    self, after: bool = True
) -> Tuple[Dict[Optional[int], Set["HookCallback"]], str, Any]:
    """
    Internal helper function to get hook context information.

    :param after: Whether we want info pertaining to hooks after instruction executes or before
    :return: Information for hooks after or before:
        - dictionary mapping pc (or None) to the set of hooks for the specified after or before
        - string of callback event
        - State function that handles the callback
    """
    if after:
        return (self._after_hooks, "did_execute_instruction", self._state_after_hook_callback)
    return (self._hooks, "will_execute_instruction", self._state_hook_callback)


def remove_hook(self, pc: Optional[int], callback: "HookCallback", after: bool = False) -> bool:
    """
    Remove a callback with the specified properties

    :param pc: Address of instruction to remove from
    :param callback: The callback function that was at the address
    :param after: Whether it was after instruction executed or not
    :return: Whether it was removed
    """
    hooks, _, _ = self._get_hook_context(after)
    callbacks = hooks.get(pc)
    if not callbacks or callback not in callbacks:
        return False

    callbacks.remove(callback)
    if not callbacks:
        # Drop the key so the table never keeps an empty callback set around
        del hooks[pc]
    return True


def add_hook(self, pc: Optional[int], callback: "HookCallback", after: bool = False) -> None:
    """
    Add a callback to be invoked on executing a program counter. Pass `None`
    for pc to invoke callback on every instruction. `callback` should be a callable
    that takes one :class:`~manticore.native.state.State` argument.

    Adding the same callback twice for the same pc has no further effect,
    because callbacks are stored in a set.

    :param pc: Address of instruction to hook
    :param callback: Hook function
    :param after: Hook after PC executes?
    """
    hooks, when, hook_callback = self._get_hook_context(after)
    hooks.setdefault(pc, set()).add(callback)
    # The table is necessarily non-empty at this point, so subscribe unconditionally
    self.subscribe(when, hook_callback)


def _resub_hooks(self) -> None:
    """
    Internal helper function to resubscribe hook callback events when the
    state is active again.

    Only a table that actually holds hooks gets its handler subscribed;
    `add_hook` subscribes on its own when a hook is added later.
    """
    for after in (False, True):
        hooks, when, hook_callback = self._get_hook_context(after)
        if hooks:
            self.subscribe(when, hook_callback)


def _run_hooks(self, hook_table: Dict[Optional[int], Set["HookCallback"]], pc: int) -> None:
    """
    Internal helper that invokes the callbacks registered for `pc`, then the
    pc-agnostic ones (registered under None), passing this state to each.

    :param hook_table: Mapping of pc (or None) to a set of callbacks
    :param pc: Address whose callbacks should run
    """
    # Snapshot first so a callback may add or remove hooks while we iterate.
    # The callbacks themselves are deliberately not copied: the registered
    # objects are the ones that must be called.
    pending = list(hook_table.get(pc, ()))
    pending.extend(hook_table.get(None, ()))
    for callback in pending:
        callback(self)


def _state_hook_callback(self, pc: int, _instruction: "Instruction") -> None:
    """
    Invoke all registered State hooks before the instruction executes.

    :param pc: Address where the hook should run
    :param _instruction: Instruction at this PC
    """
    self._run_hooks(self._hooks, pc)


def _state_after_hook_callback(self, last_pc: int, _pc: int, _instruction: "Instruction") -> None:
    """
    Invoke all registered State hooks after the instruction executes.

    :param last_pc: Address where the hook should run after instruction execution
    :param _pc: Next address to execute
    :param _instruction: Instruction at this last_pc
    """
    self._run_hooks(self._after_hooks, last_pc)
def test_state_hook(self):
    """
    Check per-state hook bookkeeping: a child state starts with its parent's
    hooks, duplicate registrations are ignored, and the number of hooked
    addresses in parent and child changes independently.
    """
    initial_state = State(ConstraintSet(), FakePlatform())

    def fake_hook(_: StateBase) -> None:
        return None

    # assertEqual (rather than assertTrue on a comparison) reports the actual
    # size when a check fails
    self.assertEqual(len(initial_state._hooks), 0)
    self.assertEqual(len(initial_state._after_hooks), 0)

    # This hook should be propagated to child state
    initial_state.add_hook(0x4000, fake_hook, after=False)

    self.assertEqual(len(initial_state._hooks), 1)
    self.assertEqual(len(initial_state._after_hooks), 0)

    with initial_state as new_state:
        # Child state has parent's hook
        self.assertEqual(len(new_state._hooks), 1)
        self.assertEqual(len(new_state._after_hooks), 0)

        # Try adding the same hook
        new_state.add_hook(0x4000, fake_hook, after=False)
        # Should not add again
        self.assertEqual(len(new_state._hooks), 1)

        # Add two hooks for after and before instruction
        new_state.add_hook(0x4001, fake_hook, after=True)
        new_state.add_hook(0x4001, fake_hook, after=False)

        # A new hook added to both lists
        self.assertEqual(len(new_state._hooks), 2)
        self.assertEqual(len(new_state._after_hooks), 1)

        # Ensure parent state was not affected
        self.assertEqual(len(initial_state._hooks), 1)
        self.assertEqual(len(initial_state._after_hooks), 0)

        # Remove one of the hooks we added; removal must be reported as successful
        self.assertTrue(new_state.remove_hook(0x4000, fake_hook, after=False))
        # Try to remove a non-existent hook
        self.assertFalse(new_state.remove_hook(0x4000, fake_hook, after=True))

        # Ensure removal
        self.assertEqual(len(new_state._hooks), 1)
        self.assertEqual(len(new_state._after_hooks), 1)

        # Ensure parent state wasn't affected
        self.assertEqual(len(initial_state._hooks), 1)
        self.assertEqual(len(initial_state._after_hooks), 0)

        # Add hook to all PC in our parent state
        initial_state.add_hook(None, fake_hook, after=True)

    # Ensure only the hooks we added are still here
    self.assertEqual(len(initial_state._hooks), 1)
    self.assertEqual(len(initial_state._after_hooks), 1)
# The two hook functions below have to be module-level functions for the
# StateHooks test; defined inside the test they fail with
#   AttributeError: Can't pickle local object
#   'StateHooks.test_state_hooks.<locals>.do_nothing'


def do_nothing(_: StateBase) -> None:
    """Hook that deliberately has no effect."""


def fin(_: StateBase) -> None:
    """Hook that announces on stdout that it was reached."""
    print("Reached fin callback")


class StateHooks(unittest.TestCase):
    """Runs a real binary with a global hook that manages state-specific hooks."""

    def setUp(self):
        # Fixed seed and single-process mode for this run
        core_cfg = config.get_group("core")
        core_cfg.seed = 61
        core_cfg.mprocessing = core_cfg.mprocessing.single

        here = os.path.dirname(__file__)
        binary = os.path.join(here, "binaries", "basic_linux_amd64")
        self.m = Manticore(binary, policy="random")

    def test_state_hooks(self):
        @self.m.hook(0x400610, after=True)
        def process_hook(state: State) -> None:
            # Hooks applied globally are stored in the Manticore class, not in
            # the State, so the state cannot remove this one
            removed_global = state.remove_hook(0x400610, process_hook, after=True)
            self.assertFalse(removed_global)

            # This one was applied specifically to this State (or its parent),
            # so the state can remove it
            removed_local = state.remove_hook(None, do_nothing, after=True)
            self.assertTrue(removed_local)

            state.add_hook(None, do_nothing, after=False)
            state.add_hook(None, do_nothing, after=True)
            state.add_hook(0x400647, fin, after=True)
            state.add_hook(0x400647, fin, after=False)

        # Register a state-specific hook on every ready state before running
        for state in self.m.ready_states:
            self.m.add_hook(None, do_nothing, after=True, state=state)

        captured = io.StringIO()
        with redirect_stdout(captured):
            self.m.run()
        self.assertIn("Reached fin callback", captured.getvalue())