# mdd

> Objects representing a multivalued decision diagram.

In [None]:
#| default_exp mdd

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from collections.abc import Hashable, ItemsView, Iterator, KeysView
from itertools import chain
from dataclasses import dataclass, field

In [None]:
#| hide
from fastcore.test import test_eq

In [None]:
#| export
@dataclass(frozen=True)
class MDDNode:
    """A single node of the MDD.

    A node is identified by the pair (layer, state). Because the dataclass
    is frozen, instances are immutable, and two nodes built from the same
    layer and state compare equal and hash identically. The state must
    therefore be a hashable object.

    Parameters
    ----------
    layer : int
        layer the node is in
    state : Hashable
        state associated with node

    """
    layer: int
    state: Hashable

    def __str__(self) -> str:
        # Rendered as N_<layer>(<state>), e.g. N_0((1, 2, 3)).
        return "N_{}({})".format(self.layer, self.state)

`MDDNode` represents a node in the decision diagram. It is uniquely identified by which layer the node is located in, and the node's state. The state should be hashable to ensure we can quickly identify when two nodes in the same layer are "equivalent".

In [None]:
# Two distinct nodes, plus a second object built from the same
# (layer, state) pair as n1 to contrast identity with equality.
n1 = MDDNode(layer=0, state=(1, 2, 3))
n2 = MDDNode(layer=1, state=(1, 3))
n1_copy = MDDNode(layer=0, state=(1, 2, 3))
print("n1 =", n1)
print("n2 =", n2)
print("n1_copy =", n1_copy)
print("`n1_copy is n1` ->", n1_copy is n1)
print("`n1_copy == n1` ->", n1_copy == n1)

n1 = N_0((1, 2, 3))
n2 = N_1((1, 3))
n1_copy = N_0((1, 2, 3))
`n1_copy is n1` -> False
`n1_copy == n1` -> True


In [None]:
#| export
@dataclass()
class MDDArc:
    """A single arc of the MDD.

    Two arcs compare equal when their label, weight, tail node and head
    node are all equal. The class is a regular (non-frozen) dataclass, so
    its fields can be reassigned after construction.

    Parameters
    ----------
    label : Hashable
        label of arc (e.g., assigned value)
    weight : float
        weight of arc (e.g., coefficient)
    tail : MDDNode
        tail/source node
    head : MDDNode
        head/destination node

    """
    label: Hashable
    weight: float
    tail: MDDNode
    head: MDDNode

    def __str__(self) -> str:
        # Rendered as A(<label>,<weight>:<tail>,<head>).
        return "A({},{}:{},{})".format(
            self.label, self.weight, self.tail, self.head
        )

`MDDArc` represents an arc in the decision diagram. It is identified by the combination of its tail node, head node, and label; note that the dataclass-generated equality check also compares `weight`.

Unlike `MDDNode`, `MDDArc` is NOT immutable; in particular, `MDDArc.weight` may be changed later.

In [None]:
# Build an arc from n1 to n2, then reassign its weight in place to show
# that MDDArc instances are mutable.
a1_2 = MDDArc(label=2, weight=0, tail=n1, head=n2)
print("Original:", a1_2)
a1_2.weight = 2
print("After modification:", a1_2)

Original: A(2,0:N_0((1, 2, 3)),N_1((1, 3)))
After modification: A(2,2:N_0((1, 2, 3)),N_1((1, 3)))


In [None]:
#| export
@dataclass()
class MDDNodeInfo:
    """Bookkeeping attached to an MDDNode: its incoming and outgoing arcs.

    Each instance gets its own fresh, empty lists when the arguments are
    omitted.

    Parameters
    ----------
    incoming : list[MDDArc]
        list of incoming arcs (default: [])
    outgoing : list[MDDArc]
        list of outgoing arcs (default: [])

    """
    incoming: list[MDDArc] = field(default_factory=list)
    outgoing: list[MDDArc] = field(default_factory=list)

    def __str__(self) -> str:
        # Rendered as <in=[a, b], out=[c]> using each arc's str() form.
        arcs_in = ", ".join(map(str, self.incoming))
        arcs_out = ", ".join(map(str, self.outgoing))
        return "<in=[" + arcs_in + "], out=[" + arcs_out + "]>"

`MDDNodeInfo` represents the information associated with a node in the decision diagram, in particular its incoming and outgoing arcs.

In [None]:
# Node info with no incoming arcs (the default) and a single outgoing arc.
node_info = MDDNodeInfo(outgoing=[a1_2])
print("info =", node_info)

info = <in=[], out=[A(2,2:N_0((1, 2, 3)),N_1((1, 3)))]>


