Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 10 additions & 232 deletions kmir/src/kmir/kmir.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import logging
import shutil
import tempfile
from contextlib import contextmanager
from functools import cached_property
Expand All @@ -13,8 +12,6 @@
from pyk.kast.inner import KApply, KSequence, KSort, KToken, KVariable, Subst
from pyk.kast.manip import abstract_term_safely, free_vars, split_config_from
from pyk.kast.prelude.collections import list_empty, list_of
from pyk.kast.prelude.kint import intToken
from pyk.kast.prelude.utils import token
from pyk.kcfg import KCFG
from pyk.kcfg.explore import KCFGExplore
from pyk.kcfg.semantics import DefaultSemantics
Expand All @@ -24,7 +21,6 @@
from pyk.proof.reachability import APRProof, APRProver
from pyk.proof.show import APRProofNodePrinter

from .build import HASKELL_DEF_DIR, LLVM_DEF_DIR, LLVM_LIB_DIR
from .cargo import cargo_get_smir_json
from .kast import mk_call_terminator, symbolic_locals
from .kparse import KParse
Expand All @@ -33,7 +29,7 @@

if TYPE_CHECKING:
from collections.abc import Iterator
from typing import Any, Final
from typing import Final

from pyk.cterm.show import CTermShow
from pyk.kast.inner import KInner
Expand Down Expand Up @@ -62,127 +58,15 @@ def __init__(
def from_kompiled_kore(
smir_info: SMIRInfo, target_dir: str, bug_report: Path | None = None, symbolic: bool = True
) -> KMIR:
kmir = KMIR(HASKELL_DEF_DIR)

def _insert_rules_and_write(input_file: Path, rules: list[str], output_file: Path) -> None:
    """Copy ``input_file`` to ``output_file`` with ``rules`` spliced in before its last line.

    The last line of ``input_file`` must start with ``endmodule``. The output
    consists of every line before it, a ``// Generated from file ...`` comment,
    the given rules (written verbatim, in order), and finally the original last
    line preceded by a newline.

    Args:
        input_file: Definition file to read.
        rules: Text fragments to insert; they are written as-is.
        output_file: File to (over)write with the extended definition.

    Raises:
        ValueError: If ``input_file`` is empty or its last line does not start
            with ``endmodule``.
    """
    with open(input_file, 'r') as f:
        lines = f.readlines()

    # Explicit check rather than `assert`: assertions are stripped under `-O`,
    # which would silently write a malformed definition. This also covers the
    # empty-file case, which would otherwise surface as a bare IndexError.
    if not lines or not lines[-1].startswith('endmodule'):
        raise ValueError(f"Last line of {input_file} must start with 'endmodule'")

    *module_body, last_line = lines

    # Insert rules before the endmodule line
    new_lines = [
        *module_body,
        f'\n// Generated from file {input_file}\n\n',
        *rules,
        '\n' + last_line,
    ]

    # Write to output file
    with open(output_file, 'w') as f:
        f.writelines(new_lines)

target_path = Path(target_dir)
# TODO if target dir exists and contains files, check file dates (definition files and interpreter)
# to decide whether or not to recompile. For now we always recompile.
target_path.mkdir(parents=True, exist_ok=True)

rules = kmir.make_kore_rules(smir_info)
_LOGGER.info(f'Generated {len(rules)} function equations to add to `definition.kore')

if symbolic:
# Create output directories
target_llvm_path = target_path / 'llvm-library'
target_llvmdt_path = target_llvm_path / 'dt'
target_hs_path = target_path / 'haskell'

_LOGGER.info(f'Creating directories {target_llvmdt_path} and {target_hs_path}')
target_llvmdt_path.mkdir(parents=True, exist_ok=True)
target_hs_path.mkdir(parents=True, exist_ok=True)

# Process LLVM definition
_LOGGER.info('Writing LLVM definition file')
llvm_def_file = LLVM_LIB_DIR / 'definition.kore'
llvm_def_output = target_llvm_path / 'definition.kore'
_insert_rules_and_write(llvm_def_file, rules, llvm_def_output)

# Run llvm-kompile-matching and llvm-kompile for LLVM
# TODO use pyk to do this if possible (subprocess wrapper, maybe llvm-kompile itself?)
# TODO align compilation options to what we use in plugin.py
import subprocess

_LOGGER.info('Running llvm-kompile-matching')
subprocess.run(
['llvm-kompile-matching', str(llvm_def_output), 'qbaL', str(target_llvmdt_path), '1/2'], check=True
)
_LOGGER.info('Running llvm-kompile')
subprocess.run(
[
'llvm-kompile',
str(llvm_def_output),
str(target_llvmdt_path),
'c',
'-O2',
'--',
'-o',
target_llvm_path / 'interpreter',
],
check=True,
)

# Process Haskell definition
_LOGGER.info('Writing Haskell definition file')
hs_def_file = HASKELL_DEF_DIR / 'definition.kore'
_insert_rules_and_write(hs_def_file, rules, target_hs_path / 'definition.kore')

# Copy all files except definition.kore and binary from HASKELL_DEF_DIR to out/hs
_LOGGER.info('Copying other artefacts into HS output directory')
for file_path in HASKELL_DEF_DIR.iterdir():
if file_path.name != 'definition.kore' and file_path.name != 'haskellDefinition.bin':
if file_path.is_file():
shutil.copy2(file_path, target_hs_path / file_path.name)
elif file_path.is_dir():
shutil.copytree(file_path, target_hs_path / file_path.name, dirs_exist_ok=True)
return KMIR(target_hs_path, target_llvm_path, bug_report=bug_report)
else:

target_llvm_path = target_path / 'llvm'
target_llvmdt_path = target_llvm_path / 'dt'
_LOGGER.info(f'Creating directory {target_llvmdt_path}')
target_llvmdt_path.mkdir(parents=True, exist_ok=True)

# Process LLVM definition
_LOGGER.info('Writing LLVM definition file')
llvm_def_file = LLVM_LIB_DIR / 'definition.kore'
llvm_def_output = target_llvm_path / 'definition.kore'
_insert_rules_and_write(llvm_def_file, rules, llvm_def_output)

import subprocess

_LOGGER.info('Running llvm-kompile-matching')
subprocess.run(
['llvm-kompile-matching', str(llvm_def_output), 'qbaL', str(target_llvmdt_path), '1/2'], check=True
)
_LOGGER.info('Running llvm-kompile')
subprocess.run(
[
'llvm-kompile',
str(llvm_def_output),
str(target_llvmdt_path),
'main',
'-O2',
'--',
'-o',
target_llvm_path / 'interpreter',
],
check=True,
)
blacklist = ['definition.kore', 'interpreter', 'dt']
to_copy = [file.name for file in LLVM_DEF_DIR.iterdir() if file.name not in blacklist]
for file in to_copy:
_LOGGER.info(f'Copying file {file}')
shutil.copy2(LLVM_DEF_DIR / file, target_llvm_path / file)
return KMIR(target_llvm_path, None, bug_report=bug_report)
from .kompile import kompile_smir
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This creates an import cycle (albeit only through a function-local import), which is why I didn't continue factoring out the compilation code. If that's not a problem, I have no objections.


kompiled_smir = kompile_smir(
smir_info=smir_info,
target_dir=target_dir,
bug_report=bug_report,
symbolic=symbolic,
)
return kompiled_smir.create_kmir(bug_report_file=bug_report)

class Symbols:
END_PROGRAM: Final = KApply('#EndProgram_KMIR-CONTROL-FLOW_KItem')
Expand All @@ -203,96 +87,6 @@ def kcfg_explore(self, label: str | None = None) -> Iterator[KCFGExplore]:
) as cts:
yield KCFGExplore(cts, kcfg_semantics=KMIRSemantics())

