# mdd

> Objects representing a multivalued decision diagram.

In [None]:
#| default_exp mdd

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from collections.abc import Callable, Collection, Hashable, ItemsView, Iterable, Iterator, KeysView, Mapping
from itertools import chain
from dataclasses import dataclass, field
from typing import Optional

In [None]:
#| hide
import ipytest
import pytest

In [None]:
#| hide
# Configure ipytest for the `%%ipytest` cells below: quiet output ("-qq") with colour.
ipytest.autoconfig(addopts=("-qq", "--color=yes"))

## Classes

In [None]:
#| export
MDDNodeState = Hashable

@dataclass(frozen=True)
class MDDNode:
    """A single node of an MDD, identified by the pair (layer, state).

    The dataclass is frozen, so equality and hashing are derived from
    `layer` and `state`. Two nodes built from equal values are therefore
    interchangeable, e.g. as dictionary keys. For the same reason the state
    must be a `collections.abc.Hashable` object.

    Parameters
    ----------
    layer : int
        index of the node layer that contains the node
    state : MDDNodeState
        state associated with the node

    """
    layer: int
    state: MDDNodeState

    def __str__(self) -> str:
        layer, state = self.layer, self.state
        return f"N_{layer}({state})"

`MDDNode` represents a node in the decision diagram. It is uniquely identified by which layer the node is located in, and the node's state. The state must be hashable, because nodes are used as dictionary keys; this also lets us quickly identify when two nodes in the same layer are "equivalent".

In [None]:
# n1 and n1_copy are distinct objects built from the same (layer, state),
# so they compare equal (frozen dataclass equality) without being identical.
n1 = MDDNode(0, (1, 2, 3))
n2 = MDDNode(1, (1, 3))
n1_copy = MDDNode(0, (1, 2, 3))
print(f"n1 = {n1}\nn2 = {n2}\nn1_copy = {n1_copy}")
print(f"`n1_copy is n1` -> {n1_copy is n1}")
print(f"`n1_copy == n1` -> {n1_copy == n1}")

n1 = N_0((1, 2, 3))
n2 = N_1((1, 3))
n1_copy = N_0((1, 2, 3))
`n1_copy is n1` -> False
`n1_copy == n1` -> True


In [None]:
#| export
MDDArcLabel = Hashable

@dataclass(frozen=True)
class MDDArc:
    """MDDArc represents a single arc in the MDD.

    An MDDArc is uniquely identified by its label together with its tail and
    head nodes. Its weight is NOT part of its identity; it is kept separately
    in an `MDDArcData` object (see `MDD.arcs`). The (arc) label must be a
    `collections.abc.Hashable` object.

    Parameters
    ----------
    label : MDDArcLabel
        label of arc (e.g., assigned value)
    tail : MDDNode
        tail/source node
    head : MDDNode
        head/destination node

    """
    label: MDDArcLabel
    tail: MDDNode
    head: MDDNode

    def __str__(self) -> str:
        return f"A({self.label}:{self.tail},{self.head})"

`MDDArc` represents an arc in the decision diagram. It is uniquely identified by the combination of its tail node, head node, and label.

In [None]:
# Arc with label 2 from n1 (layer 0) to n2 (layer 1).
a1_2 = MDDArc(2, n1, n2)
print(f"a1_2 = {a1_2}")

a1_2 = A(2:N_0((1, 2, 3)),N_1((1, 3)))


In [None]:
#| export
@dataclass()
class MDDNodeData:
    """Adjacency information stored for a single MDDNode.

    Holds the arcs ending at the node (`incoming`) and the arcs starting at
    it (`outgoing`). Each instance gets its own fresh lists.

    Parameters
    ----------
    incoming : list[MDDArc]
        arcs whose head is this node (default: [])
    outgoing : list[MDDArc]
        arcs whose tail is this node (default: [])

    """
    incoming: list[MDDArc] = field(default_factory=list)
    outgoing: list[MDDArc] = field(default_factory=list)

    def __str__(self) -> str:
        def _joined(arcs) -> str:
            # Comma-separated str() of every arc, in list order.
            return ", ".join(map(str, arcs))

        incoming_str, outgoing_str = _joined(self.incoming), _joined(self.outgoing)
        return f"<in=[{incoming_str}], out=[{outgoing_str}]>"

`MDDNodeData` represents information associated with a node in the decision diagram, in particular its incoming and outgoing arcs.

In [None]:
# Node data with no incoming arcs and the single outgoing arc a1_2.
node_data = MDDNodeData(incoming=[], outgoing=[a1_2])
print(f"node_data = {node_data}")

