In [None]:
%pip install -q spytial-diagramming
from spytial import *
from spytial.annotations import *
from typing import Dict, Tuple, Optional, Iterable, List
from dataclasses import dataclass

While CLRS does not cover binary decision diagrams (BDDs), we include them here as a bonus data structure that shows off some of SpyTial's capabilities.



# BDD Code

In [None]:
def exclude_terminals(v):
    """Build a selector fragment that holds only for non-terminal nodes.

    ``v`` is the name of a selector variable (for example ``'y'``). The
    returned text requires that variable's ``nid`` to be neither 0 nor 1,
    the ids this notebook assigns to the FALSE and TRUE constant nodes.
    """
    nid_expr = f"@num:({v}.nid)"
    return f"({nid_expr} != 0) and ({nid_expr} != 1)"

# Applies ALL annotations to a BDD node diagram in one call (useful since the
# notebook otherwise builds them up in steps).
def all_annotations(node):
    """Apply every diagram annotation used in this notebook to ``node``.

    The annotations are listed as (factory, keyword arguments) pairs. Each
    factory is called with its arguments and the resulting callable is applied
    to the running value before the next factory is called, so annotations are
    built and applied strictly in the order listed.

    Returns whatever the last annotation call returns.
    """
    steps = (
        # Hide mechanics
        (attribute, {"field": "nid"}),
        (hideAtom, {"selector": "NoneType + int"}),
        # Layering and variable ordering
        (align, {"selector": "{x, y : Node | (x != y) and (x.v) = (y.v)}", "direction": "horizontal"}),
        (orientation, {"selector": "{x, y : Node | x->y in (lo + hi)}", "directions": ["below"]}),
        # Grouping by variable
        (group, {"selector": "{vr : str, y : Node | @:(vr) = @:(y.v)}", "name": "nodes"}),
        (hideAtom, {"selector": "str"}),
        # Colors for terminals, inner nodes and the two edge kinds
        (atomColor, {"selector": "{x: Node | @num:(x.nid) = 0}", "value": 'red'}),
        (atomColor, {"selector": "{x: Node | @num:(x.nid) = 1}", "value": 'blue'}),
        (atomColor, {"selector": "{x: Node | (@num:(x.nid) != 0) and (@num:(x.nid) != 1)}", "value": 'black'}),
        (edgeColor, {"field": "hi", "value": 'green'}),
        (edgeColor, {"field": "lo", "value": 'orange'}),
        # lo children to the left, hi children to the right (terminals exempt)
        (orientation, {"selector": f"{{ x, y : Node | x->y in lo and {exclude_terminals('y')}   }}", "directions": ["left"]}),
        (orientation, {"selector": f"{{ x, y : Node | x->y in hi and {exclude_terminals('y')}   }}", "directions": ["right"]}),
    )
    for make_annotation, options in steps:
        node = make_annotation(**options)(node)
    return node

In [None]:


# Alternatively, you could apply all of the constraints at once, like this.

