# mdd

> Objects representing a multivalued decision diagram.

In [None]:
#| default_exp mdd

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from collections.abc import Callable, Collection, Hashable, ItemsView, Iterable, Iterator, KeysView
from itertools import chain
from dataclasses import dataclass, field
from typing import Optional

In [None]:
#| hide
from fastcore.test import test_eq
import traceback

In [None]:
#| export
MDDNodeState = Hashable


@dataclass(frozen=True)
class MDDNode:
    """A single node of an MDD.

    Two nodes are the same node exactly when they share both layer and
    state. The state must be a `collections.abc.Hashable` object, which
    (together with the frozen dataclass) makes the node itself hashable.

    Parameters
    ----------
    layer : int
        index of the layer the node belongs to
    state : MDDNodeState
        state attached to the node

    """
    layer: int
    state: MDDNodeState

    def __str__(self) -> str:
        # Compact form, e.g. "N_0((1, 2, 3))".
        return "N_{}({})".format(self.layer, self.state)

`MDDNode` represents a node in the decision diagram. It is uniquely identified by which layer the node is located in, and the node's state. The state should be hashable to ensure we can quickly identify when two nodes in the same layer are "equivalent".

In [None]:
# Two distinct nodes, plus a second instance built with the same layer and state as n1.
n1 = MDDNode(0, (1, 2, 3))
n2 = MDDNode(1, (1, 3))
n1_copy = MDDNode(0, (1, 2, 3))
print(f"n1 = {n1}\nn2 = {n2}\nn1_copy = {n1_copy}")
# n1_copy is a separate object, yet compares equal to n1 (same layer and state).
print(f"`n1_copy is n1` -> {n1_copy is n1}")
print(f"`n1_copy == n1` -> {n1_copy == n1}")

n1 = N_0((1, 2, 3))
n2 = N_1((1, 3))
n1_copy = N_0((1, 2, 3))
`n1_copy is n1` -> False
`n1_copy == n1` -> True


In [None]:
#| export
MDDArcLabel = Hashable


@dataclass
class MDDArc:
    """A single arc of an MDD.

    Arcs compare equal when their label, weight, tail, and head all match.
    The (arc) label must be a `collections.abc.Hashable` object. The
    dataclass is not frozen, so fields such as `weight` may be reassigned.

    Parameters
    ----------
    label : MDDArcLabel
        label of arc (e.g., assigned value)
    weight : float
        weight of arc (e.g., coefficient)
    tail : MDDNode
        tail/source node
    head : MDDNode
        head/destination node

    """
    label: MDDArcLabel
    weight: float
    tail: MDDNode
    head: MDDNode

    def __str__(self) -> str:
        # Compact form: "A(<label>,<weight>:<tail>,<head>)".
        return "A({},{}:{},{})".format(self.label, self.weight, self.tail, self.head)

`MDDArc` represents an arc in the decision diagram. It is uniquely identified by the combination of its tail node, head node, and label.

Unlike `MDDNode`, `MDDArc` is NOT immutable; in particular, `MDDArc.weight` may be changed later.

In [None]:
# Arc from n1 to n2 with label 2 and weight 0.
a1_2 = MDDArc(2, 0, n1, n2)
print(f"Original: {a1_2}")
# MDDArc is not frozen, so the weight can be reassigned in place.
a1_2.weight = 2
print(f"After modification: {a1_2}")

Original: A(2,0:N_0((1, 2, 3)),N_1((1, 3)))
After modification: A(2,2:N_0((1, 2, 3)),N_1((1, 3)))


In [None]:
#| export
@dataclass
class MDDNodeInfo:
    """Bookkeeping attached to one MDDNode: its incoming and outgoing arcs.

    Each instance gets its own fresh lists when no arguments are given.

    Parameters
    ----------
    incoming : list[MDDArc]
        list of incoming arcs (default: [])
    outgoing : list[MDDArc]
        list of outgoing arcs (default: [])

    """
    incoming: list[MDDArc] = field(default_factory=list)
    outgoing: list[MDDArc] = field(default_factory=list)

    def __str__(self) -> str:
        # Render each arc list as comma-separated arc strings.
        arcs_in = ", ".join(map(str, self.incoming))
        arcs_out = ", ".join(map(str, self.outgoing))
        return "<in=[" + arcs_in + "], out=[" + arcs_out + "]>"

