In [None]:
from dataclasses import dataclass
import operator
from kdrag.all import *
from typing import Self

@dataclass
class SymUnion(): # SymUnion[T]
    """A symbolic union: a finite map from path conditions to values.

    Each entry ``cond: value`` in ``values`` reads "in every world where
    ``cond`` holds, the value is ``value``". It behaves like the list monad
    (``lift``, ``map``, ``flatmap``, ``join``) with a side condition attached
    to every element.

    NOTE(review): the conditions are intended to be pairwise disjoint, but only
    ``disjoint_union(check=True)`` actually checks this.
    """
    # invariant: all branches are disjoint?
    values : dict[smt.BoolRef, object]
    # values : dict[object, smt.BoolRef]
    # values : list[tuple[smt.BoolRef, object]] # SymUnion[Symunion] but symunion might not be hasable

    @staticmethod
    def _is_false(c : smt.BoolRef) -> bool:
        """Return True when z3's simplifier reduces ``c`` to the literal ``False``."""
        return smt.simplify(c).eq(smt.BoolVal(False))

    @classmethod
    def lift(cls, v : object):
        """Wrap a plain value as a union with a single unconditional (``True``) branch."""
        return cls({smt.BoolVal(True) : v})

    def split(self, c : smt.BoolRef) -> tuple['SymUnion', 'SymUnion']:
        """Return ``(self guarded by c, self guarded by Not(c))``."""
        true_branch = self.guard(c)
        false_branch = self.guard(smt.Not(c))
        return true_branch, false_branch

    def disjoint_union(self, other : 'SymUnion', check=True) -> 'SymUnion':
        """Combine the branches of two unions that describe disjoint worlds.

        With ``check`` set, each condition of ``self`` is checked against each
        condition of ``other`` with a fresh solver. A pair that is satisfiable
        together (or undecided) raises ``ValueError``.
        """
        if check:
            for c1 in self.values:
                for c2 in other.values:
                    s = smt.Solver()
                    s.add(smt.And(c1, c2))
                    if s.check() != smt.unsat:
                        raise ValueError(f"SymUnion branches may overlap: {c1} and {c2}")
        # dicts do not support `+`; merge the two branch maps instead.
        return SymUnion({**self.values, **other.values})

    def guard(self, c : smt.BoolRef): # assume?
        """Return a new union with ``c`` conjoined onto every branch condition."""
        return SymUnion({smt.And(k, c) : v for k, v in self.values.items()}) # maybe compress here?

    def weaksimp(self):
        """Drop branches whose condition simplifies to ``False``; mutates and returns ``self``."""
        # weak simplify: purely syntactic, no solver call
        self.values = {k: v for k, v in self.values.items() if not self._is_false(k)}
        return self

    def simplify(self): # prune prune_impossible
        """Drop branches whose condition is unsatisfiable; mutates and returns ``self``.

        Makes one solver call per branch. Raises ``Exception`` if the solver
        answers ``unknown``.
        """
        # strong simplify
        new_values = {}
        for cond, val in self.values.items():
            s = smt.Solver()
            s.add(cond)
            res = s.check()
            if res == smt.sat:
                new_values[cond] = val
            elif res == smt.unsat:
                continue
            else:
                raise Exception("Unknown satisfiability")
        self.values = new_values
        return self

    def merge(self):
        """Collapse branches holding equal values into one branch whose condition is their disjunction.

        Values must be hashable. Mutates and returns ``self``.
        """
        new_values : dict[object, smt.BoolRef] = {}
        for cond, val in self.values.items():
            c = new_values.get(val)
            if c is None:
                new_values[val] = cond
            else:
                new_values[val] = smt.Or(c, cond)
        self.values = {cond : v for v, cond in new_values.items()}
        return self

    def merge_sym(self : "SymUnion[smt.ExprRef]") -> 'SymUnion[smt.ExprRef]':
        """Fold all branches into a single branch holding an ``If`` chain.

        Values must be ``smt.ExprRef`` of a common sort (see ``sym``). The
        result has one branch whose condition is the disjunction of all
        conditions. Outside that disjunction, the value is an unconstrained
        fresh constant.

        Raises ``ValueError`` on an empty union, because there is no sort to build from.
        """
        # put all path ways into if then else inside solver
        items = list(self.values.items())
        if not items:
            raise ValueError("merge_sym of an empty SymUnion")
        # A dict has no .sort(); take the sort from one of the branch values.
        acc = smt.FreshConst(items[0][1].sort())
        c = smt.BoolVal(False)
        for c1, v in items:
            # Select v exactly when this branch's own condition holds.
            acc = smt.If(c1, v, acc)
            c = smt.Or(c, c1)
        return SymUnion({c : acc})

    def sym(self : Self) -> 'SymUnion[smt.ExprRef]':
        """Convert every branch value to an SMT expression with ``smt._py2expr``."""
        return self.map(smt._py2expr)

    def join(self): # monadic join. A symunion of symunions can collapse into a symunion
        """Flatten a union of unions, conjoining outer and inner conditions."""
        return SymUnion({smt.And(c1, c2) : v2 for c1, v1 in self.values.items() for c2, v2 in v1.values.items()})

    def is_empty(self):
        """Return True if no branch is satisfiable (calls ``simplify``, which mutates ``self``)."""
        self.simplify()
        return len(self.values) == 0

    def map(self, f):
        """Apply ``f`` to every branch value, keeping the conditions unchanged."""
        return SymUnion({k: f(v) for k, v in self.values.items()})

    def map2(self, other, f):
        """Combine with ``other`` pointwise via ``f(self_value, other_value)``.

        If ``other`` is a SymUnion, every pair of branches is combined under the
        simplified conjunction of their conditions. Otherwise ``other`` is
        passed as-is to every call.
        """
        if isinstance(other, SymUnion):
            return SymUnion({smt.simplify(smt.And(k1,k2)) : f(v1, v2) for k1, v1 in self.values.items() for k2, v2 in other.values.items()})
        else:
            return SymUnion({k1 : f(v1, other) for k1, v1 in self.values.items()}) # self.map(lambda x: f(x, other))

    def flatmap(self, f): # monadic bind >>=
        """Monadic bind: ``f`` maps each value to a SymUnion, and the results are flattened."""
        result = {}
        for k1, v1 in self.values.items():
            su2 : SymUnion = f(v1)
            for k2, v2 in su2.values.items():
                result[smt.And(k1, k2)] = v2
        return SymUnion(result)

    def __add__(self, other):
        return self.map2(other, operator.add)
    def __sub__(self, other):
        return self.map2(other, operator.sub)
    def __mul__(self, other):
        return self.map2(other, operator.mul)
    def __truediv__(self, other):
        return self.map2(other, operator.truediv)
    def __or__(self, other):
        return self.map2(other, operator.or_)
    def __and__(self, other):
        return self.map2(other, operator.and_)
    #def __call__(self, *args, **kwargs):
    #    return self.map(lambda v: v(*args, **kwargs)) # actuall, arguments might be symunion also
    # _reflect_expr

    @classmethod
    def reflect_expr(cls, e : smt.ExprRef, hyp=None, limit : int | None = None) -> "SymUnion": # attempt to convert value to smt.ExprRef SymUnion[smt.ExprRef]
        """Enumerate every value ``e`` can take (under ``hyp``) as a symbolic union.

        The solver is asked for a model repeatedly. For each model value ``v``,
        the branch ``e == v`` is recorded and ``v`` is then blocked with
        ``e != v``. This terminates only if ``e`` has finitely many values under
        ``hyp``.

        Pass ``limit`` to raise ``ValueError`` once more than ``limit`` distinct
        values are found, instead of enumerating forever. Raises ``Exception``
        if the solver answers ``unknown``.
        """
        s = smt.Solver()
        if hyp is not None:
            s.add(hyp)
        values = {}
        while True:
            res = s.check()
            if res == smt.sat:
                v = s.model().eval(e, model_completion=True)
                if limit is not None and len(values) >= limit:
                    raise ValueError(f"{e} has more than {limit} possible values")
                values[e == v] = v # hyp not needed. redundant
                s.add(e != v)
            elif res == smt.unsat:
                break
            else:
                raise Exception("Unknown satisfiability")
        return cls(values)

    @classmethod
    def reflect_int(cls, e : smt.ArithRef, hyp=None, limit : int | None = None):
        """Like ``reflect_expr``, but with branch values converted to Python ints."""
        return cls.reflect_expr(e, hyp=hyp, limit=limit).map(lambda x: x.as_long())
    # reflect_seq, reflect_dataclass

    @classmethod
    def reflect_bool(cls, c : smt.BoolRef) -> 'SymUnion':
        """Split on ``c``: the value is Python ``True`` where ``c`` holds and ``False`` elsewhere."""
        # expand bool? `Match` kind of...  This is x.If(x == True, x == False) or someting. "the trick"
        return SymUnion({c : True, smt.Not(c) : False})

    @classmethod
    def Bool(cls, name : str) -> 'SymUnion':
        """Fresh named Boolean, reflected into a two-branch union of Python bools."""
        #return SymUnion.lift(smt.Bool(name)) ?
        return SymUnion.reflect_bool(smt.Bool(name))
    # reflect bitvec?
    # def reify(expr : smt.ExprRef) -> 'SymUnion': # attempt to convert smt.ExprRef to SymUnion By getting all model values

    def If(self, then_branch, else_branch):
        """Choose ``then_branch`` where ``self`` is truthy and ``else_branch`` elsewhere.

        ``self`` should hold concrete Python booleans (e.g. from
        ``reflect_bool``, ``Bool``, or ``flip``). Branch arguments that are not
        SymUnions are wrapped with ``lift``. Combined conditions are simplified,
        and those that simplify to ``False`` are dropped.

        Note: calls ``self.merge()``, which mutates ``self``.
        """
        # if then and else are ExprRef can use smt.If? More efficient?
        if not isinstance(then_branch, SymUnion):
            then_branch = SymUnion.lift(then_branch)
        if not isinstance(else_branch, SymUnion):
            else_branch = SymUnion.lift(else_branch)
        self.merge()
        result = {}
        for cond, val in self.values.items():
            chosen = then_branch if val else else_branch
            for b_cond, b_val in chosen.values.items():
                c = smt.simplify(smt.And(cond, b_cond))
                if c.eq(smt.BoolVal(False)):
                    # Impossible world; skipping it also avoids spurious collisions on the key False.
                    continue
                if c in result:
                    assert result[c] == b_val, f"conflicting values for branch {c}"
                else:
                    result[c] = b_val
        return SymUnion(result)

    def eq(self, other : 'SymUnion') -> bool: # syntactic equality of symunions
        """Syntactic equality: same number of branches, with structurally identical conditions and values."""
        def same(a, b):
            if isinstance(a, SymUnion) and isinstance(b, SymUnion):
                return a.eq(b)
            if isinstance(a, smt.ExprRef) or isinstance(b, smt.ExprRef):
                return isinstance(a, smt.ExprRef) and isinstance(b, smt.ExprRef) and a.eq(b)
            return a == b
        if not isinstance(other, SymUnion) or len(self.values) != len(other.values):
            return False
        return all(
            any(c1.eq(c2) and same(v1, v2) for c2, v2 in other.values.items())
            for c1, v1 in self.values.items()
        )

    def verify(self):
        """Prove ``cond => value`` for every branch with ``kd.prove``.

        Values must be Booleans (Python bools or ``smt.BoolRef``). Returns the
        list of ``kd.prove`` results; ``kd.prove`` is expected to raise if an
        implication does not hold.
        """
        # verify(a == b)  well that's kind of assuming they worlds aren't disjoint. This is kleene equality?
        return [kd.prove(smt.Implies(c1, v1)) for c1, v1 in self.values.items()]

    def assume(self, c):
        """Return ``self`` restricted to worlds where ``c`` holds (same as ``guard``)."""
        return self.guard(c)

    def solve(self):
        """Placeholder. NOTE(review): the intended semantics were never written down."""
        raise NotImplementedError("SymUnion.solve is not implemented yet")

    def reflect_cond(self): # kind of this is just "True" everywhere...
        """Map every branch to its own condition."""
        return SymUnion({c: c for c in self.values.keys()})

    @classmethod
    def flip(cls):
        """A fresh unconstrained Boolean choice, as a two-branch union of Python bools."""
        v = smt.FreshBool()
        return SymUnion({v: True, smt.Not(v): False})


    @staticmethod
    def of_expr(expr: smt.ExprRef, ctx=None) -> "SymUnion[smt.ExprRef]":
        """Enumerate the values of ``expr`` (under ``ctx``) as a SymUnion of SMT values.

        When ``ctx`` is given, it is used as the enumeration hypothesis and is
        also conjoined in front of every branch condition.

        >>> y = smt.Int("y")
        >>> su = SymUnion.of_expr(y + 1, ctx=smt.And(y > 0, y < 3))
        >>> sorted(v.as_long() for v in su.values.values())
        [2, 3]
        """
        su = SymUnion.reflect_expr(expr, hyp=ctx)
        if ctx is None:
            return su
        return SymUnion({smt.And(ctx, c) : v for c, v in su.values.items()})


    """
    but isn't .map(lambda x: match x: ...) better? More featureful, and lazier in evaluation.
    def switch(self, cases : dict[object, 'SymUnion']):
        result = {}
        for cond, val in self.values.items():
            case_su = cases.get(val)
            if case_su is None:
                raise Exception(f"No case for value {val}")
            for c_cond, c_val in case_su.values.items():
                c = smt.simplify(smt.And(cond, c_cond))
                if c in result:
                    assert result[c] == c_val
                else:
                    result[c] = c_val
    """
        
        
    