def functions(self, smir_info: SMIRInfo) -> dict[int, KInner]:
    """Build a table from function type ids to their K terms.

    Every entry of ``smir_info.items`` that is listed in
    ``smir_info.function_symbols_reverse`` is parsed as a ``MonoItem`` and its
    second wrapper argument is stored under each type id registered for it.
    Afterwards, every symbol in ``smir_info.function_symbols`` carrying an
    ``'IntrinsicSym'`` key that has no entry yet gets an ``IntrinsicFunction``
    term.

    Args:
        smir_info: Source of the items and the function symbol tables.

    Returns:
        Mapping from type id to the term describing the function.

    Raises:
        ValueError: If an item cannot be parsed as a ``MonoItem``.
    """
    functions: dict[int, KInner] = {}

    # Parse regular functions
    for item_name, item in smir_info.items.items():
        if item_name not in smir_info.function_symbols_reverse:
            _LOGGER.warning(f'Item not found in SMIR: {item_name}')
            continue
        parsed_item = self.parser.parse_mir_json(item, 'MonoItem')
        if not parsed_item:
            # Name the item that failed; the parse result is empty on this path
            # and carries no useful information for the message.
            raise ValueError(f'Could not parse MonoItemKind: {item_name}')
        parsed_item_kinner, _ = parsed_item
        assert isinstance(parsed_item_kinner, KApply) and parsed_item_kinner.label.name == 'monoItemWrapper'
        # each item can have several entries in the function table for linked SMIR JSON
        for ty in smir_info.function_symbols_reverse[item_name]:
            functions[ty] = parsed_item_kinner.args[1]

    # Add intrinsic functions (only where no regular function claimed the id)
    for ty, sym in smir_info.function_symbols.items():
        if 'IntrinsicSym' in sym and ty not in functions:
            functions[ty] = KApply(
                'IntrinsicFunction',
                [KApply('symbol(_)_LIB_Symbol_String', [token(sym['IntrinsicSym'])])],
            )

    return functions

