# Lab 8: Implementation of the sum-product algorithm
Today we work toward implementing the sum-product algorithm for bipartite trees. We provide some helper classes and functions, but the final implementation is your task.

This lab was implemented by our lab colleague Senbai Kang.

In [1]:
import _io
import os
import random
import numpy as np
from typing import Optional, List, Set, Tuple, Dict
from numpy.typing import NDArray

## Helper classes & functions

In [2]:
class Node:
    def __init__(
        self, 
        label: int, 
        branch_length: float = 1.0, 
        name: str = None, 
        parent: 'Node' = None,
        children: List['Node'] = None
    ):
        self._label: int = label
        self._branch_length: float = branch_length
        self._name: str = name
        self._parent: 'Node' = parent
        self._children: List['Node'] = []
        if children is not None:
            self._children.extend(children)
        self._states_num: int = 0
        """
        TO BE COMPLETED.
        
        Add attributes of `Node` to store the forward and backward messages.
        """

    def __eq__(self, o: object) -> bool:
        # Two nodes are equal when both their label and their name match.
        if not isinstance(o, Node): return False
        return self._label == o._label and self._name == o._name

    def __ne__(self, o: object) -> bool:
        return not self.__eq__(o)
    
    def __lt__(self, o: object) -> bool:
        if self._label < o._label:
            return True
        elif self._label > o._label:
            return False
        else:
            return self._name < o._name
        
    def __le__(self, o: object) -> bool:
        if self._label < o._label:
            return True
        elif self._label > o._label:
            return False
        else:
            return self._name <= o._name
    
    def __gt__(self, o: object) -> bool:
        if self._label > o._label:
            return True
        elif self._label < o._label:
            return False
        else:
            return self._name > o._name
        
    def __ge__(self, o: object) -> bool:
        if self._label > o._label:
            return True
        elif self._label < o._label:
            return False
        else:
            return self._name >= o._name
    
    def __hash__(self) -> int:
        return hash((self._label, self._name))
    
    def set_label(self, label: int) -> None:
        self._label = label

    def get_label(self) -> int:
        return self._label

    def set_parent(self, parent: 'Node') -> None:
        if parent is None: return None
        self._parent = parent

    def get_parent(self) -> 'Node':
        return self._parent

    def get_child(self, child_index: int) -> Optional['Node']:
        if child_index + 1 > len(self._children):
            return None
        return self._children[child_index]
    
    def get_children(self) -> List['Node']:
        return self._children

    def get_branch_length(self) -> float:
        return self._branch_length

    def set_node_name(self, name: str) -> None:
        self._name = name

    def get_node_name(self) -> str:
        return self._name

    def add_child(self, child: 'Node') -> None:
        if child is None: return None
        if child not in self._children:
            child.set_parent(self)
            self._children.append(child)
    
    def add_children(self, children: List['Node']) -> None:
        if children is None or len(children) == 0: return None
        for child in children:
            self.add_child(child)

    def is_leaf(self) -> bool:
        return len(self._children) == 0

    def is_root(self) -> bool:
        if self._parent is None:
            # Side effect: the root has no incoming branch, so its length is reset to 0.
            self._branch_length = 0.0
            return True
        return False

    def write_in_nexus(self, fh: _io.TextIOWrapper) -> None:
        if self.is_leaf():
            fh.write('{:d}:{:f}'.format(self._label, self._branch_length))
        else:
            fh.write('(')
            for index in range(len(self._children)):
                self._children[index].write_in_nexus(fh)
                if index < len(self._children) - 1:
                    fh.write(',')
            fh.write('):{0:f}'.format(self._branch_length))


class Tree:
    def __init__(self, root: Node = None):
        self._root: Node = root
        self._nodes_num: int = 0
        self._node_map: Dict[int, Node] = {}
        self._internal_nodes: List[Node] = []
        self._leaf_nodes: List[Node] = []
        self._rng = None # Random number generator
        self._states_num: int = None # Number of states
        self._trans_prob_mtx: NDArray = None # Transition probability matrix
        self._freq: NDArray = None # Discrete probability distribution of the root

    def get_root(self) -> Node:
        return self._root

    def set_root(self, root: Node) -> None:
        self._root = root

    def add_leaf_node(self, node: Node) -> None:
        if node is None or not node.is_leaf(): return None
        self._leaf_nodes.append(node)
        
    def get_leaf_nodes(self) -> List[Node]:
        return self._leaf_nodes

    def add_internal_node(self, node: Node) -> None:
        if node is None or node.is_leaf(): return None
        self._internal_nodes.append(node)
        
    def get_internal_nodes(self) -> List[Node]:
        return self._internal_nodes
    
    def get_node_labels(self) -> Optional[Dict[str, List[int]]]:
        if self._root is None: return None
        return {
            'root': [self._root.get_label()],
            'leaf': [i.get_label() for i in self.get_leaf_nodes()],
            'internal': [i.get_label() for i in self.get_internal_nodes()]
        }
    
    def build_node_map(self) -> None:
        self._node_map.clear()
        for node in [*self._leaf_nodes, *self._internal_nodes]:
            self._node_map[node.get_label()] = node
        
    def adjust_nodes_order(self) -> None:
        self._leaf_nodes = sorted(self._leaf_nodes)
        self._internal_nodes = sorted(self._internal_nodes)
        self._nodes_num = len(self._leaf_nodes) + len(self._internal_nodes)
        # Relabel consecutively: leaves get 0..n-1, internal nodes get n..N-1.
        leaf_num = len(self._leaf_nodes)
        for i, node in enumerate(self._internal_nodes):
            node.set_label(leaf_num + i)
        for i, node in enumerate(self._leaf_nodes):
            node.set_label(i)
        self.build_node_map()

    def write_in_nexus(self, output: str) -> None:
        if self._root is None:
            raise ValueError('Error! The tree is empty.')

        # Create the parent directory if needed; `dirname` is '' for a bare filename.
        out_dir = os.path.dirname(output)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # Write header
        with open(output, 'w') as fh:
            fh.write('#NEXUS\n\n')

            # Write taxa block
            fh.write('Begin taxa;\n')
            fh.write('\tDimensions ntax=' + str(len(self._leaf_nodes)) + ";\n")
            fh.write('\t\tTaxlabels\n')
            for node in self._leaf_nodes:
                fh.write('\t\t\t' + node.get_node_name() + '\n')
            fh.write('\t\t\t;\n')
            fh.write('End;\n')

            # Write trees block
            fh.write('Begin trees;\n')
            fh.write('\tTranslate\n')
            # Entries are comma-separated; the last one is terminated by ';' instead of ','.
            fh.write(',\n'.join(
                '\t\t{0:4d} {1:s}'.format(node.get_label(), node.get_node_name())
                for node in self._leaf_nodes
            ))
            fh.write('\n;\n')
            fh.write('tree TREE1 = ')
            self._root.write_in_nexus(fh)
            fh.write(';\n')
            fh.write('End;\n')
        
    def init_rng(self, seed: int) -> None:
        self._rng = np.random.default_rng(seed)

    def _gen_rand_mtx(self, dim: Tuple[int, int]) -> NDArray:
        # Uniform random entries with each row normalized to sum to 1,
        # so every row is a discrete probability distribution.
        ret: NDArray = self._rng.random(dim)
        return ret / ret.sum(axis=1, keepdims=True)

    def gen_rand_mtx(self, states_num: int) -> None:
        # Generate `self._trans_prob_mtx` (shape (states_num, states_num), rows sum to 1)
        # and `self._freq` (shape (1, states_num)).
        self._states_num = states_num
        self._trans_prob_mtx = self._gen_rand_mtx((states_num, states_num))
        self._freq = self._gen_rand_mtx((1, states_num))
        
    def get_trans_prob_mtx(self) -> NDArray:
        return self._trans_prob_mtx
    
    def get_freq(self) -> NDArray:
        return self._freq


def gen_random_tree(
    leaf_nodes_num: int,
    num_children_min: int = 2,
    num_children_max: int = 2,
    branch_length_min: float = 1.0,
    branch_length_max: float = 1.0
) -> Optional[Tree]:
    if leaf_nodes_num <= 1: return None

    nodes: List[Node] = []
    available_nodes: List[Node] = [Node(
        label=i, 
        branch_length=random.uniform(branch_length_min, branch_length_max),
        name='n' + str(i)
    ) for i in range(leaf_nodes_num)]
    
    next_node_label: int = leaf_nodes_num
    while len(available_nodes) > 1:
        children: List[Node] = random.sample(
            available_nodes, 
            min(
                len(available_nodes),
                random.randint(
                    num_children_min, 
                    num_children_max
                )
            )
        )
        parent: Node = Node(next_node_label)
        next_node_label += 1
        parent.add_children(children)
        
        nodes.extend(children)
        
        available_nodes = sorted(list(set(available_nodes) - set(nodes)))
        available_nodes.append(parent)
    
    nodes.append(available_nodes[0])
    nodes = sorted(nodes)
    tree: Tree = Tree()
    
    for node in nodes:
        if node.is_leaf():
            tree.add_leaf_node(node)
        else:
            if node.is_root():
                tree.set_root(node)
            tree.add_internal_node(node)
    tree.adjust_nodes_order()
    return tree

## Task 1: Convert into a factor graph
Describe the ways to convert a tree-structured directed acyclic graph into a factor graph. To simplify the problem, we only consider bipartite trees here, i.e., trees in which every node has either two children or none, and either one parent or none. How many conversions can you think of? Which one is better suited to a general implementation of the sum-product algorithm?
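As a sketch of one possible conversion (not necessarily the expected answer), the edge-based construction adds two kinds of factors. One is a unary factor on the root that holds the root distribution. The others are pairwise factors, one per parent-child edge, each holding the transition matrix. Because each factor is a (conditional) probability, their product is already the normalized joint, which the tiny example checks by brute force. The `parent_of` dictionary and the function names are hypothetical and do not come from the lab's classes. An alternative conversion uses one factor per family, $f(x_p, x_{c_1}, x_{c_2}) = P_{x_p, x_{c_1}} P_{x_p, x_{c_2}}$. It also yields a tree-shaped factor graph, but with three-variable factors.

```python
import itertools
from typing import Dict, List, Tuple

import numpy as np

# A factor is (scope, table): `scope` lists variable ids and `table` has one axis per variable.
Factor = Tuple[Tuple[int, ...], np.ndarray]


def tree_to_factor_graph(parent_of: Dict[int, int], root: int,
                         freq: np.ndarray, trans: np.ndarray) -> List[Factor]:
    """Edge-based conversion: one unary factor on the root plus one pairwise factor per edge.

    `parent_of` maps every non-root node to its parent (a hypothetical input format).
    """
    factors: List[Factor] = [((root,), np.asarray(freq).reshape(-1))]
    for child, parent in sorted(parent_of.items()):
        # trans[x_parent, x_child] = p(x_child | x_parent); every row sums to 1.
        factors.append(((parent, child), trans))
    return factors


def joint_prob(factors: List[Factor], assignment: Dict[int, int]) -> float:
    """Joint probability of a full assignment: the product of the matching factor entries."""
    p = 1.0
    for scope, table in factors:
        p *= table[tuple(assignment[v] for v in scope)]
    return float(p)


# Tiny example: root 2 with leaf children 0 and 1, two states per node.
freq = np.array([0.3, 0.7])
trans = np.array([[0.9, 0.1], [0.2, 0.8]])
factors = tree_to_factor_graph({0: 2, 1: 2}, root=2, freq=freq, trans=trans)
total = sum(joint_prob(factors, dict(zip((0, 1, 2), xs)))
            for xs in itertools.product(range(2), repeat=3))
print(len(factors), round(total, 10))  # prints: 3 1.0
```

A tree with N nodes gives N factors: one root prior plus N - 1 edge factors.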

## Task 2: Generate a random bipartite tree
Generate a random bipartite tree with your favorite seed and number of leaf nodes using `gen_random_tree`. 

In [3]:
random.seed(1024)
tree: Tree = gen_random_tree(8)
node_labels = tree.get_node_labels()
node_labels

{'root': [14],
 'leaf': [0, 1, 2, 3, 4, 5, 6, 7],
 'internal': [8, 9, 10, 11, 12, 13, 14]}

To visualize `tree`, save it to disk in NEXUS format, then upload the file to [iTOL](https://itol.embl.de).

In [4]:
tree.write_in_nexus('./tree.nexus')

Convert your tree into a factor graph in whichever way you think is most suitable. Drawing it out may help when you implement the sum-product algorithm.
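One way to draw the factor graph is to emit Graphviz DOT text and render it. The sketch below assumes the edge-based conversion: a unary factor on the root plus one pairwise factor per edge. It also assumes a hypothetical `parent_of` dictionary as input. Neither the function nor the input format is part of the lab code.

```python
from typing import Dict


def factor_graph_dot(parent_of: Dict[int, int], root: int) -> str:
    """Render the edge-based factor graph of a rooted tree as Graphviz DOT text.

    Circles are variable nodes x<i>, boxes are factor nodes. `parent_of` maps every
    non-root node to its parent (a hypothetical input format).
    """
    lines = ['graph factor_graph {']
    for v in sorted(set(parent_of) | {root}):
        lines.append(f'  x{v} [shape=circle];')
    # Unary factor holding the root distribution.
    lines.append(f'  f{root} [shape=box, label="p(x{root})"];')
    lines.append(f'  f{root} -- x{root};')
    # One pairwise factor per tree edge, linked to both of its endpoints.
    for child, parent in sorted(parent_of.items()):
        name = f'f{parent}_{child}'
        lines.append(f'  {name} [shape=box, label="p(x{child} | x{parent})"];')
        lines.append(f'  {name} -- x{parent};')
        lines.append(f'  {name} -- x{child};')
    lines.append('}')
    return '\n'.join(lines)


print(factor_graph_dot({0: 2, 1: 2}, root=2))
```

With the lab's objects, `parent_of` could be built as `{n.get_label(): n.get_parent().get_label() for n in tree.get_leaf_nodes() + tree.get_internal_nodes() if n.get_parent() is not None}`. To draw the graph, save the returned string to a `.dot` file and run `dot -Tpng graph.dot -o graph.png`.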

## Task 3: Generate a random transition probability matrix and a random discrete probability distribution for the root
For simplicity, we assume that every node in the tree above has the same state space and that every parent-child edge uses the same transition probability matrix. Both the transition probability matrix and the discrete probability distribution of the root are stored as attributes of the `Tree` instance.

In [5]:
tree.init_rng(1024)
tree.gen_rand_mtx(4)
print(tree.get_trans_prob_mtx())
print(tree.get_freq())

[[0.17561602 0.14070153 0.20167553 0.48200692]
 [0.26378756 0.01777078 0.37487642 0.34356524]
 [0.31319111 0.10628433 0.3870441  0.19348046]
 [0.41001416 0.18829778 0.21774498 0.18394307]]
[[0.26360287 0.32815626 0.35962737 0.0486135 ]]
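Each printed row of the transition matrix sums to 1. So row `i` is the distribution of a child's state given that its parent is in state `i`, and `freq` is a `(1, 4)` row vector for the root. The standalone NumPy sketch below reuses the same normalization recipe, with a hypothetical helper name and no claim that its numbers match the printout. With no nodes observed, the prior marginal moves down the tree one edge at a time as a vector-matrix product.

```python
import numpy as np


def random_stochastic(rng: np.random.Generator, shape) -> np.ndarray:
    """Uniform random entries, each row normalized to sum to 1 (same recipe as the lab)."""
    m = rng.random(shape)
    return m / m.sum(axis=1, keepdims=True)


rng = np.random.default_rng(1024)
trans = random_stochastic(rng, (4, 4))  # trans[i, j] = p(child state j | parent state i)
freq = random_stochastic(rng, (1, 4))   # shape (1, 4): distribution of the root state

# With nothing observed, marginals follow by pushing the root distribution down
# each edge: p(x_child) = p(x_parent) @ trans.
child = freq @ trans          # any child of the root
grandchild = child @ trans    # any node two edges below the root
print(child.round(4))
print(grandchild.round(4))
```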


## Task 4: Implementation of the sum-product algorithm
Now, with the tree, the transition probability matrix, and the root distribution available, you are ready to implement the sum-product algorithm. The goal is to compute the marginal probability of any node other than the root. To do so, complete a full cycle of message passing, both forward and backward, and store the messages in each node as attributes for later use. The two functions to complete, `sum_prod` and `get_marginal_prob`, should be added as member functions of the `Tree` class.
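As a sketch, the two passes can be written with the pairwise factors absorbed, so that messages travel directly between neighboring tree nodes. The notation here is assumed rather than taken from the lab:

- $P$ is the shared transition matrix, with $P_{ij} = p(x_{\text{child}} = j \mid x_{\text{parent}} = i)$.
- $\pi$ is the root distribution returned by `tree.get_freq()`.
- $\mathrm{ch}(v)$ and $\mathrm{pa}(v)$ are the children and parent of $v$.
- $\psi_v$ is an optional evidence potential: a one-hot vector for an observed node, all ones for an unobserved one. The lab text does not say which nodes are observed.
- $\mu_v$ is $\pi$ at the root and $m_{\mathrm{pa}(v) \to v}$ everywhere else.

Which pass the lab calls "forward" is left open; here "up" means leaves to root.

$$
\begin{aligned}
\text{up:}\quad & m_{v \to \mathrm{pa}(v)}\big(x_{\mathrm{pa}(v)}\big) = \sum_{x_v} P_{x_{\mathrm{pa}(v)},\,x_v}\, \psi_v(x_v) \prod_{c \in \mathrm{ch}(v)} m_{c \to v}(x_v), \\
\text{down:}\quad & m_{u \to v}(x_v) = \sum_{x_u} \mu_u(x_u)\, \psi_u(x_u) \Big[\prod_{c \in \mathrm{ch}(u) \setminus \{v\}} m_{c \to u}(x_u)\Big] P_{x_u,\,x_v}, \qquad v \in \mathrm{ch}(u), \\
\text{marginal:}\quad & p(x_v \mid \text{evidence}) \propto \mu_v(x_v)\, \psi_v(x_v) \prod_{c \in \mathrm{ch}(v)} m_{c \to v}(x_v).
\end{aligned}
$$

Without evidence, every upward message is all ones because the rows of $P$ sum to 1. The marginal of a node at depth $d$ then reduces to $\pi P^{d}$.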

Hints:
1. You may want to add other helper functions, either as member functions of the provided classes or as standalone functions.
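Below is a minimal standalone sketch of the two message-passing passes. It makes three assumptions:

- The tree is given as a plain `children` dictionary (internal node to list of children) rather than the lab's `Tree`/`Node` classes.
- Observed nodes are supplied through an optional `evidence` dictionary of potentials. The lab text does not mention observations.
- The upward pass runs from the leaves to the root.

It is not the `sum_prod`/`get_marginal_prob` interface the lab asks for. Porting it would mean storing the `up`, `down` and `below` arrays on each `Node`. The example checks the result against brute-force enumeration on a five-node binary tree.

```python
import itertools
from typing import Dict, List, Optional

import numpy as np


def sum_product(children: Dict[int, List[int]], root: int, freq: np.ndarray,
                trans: np.ndarray,
                evidence: Optional[Dict[int, np.ndarray]] = None) -> Dict[int, np.ndarray]:
    """Return the normalized marginal p(x_v | evidence) of every node v.

    `children` maps each internal node to its children (leaves may be absent).
    `evidence[v]` is an optional potential over the states of v, e.g. a one-hot
    vector for an observed leaf; unobserved nodes get all ones.
    """
    k = trans.shape[0]
    evidence = evidence or {}

    def psi(v: int) -> np.ndarray:
        return np.asarray(evidence.get(v, np.ones(k)), dtype=float)

    # Pre-order traversal: every parent appears before its children.
    order, stack = [], [root]
    while stack:
        v = stack.pop()
        order.append(v)
        stack.extend(children.get(v, []))

    # Upward pass (leaves -> root).
    below: Dict[int, np.ndarray] = {}  # psi_v times all messages from v's children
    up: Dict[int, np.ndarray] = {}     # message from v to its parent, indexed by x_parent
    for v in reversed(order):
        b = psi(v)
        for c in children.get(v, []):
            b = b * up[c]
        below[v] = b
        up[v] = trans @ b  # sum over x_v of trans[x_parent, x_v] * b[x_v]

    # Downward pass (root -> leaves): down[v] carries everything above v.
    down = {root: np.asarray(freq, dtype=float).reshape(-1)}
    for u in order:
        for v in children.get(u, []):
            # Everything known at u except the message that came up from v itself.
            rest = down[u] * psi(u)
            for c in children.get(u, []):
                if c != v:
                    rest = rest * up[c]
            down[v] = rest @ trans  # sum over x_u of rest[x_u] * trans[x_u, x_v]

    marginals = {}
    for v in order:
        m = down[v] * below[v]
        marginals[v] = m / m.sum()
    return marginals


def brute_force(children, root, freq, trans, evidence):
    """Reference marginals by enumerating every joint assignment (tiny trees only)."""
    parent = {c: p for p, cs in children.items() for c in cs}
    nodes = sorted(set(parent) | {root})
    k = trans.shape[0]
    marg = {v: np.zeros(k) for v in nodes}
    for xs in itertools.product(range(k), repeat=len(nodes)):
        x = dict(zip(nodes, xs))
        p = freq[x[root]]
        for c, pa in parent.items():
            p *= trans[x[pa], x[c]]
        for v, e in evidence.items():
            p *= e[x[v]]
        for v in nodes:
            marg[v][x[v]] += p
    return {v: m / m.sum() for v, m in marg.items()}


rng = np.random.default_rng(0)
k = 3
trans = rng.random((k, k))
trans /= trans.sum(axis=1, keepdims=True)
freq = rng.random(k)
freq /= freq.sum()
children = {4: [3, 2], 3: [0, 1]}              # root 4; leaves 0, 1, 2
evidence = {0: np.eye(k)[1], 2: np.eye(k)[0]}  # leaf 0 observed in state 1, leaf 2 in state 0
marg = sum_product(children, 4, freq, trans, evidence)
ref = brute_force(children, 4, freq, trans, evidence)
print(all(np.allclose(marg[v], ref[v]) for v in ref))  # True
```

For deep trees the messages can underflow. Dividing each `up` and `down` message by its sum as it is computed avoids this. It does not change the marginals, which are renormalized at the end.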