# Quick smoke checks of the SymUnion API.
SymUnion.lift(lambda x: x)
SymUnion.lift(3) + SymUnion.lift(4)

b = SymUnion.Bool("b")
b.If(SymUnion.lift(10), SymUnion.lift(20)).map(lambda x: x + 1)

x = smt.Int("x")
# SymUnion has no `reflect`; the model-enumeration entry point is `reflect_expr`.
SymUnion.reflect_expr(x, hyp=smt.And(x >= 1, x <= 3)).map(lambda v: v.as_long())
    

SymUnion(values={1 == x: 1, 2 == x: 2, 3 == x: 3})

Keep it in a BDD-like structure? Sort the conditions.
Sorting them at construction time already helps.
smt.simplify


What is the minimal macro trickery I could use to get more ordinary Python code to work?
Just a visitor on If nodes
if 
else

if isisntance(x, SymbolicUnion)
    x = SymUnion.empty
    for c,v in x.values:
        x |= whatev
    merge()
else:
    x = whatev

ugh


I need a solution to mutation

SymCell




In [None]:


def eval_(*args, **kwargs):
    """Placeholder for an evaluator that has not been written yet.

    NOTE(review): the original cell contained only the header ``def eval_()``
    (a syntax error), so the intended signature is unknown.
    """
    raise NotImplementedError("eval_ is not implemented yet")