In [None]:
#| export
@dataclass()
class MDD:
    """MDD represents a multivalued decision diagram, or MDD.

    Nodes are stored layer by layer: `nodes[j]` is a dictionary that maps
    every `MDDNode` in layer `j` to the `MDDNodeInfo` holding that node's
    incoming and outgoing arcs. Arcs are not stored separately; each arc is
    referenced from its tail's `outgoing` list and its head's `incoming`
    list.

    Parameters
    ----------
    name : str
        name of MDD (default: 'mdd')
    nodes : list[dict[MDDNode, MDDNodeInfo]]
        nodes of MDD (default: [])

    """
    name: str = "mdd"
    nodes: list[dict[MDDNode, MDDNodeInfo]] = field(default_factory=list)

    @property
    def numNodeLayers(self) -> int:
        """Number of node layers; equal to number of 'variables' + 1."""
        return len(self.nodes)

    @property
    def numArcLayers(self) -> int:
        """Number of arc layers; equal to number of 'variables'.

        Computed as `numNodeLayers - 1`, so an MDD without any node layer
        reports -1.
        """
        return len(self.nodes) - 1

    @property
    def widthList(self) -> list[int]:
        """Number of nodes in each layer."""
        return [len(layer) for layer in self.nodes]

    @property
    def maxWidth(self) -> int:
        """Maximum number of nodes in a single node layer.

        Returns 0 when the MDD has no layers.
        """
        # `default=0` keeps an MDD without layers from raising ValueError.
        return max((len(layer) for layer in self.nodes), default=0)

    def allnodes(self) -> Iterator[MDDNode]:
        """Iterate over all MDDNodes in the MDD, layer by layer."""
        return chain.from_iterable(layer.keys() for layer in self.nodes)

    def allnodeitems_in_layer(
        self,
        layer: int, # index of layer
    ) -> ItemsView[MDDNode, MDDNodeInfo]:
        """Return a view of all (MDDNode, MDDNodeInfo) pairs in a layer."""
        return self.nodes[layer].items()

    def allnodes_in_layer(
        self,
        layer: int, # index of layer
    ) -> KeysView[MDDNode]:
        """Return a view of all MDDNodes in a layer."""
        # Hand out a keys view, as annotated, rather than the layer
        # dictionary itself, so callers cannot modify the layer through it.
        return self.nodes[layer].keys()

    def alloutgoingarcs(self) -> Iterator[MDDArc]:
        """Iterate over all outgoing arcs in the MDD.

        Visits the `outgoing` lists of the nodes in every layer except the
        last one.
        """
        return chain.from_iterable(
            info.outgoing
            for j in range(self.numArcLayers)
            for info in self.nodes[j].values()
        )

    def allincomingarcs(self) -> Iterator[MDDArc]:
        """Iterate over all incoming arcs in the MDD.

        Visits the `incoming` lists of the nodes in every layer except the
        first one.
        """
        return chain.from_iterable(
            info.incoming
            for j in range(self.numArcLayers)
            for info in self.nodes[j + 1].values()
        )

    def __str__(
        self,
        showLong: bool = False, # use more vertical space (default: False)
        showIncoming: bool = False, # show incoming arcs (default: False)
    ) -> str:
        """Return a (human-readable) string representation of the MDD."""
        # Long form puts every node and arc on its own line; short form
        # lists them comma-separated on a single line.
        arc_sep = '\n' if showLong else ', '
        # Collect the pieces and join once instead of repeated `+=`.
        parts = [
            '== MDD (' + self.name + ', ' + str(self.numArcLayers) + ' layers) ==\n',
            '# Nodes\n',
        ]
        for j, lyr in enumerate(self.nodes):
            if showLong:
                parts.append('L' + str(j) + ':\n')
                for v, info in lyr.items():
                    incoming = ', '.join(str(a) for a in info.incoming)
                    outgoing = ', '.join(str(a) for a in info.outgoing)
                    parts.append(
                        '\t' + str(v) + ': <'
                        + 'in={' + incoming + '}, '
                        + 'out={' + outgoing + '}'
                        + '>\n'
                    )
            else:
                parts.append('L' + str(j) + ': ' + ', '.join(str(v) for v in lyr) + '\n')
        parts.append('# (Outgoing) Arcs\n')
        parts.append(arc_sep.join(str(a) for a in self.alloutgoingarcs()))
        if showIncoming:
            parts.append('\n# (Incoming) Arcs\n')
            parts.append(arc_sep.join(str(a) for a in self.allincomingarcs()))
        return ''.join(parts)

    def clear(self) -> None:
        """Reset the MDD by replacing `nodes` with a new empty list."""
        self.nodes = []

    def append_new_layer(self) -> None:
        """Append a new (empty) layer to the MDD."""
        self.nodes.append(dict())

    def get_node_info(
        self,
        node: MDDNode, # node in MDD
    ) -> MDDNodeInfo:
        """Get `MDDNodeInfo` corresponding to `node`.

        Note this function can NOT be used to populate the underlying
        dictionary; it can only be used to reference the object. Looking up
        a node that is not in the MDD raises a `KeyError` (or an
        `IndexError` if its layer does not exist).

        In general, you should use allnodeitems_in_layer(...) if you
        want to update node info in a systematic manner. The author
        recommends only using this function if allnodeitems_in_layer(...)
        cannot be used.
        """
        return self.nodes[node.layer][node]

    def add_arc(
        self,
        newarc: MDDArc, # arc to be added
    ) -> None:
        """Add an arc to the MDD, without any sanity checks.

        The head and tail nodes of the arc should already exist in the MDD.
        The arc is appended to the tail's outgoing list and to the head's
        incoming list.

        """
        self.get_node_info(newarc.tail).outgoing.append(newarc)
        self.get_node_info(newarc.head).incoming.append(newarc)

    def remove_arc(
        self,
        rmvarc: MDDArc, # arc to be removed
    ) -> None:
        """Remove an arc from the MDD, without any sanity checks.

        The first arc comparing equal to `rmvarc` is removed from the
        tail's outgoing list and from the head's incoming list; a
        `ValueError` is raised if no such arc is present.
        """
        self.get_node_info(rmvarc.tail).outgoing.remove(rmvarc)
        self.get_node_info(rmvarc.head).incoming.remove(rmvarc)

    def add_node(
        self,
        newnode: MDDNode, # node to be added
    ) -> None:
        """Add a node to the MDD, without any sanity checks.

        The node's layer should already exist in the MDD.

        NOTE: If an identical node already exists, its incoming and outgoing
        arcs will be ERASED!!!

        """
        self.nodes[newnode.layer][newnode] = MDDNodeInfo()

    def remove_node(
        self,
        rmvnode: MDDNode, # node to be removed
    ) -> None:
        """Remove a node from the MDD, without any sanity checks.

        Every arc entering or leaving the node is first detached from the
        node at its other end, then the node itself is deleted from its
        layer.
        """
        for arc in self.get_node_info(rmvnode).incoming:
            self.get_node_info(arc.tail).outgoing.remove(arc)
        for arc in self.get_node_info(rmvnode).outgoing:
            self.get_node_info(arc.head).incoming.remove(arc)
        del self.nodes[rmvnode.layer][rmvnode]