#attribute(field="nid"),
#hideAtom(selector="NoneType + int"),
#align(selector="{x, y : Node | (x != y) and (x.v) = (y.v)}", direction="horizontal"),
#orientation(selector="{x, y : Node | x->y in (lo + hi)}", directions=["below"]),
#group(selector="{vr : str, y : Node | @:(vr) = @:(y.v)}", name="nodes"),
#hideAtom(selector="str"),
#atomColor(selector="{x: Node | @num:(x.nid) = 0}", value='red'),
#atomColor(selector="{x: Node | @num:(x.nid) = 1}", value='blue'),
#atomColor(selector="{x: Node | (@num:(x.nid) != 0) and (@num:(x.nid) != 1)}", value='black'),
#edgeColor(field="hi", value='green'),
#edgeColor(field="lo", value='orange'),
### Subtle overconstraint
#orientation(selector=f"{{ x, y : Node | x->y in lo and {exclude_terminals('y')}   }}", directions=["left"]),
#orientation(selector=f"{{ x, y : Node | x->y in hi and {exclude_terminals('y')}   }}", directions=["right"])
@dataclass
class Node:
    """One vertex of a binary decision diagram.

    A node is either a constant (``v`` is ``None``) or a decision on the
    variable ``v`` with a ``lo`` child (0-edge) and a ``hi`` child (1-edge).

    ``__init__`` and ``__repr__`` are written by hand below; ``@dataclass``
    leaves methods the class already defines in place, so it only contributes
    the field declarations and the generated comparison behaviour.
    """
    nid: int                    # stable integer id (for debugging / refs)
    v: Optional[str]         # None for constants; otherwise variable name
    lo: Optional["Node"]       # 0-edge (None for constants)
    hi: Optional["Node"]       # 1-edge (None for constants)

    def __init__(self, id: int, v: Optional[str] = None, lo: Optional["Node"] = None, hi: Optional["Node"] = None):
        """Store the id (passed as ``id``, kept as ``nid``), variable and children."""
        self.nid = id
        self.v = v
        self.lo = lo
        self.hi = hi

    def is_const(self) -> bool:
        """Return True when this node has no variable, i.e. it is a constant."""
        return self.v is None

    def __repr__(self, depth=0, max_depth=8, visited=None):
        """Render the node and its descendants as nested ``Node(...)`` text.

        Args:
            depth: Current nesting depth (0 for the outermost call).
            max_depth: Nodes deeper than this are abbreviated.
            visited: Ids of the nodes on the path from the outermost call;
                used to stop at cycles. A fresh set is used when omitted.

        Returns:
            ``"TRUE_NODE"`` / ``"FALSE_NODE"`` for constants,
            ``"Node(<nid>, <v>, ...)"`` for abbreviated nodes, and
            ``"Node(<nid>, '<v>', <lo>, <hi>)"`` otherwise.
        """
        # Constants have nothing to elide, so name them before the depth and
        # cycle guard; otherwise a terminal just past max_depth would be
        # printed as a truncated "Node(1, None, ...)".
        if self.is_const():
            return "TRUE_NODE" if self.nid == 1 else "FALSE_NODE"

        if visited is None:
            visited = set()

        # Handle cycles and depth limits
        if self.nid in visited or depth > max_depth:
            return f"Node({self.nid}, {self.v}, ...)"

        visited.add(self.nid)

        def render(child):
            # Each branch gets its own copy of the path, so a node shared by
            # both branches of a DAG is not mistaken for a cycle.
            if child is None:
                return "None"
            return child.__repr__(depth + 1, max_depth, visited.copy())

        return f"Node({self.nid}, '{self.v}', {render(self.lo)}, {render(self.hi)})"

# Pre-create the two shared constants: id 0 is FALSE, id 1 is TRUE.
# Neither has a variable or children.
FALSE_NODE, TRUE_NODE = (
    Node(id=const_id, v=None, lo=None, hi=None) for const_id in (0, 1)
)