In [None]:
from functools import singledispatch
# Bare expression: it only checks that kd.notation.SortDispatch exists. It is not
# the cell's last expression, so the notebook displays nothing for it.
kd.notation.SortDispatch

# a nested chain of SortDispatch. So it always disambiguates on the first sort.
# if I don't have coercions, it's all kind of straightforward?

class MultiDispatch:
    """Dispatch a function on the sorts of *all* of its arguments.

    Implementations live in a trie of plain dicts. Level i is keyed by the sort
    of argument i. The implementation for a complete sort tuple is stored under
    the key ``None`` at the final level; sorts are never ``None``, so that key
    cannot collide. Sorts must be hashable.
    """

    f : dict

    def __init__(self):
        self.f = {}

    def get(self, sorts):
        """Return the implementation registered for exactly ``sorts``, or ``None``."""
        d = self.f
        for sort in sorts:
            d = d.get(sort)
            if d is None:
                return None
        return d.get(None)

    def register(self, sorts, func):
        """Register ``func`` for arguments whose sorts are exactly ``sorts``; returns ``func``."""
        d = self.f
        for sort in sorts:
            d = d.setdefault(sort, {})
        d[None] = func
        return func

    def __call__(self, *args):
        """Call the implementation registered for ``[a.sort() for a in args]``.

        Raises NotImplementedError if no implementation matches.
        """
        sorts = [a.sort() for a in args]
        func = self.get(sorts)
        if func is None:
            raise NotImplementedError(f"no implementation registered for sorts {sorts}")
        return func(*args)