`MDDNodeInfo` represents information associated with a node in the decision diagram, in particular its incoming and outgoing arcs.

In [None]:
# Node info with no incoming arcs and a single outgoing arc (a1_2 from the cell above).
node_info = MDDNodeInfo(incoming=[], outgoing=[a1_2])
print(f"info = {node_info}")

info = <in=[], out=[A(2,2:N_0((1, 2, 3)),N_1((1, 3)))]>


In [None]:
#| export
@dataclass()
class MDD:
    """MDD represents a multivalued decision diagram, or MDD.

    Nodes are stored per layer: `nodes[j]` maps every `MDDNode` in layer `j`
    to the `MDDNodeInfo` holding that node's incoming and outgoing arcs.

    Parameters
    ----------
    name : str
        name of MDD (default: 'mdd')
    nodes : list[dict[MDDNode, MDDNodeInfo]]
        nodes of MDD (default: [])

    """
    name: str = "mdd"
    nodes: list[dict[MDDNode, MDDNodeInfo]] = field(default_factory=list)

    @property
    def numNodeLayers(self) -> int:
        """Number of node layers; equal to number of 'variables' + 1."""
        return len(self.nodes)

    @property
    def numArcLayers(self) -> int:
        """Number of arc layers; equal to number of 'variables'.

        This is always `numNodeLayers - 1`, so an MDD with no layers
        reports -1.
        """
        return len(self.nodes) - 1

    @property
    def widthList(self) -> list[int]:
        """Number of nodes in each layer."""
        return [len(layer) for layer in self.nodes]

    @property
    def maxWidth(self) -> int:
        """Maximum number of nodes in a single node layer (0 if there are no layers)."""
        # `default=0` keeps an empty MDD from raising ValueError.
        return max((len(layer) for layer in self.nodes), default=0)

    def allnodes(self) -> Iterator[MDDNode]:
        """Iterate over all MDDNodes in the MDD, layer by layer."""
        return chain.from_iterable(layer.keys() for layer in self.nodes)

    def allnodeitems_in_layer(
        self,
        layer: int, # index of layer
    ) -> ItemsView[MDDNode, MDDNodeInfo]:
        """Return a view of all (MDDNode, MDDNodeInfo) pairs in a layer."""
        return self.nodes[layer].items()

    def allnodes_in_layer(
        self,
        layer: int, # index of layer
    ) -> KeysView[MDDNode]:
        """Return a view of all MDDNodes in a layer."""
        # Hand out a keys view (as annotated) rather than the layer dict
        # itself, so callers cannot mutate the layer through the result.
        return self.nodes[layer].keys()

    def alloutgoingarcs(self) -> Iterator[MDDArc]:
        """Iterate over all outgoing arcs in the MDD.

        Only nodes in layers 0, ..., numArcLayers - 1 are visited; outgoing
        arcs stored on nodes of the last layer are not reported.
        """
        return chain.from_iterable(
            info.outgoing
            for j in range(self.numArcLayers)
            for info in self.nodes[j].values()
        )

    def allincomingarcs(self) -> Iterator[MDDArc]:
        """Iterate over all incoming arcs in the MDD.

        Only nodes in layers 1, ..., numArcLayers are visited; incoming
        arcs stored on nodes of layer 0 are not reported.
        """
        return chain.from_iterable(
            info.incoming
            for j in range(self.numArcLayers)
            for info in self.nodes[j + 1].values()
        )

    def __str__(
        self,
        showLong: bool = False, # use more vertical space (default: False)
        showIncoming: bool = False, # show incoming arcs (default: False)
    ) -> str:
        """Return a (human-readable) string representation of the MDD."""
        # Long form lists one arc per line; short form separates arcs by commas.
        arc_sep = '\n' if showLong else ', '
        parts = [
            f'== MDD ({self.name}, {self.numArcLayers} layers) ==\n',
            '# Nodes\n',
        ]
        for j, lyr in enumerate(self.nodes):
            if showLong:
                # One line per node, including its incoming/outgoing arcs.
                parts.append('L' + str(j) + ':\n')
                for v, info in lyr.items():
                    in_str = ', '.join(str(a) for a in info.incoming)
                    out_str = ', '.join(str(a) for a in info.outgoing)
                    parts.append('\t' + str(v) + ': <in={' + in_str + '}, out={' + out_str + '}>\n')
            else:
                # One line per layer, nodes only.
                parts.append('L' + str(j) + ': ' + ', '.join(str(v) for v in lyr) + '\n')
        parts.append('# (Outgoing) Arcs\n')
        parts.append(arc_sep.join(str(a) for a in self.alloutgoingarcs()))
        if showIncoming:
            parts.append('\n# (Incoming) Arcs\n')
            parts.append(arc_sep.join(str(a) for a in self.allincomingarcs()))
        return ''.join(parts)

    def clear(self) -> None:
        """Reset the MDD by dropping all layers (the name is kept)."""
        self.nodes = []

    def append_new_layers(
        self,
        n: int = 1 # number of layers to append
    ) -> None:
        """Append new (empty) layers to the MDD."""
        self.nodes.extend({} for _ in range(n))

    def get_node_info(
        self,
        node: MDDNode, # node in MDD
    ) -> MDDNodeInfo:
        """Get `MDDNodeInfo` corresponding to `node`.

        Note this function can NOT be used to populate the underlying
        dictionary; it can only be used to reference the object.

        In general, you should use allnodeitems_in_layer(...) if you
        want to update node info in a systematic manner. The author
        recommends only using this function if allnodeitems_in_layer(...)
        cannot be used.
        """
        return self.nodes[node.layer][node]

    def add_arc(
        self,
        newarc: MDDArc, # arc to be added
    ) -> None:
        """Add an arc to the MDD, without any sanity checks.

        The head and tail nodes of the arc should already exist in the MDD.
        """
        self.get_node_info(newarc.tail).outgoing.append(newarc)
        self.get_node_info(newarc.head).incoming.append(newarc)

    def add_arcs(
        self,
        newarcs: Iterable[MDDArc], # arcs to be added
    ) -> None:
        """Add arcs to the MDD, without any sanity checks."""
        for arc in newarcs:
            self.add_arc(arc)

    def remove_arc(
        self,
        rmvarc: MDDArc, # arc to be removed
    ) -> None:
        """Remove an arc from the MDD, without any sanity checks."""
        self.get_node_info(rmvarc.tail).outgoing.remove(rmvarc)
        self.get_node_info(rmvarc.head).incoming.remove(rmvarc)

    def remove_arcs(
        self,
        rmvarcs: Iterable[MDDArc], # arcs to be removed
    ) -> None:
        """Remove arcs from the MDD, without any sanity checks."""
        for arc in rmvarcs:
            self.remove_arc(arc)

    def add_node(
        self,
        newnode: MDDNode, # node to be added
    ) -> None:
        """Add a node to the MDD, without any sanity checks.

        The node's layer should already exist in the MDD.

        NOTE: If an identical node already exists, its incoming and outgoing
        arcs will be ERASED!!!
        """
        self.nodes[newnode.layer][newnode] = MDDNodeInfo()

    def add_nodes(
        self,
        newnodes: Iterable[MDDNode], # nodes to be added
    ) -> None:
        """Add nodes to the MDD, without any sanity checks."""
        for node in newnodes:
            self.add_node(node)

    def remove_node(
        self,
        rmvnode: MDDNode, # node to be removed
    ) -> None:
        """Remove a node from the MDD, without any sanity checks.

        Every arc entering or leaving the node is also detached from the
        node at its other end.
        """
        info = self.get_node_info(rmvnode)
        for arc in info.incoming:
            self.get_node_info(arc.tail).outgoing.remove(arc)
        for arc in info.outgoing:
            self.get_node_info(arc.head).incoming.remove(arc)
        del self.nodes[rmvnode.layer][rmvnode]

    def remove_nodes(
        self,
        rmvnodes: Iterable[MDDNode], # nodes to be removed
    ) -> None:
        """Remove nodes from the MDD, without any sanity checks."""
        for node in rmvnodes:
            self.remove_node(node)

    @staticmethod
    def _default_awfun(w: float, ns: MDDNodeState, nt: MDDNodeState, j: int) -> float:
        """By default, do not adjust the weight of the arc."""
        return w

    def merge_nodes(
        self,
        mnodes: Collection[MDDNode],
        mlayer: int,
        nsfun: Callable[[Collection[MDDNodeState], int], MDDNodeState],
        awinfun: Optional[Callable[[float, MDDNodeState, MDDNodeState, int], float]] = None,
        awoutfun: Optional[Callable[[float, MDDNodeState, MDDNodeState, int], float]] = None,
    ) -> MDDNode:
        """Merge specified nodes into a new supernode, and modify arcs appropriately.

        NOTE: All nodes to be merged must be located on the same layer.

        NOTE: The supernode is inserted with `add_node`, so if its state
        equals that of a node already in the layer that is not being merged,
        that node's arc lists are reset.

        Parameters
        ----------
        mnodes : Collection[MDDNode]
            nodes to be merged together
        mlayer : int
            layer containing all nodes to be merged
        nsfun : Callable[[Collection[MDDNodeState], int], MDDNodeState]
            nsfun(slist,j) returns the node state resulting from merging node states in 'slist' in layer 'j'
        awinfun : Optional[Callable[[float, MDDNodeState, MDDNodeState, int], float]]
            awinfun(w,os,ms,j) returns the adjusted weight of an arc with weight 'w', old head node state 'os', and new head node (i.e., merged supernode in layer 'j') state 'ms';
            if awinfun is None (default), the original weight is used
        awoutfun : Optional[Callable[[float, MDDNodeState, MDDNodeState, int], float]]
            awoutfun(w,os,ms,j) returns the adjusted weight of an arc with weight 'w', old tail node state 'os', and new tail node (i.e., merged supernode in layer 'j') state 'ms';
            if awoutfun is None (default), the original weight is used

        Returns
        -------
        MDDNode
            new merged supernode
        """
        if awinfun is None:
            awinfun = self._default_awfun
        if awoutfun is None:
            awoutfun = self._default_awfun

        # Create new supernode, and new incoming / outgoing arcs
        mState = nsfun([v.state for v in mnodes], mlayer)
        mNode = MDDNode(mlayer, mState)

        # Build the redirected arcs BEFORE deleting anything: remove_node
        # detaches the old arcs, after which they can no longer be read.
        newIncoming = []
        newOutgoing = []
        for v in mnodes:
            node_info = self.get_node_info(v)
            for inarc in node_info.incoming:
                new_weight = awinfun(inarc.weight, inarc.head.state, mState, mlayer)
                newIncoming.append(MDDArc(inarc.label, new_weight, inarc.tail, mNode))
            for outarc in node_info.outgoing:
                new_weight = awoutfun(outarc.weight, outarc.tail.state, mState, mlayer)
                newOutgoing.append(MDDArc(outarc.label, new_weight, mNode, outarc.head))

        # Delete merged nodes
        for v in mnodes:
            self.remove_node(v)

        # Add supernode and its arcs to MDD
        self.add_node(mNode)
        self.add_arcs(newIncoming)
        self.add_arcs(newOutgoing)

        # Return new merged supernode
        return mNode

