In [17]:
from dataclasses import dataclass
import operator
from kdrag.all import *
@dataclass
class SymUnion:  # SymUnion[T]
    """A symbolic union: concrete values guarded by SMT conditions.

    Each entry ``cond: value`` in ``values`` reads "the value is ``value``
    whenever ``cond`` holds". The intended invariant is that the conditions
    are pairwise disjoint. The constructors and combinators below are written
    to preserve it when their inputs satisfy it, but ``__init__`` does not
    check it.
    """

    # Map from guard condition to the concrete value taken under that guard.
    values : dict[smt.BoolRef, object]

    @classmethod
    def lift(cls, v : object):
        """Wrap a plain value as a union with a single, always-true branch."""
        return cls({smt.BoolVal(True) : v})

    def split(self, c : smt.BoolRef) -> tuple['SymUnion', 'SymUnion']:
        """Return ``(self guarded by c, self guarded by Not(c))``."""
        true_branch = self.guard(c)
        false_branch = self.guard(smt.Not(c))
        return true_branch, false_branch

    def guard(self, c : smt.BoolRef):
        """Return a new union with ``c`` conjoined onto every branch condition.

        Conditions are not simplified here, so infeasible branches are kept;
        use ``weaksimp`` or ``simplify`` to prune them.
        """
        return SymUnion({smt.And(k, c) : v for k, v in self.values.items()})  # maybe compress here?

    def weaksimp(self):
        """Cheaply drop branches whose condition z3's simplifier reduces to False.

        Makes no solver calls, so unsatisfiable conditions the simplifier
        cannot decide are kept. Mutates and returns ``self`` (like
        ``simplify`` and ``merge``).
        """
        self.values = {k: v for k, v in self.values.items() if not smt.simplify(k).eq(smt.BoolVal(False))}
        return self

    def simplify(self):
        """Drop every branch whose condition the solver proves unsatisfiable.

        A fresh solver is used per condition. Mutates and returns ``self``.

        Raises:
            Exception: if the solver answers ``unknown`` for some condition.
        """
        new_values = {}
        for cond, val in self.values.items():
            s = smt.Solver()
            s.add(cond)
            res = s.check()
            if res == smt.sat:
                new_values[cond] = val
            elif res == smt.unsat:
                continue
            else:
                raise Exception("Unknown satisfiability")
        self.values = new_values
        return self

    def merge(self):
        """Combine branches holding equal values into a single branch.

        The combined branch's condition is the ``Or`` of the original ones.
        Values are used as dict keys, so they must be hashable. The key also
        includes the value's type so that e.g. ``True`` and ``1`` (which are
        equal and hash equally in Python) are not conflated.
        Mutates and returns ``self``.
        """
        grouped : dict[tuple[type, object], tuple[smt.BoolRef, object]] = {}
        for cond, val in self.values.items():
            key = (type(val), val)
            prev = grouped.get(key)
            if prev is None:
                grouped[key] = (cond, val)
            else:
                grouped[key] = (smt.Or(prev[0], cond), prev[1])
        self.values = {cond: val for cond, val in grouped.values()}
        return self

    def is_empty(self):
        """Return True if no branch condition is satisfiable.

        Calls ``simplify``, so unsatisfiable branches are pruned from ``self``.
        """
        self.simplify()
        return len(self.values) == 0

    def map(self, f):
        """Return a new union with ``f`` applied to every value; conditions unchanged."""
        return SymUnion({k: f(v) for k, v in self.values.items()})

    def map2(self, other, f):
        """Combine with ``other`` using the binary function ``f``.

        If ``other`` is a ``SymUnion``, every pair of branches is combined and
        guarded by the simplified conjunction of both conditions. Pairs whose
        conjunctions simplify to the same formula overwrite each other (last
        one wins). Otherwise ``other`` is treated as a plain value and
        ``f(v, other)`` is applied to each branch value.
        """
        if isinstance(other, SymUnion):
            return SymUnion({smt.simplify(smt.And(k1, k2)) : f(v1, v2) for k1, v1 in self.values.items() for k2, v2 in other.values.items()})
        else:
            return SymUnion({k1 : f(v1, other) for k1, v1 in self.values.items()})

    def flatmap(self, f):
        """Monadic bind: ``f`` maps each value to a ``SymUnion``.

        The result's branches are guarded by the conjunction of the outer
        and inner conditions.
        """
        result = {}
        for k1, v1 in self.values.items():
            su2 : SymUnion = f(v1)
            for k2, v2 in su2.values.items():
                result[smt.And(k1, k2)] = v2
        return SymUnion(result)

    def __add__(self, other):
        return self.map2(other, operator.add)
    def __sub__(self, other):
        return self.map2(other, operator.sub)
    def __mul__(self, other):
        return self.map2(other, operator.mul)
    def __truediv__(self, other):
        return self.map2(other, operator.truediv)
    def __or__(self, other):
        return self.map2(other, operator.or_)
    def __and__(self, other):
        return self.map2(other, operator.and_)

    # Reflected operators let a plain value appear on the left, e.g. ``1 + su``.
    # (When both operands are SymUnions the forward operator above handles it.)
    def __radd__(self, other):
        return self.map(lambda v: other + v)
    def __rsub__(self, other):
        return self.map(lambda v: other - v)
    def __rmul__(self, other):
        return self.map(lambda v: other * v)
    def __rtruediv__(self, other):
        return self.map(lambda v: other / v)
    def __ror__(self, other):
        return self.map(lambda v: other | v)
    def __rand__(self, other):
        return self.map(lambda v: other & v)

    # TODO: __call__ that maps a call over the branches; note the arguments
    # might themselves be SymUnions, so this probably wants map2/flatmap:
    #def __call__(self, *args, **kwargs):
    #    return self.map(lambda v: v(*args, **kwargs))

    @classmethod
    def reflect(cls, e : smt.ExprRef, hyp=None, max_values=None) -> "SymUnion":
        """Enumerate every concrete value of ``e`` and return them as a union.

        Repeatedly asks the solver for a model, records ``(e == v): v`` for
        the value ``v`` that ``e`` takes in that model, then blocks ``v`` and
        asks again until no further value exists.

        Args:
            e: the SMT expression whose values are enumerated.
            hyp: optional constraint restricting the values of ``e``.
            max_values: optional cap on the number of values. Without it the
                loop only terminates when ``e`` has finitely many values under
                ``hyp``; pass a cap to fail fast on unbounded domains.

        Raises:
            ValueError: if more than ``max_values`` distinct values exist.
            Exception: if the solver answers ``unknown``.
        """
        s = smt.Solver()
        if hyp is not None:
            s.add(hyp)
        values = {}
        while True:
            res = s.check()
            if res == smt.unsat:
                break
            if res != smt.sat:
                raise Exception("Unknown satisfiability")
            if max_values is not None and len(values) >= max_values:
                raise ValueError(f"reflect: more than {max_values} values for {e}")
            v = s.model().eval(e, model_completion=True)
            values[e == v] = v
            # Block this value so the next model must choose a different one.
            s.add(e != v)
        return cls(values)

    @classmethod
    def reflect_bool(cls, c : smt.BoolRef) -> 'SymUnion':
        """Turn a symbolic boolean into the two-branch union ``{c: True, Not(c): False}``."""
        return cls({c : True, smt.Not(c) : False})

    @classmethod
    def Bool(cls, name : str) -> 'SymUnion':
        """Create a fresh named SMT boolean and reflect it as a two-branch union."""
        return cls.reflect_bool(smt.Bool(name))

    # TODO: reflect for bit-vectors?
    # (Converting an ExprRef into a SymUnion by enumerating model values is
    # what ``reflect`` does.)

    def If(self, then_branch, else_branch):
        """Symbolic if-then-else, using this union's values as the condition.

        Branches of ``self`` whose value is truthy select ``then_branch``; the
        rest select ``else_branch``. Each resulting branch is guarded by the
        simplified conjunction of both conditions. Plain (non-``SymUnion``)
        arguments are lifted first, in the same way ``map2`` accepts plain
        operands.

        Combined conditions that simplify to ``False`` are unreachable and are
        skipped. This prevents syntactically infeasible combinations (such as
        ``And(b, Not(b))``) from colliding with each other.

        Note: calls ``self.merge()`` first, which rewrites ``self`` in place.

        Raises:
            ValueError: if two reachable combined branches simplify to the same
                condition but carry different values.
        """
        if not isinstance(then_branch, SymUnion):
            then_branch = SymUnion.lift(then_branch)
        if not isinstance(else_branch, SymUnion):
            else_branch = SymUnion.lift(else_branch)
        self.merge()
        false = smt.BoolVal(False)
        result = {}
        for cond, val in self.values.items():
            chosen = then_branch if val else else_branch
            for b_cond, b_val in chosen.values.items():
                c = smt.simplify(smt.And(cond, b_cond))
                if c.eq(false):
                    continue  # unreachable combination
                if c not in result:
                    result[c] = b_val
                elif not (result[c] == b_val):
                    # Written as `not (==)` to mirror the original equality check;
                    # for z3 values `!=` builds a symbolic term -- confirm before changing.
                    raise ValueError(f"If: conflicting values for branch {c}")
        return SymUnion(result)
        
        
    

# A union may hold any Python value, even a function.
SymUnion.lift(lambda arg: arg)
# Arithmetic on unions is applied branch-wise.
SymUnion.lift(3) + SymUnion.lift(4)

# Branch on a symbolic boolean, then post-process every outcome.
b = SymUnion.Bool("b")
b.If(then_branch=SymUnion.lift(10), else_branch=SymUnion.lift(20)).map(lambda n: n + 1)

# Enumerate the values of a bounded integer and convert them to Python ints.
x = smt.Int("x")
SymUnion.reflect(x, hyp=smt.And(1 <= x, x <= 3)).map(lambda val: val.as_long())
    

SymUnion(values={1 == x: 1, 2 == x: 2, 3 == x: 3})

Symbolic unions and Rosette

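In ordinary symbolic execution, each explored path carries a path condition: a record of the constraints the inputs must satisfy for execution to take that path.
A symbolic union reorganizes these path conditions, somewhat like turning an array of structs into a struct of arrays. Rather than following one path at a time, it stores every possible value together with the condition under which that value arises.
Turning an algorithm or process into a data structure can be a very powerful idea.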