node_data = <in=[], out=[A(2:N_0((1, 2, 3)),N_1((1, 3)))]>


In [None]:
#| export
@dataclass()
class MDDArcData:
    """Payload attached to an MDDArc; currently only the arc's weight.

    Parameters
    ----------
    weight : float
        weight of arc (default: 0.0)

    """
    weight: float = 0.0

    def __str__(self) -> str:
        w = self.weight
        return f"<{w}>"

In [None]:
#| export
@dataclass()
class MDD:
    """MDD represents a multivalued decision diagram, or MDD.

    Nodes are stored layer by layer: `nodes[j]` maps every MDDNode in layer
    `j` to its MDDNodeData (incoming/outgoing arc lists). The data of each
    arc is stored separately in `arcs`, keyed by the MDDArc itself.

    Parameters
    ----------
    name : str
        name of MDD (default: 'mdd')
    nodes : list[dict[MDDNode, MDDNodeData]]
        nodes of MDD (default: [])
    arcs: dict[MDDArc, MDDArcData]
        arcs of MDD (default: dict())

    """
    name: str = "mdd"
    nodes: list[dict[MDDNode, MDDNodeData]] = field(default_factory=list)
    arcs: dict[MDDArc, MDDArcData] = field(default_factory=dict)

    @property
    def num_node_layers(self) -> int:
        """Number of node layers; equal to number of 'variables' + 1."""
        return len(self.nodes)

    @property
    def num_arc_layers(self) -> int:
        """Number of arc layers; equal to number of 'variables'.

        An MDD without any node layers reports -1.
        """
        return len(self.nodes) - 1

    @property
    def width_list(self) -> list[int]:
        """Number of nodes in each layer."""
        return [len(layer) for layer in self.nodes]

    @property
    def max_width(self) -> int:
        """Maximum number of nodes in a single node layer (0 if there are no layers)."""
        # `default=0` avoids a ValueError from max() on an MDD with no layers.
        return max((len(layer) for layer in self.nodes), default=0)

    def all_nodes(self) -> Iterator[MDDNode]:
        """Iterate over all MDDNodes in the MDD, layer by layer."""
        return chain.from_iterable(layer.keys() for layer in self.nodes)

    def all_nodeitems_in_layer(
        self,
        layer: int, # index of layer
    ) -> ItemsView[MDDNode, MDDNodeData]:
        """Return a view of all (MDDNode, MDDNodeData) pairs in a layer."""
        return self.nodes[layer].items()

    def allnodes_in_layer(
        self,
        layer: int, # index of layer
    ) -> KeysView[MDDNode]:
        """Return a view of all MDDNodes in a layer."""
        # Return an actual keys view (as annotated), not the mutable layer dict.
        return self.nodes[layer].keys()

    def all_outgoing_arcs(self) -> Iterator[MDDArc]:
        """Iterate over all outgoing arcs in the MDD.

        Arcs are produced tail node by tail node, for tails in layers
        0 .. num_arc_layers-1.
        """
        return chain.from_iterable(
            ui.outgoing for j in range(self.num_arc_layers) for ui in self.nodes[j].values()
        )

    def all_incoming_arcs(self) -> Iterator[MDDArc]:
        """Iterate over all incoming arcs in the MDD.

        Arcs are produced head node by head node, for heads in layers
        1 .. num_node_layers-1.
        """
        return chain.from_iterable(
            ui.incoming for j in range(self.num_arc_layers) for ui in self.nodes[j+1].values()
        )

    def __str__(
        self,
        show_long: bool = False, # use more vertical space (default: False)
        show_incoming: bool = False, # show incoming arcs (default: False)
    ) -> str:
        """Return a (human-readable) string representation of the MDD.

        `str(mdd)` always uses the defaults. Call `mdd.__str__(...)` directly
        to select the long form and/or to list incoming arcs.
        """
        s = f"== MDD ({self.name}, {self.num_arc_layers} layers) ==\n"
        if show_long:
            # Long form: one line per node, including its arc lists
            s += "# Nodes\n"
            for (j, lyr) in enumerate(self.nodes):
                s += f"L{j}:\n"
                for v, v_data in lyr.items():
                    s += f"\t{v}: <"
                    s += "in={" + ", ".join(str(a) for a in v_data.incoming) + "}, "
                    s += "out={" + ", ".join(str(a) for a in v_data.outgoing) + "}"
                    s += ">\n"
            s += "# (Outgoing) Arcs\n"
            s += "\n".join(f"{a}:{self.arcs[a]}" for a in self.all_outgoing_arcs())
            if show_incoming:
                s += "\n# (Incoming) Arcs\n"
                s += "\n".join(f"{a}:{self.arcs[a]}" for a in self.all_incoming_arcs())
        else:
            # Short form: one line per layer
            s += "# Nodes\n"
            for (j, lyr) in enumerate(self.nodes):
                s += f"L{j}: "
                s += ", ".join(str(v) for v in lyr) + "\n"
            s += "# (Outgoing) Arcs\n"
            s += ", ".join(f"{a}:{self.arcs[a]}" for a in self.all_outgoing_arcs())
            if show_incoming:
                s += "\n# (Incoming) Arcs\n"
                s += ", ".join(f"{a}:{self.arcs[a]}" for a in self.all_incoming_arcs())
        return s

    def clear(self) -> None:
        """Reset the MDD, removing all layers, nodes, and arcs."""
        self.nodes = []
        # Arc data must be reset too; otherwise `arcs` keeps entries for arcs
        # whose endpoints no longer exist.
        self.arcs = {}

    def append_new_layers(
        self,
        n: int = 1 # number of layers to append
    ) -> None:
        """Append new layers to the MDD."""
        self.nodes.extend([dict() for _ in range(n)])

    def get_node_data(
        self,
        node: MDDNode, # node in MDD
    ) -> MDDNodeData:
        """Get `MDDNodeData` corresponding to `node`.

        Note this function can NOT be used to populate the underlying
        dictionary; it can only be used to reference the object.

        In general, you should use all_nodeitems_in_layer(...) if you
        want to update node data in a systematic manner. The author
        recommends only using this function if all_nodeitems_in_layer(...)
        cannot be used.
        """
        return self.nodes[node.layer][node]

    def add_arc(
        self,
        newarc: MDDArc, # arc to be added
        newarc_data: MDDArcData, # data for arc to be added
    ) -> None:
        """Add an arc to the MDD, without any sanity checks.

        The head and tail nodes of the arc should already exist in the MDD.
        """
        self.get_node_data(newarc.tail).outgoing.append(newarc)
        self.get_node_data(newarc.head).incoming.append(newarc)
        self.arcs[newarc] = newarc_data

    def add_arcs(
        self,
        newarcs: Mapping[MDDArc, MDDArcData], # arcs and their data to be added
    ) -> None:
        """Add arcs to the MDD, without any sanity checks."""
        for arc, arc_data in newarcs.items():
            self.add_arc(arc, arc_data)

    def remove_arc(
        self,
        rmvarc: MDDArc, # arc to be removed
    ) -> None:
        """Remove an arc from the MDD, without any sanity checks."""
        self.get_node_data(rmvarc.tail).outgoing.remove(rmvarc)
        self.get_node_data(rmvarc.head).incoming.remove(rmvarc)
        del self.arcs[rmvarc]

    def remove_arcs(
        self,
        rmvarcs: Iterable[MDDArc], # arcs to be removed
    ) -> None:
        """Remove arcs from the MDD, without any sanity checks."""
        for arc in rmvarcs:
            self.remove_arc(arc)

    def add_node(
        self,
        newnode: MDDNode, # node to be added
    ) -> None:
        """Add a node to the MDD, without any sanity checks.

        The node's layer should already exist in the MDD.

        NOTE: If an identical node already exists, its incoming and outgoing
        arcs will be ERASED!!!
        """
        self.nodes[newnode.layer][newnode] = MDDNodeData()

    def add_nodes(
        self,
        newnodes: Iterable[MDDNode], # nodes to be added
    ) -> None:
        """Add nodes to the MDD, without any sanity checks."""
        for node in newnodes:
            self.add_node(node)

    def remove_node(
        self,
        rmvnode: MDDNode, # node to be removed
    ) -> None:
        """Remove a node and all of its incident arcs from the MDD, without any sanity checks."""
        node_data = self.get_node_data(rmvnode)
        for arc in node_data.incoming:
            self.get_node_data(arc.tail).outgoing.remove(arc)
            # Drop the arc's data as well so `arcs` does not keep stale entries.
            self.arcs.pop(arc, None)
        for arc in node_data.outgoing:
            self.get_node_data(arc.head).incoming.remove(arc)
            self.arcs.pop(arc, None)
        del self.nodes[rmvnode.layer][rmvnode]

    def remove_nodes(
        self,
        rmvnodes: Iterable[MDDNode], # nodes to be removed
    ) -> None:
        """Remove nodes from the MDD, without any sanity checks."""
        for node in rmvnodes:
            self.remove_node(node)

    @staticmethod
    def _default_adfun(d: MDDArcData, ns: MDDNodeState, nt: MDDNodeState, j: int) -> MDDArcData:
        """By default, use the original arc data."""
        return d

    def merge_nodes(
        self,
        mnodes: Collection[MDDNode],
        mlayer: int,
        nsfun: Callable[[Collection[MDDNodeState], int], MDDNodeState],
        adinfun: Optional[Callable[[MDDArcData, MDDNodeState, MDDNodeState, int], MDDArcData]] = None,
        adoutfun: Optional[Callable[[MDDArcData, MDDNodeState, MDDNodeState, int], MDDArcData]] = None,
    ) -> MDDNode:
        """Merge specified nodes into a new supernode, and modify arcs appropriately.

        NOTE: All nodes to be merged must be located on the same layer.

        NOTE: If the merged state equals the state of a node already in
        `mlayer` that is NOT in `mnodes`, that node is overwritten by
        `add_node`, which erases its incoming and outgoing arc lists.

        Parameters
        ----------
        mnodes : Collection[MDDNode]
            nodes to be merged together
        mlayer : int
            layer containing all nodes to be merged
        nsfun : Callable[[Collection[MDDNodeState], int], MDDNodeState]
            nsfun(slist,j) returns the node state resulting from merging node states in 'slist' in layer 'j'
        adinfun : Optional[Callable[[MDDArcData, MDDNodeState, MDDNodeState, int], MDDArcData]]
            adinfun(d,os,ms,j) returns the adjusted data of an arc with data 'd', old head node state 'os', and new head node (i.e., merged supernode in layer 'j') state 'ms';
            if adinfun is None (default), the original arc data is used
        adoutfun : Optional[Callable[[MDDArcData, MDDNodeState, MDDNodeState, int], MDDArcData]]
            adoutfun(d,os,ms,j) returns the adjusted data of an arc with data 'd', old tail node state 'os', and new tail node (i.e., merged supernode in layer 'j') state 'ms';
            if adoutfun is None (default), the original arc data is used

        Returns
        -------
        MDDNode
            new merged supernode

        Raises
        ------
        KeyError
            if a node in `mnodes` is not in the MDD, or if the merged nodes
            are connected to each other by an arc (the redirected arc would
            reference an already-removed node)
        """
        if adinfun is None:
            adinfun = self._default_adfun
        if adoutfun is None:
            adoutfun = self._default_adfun

        # Create new supernode, and new incoming / outgoing arcs.
        # Arc data is read here, before the merged nodes (and their arcs) are removed.
        merged_state = nsfun([v.state for v in mnodes], mlayer)
        merged_node = MDDNode(mlayer, merged_state)

        new_incoming = dict()
        new_outgoing = dict()
        for v in mnodes:
            node_data = self.get_node_data(v)
            for inarc in node_data.incoming:
                new_inarc = MDDArc(inarc.label, inarc.tail, merged_node)
                new_incoming[new_inarc] = adinfun(self.arcs[inarc], inarc.head.state, merged_state, mlayer)
            for outarc in node_data.outgoing:
                new_outarc = MDDArc(outarc.label, merged_node, outarc.head)
                new_outgoing[new_outarc] = adoutfun(self.arcs[outarc], outarc.tail.state, merged_state, mlayer)

        # Delete merged nodes (and their incident arcs)
        for v in mnodes:
            self.remove_node(v)

        # Add supernode and its arcs to MDD
        self.add_node(merged_node)
        self.add_arcs(new_incoming)
        self.add_arcs(new_outgoing)

        # Return new merged supernode
        return merged_node