class BDD:
    """
    Minimal ROBDD manager in which nodes hold direct references to their children.

    - Canonical nodes: a unique table keyed by (v, lo.nid, hi.nid)
    - Variable ordering is append-only (a variable's level is its insertion index)
    - Minimal boolean operations: neg, and, or -- all expressed through ITE
    - Results of those operations are recorded as roots automatically
    """

    def __init__(self, ordering: Optional[Iterable[str]] = None):
        """Create an empty manager; names in ``ordering`` are registered in order."""
        # Ids 0 and 1 are taken by the shared constants, so new ids start at 2.
        self._nodes: Dict[int, Node] = {0: FALSE_NODE, 1: TRUE_NODE}
        self._next_id: int = 2
        self._unique: Dict[Tuple[str, int, int], Node] = {}
        self._ite_cache: Dict[Tuple[int, int, int], Node] = {}
        self._var2level: Dict[str, int] = {}
        self._level2var: List[str] = []
        self._root_nodes: List[Node] = []  # explicitly tracked roots, in insertion order
        for name in ordering or ():
            self.add_var(name)

    # ---- Readable views ----
    @property
    def nodes(self) -> Dict[int, Node]:
        """Mapping from node id to node (the manager's own dict, not a copy)."""
        return self._nodes

    @property
    def variables(self) -> Tuple[str, ...]:
        """Registered variable names, in level order."""
        return tuple(self._level2var)

    def add_root(self, node: Node) -> None:
        """Track ``node`` as a root (result of a computation) unless already tracked."""
        if node in self._root_nodes:
            return
        self._root_nodes.append(node)

    @property
    def roots(self) -> Tuple[Node, ...]:
        """The explicitly tracked root nodes, as a tuple."""
        return tuple(self._root_nodes)

    # ---- Variables ----
    def add_var(self, v: str) -> None:
        """Register ``v`` at the next free level; a known name is left untouched."""
        if v in self._var2level:
            return
        self._var2level[v] = len(self._level2var)
        self._level2var.append(v)

    def v(self, name: str) -> Node:
        """Return the node for variable ``name`` (lo = FALSE, hi = TRUE), registering it if new."""
        if name not in self._var2level:
            self.add_var(name)
        return self._mk(name, FALSE_NODE, TRUE_NODE)

    def vars(self, *names: str):
        """Return a tuple with the variable node for each of ``names``, in order."""
        return tuple(map(self.v, names))

    # ---- Unique constructor ----
    def _mk(self, v: str, lo: Node, hi: Node) -> Node:
        """Return the canonical node for (v, lo, hi), creating it only if needed."""
        # A test whose two branches are the same node is redundant.
        if lo is hi:
            return lo
        key = (v, lo.nid, hi.nid)
        existing = self._unique.get(key)
        if existing is not None:
            return existing
        fresh = Node(id=self._next_id, v=v, lo=lo, hi=hi)
        self._next_id += 1
        self._unique[key] = fresh
        self._nodes[fresh.nid] = fresh
        return fresh

    # ---- Helpers for ITE ----
    def _level(self, v: Optional[str]) -> int:
        """Level of variable ``v``; ``None`` (a constant) ranks after every variable."""
        if v is None:
            return len(self._level2var) + 1
        return self._var2level[v]

    def _top(self, u: Node) -> Optional[str]:
        """Variable tested at ``u`` (``None`` for constants)."""
        return u.v

    def _split(self, u: Node, v: str) -> Tuple[Node, Node]:
        """Return the (lo, hi) cofactors of ``u`` with respect to variable ``v``."""
        if not u.is_const() and u.v == v:
            return (u.lo, u.hi)
        # ``u`` does not test ``v`` at its top, so both cofactors are ``u`` itself.
        return (u, u)

    # ---- ITE ----
    def ite(self, i: Node, t: Node, e: Node) -> Node:
        """Return the node for "if ``i`` then ``t`` else ``e``", memoised on the three ids."""
        key = (i.nid, t.nid, e.nid)
        try:
            return self._ite_cache[key]
        except KeyError:
            pass

        if i is TRUE_NODE:
            res = t
        elif i is FALSE_NODE:
            res = e
        elif t is e:
            res = t
        elif t is TRUE_NODE and e is FALSE_NODE:
            # ITE(i, 1, 0) == i
            res = i
        else:
            operands = (i, t, e)
            # Branch on the earliest variable tested by any of the three operands.
            v = min((self._top(u) for u in operands), key=self._level)
            assert v is not None
            (i0, i1), (t0, t1), (e0, e1) = (self._split(u, v) for u in operands)
            lo = self.ite(i0, t0, e0)
            hi = self.ite(i1, t1, e1)
            res = self._mk(v, lo, hi)

        self._ite_cache[key] = res
        return res

    # ---- Minimal boolean basis (auto-register results as roots) ----
    def _tracked(self, result: Node) -> Node:
        """Register ``result`` as a root and hand it back."""
        self.add_root(result)
        return result

    def neg(self, u: Node) -> Node:
        """Negation of ``u``; the result is tracked as a root."""
        return self._tracked(self.ite(u, FALSE_NODE, TRUE_NODE))

    def op_and(self, a: Node, b: Node) -> Node:
        """Conjunction of ``a`` and ``b``; the result is tracked as a root."""
        return self._tracked(self.ite(a, b, FALSE_NODE))

    def op_or(self, a: Node, b: Node) -> Node:
        """Disjunction of ``a`` and ``b``; the result is tracked as a root."""
        return self._tracked(self.ite(a, TRUE_NODE, b))

    # ---- Evaluate ----
    def evaluate(self, u: Node, assignment: Dict[str, bool]) -> bool:
        """Follow ``assignment`` from ``u`` down to a constant and report whether it is TRUE.

        Variables missing from ``assignment`` are treated as False.
        """
        current = u
        while not current.is_const():
            take_hi = assignment.get(current.v, False)
            current = current.hi if take_hi else current.lo
        return current is TRUE_NODE