Put a `map` around anything that can't be lifted, if we make it curryable. Maybe memoize the results.

Avoid the double loop in `map2`.
We can sometimes key on the conditions instead: we will often be `map2`-ing two SymUnions whose branches share the same contexts.

Yea, the trick, pattern matching, learning stuff inside branches.
Tries and telescopes


Two notions of equality, syntactic .eq and == and semantics
Python vs z3


The python ADT


In [7]:
import functools


def foo(x, *y):
    """Add ``x`` to any extra arguments; with none given, return a partial awaiting them.

    For example, ``foo(3, 4)`` and ``foo(3)(4)`` both evaluate to 7.
    """
    if not y:
        # Nothing to add yet: curry by pre-binding x.
        return functools.partial(foo, x)
    return x + sum(y)

foo(3)(4)
foo(3,4)
foo(3)

functools.partial(<function foo at 0x794ba1be2660>, 3)

Does `@x.map` need this `partial` trick to work as a decorator? No, it works without `partial`.




In [None]:
class Biz:
    """Toy object showing what a ``map`` method does when used as a decorator."""

    def map(self, f):
        # Announce the call, then hand this object to f and return f's result.
        print("mapped")
        result = f(self)
        return result

y = Biz()


def foo(z):
    # After the y.map(...) call below, the name `foo` is bound to whatever
    # y.map returns (here: y itself), not to this function.
    print("in foo")
    return z