The `MDD` is the core object of this module, representing a multivalued decision diagram. An MDD starts out empty, with 0 node layers and -1 arc layers. In addition to basic getter methods, it has a `__str__` method for pretty printing its current state.

In [None]:
# A freshly constructed MDD has no layers: 0 node layers, hence -1 arc layers.
mdd0 = MDD()
print(f"mdd0 = {mdd0}")
print(f"numNodeLayers = {mdd0.numNodeLayers}")
print(f"numArcLayers = {mdd0.numArcLayers}")

mdd0 = == MDD (mdd, -1 layers) ==
# Nodes
# (Outgoing) Arcs

numNodeLayers = 0
numArcLayers = -1


In [None]:
show_doc(MDD.clear)

---

[source](https://github.com/rkimura47/python-mdd/blob/main/python_mdd/mdd.py#L184){target="_blank" style="float:right; font-size:smaller"}

### MDD.clear

>      MDD.clear ()

*Reset the MDD.*

In [None]:
show_doc(MDD.append_new_layers)

---

[source](https://github.com/rkimura47/python-mdd/blob/main/python_mdd/mdd.py#L188){target="_blank" style="float:right; font-size:smaller"}

### MDD.append_new_layers

>      MDD.append_new_layers (n:int=1)

*Append new layers to the MDD.*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| n | int | 1 | number of layers to append |
| **Returns** | **None** |  |  |

In [None]:
show_doc(MDD.get_node_info)

---

[source](https://github.com/rkimura47/python-mdd/blob/main/python_mdd/mdd.py#L195){target="_blank" style="float:right; font-size:smaller"}

### MDD.get_node_info

>      MDD.get_node_info (node:__main__.MDDNode)

*Get `MDDNodeInfo` corresponding to `node`.

Note this function can NOT be used to populate the underlying
dictionary; it can only be used to reference the object.

In general, you should use allnodeitems_in_layer(...) if you
want to update node info in a systematic manner. The author
recommends only using this function if allnodeitems_in_layer(...)
cannot be used.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| node | MDDNode | node in MDD |
| **Returns** | **MDDNodeInfo** |  |

In [None]:
show_doc(MDD.add_arc)

---

[source](https://github.com/rkimura47/python-mdd/blob/main/python_mdd/mdd.py#L211){target="_blank" style="float:right; font-size:smaller"}

### MDD.add_arc

>      MDD.add_arc (newarc:__main__.MDDArc)

*Add an arc to the MDD, without any sanity checks.

The head and tail nodes of the arc should already exist in the MDD.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| newarc | MDDArc | arc to be added |
| **Returns** | **None** |  |

In [None]:
show_doc(MDD.add_arcs)

---

[source](https://github.com/rkimura47/python-mdd/blob/main/python_mdd/mdd.py#L222){target="_blank" style="float:right; font-size:smaller"}

### MDD.add_arcs

>      MDD.add_arcs (newarcs:collections.abc.Iterable[__main__.MDDArc])

*Add arcs to the MDD, without any sanity checks.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| newarcs | Iterable | arcs to be added |
| **Returns** | **None** |  |

In [None]:
show_doc(MDD.remove_arc)

---

[source](https://github.com/rkimura47/python-mdd/blob/main/python_mdd/mdd.py#L230){target="_blank" style="float:right; font-size:smaller"}

### MDD.remove_arc

>      MDD.remove_arc (rmvarc:__main__.MDDArc)

*Remove an arc from the MDD, without any sanity checks.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| rmvarc | MDDArc | arc to be removed |
| **Returns** | **None** |  |

In [None]:
show_doc(MDD.remove_arcs)

---

[source](https://github.com/rkimura47/python-mdd/blob/main/python_mdd/mdd.py#L238){target="_blank" style="float:right; font-size:smaller"}

### MDD.remove_arcs

>      MDD.remove_arcs (rmvarcs:collections.abc.Iterable[__main__.MDDArc])

*Remove arcs from the MDD, without any sanity checks.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| rmvarcs | Iterable | arcs to be removed |
| **Returns** | **None** |  |

In [None]:
show_doc(MDD.add_node)

---

[source](https://github.com/rkimura47/python-mdd/blob/main/python_mdd/mdd.py#L246){target="_blank" style="float:right; font-size:smaller"}

### MDD.add_node

>      MDD.add_node (newnode:__main__.MDDNode)

*Add a node to the MDD, without any sanity checks.

The node's layer should already exist in the MDD.

NOTE: If an identical node already exists, its incoming and outgoing
arcs will be ERASED!!!*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| newnode | MDDNode | node to be added |
| **Returns** | **None** |  |

In [None]:
show_doc(MDD.add_nodes)

---

[source](https://github.com/rkimura47/python-mdd/blob/main/python_mdd/mdd.py#L259){target="_blank" style="float:right; font-size:smaller"}

### MDD.add_nodes

>      MDD.add_nodes (newnodes:collections.abc.Iterable[__main__.MDDNode])

*Add nodes to the MDD, without any sanity checks.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| newnodes | Iterable | nodes to be added |
| **Returns** | **None** |  |

In [None]:
show_doc(MDD.remove_node)

---

[source](https://github.com/rkimura47/python-mdd/blob/main/python_mdd/mdd.py#L267){target="_blank" style="float:right; font-size:smaller"}

### MDD.remove_node

>      MDD.remove_node (rmvnode:__main__.MDDNode)

*Remove a node from the MDD, without any sanity checks.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| rmvnode | MDDNode | node to be removed |
| **Returns** | **None** |  |

In [None]:
show_doc(MDD.remove_nodes)

---

[source](https://github.com/rkimura47/python-mdd/blob/main/python_mdd/mdd.py#L278){target="_blank" style="float:right; font-size:smaller"}

### MDD.remove_nodes

>      MDD.remove_nodes (rmvnodes:collections.abc.Iterable[__main__.MDDNode])

*Remove nodes from the MDD, without any sanity checks.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| rmvnodes | Iterable | nodes to be removed |
| **Returns** | **None** |  |

To populate the MDD, you must first add the necessary layers, then the nodes and arcs. Note that the basic functions do not perform any sanity checks, so it's the programmer's responsibility to avoid errors.

In [None]:
# Layers must exist before nodes are added, and nodes before the arcs joining them.
mdd0.append_new_layers(n=2)
mdd0.add_nodes([n1, n2])
mdd0.add_arc(a1_2)
print(f"mdd0 = {mdd0}")

mdd0 = == MDD (mdd, 1 layers) ==
# Nodes
L0: N_0((1, 2, 3))
L1: N_1((1, 3))
# (Outgoing) Arcs
A(2,2:N_0((1, 2, 3)),N_1((1, 3)))


In [None]:
#| hide
def test_basic_mdd_operations() -> None:
    """Check layer, node, and arc insertion plus node removal on a small MDD."""
    diagram = MDD()
    # A brand-new MDD has no layers at all.
    test_eq(diagram.numNodeLayers, 0)
    test_eq(diagram.numArcLayers, -1)
    test_eq(list(diagram.allnodes()), [])

    # Layer 0 gets a single root; layer 1 starts out empty.
    diagram.append_new_layers()
    root = MDDNode(0, 0)
    diagram.add_node(root)
    diagram.append_new_layers()
    test_eq(diagram.widthList, [1,0])

    # Fill layer 1 with two nodes.
    first = MDDNode(1, 1)
    diagram.add_node(first)
    second = MDDNode(1, 2)
    diagram.add_node(second)
    test_eq(diagram.numNodeLayers, 2)
    test_eq(diagram.widthList, [1,2])
    test_eq(diagram.maxWidth, 2)

    # Connect the root to both layer-1 nodes.
    to_first = MDDArc(0, 0, root, first)
    to_second = MDDArc(1, 10, root, second)
    for arc in (to_first, to_second):
        diagram.add_arc(arc)
    test_eq(list(diagram.alloutgoingarcs()), [to_first, to_second])
    test_eq(list(diagram.allincomingarcs()), [to_first, to_second])
    print(diagram)

    # Removing a node must also drop the arc that pointed to it.
    diagram.remove_node(second)
    assert all(node != second for node in diagram.allnodes())
    test_eq(list(diagram.alloutgoingarcs()), [to_first])
    test_eq(list(diagram.allincomingarcs()), [to_first])

test_basic_mdd_operations()

== MDD (mdd, 1 layers) ==
# Nodes
L0: N_0(0)
L1: N_1(1), N_1(2)
# (Outgoing) Arcs
A(0,0:N_0(0),N_1(1)), A(1,10:N_0(0),N_1(2))


In [None]:
show_doc(MDD.merge_nodes)

---

[source](https://github.com/rkimura47/python-mdd/blob/main/python_mdd/mdd.py#L291){target="_blank" style="float:right; font-size:smaller"}

### MDD.merge_nodes

>      MDD.merge_nodes (mnodes:collections.abc.Collection[__main__.MDDNode],
>                       mlayer:int, nsfun:collections.abc.Callable[[collections.
>                       abc.Collection[collections.abc.Hashable],int],collection
>                       s.abc.Hashable], awinfun:Optional[collections.abc.Callab
>                       le[[float,collections.abc.Hashable,collections.abc.Hasha
>                       ble,int],float]]=None, awoutfun:Optional[collections.abc
>                       .Callable[[float,collections.abc.Hashable,collections.ab
>                       c.Hashable,int],float]]=None)

*Merge specified nodes into a new supernode, and modify arcs appropriately.

NOTE: All nodes to be merged must be located on the same layer.*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| mnodes | Collection |  | nodes to be merged together |
| mlayer | int |  | layer containing all nodes to be merged |
| nsfun | Callable |  | nsfun(slist,j) returns the node state resulting from merging node states in 'slist' in layer 'j' |
| awinfun | Optional | None | awinfun(w,os,ms,j) returns the adjusted weight of an arc with weight 'w', old head node state 'os', and new head node (i.e., merged supernode in layer 'j') state 'ms';<br>if awfun is None (default), the original weight is used |
| awoutfun | Optional | None | awoutfun(w,os,ms,j) returns the adjusted weight of an arc with weight 'w', old tail node state 'os', and new tail node (i.e., merged supernode in layer 'j') state 'ms';<br>if awoutfun is None (default), the original weight is used |
| **Returns** | **MDDNode** |  | **new merged supernode** |

You can merge nodes in an MDD to reduce its size. Depending on the merging rule, this will typically result in a relaxation or restriction of the original MDD. The basic `merge_nodes()` function is configured based on node state functions, and is quite flexible. However, it can only merge nodes on the same layer.

In [None]:
def test_merge_nodes() -> None:
    """Merge the three middle-layer nodes of a small MDD, printing it before and after."""
    diagram = MDD()
    diagram.append_new_layers()
    root = MDDNode(0, 0)
    diagram.add_node(root)
    diagram.append_new_layers()
    middle = [MDDNode(1, k) for k in (1, 2, 3)]
    diagram.add_nodes(middle)
    diagram.append_new_layers()
    sink = MDDNode(2, 4)
    diagram.add_node(sink)
    # Arc into middle node k has label k and weight k; the arc out of it
    # has label k and weight -k.
    into_middle = [MDDArc(node.state, node.state, root, node) for node in middle]
    out_of_middle = [MDDArc(node.state, -node.state, node, sink) for node in middle]
    diagram.add_arcs(into_middle + out_of_middle)
    print(diagram)
    print()
    # The merged state is the largest of the merged states.
    diagram.merge_nodes(middle, 1, lambda states, layer: max(states))
    print(diagram)

test_merge_nodes()

== MDD (mdd, 2 layers) ==
# Nodes
L0: N_0(0)
L1: N_1(1), N_1(2), N_1(3)
L2: N_2(4)
# (Outgoing) Arcs
A(1,1:N_0(0),N_1(1)), A(2,2:N_0(0),N_1(2)), A(3,3:N_0(0),N_1(3)), A(1,-1:N_1(1),N_2(4)), A(2,-2:N_1(2),N_2(4)), A(3,-3:N_1(3),N_2(4))

== MDD (mdd, 2 layers) ==
# Nodes
L0: N_0(0)
L1: N_1(3)
L2: N_2(4)
# (Outgoing) Arcs
A(1,1:N_0(0),N_1(3)), A(2,2:N_0(0),N_1(3)), A(3,3:N_0(0),N_1(3)), A(1,-1:N_1(3),N_2(4)), A(2,-2:N_1(3),N_2(4)), A(3,-3:N_1(3),N_2(4))


In [None]:
#| hide
def test_merge_one_node() -> None:
    """Merge a single node with itself and print the resulting MDD."""
    diagram = MDD()
    diagram.append_new_layers()
    diagram.add_node(MDDNode(0, 0))
    diagram.append_new_layers()
    lone = MDDNode(1, 1)
    bystander = MDDNode(1, 2)
    diagram.add_nodes([lone, bystander])
    # Only `lone` is merged; `bystander` stays in the layer untouched.
    diagram.merge_nodes([lone], 1, lambda states, layer: max(states))
    print(diagram)

test_merge_one_node()

== MDD (mdd, 1 layers) ==
# Nodes
L0: N_0(0)
L1: N_1(2), N_1(1)
# (Outgoing) Arcs



While `MDD` does not explicitly forbid "skip arcs" (arcs that connect nodes in non-adjacent layers), most functionality is designed assuming there are no skip arcs.

In [None]:
def test_merge_skip_arcs_fails() -> None:
    """Show that merging nodes from different layers of a path MDD raises KeyError."""
    diagram = MDD()
    diagram.append_new_layers(4)
    path = [MDDNode(k, k) for k in range(4)]
    n0, n1, n2, n3 = path
    diagram.add_nodes(path)
    diagram.add_arcs([
        MDDArc(1, 1, n0, n1),
        MDDArc(2, 2, n1, n2),
        MDDArc(2, 2, n2, n3),
    ])
    print(diagram)
    try:
        # n1 and n2 sit on different layers and are joined by an arc.
        diagram.merge_nodes([n1, n2], 1, lambda states, layer: max(states))
    except KeyError as err:
        print("Success: Caught KeyError as expected")
        print(err)
    else:
        print(diagram)
        raise ValueError("Expected KeyError")

test_merge_skip_arcs_fails()

== MDD (mdd, 3 layers) ==
# Nodes
L0: N_0(0)
L1: N_1(1)
L2: N_2(2)
L3: N_3(3)
# (Outgoing) Arcs
A(1,1:N_0(0),N_1(1)), A(2,2:N_1(1),N_2(2)), A(2,2:N_2(2),N_3(3))
Success: Caught KeyError as expected
MDDNode(layer=1, state=1)


In [None]:
def test_merge_skip_arcs_works_sometimes() -> None:
    """Merge two nodes from different layers in an MDD with skip arcs; no error is raised."""
    diagram = MDD()
    diagram.append_new_layers(4)
    members = [MDDNode(k, k) for k in range(4)]
    n0, n1, n2, n3 = members
    diagram.add_nodes(members)
    # n0 -> n2 and n1 -> n3 are skip arcs; n1 and n2 are not directly connected.
    diagram.add_arcs([
        MDDArc(1, 1, n0, n1),
        MDDArc(2, 2, n0, n2),
        MDDArc(2, 2, n1, n3),
        MDDArc(1, 1, n2, n3),
    ])
    print(diagram)
    diagram.merge_nodes([n1, n2], 2, lambda states, layer: max(states))
    print(diagram)

test_merge_skip_arcs_works_sometimes()

== MDD (mdd, 3 layers) ==
# Nodes
L0: N_0(0)
L1: N_1(1)
L2: N_2(2)
L3: N_3(3)
# (Outgoing) Arcs
A(1,1:N_0(0),N_1(1)), A(2,2:N_0(0),N_2(2)), A(2,2:N_1(1),N_3(3)), A(1,1:N_2(2),N_3(3))
== MDD (mdd, 3 layers) ==
# Nodes
L0: N_0(0)
L1: 
L2: N_2(2)
L3: N_3(3)
# (Outgoing) Arcs
A(1,1:N_0(0),N_2(2)), A(2,2:N_0(0),N_2(2)), A(2,2:N_2(2),N_3(3)), A(1,1:N_2(2),N_3(3))


In [None]:
#| hide
import nbdev; nbdev.nbdev_export()