def make_kore_rules(self, smir_info: SMIRInfo) -> list[str]:  # generates kore syntax directly as string
    """Produce the kore equations for the lookup functions of ``smir_info``.

    Equations are emitted in three groups, in this order: one ``lookupFunction``
    equation per entry of ``self.functions(smir_info)``, one ``lookupTy``
    equation per entry of ``smir_info._smir['types']``, and one ``lookupAlloc``
    equation per entry of ``smir_info._smir['allocs']``.

    As a side effect the ``pyk.ktool.kprint`` logger is set to ``WARNING``.

    Raises:
        ValueError: If two type mappings share the same key.
    """
    # kprint tool is too chatty
    logging.getLogger('pyk.ktool.kprint').setLevel(logging.WARNING)

    rules = [
        self._mk_equation('lookupFunction', KApply('ty', (token(fn_ty),)), 'Ty', fn_kind, 'MonoItemKind')
        for fn_ty, fn_kind in self.functions(smir_info).items()
    ]

    seen_tys: set[KInner] = set()
    for raw_type in smir_info._smir['types']:
        parsed = self.parser.parse_mir_json(raw_type, 'TypeMapping')
        assert parsed is not None
        mapping, _ = parsed
        assert isinstance(mapping, KApply) and len(mapping.args) == 2
        ty, ty_info = mapping.args
        if ty in seen_tys:
            raise ValueError(f'Key collision in type map: {ty}')
        seen_tys.add(ty)
        rules.append(self._mk_equation('lookupTy', ty, 'Ty', ty_info, 'TypeInfo'))

    for raw_alloc in smir_info._smir['allocs']:
        alloc_key, alloc_value = self._decode_alloc(smir_info=smir_info, raw_alloc=raw_alloc)
        rules.append(self._mk_equation('lookupAlloc', alloc_key, 'AllocId', alloc_value, 'Evaluation'))

    return rules

def _mk_equation(self, fun: str, arg: KInner, arg_sort: str, result: KInner, result_sort: str) -> str:
    """Render the kore axiom text for the equation ``fun(arg) => result``.

    Args:
        fun: Function name; the kore symbol used is ``'Lbl' + fun``.
        arg: The single argument term of the function application.
        arg_sort: Sort name of ``arg`` (without the ``Sort`` prefix).
        result: Right-hand side term of the equation.
        result_sort: Sort name of ``result`` (without the ``Sort`` prefix).

    Returns:
        Two empty lines, a ``//`` comment showing the pretty-printed equation,
        and the axiom text, joined by newlines. The comment omits the result
        when its pretty-printed form is 1000 characters or longer.
    """
    from pyk.kore.rule import FunctionRule
    from pyk.kore.syntax import App, SortApp

    arg_kore = self.kast_to_kore(arg, KSort(arg_sort))
    fun_app = App('Lbl' + fun, (), (arg_kore,))
    result_kore = self.kast_to_kore(result, KSort(result_sort))

    rule = FunctionRule(
        lhs=fun_app,
        rhs=result_kore,
        req=None,
        ens=None,
        sort=SortApp('Sort' + result_sort),
        arg_sorts=(SortApp('Sort' + arg_sort),),
        anti_left=None,
        priority=50,
        # NOTE(review): uid and label look like placeholders — confirm nothing downstream relies on them.
        uid='fubar',
        label='fubaz',
    )

    comment = f'// {fun}({self.pretty_print(arg)})'
    # Pretty-print the result only once: the length check exists because the
    # text can be very large, so rendering it twice is wasted work.
    result_text = self.pretty_print(result)
    if len(result_text) < 1000:
        comment += f' => {result_text}'

    return '\n'.join(['', '', comment, rule.to_axiom().text])

def make_call_config(
self,
smir_info: SMIRInfo,
Expand Down Expand Up @@ -327,22 +121,6 @@ def make_call_config(
config = self.definition.empty_config(KSort(sort))
return (subst.apply(config), constraints)

def _decode_alloc(self, smir_info: SMIRInfo, raw_alloc: Any) -> tuple[KInner, KInner]:
from .decoding import UnableToDecodeValue, decode_alloc_or_unable

alloc_id = raw_alloc['alloc_id']
alloc_info = smir_info.allocs[alloc_id]
value = decode_alloc_or_unable(alloc_info=alloc_info, types=smir_info.types)

match value:
case UnableToDecodeValue(msg):
_LOGGER.debug(f'Decoding failed: {msg}')
case _:
pass

alloc_id_term = KApply('allocId', intToken(alloc_id))
return alloc_id_term, value.to_kast()

def run_smir(self, smir_info: SMIRInfo, start_symbol: str = 'main', depth: int | None = None) -> Pattern:
smir_info = smir_info.reduce_to(start_symbol)
init_config, init_constraints = self.make_call_config(smir_info, start_symbol=start_symbol, init=True)
Expand Down
Loading