diff --git a/kmir/src/kmir/__main__.py b/kmir/src/kmir/__main__.py index bcc3adbae..846cf91ef 100644 --- a/kmir/src/kmir/__main__.py +++ b/kmir/src/kmir/__main__.py @@ -263,6 +263,12 @@ def _arg_parser() -> ArgumentParser: prove_args.add_argument( '--break-on-thunk', dest='break_on_thunk', action='store_true', help='Break on thunk evaluation' ) + prove_args.add_argument( + '--terminate-on-thunk', + dest='terminate_on_thunk', + action='store_true', + help='Terminate proof when reaching a thunk', + ) prove_args.add_argument( '--break-every-statement', dest='break_every_statement', @@ -507,6 +513,7 @@ def _parse_args(ns: Namespace) -> KMirOpts: break_on_terminator_unreachable=ns.break_on_terminator_unreachable, break_every_terminator=ns.break_every_terminator, break_every_step=ns.break_every_step, + terminate_on_thunk=ns.terminate_on_thunk, ) case 'link': return LinkOpts( diff --git a/kmir/src/kmir/kmir.py b/kmir/src/kmir/kmir.py index 747bc770e..c5d629c62 100644 --- a/kmir/src/kmir/kmir.py +++ b/kmir/src/kmir/kmir.py @@ -9,7 +9,7 @@ from pyk.cli.utils import bug_report_arg from pyk.cterm import CTerm, cterm_symbolic -from pyk.kast.inner import KApply, KSequence, KSort, KToken, KVariable, Subst +from pyk.kast.inner import KApply, KLabel, KSequence, KSort, KToken, KVariable, Subst from pyk.kast.manip import abstract_term_safely, split_config_from from pyk.kcfg import KCFG from pyk.kcfg.explore import KCFGExplore @@ -128,13 +128,14 @@ def from_kompiled_kore( class Symbols: END_PROGRAM: Final = KApply('#EndProgram_KMIR-CONTROL-FLOW_KItem') + THUNK: Final = KLabel('thunk(_)_RT-DATA_Value_Evaluation') @cached_property def parser(self) -> Parser: return Parser(self.definition) @contextmanager - def kcfg_explore(self, label: str | None = None) -> Iterator[KCFGExplore]: + def kcfg_explore(self, label: str | None = None, terminate_on_thunk: bool = False) -> Iterator[KCFGExplore]: with cterm_symbolic( self.definition, self.definition_dir, @@ -143,7 +144,7 @@ def kcfg_explore(self, label: str | None = None) -> Iterator[KCFGExplore]: id=label if self.bug_report is not None else None, # NB bug report arg.s must be coherent simplify_each=30, ) as cts: - yield KCFGExplore(cts, kcfg_semantics=KMIRSemantics()) + yield KCFGExplore(cts, kcfg_semantics=KMIRSemantics(terminate_on_thunk=terminate_on_thunk)) def run_smir( self, @@ -247,7 +248,7 @@ def prove_rs(opts: ProveRSOpts) -> APRProof: break_on_calls=opts.break_on_calls, break_on_function_calls=opts.break_on_function_calls, break_on_intrinsic_calls=opts.break_on_intrinsic_calls, - break_on_thunk=opts.break_on_thunk, + break_on_thunk=opts.break_on_thunk or opts.terminate_on_thunk, # must break for terminal rule break_every_statement=opts.break_every_statement, break_on_terminator_goto=opts.break_on_terminator_goto, break_on_terminator_switch_int=opts.break_on_terminator_switch_int, @@ -260,15 +261,26 @@ def prove_rs(opts: ProveRSOpts) -> APRProof: break_every_step=opts.break_every_step, ) - with kmir.kcfg_explore(label) as kcfg_explore: + with kmir.kcfg_explore(label, terminate_on_thunk=opts.terminate_on_thunk) as kcfg_explore: prover = APRProver(kcfg_explore, execute_depth=opts.max_depth, cut_point_rules=cut_point_rules) prover.advance_proof(apr_proof, max_iterations=opts.max_iterations) return apr_proof class KMIRSemantics(DefaultSemantics): + terminate_on_thunk: bool + + def __init__(self, terminate_on_thunk: bool = False) -> None: + self.terminate_on_thunk = terminate_on_thunk + def is_terminal(self, cterm: CTerm) -> bool: k_cell = cterm.cell('K_CELL') + + if self.terminate_on_thunk: # terminate on `thunk ( ... )` rule + match k_cell: + case KApply(label, _) | KSequence((KApply(label, _), *_)) if label == KMIR.Symbols.THUNK: + return True + # #EndProgram if k_cell == KMIR.Symbols.END_PROGRAM: return True diff --git a/kmir/src/kmir/options.py b/kmir/src/kmir/options.py index cda622524..ff6371fdd 100644 --- a/kmir/src/kmir/options.py +++ b/kmir/src/kmir/options.py @@ -55,6 +55,7 @@ class ProveOpts(KMirOpts): break_on_terminator_unreachable: bool break_every_terminator: bool break_every_step: bool + terminate_on_thunk: bool def __init__( self, @@ -77,6 +78,7 @@ def __init__( break_on_terminator_unreachable: bool = False, break_every_terminator: bool = False, break_every_step: bool = False, + terminate_on_thunk: bool = False, ) -> None: self.proof_dir = Path(proof_dir).resolve() if proof_dir is not None else None self.bug_report = bug_report @@ -97,6 +99,7 @@ def __init__( self.break_on_terminator_unreachable = break_on_terminator_unreachable self.break_every_terminator = break_every_terminator self.break_every_step = break_every_step + self.terminate_on_thunk = terminate_on_thunk @dataclass @@ -131,6 +134,7 @@ def __init__( break_on_terminator_unreachable: bool = False, break_every_terminator: bool = False, break_every_step: bool = False, + terminate_on_thunk: bool = False, ) -> None: self.rs_file = rs_file self.proof_dir = Path(proof_dir).resolve() if proof_dir is not None else None @@ -155,6 +159,7 @@ def __init__( self.break_on_terminator_unreachable = break_on_terminator_unreachable self.break_every_terminator = break_every_terminator self.break_every_step = break_every_step + self.terminate_on_thunk = terminate_on_thunk @dataclass