# ---------- Tiny wrapper for notebook ergonomics ----------
class B:
    """Operator-friendly handle pairing a BDD manager with one of its nodes.

    ``~a``, ``a & b`` and ``a | b`` delegate to the manager's ``neg``,
    ``op_and`` and ``op_or`` and wrap the resulting node in a new ``B``.
    """
    __slots__ = ("mgr", "node")

    def __init__(self, mgr: "BDD", node: "Node"):
        self.mgr, self.node = mgr, node

    def _require_same_manager(self, o: "B") -> None:
        """Raise ValueError when ``o`` belongs to a different manager."""
        # Node ids restart in every manager and the manager's tables are keyed
        # by those ids, so nodes from two managers must never be combined.
        if o.mgr is not self.mgr:
            raise ValueError("cannot combine B values that belong to different BDD managers")

    def __invert__(self):
        """Return the negation as a new ``B`` on the same manager."""
        return B(self.mgr, self.mgr.neg(self.node))

    def __and__(self, o: "B"):
        """Return the conjunction with ``o`` as a new ``B``.

        Returns ``NotImplemented`` for operands that are not ``B`` (Python
        then raises ``TypeError``); raises ``ValueError`` for a ``B`` from a
        different manager.
        """
        if not isinstance(o, B):
            return NotImplemented
        self._require_same_manager(o)
        return B(self.mgr, self.mgr.op_and(self.node, o.node))

    def __or__(self, o: "B"):
        """Return the disjunction with ``o`` as a new ``B``.

        Returns ``NotImplemented`` for operands that are not ``B`` (Python
        then raises ``TypeError``); raises ``ValueError`` for a ``B`` from a
        different manager.
        """
        if not isinstance(o, B):
            return NotImplemented
        self._require_same_manager(o)
        return B(self.mgr, self.mgr.op_or(self.node, o.node))

    def evaluate(self, env):
        """Evaluate the wrapped node under ``env`` via the manager."""
        return self.mgr.evaluate(self.node, env)

    def __repr__(self):
        return f"B(Node id={self.node.nid}, v={self.node.v})"

# Visualizing a Formula

As a concrete case, consider the following boolean formula:

`(~x1 & ~x2 & ~x3) | (x1 & x2) | (x2 & x3)`


If we inspect the BDD value directly in Python, we see a record-style
representation: each node is shown with its field names and references to
children. Such output is typical of default `__repr__` or serialization
mechanisms: they expose structure as text but do not convey the overall shape
of the BDD.  

In [None]:
# BDD manager whose variables are registered in the order x1, x2, x3
m2 = BDD(["x1", "x2", "x3"])

# Fetch the node for each variable, then wrap it in B so that Python's
# ~, & and | operators build BDD nodes through the manager
x1_node, x2_node, x3_node = m2.vars("x1", "x2", "x3")
x1, x2, x3 = (B(m2, var_node) for var_node in (x1_node, x2_node, x3_node))

# The formula from the text. The parentheses only spell out the grouping that
# Python's operator precedence already gives this expression.
f = (~x1 & ~x2 & ~x3) | (x1 & x2) | (x2 & x3)
print(f.node)

The default SpyTial diagram shows a graph: Nodes in the BDD become boxes, fields become arrows.

In [None]:
# No annotations have been applied yet, so this is SpyTial's default rendering.
diagram(f.node, height=800)

### Step 1: **Hide Irrelevant Detail / Mechanics**


In [None]:


# Step 1: apply the `attribute` annotation to the nid field, then hide the
# atoms matched by "NoneType + int".
# NOTE(review): presumably `attribute` draws nid inside the node rather than
# as an edge -- confirm against the SpyTial docs.
f.node = attribute(
    field="nid",
)(f.node)
f.node = hideAtom(
    selector="NoneType + int",
)(f.node)
diagram(f.node, height=600)