The `MDD` is the core object of this module, representing a multivalued decision diagram. An MDD starts out empty, with 0 node layers and -1 arc layers. In addition to basic getter methods, it has a `__str__` method for pretty printing its current state.

In [None]:
# A freshly constructed MDD has no node layers, so num_arc_layers is -1.
mdd0 = MDD()
print(f"mdd0 = {mdd0}")
print(f"num_node_layers = {mdd0.num_node_layers}")
print(f"num_arc_layers = {mdd0.num_arc_layers}")

mdd0 = == MDD (mdd, -1 layers) ==
# Nodes
# (Outgoing) Arcs

num_node_layers = 0
num_arc_layers = -1


## MDD: Basic operations

In [None]:
# Render the API documentation for MDD.clear.
show_doc(MDD.clear)

---

[source](https://github.com/rkimura47/python-mdd/blob/main/python_mdd/mdd.py#L200){target="_blank" style="float:right; font-size:smaller"}

### MDD.clear

>      MDD.clear ()

*Reset the MDD.*

In [None]:
# Render the API documentation for MDD.append_new_layers.
show_doc(MDD.append_new_layers)

---

[source](https://github.com/rkimura47/python-mdd/blob/main/python_mdd/mdd.py#L204){target="_blank" style="float:right; font-size:smaller"}

### MDD.append_new_layers

>      MDD.append_new_layers (n:int=1)

*Append new layers to the MDD.*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| n | int | 1 | number of layers to append |
| **Returns** | **None** |  |  |

In [None]:
# Render the API documentation for MDD.get_node_data.
show_doc(MDD.get_node_data)

---

[source](https://github.com/rkimura47/python-mdd/blob/main/python_mdd/mdd.py#L211){target="_blank" style="float:right; font-size:smaller"}

### MDD.get_node_data

>      MDD.get_node_data (node:__main__.MDDNode)

*Get `MDDNodeData` corresponding to `node`.

Note this function can NOT be used to populate the underlying
dictionary; it can only be used to reference the object.

In general, you should use all_nodeitems_in_layer(...) if you
want to update node data in a systematic manner. The author
recommends only using this function if all_nodeitems_in_layer(...)
cannot be used.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| node | MDDNode | node in MDD |
| **Returns** | **MDDNodeData** |  |

In [None]:
# Render the API documentation for MDD.add_arc.
show_doc(MDD.add_arc)

---

[source](https://github.com/rkimura47/python-mdd/blob/main/python_mdd/mdd.py#L227){target="_blank" style="float:right; font-size:smaller"}

### MDD.add_arc

>      MDD.add_arc (newarc:__main__.MDDArc, newarc_data:__main__.MDDArcData)

*Add an arc to the MDD, without any sanity checks.

The head and tail nodes of the arc should already exist in the MDD.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| newarc | MDDArc | arc to be added |
| newarc_data | MDDArcData | data for arc to be added |
| **Returns** | **None** |  |

In [None]:
# Render the API documentation for MDD.add_arcs.
show_doc(MDD.add_arcs)

---

[source](https://github.com/rkimura47/python-mdd/blob/main/python_mdd/mdd.py#L240){target="_blank" style="float:right; font-size:smaller"}

### MDD.add_arcs

>      MDD.add_arcs
>                    (newarcs:collections.abc.Mapping[__main__.MDDArc,__main__.M
>                    DDArcData])

*Add arcs to the MDD, without any sanity checks.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| newarcs | Mapping | arcs and their data to be added |
| **Returns** | **None** |  |

In [None]:
# Render the API documentation for MDD.remove_arc.
show_doc(MDD.remove_arc)

---

[source](https://github.com/rkimura47/python-mdd/blob/main/python_mdd/mdd.py#L248){target="_blank" style="float:right; font-size:smaller"}

### MDD.remove_arc

>      MDD.remove_arc (rmvarc:__main__.MDDArc)

*Remove an arc from the MDD, without any sanity checks.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| rmvarc | MDDArc | arc to be removed |
| **Returns** | **None** |  |

In [None]:
# Render the API documentation for MDD.remove_arcs.
show_doc(MDD.remove_arcs)

---

[source](https://github.com/rkimura47/python-mdd/blob/main/python_mdd/mdd.py#L257){target="_blank" style="float:right; font-size:smaller"}

### MDD.remove_arcs

>      MDD.remove_arcs (rmvarcs:collections.abc.Iterable[__main__.MDDArc])

*Remove arcs from the MDD, without any sanity checks.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| rmvarcs | Iterable | arcs to be removed |
| **Returns** | **None** |  |

In [None]:
# Render the API documentation for MDD.add_node.
show_doc(MDD.add_node)

---

[source](https://github.com/rkimura47/python-mdd/blob/main/python_mdd/mdd.py#L265){target="_blank" style="float:right; font-size:smaller"}

### MDD.add_node

>      MDD.add_node (newnode:__main__.MDDNode)

*Add a node to the MDD, without any sanity checks.

The node's layer should already exist in the MDD.

NOTE: If an identical node already exists, its incoming and outgoing
arcs will be ERASED!!!*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| newnode | MDDNode | node to be added |
| **Returns** | **None** |  |

In [None]:
# Render the API documentation for MDD.add_nodes.
show_doc(MDD.add_nodes)

---

[source](https://github.com/rkimura47/python-mdd/blob/main/python_mdd/mdd.py#L278){target="_blank" style="float:right; font-size:smaller"}

### MDD.add_nodes

>      MDD.add_nodes (newnodes:collections.abc.Iterable[__main__.MDDNode])

*Add nodes to the MDD, without any sanity checks.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| newnodes | Iterable | nodes to be added |
| **Returns** | **None** |  |

In [None]:
# Render the API documentation for MDD.remove_node.
show_doc(MDD.remove_node)

---

[source](https://github.com/rkimura47/python-mdd/blob/main/python_mdd/mdd.py#L286){target="_blank" style="float:right; font-size:smaller"}

### MDD.remove_node

>      MDD.remove_node (rmvnode:__main__.MDDNode)

*Remove a node from the MDD, without any sanity checks.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| rmvnode | MDDNode | node to be removed |
| **Returns** | **None** |  |

In [None]:
# Render the API documentation for MDD.remove_nodes.
show_doc(MDD.remove_nodes)

---

[source](https://github.com/rkimura47/python-mdd/blob/main/python_mdd/mdd.py#L297){target="_blank" style="float:right; font-size:smaller"}

### MDD.remove_nodes

>      MDD.remove_nodes (rmvnodes:collections.abc.Iterable[__main__.MDDNode])

*Remove nodes from the MDD, without any sanity checks.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| rmvnodes | Iterable | nodes to be removed |
| **Returns** | **None** |  |

To populate the MDD, you must first add the necessary layers, then the nodes and arcs. Note that the basic functions do not perform any sanity checks, so it's the programmer's responsibility to avoid errors.

In [None]:
# Populate mdd0: create two node layers, add n1 (layer 0) and n2 (layer 1),
# then connect them with arc a1_2 carrying weight 0.
mdd0.append_new_layers(n=2)
mdd0.add_nodes([n1, n2])
mdd0.add_arc(a1_2, MDDArcData(0))
print(f"mdd0 = {mdd0}")

mdd0 = == MDD (mdd, 1 layers) ==
# Nodes
L0: N_0((1, 2, 3))
L1: N_1((1, 3))
# (Outgoing) Arcs
A(2:N_0((1, 2, 3)),N_1((1, 3))):<0>


In [None]:
%%ipytest
#| hide
def test_basic_mdd_operations() -> None:
    """Exercise layer/node/arc insertion, width bookkeeping, and node removal."""
    mymdd = MDD()
    assert mymdd.num_node_layers == 0
    assert mymdd.num_arc_layers == -1
    assert list(mymdd.all_nodes()) == []
    
    mymdd.append_new_layers()
    mynode0 = MDDNode(0, 0)
    mymdd.add_node(mynode0)
    mymdd.append_new_layers()
    assert mymdd.width_list == [1,0]
    
    mynode1 = MDDNode(1, 1)
    mymdd.add_node(mynode1)
    mynode2 = MDDNode(1, 2)
    mymdd.add_node(mynode2)
    assert mymdd.num_node_layers == 2
    assert mymdd.width_list == [1,2]
    assert mymdd.max_width == 2
    
    myarc0_1 = MDDArc(0, mynode0, mynode1)
    myarc0_2 = MDDArc(1, mynode0, mynode2)
    mymdd.add_arc(myarc0_1, MDDArcData(0))
    mymdd.add_arc(myarc0_2, MDDArcData(10))
    # Arcs are listed in insertion order.
    assert list(mymdd.all_outgoing_arcs()) == [myarc0_1, myarc0_2]
    assert list(mymdd.all_incoming_arcs()) == [myarc0_1, myarc0_2]
    
    # Removing mynode2 must also detach myarc0_2 from mynode0's outgoing list.
    mymdd.remove_node(mynode2)
    assert mynode2 not in list(mymdd.all_nodes())
    assert list(mymdd.all_outgoing_arcs()) == [myarc0_1]
    assert list(mymdd.all_incoming_arcs()) == [myarc0_1]

test_basic_mdd_operations()

[32m.[0m[32m                                                                                            [100%][0m


In [None]:
# Render the API documentation for MDD.merge_nodes.
show_doc(MDD.merge_nodes)

---

[source](https://github.com/rkimura47/python-mdd/blob/main/python_mdd/mdd.py#L310){target="_blank" style="float:right; font-size:smaller"}

### MDD.merge_nodes

>      MDD.merge_nodes (mnodes:collections.abc.Collection[__main__.MDDNode],
>                       mlayer:int, nsfun:collections.abc.Callable[[collections.
>                       abc.Collection[collections.abc.Hashable],int],collection
>                       s.abc.Hashable], adinfun:Optional[collections.abc.Callab
>                       le[[__main__.MDDArcData,collections.abc.Hashable,collect
>                       ions.abc.Hashable,int],__main__.MDDArcData]]=None, adout
>                       fun:Optional[collections.abc.Callable[[__main__.MDDArcDa
>                       ta,collections.abc.Hashable,collections.abc.Hashable,int
>                       ],__main__.MDDArcData]]=None)

*Merge specified nodes into a new supernode, and modify arcs appropriately.

NOTE: All nodes to be merged must be located on the same layer.*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| mnodes | Collection |  | nodes to be merged together |
| mlayer | int |  | layer containing all nodes to be merged |
| nsfun | Callable |  | nsfun(slist,j) returns the node state resulting from merging node states in 'slist' in layer 'j' |
| adinfun | Optional | None | adinfun(d,os,ms,j) returns the adjusted data of an arc with data 'd', old head node state 'os', and new head node (i.e., merged supernode in layer 'j') state 'ms';<br>if adinfun is None (default), the original arc data is used |
| adoutfun | Optional | None | adoutfun(d,os,ms,j) returns the adjusted data of an arc with data 'd', old tail node state 'os', and new tail node (i.e., merged supernode in layer 'j') state 'ms';<br>if adoutfun is None (default), the original arc data is used |
| **Returns** | **MDDNode** |  | **new merged supernode** |

You can merge nodes in an MDD to reduce its size. Depending on the merging rule, this will typically result in a relaxation or restriction of the original MDD. The basic `merge_nodes()` function is configured based on node state functions, and is quite flexible. However, it can only merge nodes on the same layer.

In [None]:
%%ipytest
def test_merge_nodes() -> None:
    """Merge all three layer-1 nodes into one supernode and check that every arc survives.

    With nsfun = max, the supernode gets state 3. All six original arc labels
    remain, redirected to/from the supernode.
    """
    mymdd = MDD()
    mymdd.append_new_layers()
    n0 = MDDNode(0, 0)
    mymdd.add_node(n0)
    mymdd.append_new_layers()
    n1 = MDDNode(1, 1)
    n2 = MDDNode(1, 2)
    n3 = MDDNode(1, 3)
    mymdd.add_nodes([n1, n2, n3])
    mymdd.append_new_layers()
    n4 = MDDNode(2, 4)
    mymdd.add_node(n4)
    all_arcs = {
        MDDArc(1, n0, n1): MDDArcData(1),
        MDDArc(2, n0, n2): MDDArcData(2),
        MDDArc(3, n0, n3): MDDArcData(3),
        MDDArc(14, n1, n4): MDDArcData(-1),
        MDDArc(24, n2, n4): MDDArcData(-2),
        MDDArc(34, n3, n4): MDDArcData(-3),
    }
    mymdd.add_arcs(all_arcs)
    print(mymdd)
    print()
    mymdd.merge_nodes([n1, n2, n3], 1, lambda ns, l: max(s for s in ns))
    print(mymdd)
    # n0, the supernode N_1(3), and n4 remain.
    assert len(list(mymdd.all_nodes())) == 3
    assert {a.label for a in mymdd.all_outgoing_arcs()} == {1, 2, 3, 14, 24, 34}

test_merge_nodes()

== MDD (mdd, 2 layers) ==
# Nodes
L0: N_0(0)
L1: N_1(1), N_1(2), N_1(3)
L2: N_2(4)
# (Outgoing) Arcs
A(1:N_0(0),N_1(1)):<1>, A(2:N_0(0),N_1(2)):<2>, A(3:N_0(0),N_1(3)):<3>, A(14:N_1(1),N_2(4)):<-1>, A(24:N_1(2),N_2(4)):<-2>, A(34:N_1(3),N_2(4)):<-3>

== MDD (mdd, 2 layers) ==
# Nodes
L0: N_0(0)
L1: N_1(3)
L2: N_2(4)
# (Outgoing) Arcs
A(1:N_0(0),N_1(3)):<1>, A(2:N_0(0),N_1(3)):<2>, A(3:N_0(0),N_1(3)):<3>, A(14:N_1(3),N_2(4)):<-1>, A(24:N_1(3),N_2(4)):<-2>, A(34:N_1(3),N_2(4)):<-3>
[32m.[0m[32m                                                                                            [100%][0m


In [None]:
%%ipytest
#| hide
def test_merge_one_node() -> None:
    """Merging a single node with nsfun = max reproduces that same node.

    n1 is removed and re-added, so layer 1 still holds exactly {n1, n2}.
    Only their order changes: n1 is now last, as the printed output shows.
    """
    mymdd = MDD()
    mymdd.append_new_layers()
    mymdd.add_node(MDDNode(0, 0))
    mymdd.append_new_layers()
    n1 = MDDNode(1, 1)
    n2 = MDDNode(1, 2)
    mymdd.add_nodes([n1, n2])
    mymdd.merge_nodes([n1], 1, lambda ns, l: max(s for s in ns))
    print(mymdd)
    assert len(list(mymdd.all_nodes())) == 3
    assert set(mymdd.allnodes_in_layer(1)) == {n1, n2}

test_merge_one_node()

== MDD (mdd, 1 layers) ==
# Nodes
L0: N_0(0)
L1: N_1(2), N_1(1)
# (Outgoing) Arcs

[32m.[0m[32m                                                                                            [100%][0m


While `MDD` does not explicitly forbid "skip arcs" (arcs that connect nodes in non-adjacent layers), most functionality is designed assuming there are no skip arcs.

In [None]:
%%ipytest
def test_merge_skip_arcs_fails() -> None:
    """Merging nodes from different layers that are joined by an arc raises KeyError.

    n1 (layer 1) and n2 (layer 2) are connected by an arc. After both are
    removed, the redirected arc from n1 to the supernode still references
    the removed node n1.
    """
    mymdd = MDD()
    mymdd.append_new_layers(4)
    n0 = MDDNode(0, 0)
    n1 = MDDNode(1, 1)
    n2 = MDDNode(2, 2)
    n3 = MDDNode(3, 3)
    mymdd.add_nodes([n0, n1, n2, n3])
    mymdd.add_arcs({MDDArc(1, n0, n1): MDDArcData(1), MDDArc(2, n1, n2): MDDArcData(2), MDDArc(2, n2, n3): MDDArcData(2)})
    print(mymdd)
    print()
    with pytest.raises(KeyError):
        mymdd.merge_nodes([n1, n2], 1, lambda ns, l: max(ns))
        # Not reached: merge_nodes raises on the line above.
        print(mymdd)

test_merge_skip_arcs_fails()

== MDD (mdd, 3 layers) ==
# Nodes
L0: N_0(0)
L1: N_1(1)
L2: N_2(2)
L3: N_3(3)
# (Outgoing) Arcs
A(1:N_0(0),N_1(1)):<1>, A(2:N_1(1),N_2(2)):<2>, A(2:N_2(2),N_3(3)):<2>

[32m.[0m[32m                                                                                            [100%][0m


In [None]:
%%ipytest
def test_merge_skip_arcs_works_sometimes() -> None:
    """Merging nodes from different layers succeeds when no arc joins them.

    n1 (layer 1) and n2 (layer 2) are merged into a layer-2 supernode with
    state max(1, 2) = 2. Layer 1 ends up empty, and arc "a" becomes a skip
    arc from layer 0 to layer 2.
    """
    mymdd = MDD()
    mymdd.append_new_layers(4)
    n0 = MDDNode(0, 0)
    n1 = MDDNode(1, 1)
    n2 = MDDNode(2, 2)
    n3 = MDDNode(3, 3)
    mymdd.add_nodes([n0, n1, n2, n3])
    mymdd.add_arcs({
        MDDArc("a", n0, n1): MDDArcData(1),
        MDDArc("b", n0, n2): MDDArcData(2),
        MDDArc("c", n1, n3): MDDArcData(2),
        MDDArc("d", n2, n3): MDDArcData(1),
    })
    print(mymdd)
    print()
    mymdd.merge_nodes([n1, n2], 2, lambda ns, l: max(ns))
    print(mymdd)
    assert len(list(mymdd.allnodes_in_layer(1))) == 0
    assert {a.label for a in mymdd.all_outgoing_arcs()} == {"a", "b", "c", "d"}

test_merge_skip_arcs_works_sometimes()

== MDD (mdd, 3 layers) ==
# Nodes
L0: N_0(0)
L1: N_1(1)
L2: N_2(2)
L3: N_3(3)
# (Outgoing) Arcs
A(a:N_0(0),N_1(1)):<1>, A(b:N_0(0),N_2(2)):<2>, A(c:N_1(1),N_3(3)):<2>, A(d:N_2(2),N_3(3)):<1>

== MDD (mdd, 3 layers) ==
# Nodes
L0: N_0(0)
L1: 
L2: N_2(2)
L3: N_3(3)
# (Outgoing) Arcs
A(a:N_0(0),N_2(2)):<1>, A(b:N_0(0),N_2(2)):<2>, A(c:N_2(2),N_3(3)):<2>, A(d:N_2(2),N_3(3)):<1>
[32m.[0m[32m                                                                                            [100%][0m


In [None]:
#| hide
# Export this notebook's `#| export` cells to the library module (`default_exp mdd`).
import nbdev; nbdev.nbdev_export()