foo = y.map(foo)  # same as decorating `def foo` with @y.map
# foo = y.map(lambda : ...)

def case7(x):
    """Name a small count: 0 gives "zero", 1 gives "one", anything else "many"."""
    # Literal patterns in `match` compare with ==, so a plain == chain is equivalent.
    if x == 0:
        return "zero"
    if x == 1:
        return "one"
    return "many"
# NOTE(review): the only module-level `x` in the visible cells is
# `x = smt.Int("x")`, which has no `.map` method. `x` was presumably meant to
# be a SymUnion of concrete values here; confirm.
foo = x.map(case7)
# make flatmap just also work silently like map if arguments are not symunion?
# kind of a coercion to symunion.

foo

mapped
in foo


<__main__.Biz at 0x794ba08b6660>

In [None]:
# Decorator spellings of map/flatmap; each decoration rebinds `_` to the result.
# NOTE(review): the only module-level `x` in the visible cells is the z3 Int
# from `x = smt.Int("x")`, which has neither `.map` nor `.flatmap`. `x` was
# presumably meant to be a SymUnion; confirm.
@x.map
def _(x):
    if x:
        return 1
    else:
        return 2
# flatmap variant: the function returns a SymUnion per value.
@x.flatmap
def _(x):
    if x:
        return SymUnion.lift(1)
    else:
        return SymUnion.lift(2)
# Equivalent plain-call form of the first decorator.
x.map(lambda x: 1 if x else 2)


In [None]:
from dataclasses import dataclass
import kdrag.smt as smt
from typing import Self, Callable
@dataclass(frozen=True)
class SymUnion[T]:
    values : tuple[tuple[smt.BoolRef, T], ...]
    @classmethod
    def from_value(cls, v : T) -> 'Self[T]':
        return cls(((smt.BoolVal(True), v),))
    def map[U](self : 'SymUnion[T]', f : Callable[[T], U]) -> 'SymUnion[U]':
        return SymUnion(tuple((k, f(v)) for k, v in self.values))
    

NameError: name 'Callable' is not defined

In [None]:



# we need something lazy
# we can't overload python if

# no. this doesn't really work. I want branches.
def guard(x):
    """Decorator factory sketch: run the decorated thunk immediately when ``x`` is a SymUnion.

    The decorated name stays bound to the original function. Previously the
    inner decorator was never returned, so ``@guard(...)`` applied ``None``.

    NOTE(review): this is an experiment the notes say does not really work.
    """
    def res(f):
        if isinstance(x, SymUnion):
            f()
        return f
    return res


# NOTE(review): guard requires the value being guarded, but the original called
# guard() with no argument; `x` here is a guess. guard invokes the thunk with no
# arguments, so the function takes none.
@guard(x)
def _():
    ...  # do stuff


Caleb said
"I mean you can construct a Python object like this:

class MyIf:
	def __init__(self, cond, tb, fb):
        self.cond = cond
		self.tb = tb
		self.fb = fb

MyIf(claripy.ast.BoolS("cond"), "I'm in world A!", "I'm in world B!")
"
Which is a good point. That way you kind of head toward BDDs.


Hmm. SymUnion is super monadic. It's the list monad with extra side conditions.
Maybe I should look at grisette to see what it looks like
https://hackage.haskell.org/package/grisette
https://digital.lib.washington.edu/researchworks/items/61a8d175-cd6a-40cf-9abc-6b87b79f02d7 grisette thesis

hierasynth and tensoright. Hmm. Ok

Ordered guards... This smells bdd like? BDD(T). That's a fun idea. You could abstract atoms and then if you have a contextual rewriting system... well but that precludes variable reordering. Hmm. Well if you have a global normalization system. But that is almost not interesting? 

Pythette

I mean, I could just copy the leancall design:
only transfer values back and forth.

One could do normalization by evaluation. Symbolic values and stuck terms. It's the same idea.
|- {True, False, smt.Bool("a")}

# Candidate signatures only (no bodies), so this fragment does not run as-is.
# Question being explored: should a symbolic Bool be a set of Python/SMT
# values, or a SymUnion?
def Bool(ctx : SymUnion) -> set[bool | smt.ExprRef]:
def Bool(ctx : SymUnion) -> SymUnion:
def Bool()