The `MDD` object is the core object. It comes with a few basic operations.

In [None]:
#| hide
def test_basic_mdd_operations() -> None:
    """Exercise layer/node/arc insertion and node removal on a tiny MDD."""
    # A freshly created MDD has no layers at all.
    mdd = MDD()
    test_eq(mdd.numNodeLayers, 0)
    test_eq(mdd.numArcLayers, -1)
    test_eq(list(mdd.allnodes()), [])

    # One node in layer 0, followed by an (initially empty) layer 1.
    mdd.append_new_layer()
    root = MDDNode(0, 0)
    mdd.add_node(root)
    mdd.append_new_layer()
    test_eq(mdd.widthList, [1,0])

    # Two nodes in layer 1.
    child_a = MDDNode(1, 1)
    child_b = MDDNode(1, 2)
    mdd.add_node(child_a)
    mdd.add_node(child_b)
    test_eq(mdd.numNodeLayers, 2)
    test_eq(mdd.widthList, [1,2])
    test_eq(mdd.maxWidth, 2)

    # Connect the root to both children.
    arc_to_a = MDDArc(0, 0, root, child_a)
    arc_to_b = MDDArc(1, 10, root, child_b)
    for new_arc in (arc_to_a, arc_to_b):
        mdd.add_arc(new_arc)
    test_eq(list(mdd.alloutgoingarcs()), [arc_to_a, arc_to_b])
    test_eq(list(mdd.allincomingarcs()), [arc_to_a, arc_to_b])

    # Removing a node must also drop the arc that pointed to it.
    mdd.remove_node(child_b)
    assert child_b not in list(mdd.allnodes())
    test_eq(list(mdd.alloutgoingarcs()), [arc_to_a])
    test_eq(list(mdd.allincomingarcs()), [arc_to_a])

test_basic_mdd_operations()

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()