In [1]:
import sys
from pathlib import Path

# Add the parent directory to the Python path
sys.path.append(str(Path().resolve().parent))

from spytial import diagram
from spytial.annotations import orientation, attribute, hideAtom, atomColor, group, flag


In [2]:
"""
Pure‑Python Reduced Ordered Binary Decision Diagrams (ROBDDs)
================================================================

A tiny, dependency‑free BDD manager using **dataclasses** and a
**relatively naive** implementation. It favors clarity over speed.

Features
--------
- Canonical reduced, ordered BDDs (ROBDDs)
- Unique table + memoized ITE
- Boolean ops via Python operators: `&`, `|`, `^`, `~`, `>>` (implies)
- Restriction, substitution (functional composition), quantification
- Evaluation, model counting, one satisfying assignment
- DOT export (for Graphviz)

Design
------
- Node ids are integers: 0 = `FALSE`, 1 = `TRUE`, ≥2 = internal `(var, lo, hi)`.
- Variable order is append‑only; you may add variables, not reorder.
- The code is intentionally straightforward, relying on Python dicts.

Example
-------
>>> mgr = BDD(["x","y","z"])   # variable order x < y < z
>>> x,y,z = mgr.vars("x","y","z")
>>> f = (x & ~y) | z
>>> mgr.evaluate(f, {"x": True, "y": True, "z": False})
False
>>> mgr.sat_count(f)
6
>>> mgr.pick_one_sat(f)
{'x': False, 'y': False, 'z': False}  # one possible model

"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Tuple, Optional, Iterable, Callable, List, Set, Any

# ----- Core types & constants -----
NodeId = int
Var = str
FALSE: NodeId = 0
TRUE: NodeId = 1


@dataclass(frozen=True)
class Node:
    var: Var
    lo: NodeId
    hi: NodeId


@dataclass(frozen=True)
class BDDRef:
    """Lightweight reference to a node managed by a BDD."""
    mgr: "BDD"
    id: NodeId

    # Pretty printing
    def __repr__(self) -> str:
        if self.id == TRUE: return "BDDRef(TRUE)"
        if self.id == FALSE: return "BDDRef(FALSE)"
        n = self.mgr._nodes[self.id]
        return f"BDDRef({n.var}?{n.lo}:{n.hi})"

    # Boolean operators -> delegate to manager
    def __invert__(self) -> "BDDRef":
        return BDDRef(self.mgr, self.mgr.neg(self.id))

    def __and__(self, other: "BDDRef") -> "BDDRef":
        self._same_mgr(other)
        return BDDRef(self.mgr, self.mgr.apply(lambda a,b: a and b, self.id, other.id))

    def __or__(self, other: "BDDRef") -> "BDDRef":
        self._same_mgr(other)
        return BDDRef(self.mgr, self.mgr.apply(lambda a,b: a or b, self.id, other.id))

    def __xor__(self, other: "BDDRef") -> "BDDRef":
        self._same_mgr(other)
        return BDDRef(self.mgr, self.mgr.apply(lambda a,b: (a and not b) or (b and not a), self.id, other.id))

    def __rshift__(self, other: "BDDRef") -> "BDDRef":  # implies
        self._same_mgr(other)
        return (~self) | other

    def iff(self, other: "BDDRef") -> "BDDRef":
        return ~(self ^ other)

    def restrict(self, assignment: Dict[Var, bool]) -> "BDDRef":
        return BDDRef(self.mgr, self.mgr.restrict(self.id, assignment))

    def exists(self, vars_to_eliminate: Iterable[Var]) -> "BDDRef":
        return BDDRef(self.mgr, self.mgr.exists(self.id, set(vars_to_eliminate)))

    def forall(self, vars_to_eliminate: Iterable[Var]) -> "BDDRef":
        return BDDRef(self.mgr, self.mgr.forall(self.id, set(vars_to_eliminate)))

    def compose(self, substitution: Dict[Var, "BDDRef"]) -> "BDDRef":
        sub = {v:r.id for v,r in substitution.items()}
        return BDDRef(self.mgr, self.mgr.compose(self.id, sub))

    def evaluate(self, assignment: Dict[Var, bool]) -> bool:
        return self.mgr.evaluate(self.id, assignment)

    def sat_count(self) -> int:
        return self.mgr.sat_count(self.id)

    def pick_one_sat(self) -> Optional[Dict[Var, bool]]:
        return self.mgr.pick_one_sat(self.id)

    def to_dot(self, name: str = "F") -> str:
        return self.mgr.to_dot(self.id, name)

    def _same_mgr(self, other: "BDDRef") -> None:
        if self.mgr is not other.mgr:
            raise ValueError("Cannot combine BDDs from different managers")


class BDD:
    """A simple ROBDD manager with an append‑only variable order."""
    def __init__(self, ordering: Optional[Iterable[Var]] = None):
        self._nodes: Dict[NodeId, Node] = {}            # id -> Node
        self._unique: Dict[Tuple[Var, NodeId, NodeId], NodeId] = {}
        self._var2level: Dict[Var, int] = {}
        self._level2var: List[Var] = []
        self._next_id: int = 2

        # Caches (naive: just python dicts)
        self._ite_cache: Dict[Tuple[NodeId, NodeId, NodeId], NodeId] = {}
        self._restrict_cache: Dict[Tuple[NodeId, Tuple[Tuple[Var,bool],...]], NodeId] = {}
        self._compose_cache: Dict[Tuple[NodeId, Tuple[Tuple[Var,NodeId],...]], NodeId] = {}
        self._satcount_cache: Dict[Tuple[NodeId,int], int] = {}

        if ordering:
            for v in ordering:
                self.add_var(v)

    # ---- Variables ----
    @property
    def num_vars(self) -> int:
        return len(self._level2var)

    def add_var(self, v: Var) -> None:
        if v in self._var2level:
            return
        self._var2level[v] = len(self._level2var)
        self._level2var.append(v)

    def vars(self, *names: Var) -> Tuple[BDDRef, ...]:
        return tuple(self.var(n) for n in names)

    def var(self, name: Var) -> BDDRef:
        if name not in self._var2level:
            self.add_var(name)
        return BDDRef(self, self._mk(name, FALSE, TRUE))

    # ---- Constructors ----
    def _mk(self, var: Var, lo: NodeId, hi: NodeId) -> NodeId:
        if lo == hi:
            return lo
        key = (var, lo, hi)
        n = self._unique.get(key)
        if n is not None:
            return n
        nid = self._next_id
        self._next_id += 1
        self._unique[key] = nid
        self._nodes[nid] = Node(var, lo, hi)
        return nid

    # ---- Helpers ----
    def _level(self, var: Var) -> int:
        try:
            return self._var2level[var]
        except KeyError:
            raise KeyError(f"Unknown variable '{var}'. Add it first.")

    def _top_var(self, u: NodeId) -> Optional[Var]:
        if u in (FALSE, TRUE):
            return None
        return self._nodes[u].var

    def _split(self, u: NodeId, v: Var) -> Tuple[NodeId, NodeId]:
        """Return (u0,u1) wrt variable v: if top(u)==v, return its (lo,hi), else (u,u)."""
        if u in (FALSE, TRUE):
            return (u, u)
        n = self._nodes[u]
        if n.var == v:
            return (n.lo, n.hi)
        return (u, u)

    # ---- ITE and logical ops ----
    def ite(self, i: NodeId, t: NodeId, e: NodeId) -> NodeId:
        key = (i, t, e)
        if key in self._ite_cache:
            return self._ite_cache[key]

        # Terminal reductions
        if i == TRUE:
            res = t
        elif i == FALSE:
            res = e
        elif t == e:
            res = t
        elif t == TRUE and e == FALSE:
            res = i
        elif t == FALSE and e == TRUE:
            res = self.neg(i)
        else:
            # determine top variable among i,t,e
            tops: List[Tuple[int, Optional[Var], NodeId]] = []
            for u in (i, t, e):
                v = self._top_var(u)
                tops.append((self._level(v) if v is not None else self.num_vars + 1, v, u))
            v = min(tops, key=lambda x: x[0])[1]
            assert v is not None
            i0, i1 = self._split(i, v)
            t0, t1 = self._split(t, v)
            e0, e1 = self._split(e, v)
            lo = self.ite(i0, t0, e0)
            hi = self.ite(i1, t1, e1)
            res = self._mk(v, lo, hi)

        self._ite_cache[key] = res
        return res

    def neg(self, u: NodeId) -> NodeId:
        return self.ite(u, FALSE, TRUE)

    def apply(self, op: Callable[[bool, bool], bool], a: NodeId, b: NodeId) -> NodeId:
        """Naively implement binary op using ITE with the op's truth table."""
        # Compute constants once
        tt = TRUE if op(True, True) else FALSE
        tf = TRUE if op(True, False) else FALSE
        ft = TRUE if op(False, True) else FALSE
        ff = TRUE if op(False, False) else FALSE
        # op(a,b) == ITE(a, ITE(b, tt, tf), ITE(b, ft, ff))
        return self.ite(a, self.ite(b, tt, tf), self.ite(b, ft, ff))

    # ---- Restriction, quantification, composition ----
    def restrict(self, u: NodeId, assignment: Dict[Var, bool]) -> NodeId:
        key = (u, tuple(sorted(assignment.items())))
        if key in self._restrict_cache:
            return self._restrict_cache[key]
        if u in (FALSE, TRUE):
            res = u
        else:
            n = self._nodes[u]
            if n.var in assignment:
                res = self.restrict(n.hi if assignment[n.var] else n.lo, assignment)
            else:
                lo = self.restrict(n.lo, assignment)
                hi = self.restrict(n.hi, assignment)
                res = self._mk(n.var, lo, hi)
        self._restrict_cache[key] = res
        return res

    def exists(self, u: NodeId, vars_set: Set[Var]) -> NodeId:
        if u in (FALSE, TRUE) or not vars_set:
            return u
        n = self._nodes.get(u)
        if n is None:
            return u
        if n.var in vars_set:
            # ∃x. f = (∃rest. f0) OR (∃rest. f1)
            return self.apply(lambda a,b: a or b,
                              self.exists(n.lo, vars_set),
                              self.exists(n.hi, vars_set))
        else:
            return self._mk(n.var, self.exists(n.lo, vars_set), self.exists(n.hi, vars_set))

    def forall(self, u: NodeId, vars_set: Set[Var]) -> NodeId:
        if u in (FALSE, TRUE) or not vars_set:
            return u
        n = self._nodes.get(u)
        if n is None:
            return u
        if n.var in vars_set:
            # ∀x. f = (∀rest. f0) AND (∀rest. f1)
            return self.apply(lambda a,b: a and b,
                              self.forall(n.lo, vars_set),
                              self.forall(n.hi, vars_set))
        else:
            return self._mk(n.var, self.forall(n.lo, vars_set), self.forall(n.hi, vars_set))

    def compose(self, u: NodeId, subst: Dict[Var, NodeId]) -> NodeId:
        key = (u, tuple(sorted(subst.items())))
        if key in self._compose_cache:
            return self._compose_cache[key]
        if u in (FALSE, TRUE):
            res = u
        else:
            n = self._nodes[u]
            if n.var in subst:
                res = self.ite(subst[n.var], self.compose(n.hi, subst), self.compose(n.lo, subst))
            else:
                res = self._mk(n.var, self.compose(n.lo, subst), self.compose(n.hi, subst))
        self._compose_cache[key] = res
        return res

    # ---- Evaluation & models ----
    def evaluate(self, u: NodeId, assignment: Dict[Var, bool]) -> bool:
        """Evaluate f under a (possibly partial) assignment (defaults False)."""
        while u not in (FALSE, TRUE):
            n = self._nodes[u]
            bit = assignment.get(n.var, False)
            u = n.hi if bit else n.lo
        return u == TRUE

    def _sat_count_from(self, u: NodeId, level: int) -> int:
        key = (u, level)
        if key in self._satcount_cache:
            return self._satcount_cache[key]
        if u == FALSE:
            res = 0
        elif u == TRUE:
            # free vars from this level to the end are unconstrained
            res = 1 << (self.num_vars - level)
        else:
            n = self._nodes[u]
            vlevel = self._level(n.var)
            gap = vlevel - level
            # account for skipped variables by multiplying by 2^gap
            factor = 1 << gap
            res = factor * (self._sat_count_from(n.lo, vlevel + 1) +
                            self._sat_count_from(n.hi, vlevel + 1))
        self._satcount_cache[key] = res
        return res

    def sat_count(self, u: NodeId) -> int:
        return self._sat_count_from(u, 0)

    def pick_one_sat(self, u: NodeId) -> Optional[Dict[Var, bool]]:
        if u == FALSE:
            return None
        model: Dict[Var, bool] = {}
        level = 0
        while u not in (FALSE, TRUE):
            n = self._nodes[u]
            vlevel = self._level(n.var)
            # Fill skipped vars with False
            while level < vlevel:
                model[self._level2var[level]] = False
                level += 1
            # Prefer hi branch when satisfiable
            if self._sat_count_from(n.hi, vlevel + 1) > 0:
                model[n.var] = True
                u = n.hi
            else:
                model[n.var] = False
                u = n.lo
            level = vlevel + 1
        # Fill trailing vars with False
        while level < self.num_vars:
            model[self._level2var[level]] = False
            level += 1
        return model if u == TRUE else None

    # ---- DOT export ----
    def to_dot(self, u: NodeId, name: str = "F") -> str:
        lines: List[str] = ["digraph BDD {", "  rankdir=TB;", "  node [shape=circle];"]
        lines.append(f"  label=\"{name}\";")
        seen: Set[NodeId] = set()
        def walk(x: NodeId) -> None:
            if x in seen: return
            seen.add(x)
            if x == TRUE:
                lines.append("  T [shape=box,label=1];")
            elif x == FALSE:
                lines.append("  F [shape=box,label=0];")
            else:
                n = self._nodes[x]
                lines.append(f"  {x} [label=\"{n.var}\"];")
                # dashed 0-edge, solid 1-edge (classic convention)
                for child, style, lbl in ((n.lo, "dashed", 0), (n.hi, "solid", 1)):
                    tgt = "T" if child == TRUE else "F" if child == FALSE else str(child)
                    lines.append(f"  {x} -> {tgt} [style={style},label=\"{lbl}\"];\n")
                walk(n.lo); walk(n.hi)
        walk(u)
        lines.append("}")
        return "\n".join(lines)

    # ---- Convenience ----
    def const(self, value: bool) -> BDDRef:
        return BDDRef(self, TRUE if value else FALSE)

    # Evaluation helpers for users
    def is_true(self, r: BDDRef) -> bool:  # canonical equality check
        return r.id == TRUE
    def is_false(self, r: BDDRef) -> bool:
        return r.id == FALSE


# ---- Simple manual test ----
if __name__ == "__main__":
    mgr = BDD(["x","y","z"]) 
    x,y,z = mgr.vars("x","y","z")
    f = (x & ~y) | z
    print("sat_count(f) =", f.sat_count())
    print("evaluate(f,{x:1,y:1,z:0}) =", f.evaluate({"x":True,"y":True,"z":False}))
    print("one model:", f.pick_one_sat())
    print(mgr.to_dot(f.id))


RecursionError: maximum recursion depth exceeded