[smt.Bool("a"), smt.Bool("b")]


A fixed "parser"

from functools import singledispatch


@singledispatch
def py2expr(x : object) -> "smt.ExprRef":
    """Convert a Python value to an SMT expression.

    Objects that define their own ``py2expr`` method are converted by calling
    it. Everything else falls back to ``smt._py2expr``. More specific types
    are added with ``@py2expr.register``.
    """
    if hasattr(x, "py2expr"):
        return x.py2expr()
    return smt._py2expr(x)


@py2expr.register
def _(x : tuple):
    # Tuples are converted element-wise and packed with kd.tuple_.
    # NOTE(review): whether kd.tuple_ takes one iterable or *args is not
    # visible here; confirm this call shape.
    return kd.tuple_(map(py2expr, x))


# NOTE(review): the notes originally redefined py2expr here as a second, empty
# @singledispatch function. That would discard the registration above, so it is
# kept only as a comment:
# @singledispatch
# def py2expr()


# Sort-keyed dispatch table, presumably for converting SMT expressions back to
# Python values (judging by the name only).
# NOTE(review): assumes kd.SortDispatch is the same class as kd.notation.SortDispatch; confirm.
expr2py = kd.SortDispatch()
# Coercion dispatch whose fallback calls S.coerce(x).
# NOTE(review): assumes kd.SortDispatch accepts a `default=` keyword and that
# sorts provide `.coerce`; confirm.
coerce = kd.SortDispatch(default = lambda x,S: S.coerce(x))
# NOTE(review): sketch only; `S` is not defined in the visible cells.
x.coerce(S, )


Hmm. kdrag.reflect.
I could have verify and synth calls in there...
Allow more complex stuff via symbolic union...

Maybe just a pass that reinterprets `while` and `if` as operating over SymUnions,
and lifts recursive calls into something.
An AST walker, then unparse.



symbolic union and rosette

In ordinary symbolic execution, there is a path condition that records what preconditions are needed on the input to take the path one has taken.
The symbolic union is a kind of struct-of-arrays reorganization of these path conditions.
Turning an algorithm or process into a data structure can be a very powerful idea.


Would converting my riscv symbolic executor to this form clean it up?

https://docs.racket-lang.org/rosette-guide/sec_value-reflection.html#(part._sec~3asymbolic-unions)

thunks for recursion

`SymUnion[smt.ExprRef]`

There's almost no point in storing `values : dict[smt.BoolRef, object]`: we can never usefully index into a Boolean expression.
In fact, the opposite is true: it is more useful to store `values : dict[object, smt.BoolRef]`.

Can I get symbolic execution for free from the Python interpreter?
Maybe, by rewriting if-then-else?
Using Knuckledragger's reify?


So Roulette used a BDD of probabilities and also a BDD of values?
Or a BDD of probabilities and a SymUnion of values?


Caleb says claripy ITE works like this.
hm claripy has the abstract interp...


https://sites.google.com/cs.washington.edu/cse507-25au
https://gitlab.cs.washington.edu/cse507-25au/cse507-25au-public/-/tree/main?ref_type=heads

This bitvector example is gnarly




In [None]:
%%file /tmp/bvlang.scm
#lang rosette

; https://gitlab.cs.washington.edu/cse507-25au/cse507-25au-public/-/blob/main/lecture/lec02/bvlang.rkt?ref_type=heads
(require
  (prefix-in $ (only-in rosette bveq bvslt bvsgt bvsle bvsge bvult bvugt bvule bvuge))
  rosette/lib/angelic rosette/lib/match)
; ----- BV semantics -----;
; BV comparison operators return 1/0 instead of #t/#f.
; The language is similar to the one defined in this paper:
; https://dl.acm.org/citation.cfm?id=1993506
; Registers are named by integer indices into `memory`; `load` treats any
; non-integer operand (e.g. a bitvector constant) as an immediate value.
(define register? integer?)
; Type predicate for 32-bit bitvectors.
(define int32? (bitvector 32))
; Build the 32-bit bitvector constant with value c.
(define (int32 c) (bv c int32?))
; (define-comparators [id op] ...) defines each `id` as a two-argument
; procedure that applies the Boolean comparison `op` and returns (int32 1)
; when it holds and (int32 0) otherwise, so comparison results can flow
; back into bitvector arithmetic.
(define-syntax-rule (define-comparators [id op] ...)
  (begin (define (id x y) (if (op x y) (int32 1) (int32 0))) ...))