### Step 2: Layering

Nodes testing the same variable occupy a common layer and are horizontally aligned.


In [None]:
# Step 2: every pair of distinct nodes that test the same variable (same `v`)
# is aligned horizontally, so each variable gets its own layer.
f.node = align(
    selector="{x, y : Node | (x != y) and (x.v) = (y.v)}",
    direction="horizontal",
)(f.node)
diagram(f.node, height=800)

### Step 3: Variable Ordering

Next, the variable order is embodied as spatial order. In this case, we choose vertical order -- parents above children.

In [None]:
# Step 3: for every lo or hi edge x -> y, place the child y below its parent x.
f.node = orientation(
    selector="{x, y : Node | x->y in (lo + hi)}",
    directions=["below"],
)(f.node)
diagram(f.node)

### Step 4: Grouping By Variable

In [None]:
# Step 4: group each node with the string atom that matches its variable `v`,
# then hide the string atoms themselves.
f.node = group(
    selector="{vr : str, y : Node | @:(vr) = @:(y.v)}",
    name="nodes",
)(f.node)
f.node = hideAtom(
    selector="str",
)(f.node)
diagram(f.node, height=800)


# Conventions


At this stage, the diagram reflects the familiar conception of a BDD, albeit with a lot more detail.
From here, further refinements target presentation rather than structure.


### Coloring / Distinguishing
Nodes and edges can be colored in order to make them distinct.

For example, distinguish terminal nodes, and style edges to distinguish low/high branches.

In [None]:
# Terminals: the node with nid 0 in red, the node with nid 1 in blue
f.node = atomColor(
    selector="{x: Node | @num:(x.nid) = 0}",
    value='red',
)(f.node)
f.node = atomColor(
    selector="{x: Node | @num:(x.nid) = 1}",
    value='blue',
)(f.node)
# Every other (non-constant) node in black
f.node = atomColor(
    selector="{x: Node | (@num:(x.nid) != 0) and (@num:(x.nid) != 1)}",
    value='black',
)(f.node)

# Edges: hi (1-edge) in green, lo (0-edge) in orange
f.node = edgeColor(
    field="hi",
    value='green',
)(f.node)
f.node = edgeColor(
    field="lo",
    value='orange',
)(f.node)
diagram(f.node, height=800)

### Organization

We then orient branching: low edges point left and high edges point right, unless they point to terminals.
(An astute reader will notice that this can overconstrain layouts; we return
to this below.)  



In [None]:

# lo children go to the left of their parent; edges into terminals are exempt
f.node = orientation(
    selector=f"{{ x, y : Node | x->y in lo and {exclude_terminals('y')}   }}",
    directions=["left"],
)(f.node)
# hi children go to the right of their parent; edges into terminals are exempt
f.node = orientation(
    selector=f"{{ x, y : Node | x->y in hi and {exclude_terminals('y')}   }}",
    directions=["right"],
)(f.node)

diagram(f.node, height=800)

# When Constraints are unsatisfiable

The formula

`(~x1 & ~x2 & ((x3 & ~x4) | (~x3 & x4))) | (x1 & x2 & ((x3 & ~x4) | (~x3 & x4)))`

produces a BDD whose diagram cannot satisfy all of these layout constraints at once.

In [None]:

# A second, independent manager with variables registered as x1, x2, x3, x4
mgr = BDD(["x1", "x2", "x3", "x4"])

# Fetch and wrap the variables. From here on x1..x3 refer to this manager's
# variables, not to the ones created for the first formula.
x1_node, x2_node, x3_node, x4_node = mgr.vars("x1", "x2", "x3", "x4")
x1, x2, x3, x4 = (
    B(mgr, var_node) for var_node in (x1_node, x2_node, x3_node, x4_node)
)

f2 = (
    (~x1 & ~x2 & ((x3 & ~x4) | (~x3 & x4)))
    | (x1 & x2 & ((x3 & ~x4) | (~x3 & x4)))
)
# Apply every annotation from the step-by-step walkthrough in one call
f2.node = all_annotations(f2.node)

diagram(f2.node, height=1000)