; The `$`-prefixed names are Rosette's own comparators (from the prefix-in
; require above); the unprefixed names become the 1/0-returning versions.
(define-comparators
  [bveq $bveq]
  [bvslt $bvslt]
  [bvult $bvult]
  [bvsle $bvsle]
  [bvule $bvule]
  [bvsgt $bvsgt]
  [bvugt $bvugt]
  [bvsge $bvsge]
  [bvuge $bvuge])
; Every operator an instruction may use: the 1/0 comparators defined above
; plus Rosette bitvector operators.
(define ops (list bveq bvslt bvult bvsle bvule bvsgt bvugt bvsge
                  bvnot bvand bvor bvxor bvshl bvlshr bvashr
                  bvneg bvadd bvsub bvmul bvsdiv bvudiv bvsrem bvurem bvsmod))
; Association list from each operator's name (a symbol, via object-name) to
; the procedure; instructions refer to operators by these symbols.
(define optable (for/list ([op ops]) (cons (object-name op) op)))
; Returns the procedure corresponding to the given opcode.
; Errors if the opcode is not in optable (assoc yields #f and cdr fails).
(define (lookup op)
  (cdr (assoc op optable)))
; Global registers.
; `memory` closes over one mutable vector: (memory) returns the current
; vector, and (memory size) replaces it with a fresh vector of `size` slots,
; each initialised to #f.
(define memory
  (let ([m (vector)])
    (case-lambda [() m]
                 [(size) (set! m (make-vector size #f))])))
; Returns the contents of the register idx if idx is a register,
; otherwise returns idx itself.
(define (load idx)
  (if (register? idx)
      (vector-ref (memory) idx)
      idx))
; Stores val in the register idx.
(define (store idx val)
  (vector-set! (memory) idx val))
; Returns the index of the last register (not its contents; use
; (load (last)) to read it).
(define (last)
  (sub1 (vector-length (memory))))
; Creates the registers for the given program and input: one register per
; input plus one per instruction. Input i is stored in register i; the
; remaining registers start as #f.
(define (make-registers prog inputs)
  (memory (+ (length prog) (length inputs)))
  (for ([(in idx) (in-indexed inputs)])
    (store idx in)))
; The BV interpreter.
; Resets the registers, then executes each statement (out opcode in ...) in
; order. For each statement it looks up the opcode, loads every operand
; (register contents or immediate), and stores the result in register `out`.
; Returns the contents of the last register, which is where the program's
; result is expected to end up.
(define (interpret prog inputs)
  (make-registers prog inputs)
  (for ([stmt prog])
    (match stmt
      [(list out opcode in ...)
       (define op (lookup opcode))
       (define args (map load in))
       (store out (apply op args))]))
  (load (last)))
; The BV verifier.
; Creates one fresh symbolic int32 input per parameter of `spec`, then asks
; the solver for inputs on which interpreting `impl`'s body disagrees with
; `spec`. Returns the concrete counterexample inputs if one exists, otherwise
; the unsatisfiable result returned by verify.
(define (ver impl spec)
  (define-symbolic* in int32? #:length (procedure-arity spec))
  (define cex
    (verify
     (assert (equal? (interpret (prog-body impl) in) (apply spec in)))))
  (if (sat? cex)
      (evaluate in cex)
      cex))
; The BV synthesizer.
; Solves for the symbolic choices inside `impl` (e.g. a sketch) so that it
; agrees with `spec` for all inputs (#:forall in). Returns `impl` with the
; solution substituted if one is found, otherwise the unsatisfiable result.
(define (syn impl spec)
  (define-symbolic* in int32? #:length (procedure-arity spec))
  (define sol
    (synthesize
     #:forall in
     #:guarantee (assert (equal? (interpret (prog-body impl) in) (apply spec in)))))
  (if (sat? sol)
      (evaluate impl sol)
      sol))
; ----- BV syntax -----;
; Builds a sketch of `len` instructions for a program with `args` inputs.
; Here `ops` is the list of opcode symbols passed in by the def macro (it
; shadows the global `ops`). It is split into unary (bvneg, bvnot) and
; binary opcodes.
; Instruction r writes register (+ r args). Each operand is a choose* among a
; fresh symbolic int32 constant (c0 / c1, used by `load` as an immediate since
; it is not a register index) and every earlier register.
(define (make-sketch args len ops)
  (define-values (unop binop)
    (partition (or/c 'bvneg 'bvnot) ops))
  (for/list ([r len])
    (define-symbolic* c0 c1 int32?)
    (define out (+ r args))
    (define ins (build-list out identity))
    (define i0 (apply choose* c0 ins))
    (define i1 (apply choose* c1 ins))
    ; inst1 / inst2 are #f when there are no unary / binary opcodes; the
    ; filter below drops them before choosing an instruction shape.
    (define inst1 (and (not (null? unop)) `(,(apply choose* unop) ,i0)))
    (define inst2 (and (not (null? binop)) `(,(apply choose* binop) ,i0 ,i1)))
    (cons out (apply choose* (filter identity (list inst1 inst2))))))
; A program: its name, the list of input register indices, and its body (a
; list of instructions, possibly containing symbolic choices when it comes
; from a sketch).
(struct prog (name args body) #:transparent)
; The def macro turns a list of BV instructions, or a BV #:sketch specification,
; into a prog struct. ver and syn later run its body with `interpret`.
(define-syntax def
  (syntax-rules ()
    [(_ id (idx ...) (out op in ...) ...)
     (define id
       (prog 'id '(idx ...) '((out op in ...) ...)))]
    [(_ id (idx ...) #:sketch len (op ...))
     (define id
       (prog 'id '(idx ...) (make-sketch (length '(idx ...)) len '(op ...))))]))
; ----- BV demo -----;
; Reference specification: the signed maximum of x and y.
(define (bvmax0 x y)
  (if (equal? (bvsge x y) (int32 1)) x y))
; A hand-written straight-line candidate for bvmax0.
; NOTE(review): instruction 4 xors input 0 with register 2 (the comparison
; result). The synthesized programs recorded at the end of this file use
; (4 bvxor 0 1), and the recorded `ver` result is a counterexample, so this
; candidate looks deliberately buggy; confirm.
(def bvmax1 (0 1)
  (2 bvsge 0 1)
  (3 bvneg 2)
  (4 bvxor 0 2)
  (5 bvand 3 4)
  (6 bvxor 1 5))
; A sketch with 5 instruction slots over bvsge/bvneg/bvxor/bvand, to be
; completed by `syn`.
(def bvmax2 (0 1)
  #:sketch 5 (bvsge bvneg bvxor bvand))
; Interaction script: comment out the following two
; lines if you don't have boolector 3.2.1 installed.
; Z3 will work but it's slower.
(require rosette/solver/smt/boolector)
(current-solver (boolector))
; NOTE(review): current-bitwidth sets the precision Rosette uses when encoding
; integer/real terms. It should not change the 32-bit bitvectors used by this
; language; confirm against the Rosette guide.
(current-bitwidth 4)
; Verify the hand-written candidate, then synthesize from the sketch, timing both.
(time (ver bvmax1 bvmax0))
(time (syn bvmax2 bvmax0))
; Results recorded from earlier runs:
; z3:
; cpu time: 61 real time: 480 gc time: 0
; (list (bv #x00010000 32) (bv #x00000000 32))
; cpu time: 13002 real time: 58496 gc time: 995
; (prog 'bvmax2 '(0 1) '((2 bvsge 0 1) (3 bvneg 2) (4 bvxor 0 1) (5 bvand 3 4) (6 bvxor 1 5)))
; boolector 3.2.1 with CaDiCal:
; cpu time: 46 real time: 325 gc time: 0
; (list (bv #xffffffff 32) (bv #xffffffff 32))
; cpu time: 1106 real time: 5611 gc time: 48
; (prog 'bvmax2 '(0 1) '((2 bvxor 1 0) (3 bvsge 0 1) (4 bvneg 3) (5 bvand 2 4) (6 bvxor 1 5)))

In [None]:
from enum import Enum
class OpCode(Enum):
    bvadd = 1
    bvsub = 2
    bvxor = 3
    bvsge = 4

@dataclass
class Op():
    out : int
    opcode : OpCode
    args : list[int]

type prog = list[Op]

def interp(prog, state):
    for op in prog:
        match op.opcode:# wrong
            case OpCode.bvadd:
                state[op.out] = state[op.args[0]] + state[op.args[]]
            case OpCode.bvsub:
                state[op.out] = state[op.args[0]] - state[op.args]
            case OpCode.bvxor:
                state[op.out] = state[op.args[0]] ^ state[op.args]
            case OpCode.bvsge:
                state[op.out] = state[op.args[0]] >= state[op.args]             
    return state


