## Load data

In [35]:
import numpy as np
import pandas as pd
from bigtree import Node, dataframe_to_tree_by_relation, postorder_iter, levelorder_iter

In [36]:
path = "data/tree.txt"
with open(path) as handle:
    # First line: seq_length (a single integer); the rest is a headerless CSV.
    seq_length = int(handle.readline().strip('\n'))
    df = pd.read_csv(handle, header=None).rename(
        columns={0: 'parent_id', 1: 'distance', 2: 'seq'})
df

Unnamed: 0,parent_id,distance,seq
0,707,0.026966,UUCGGUGGUUAUAGCGGUGGGGACACGCCCGGUCCCAUUCGAACCC...
1,708,0.008823,UUCGGUGGUCACAGCGGUGGGGAAACGCCCGGUCCCAUUCGAACCC...
2,709,0.026650,UUCGGUGGUAAUAGCGGUGGGGAAACGCCCGGUCCCAUUCGAACCC...
3,712,0.121237,-UCGGUGGCCAUAGCAGCAGGGAA-CGCCCGGACCCAUUCGAACCC...
4,713,0.017874,UUCGGUGGUUUUAGCGUCAGGGAAACGCCCGGUCCCAUUCGAACCC...
...,...,...,...
1405,1404,0.008787,
1406,1402,0.018323,
1407,1406,0.026961,
1408,1407,0.063071,


In [37]:
df = df.assign(child_id=df.index.map(str))  # node name = row index as a string

In [38]:
df = df.assign(parent_id=df['parent_id'].map(str))  # must match the string child names

In [39]:
# Build the bigtree tree from child/parent pairs, keeping distance and seq as node attributes.
root = dataframe_to_tree_by_relation(
    df,
    child_col='child_id',
    parent_col='parent_id',
    attribute_cols=['distance', 'seq'],
)

## Better ordering

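Here we sort the nodes by their height, i.e. the height of the subtree rooted at the node (the length of the longest path from the node down to a leaf). Ties are broken by node id. This is still a topological ordering (every node comes after all of its descendants), and all leaves are listed first.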

In [6]:
from bigtree import find_full_path

In [7]:
# Peel a copy of the tree one leaf layer at a time. The nodes removed in pass k
# are exactly those whose subtree has height k. Appending each pass (sorted by
# numeric node name) therefore yields a children-before-parents order with all
# leaves first. The node `aux` itself is never appended.
aux = root.copy()
target_order = []
while aux.children:
    # Walk the tree once per pass and reuse the result for both steps below.
    leaves = sorted(aux.leaves, key=lambda node: int(node.name))
    target_order.extend(leaf.path_name for leaf in leaves)
    for leaf in leaves:
        # Detach the leaf from the working copy only (`root` is untouched).
        leaf.parent.children = tuple(child for child in leaf.parent.children if child is not leaf)

44


In [192]:
# Original node name -> position in the height-sorted order.
old_to_new = {
    find_full_path(root, node_path).name: n
    for n, node_path in enumerate(target_order)
}

In [193]:
save_path = 'data/tree_topological.csv'
with open(save_path, 'w') as file:
    file.write("parent,left,right,distance,sequence\n")
    for node_path in target_order:
        node = find_full_path(root, node_path)
        child_ids = [str(old_to_new[child.name]) for child in node.children]
        # Pad leaves (and any unary node) to two child slots.
        child_ids += [''] * (2 - len(child_ids))
        tail = f"{getattr(node, 'distance', '')},{getattr(node, 'seq', '')}"
        if len(child_ids) == 3:
            # Three children (presumably only the tree root): no parent column is written.
            row = f"{','.join(child_ids)},{tail}"
        else:
            row = f"{old_to_new[node.parent.name]},{','.join(child_ids)},{tail}"
        file.write(row + "\n")

## DFS postorder

This is also a topological ordering (children come before parents), but it has the drawback that some non-terminal nodes come before some leaves.

In [6]:
# Maps a bigtree node name to its DFS-postorder position (filled in below).
old_to_new = dict()

In [7]:
# Number every node under (and including) `root` in postorder: children before parents.
old_to_new.update((node.name, n) for n, node in enumerate(postorder_iter(root)))

In [9]:
# The export below starts from the first child of `root`.
# NOTE(review): presumably `root` is a synthetic node above the real tree root — confirm.
top_children = root.children
start = top_children[0]

In [10]:
save_path = 'data/tree_preprocessed.csv'
with open(save_path, 'w') as file:
    file.write("parent,left,right,distance,sequence\n")
    for node in postorder_iter(start):
        child_ids = [str(old_to_new[child.name]) for child in node.children]
        # Pad leaves (and any unary node) to two child slots.
        child_ids += [''] * (2 - len(child_ids))
        tail = f"{getattr(node, 'distance', '')},{getattr(node, 'seq', '')}"
        if len(child_ids) == 3:
            # Three children (presumably only the tree root): no parent column is written.
            row = f"{','.join(child_ids)},{tail}"
        else:
            row = f"{old_to_new[node.parent.name]},{','.join(child_ids)},{tail}"
        file.write(row + "\n")

## Pytorch implementation?

`DIM`: the dimension of log_p for each node (for example, `DIM = 16` for residue pairs).

__We assume that all nodes in the tree either have 2 children, or none. (*)__
<br>We also assume that the root has at most 3 children (but this is not important).

We are going to traverse the tree in BFS order. It is important that the order of children of each node is fixed. We can then encode the path from the root to each node as a number, for example, '010100' or '210111' (_or, rather, as a string of digits_).

Let's call the set of nodes with paths of length $n$ the $n$-th level of the tree. Suppose we list the nodes of each level in order of increasing path codes. Then, by assumption (*), the $2i$-th and $(2i+1)$-st nodes in this order (counting from $0$) are precisely the children of the $i$-th _non-terminal_ node on the previous level.

In [82]:
# IUPAC nucleotide codes -> probability vectors over (A, C, G, T/U); reused from double_ems.
# Gaps ('-', '.') are treated like 'N' (uniform over all four bases).
_IUPAC_BASES = {
    'A': 'A', 'C': 'C', 'G': 'G', 'T': 'T', 'U': 'T',
    '-': 'ACGT', '.': 'ACGT',
    'R': 'AG', 'Y': 'CT', 'S': 'CG', 'W': 'AT', 'K': 'GT', 'M': 'AC',
    'B': 'CGT', 'D': 'AGT', 'H': 'ACT', 'V': 'ACG',
    'N': 'ACGT',
}
soft_nuc_mapping = {
    code: [1 / len(bases) if base in bases else 0.0 for base in 'ACGT']
    for code, bases in _IUPAC_BASES.items()
}

In [71]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from itertools import compress

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# States per node: 4 for single nucleotides (the columns of soft_nuc_mapping).
DIM = 4

In [32]:
# Quick check of the odd-index slice used in TreeLevel.forward
# (`contrib[1::2]` selects the second sibling of each pair).
torch.arange(10)[1::2]

tensor([1, 3, 5, 7, 9])

In [188]:
class TreeLevel(nn.Module):
    """One level of the tree for log-space likelihood pruning (Felsenstein-style).

    Let the level contain n nodes, m of which are non-terminal. Let
    D = exp_rate.shape[-1] be the number of states per node.

    Constructor arguments:
        distance: branch lengths from the nodes of this level to their parents,
            shape (n,). Stored as a buffer.
        log_p_leaves: log_p of the leaf nodes of this level, in level order,
            shape (n - m, D). Stored as a buffer.
        is_leaf: boolean mask of the leaf nodes, shape (n,). Stored as a buffer.
        exp_rate: (D, D) matrix used as the generator of the per-node transition
            matrices matrix_exp(distance * exp_rate). Kept as a module attribute
            so that one nn.Parameter can be shared by all levels.

    forward(x):
        x: combined children's contributions from the level below, one row per
            non-terminal node of this level (in level order), shape (m, D).
            Not read when the level has no non-terminal nodes.
        Returns the combined contributions of sibling pairs, i.e. one row per
        non-terminal node of the previous level, shape (n // 2, D).
        Assumes every non-terminal node has exactly two children, so n is even
        and siblings are adjacent in level order.
    """

    def __init__(self, distance, log_p_leaves, is_leaf, exp_rate: nn.Parameter):
        super().__init__()
        self.exp_rate = exp_rate
        # TODO(optimize): the leaf rows of log_p could be filled once here instead of per call.
        self.register_buffer("distance", distance)
        self.register_buffer("is_leaf", is_leaf)
        self.register_buffer("log_p_leaves", log_p_leaves)

    def forward(self, x):
        # matrix_exp acts on the last two dimensions: one (D, D) transition matrix per node.
        log_transition = torch.log(
            torch.linalg.matrix_exp(self.distance[:, None, None] * self.exp_rate[None, :, :]))

        n_nodes = self.distance.shape[0]
        n_states = self.exp_rate.shape[-1]
        # Allocate on the module's own device/dtype (buffers follow .to()), not x's:
        # for an all-leaf level x may be just a placeholder.
        log_p = torch.empty((n_nodes, n_states), dtype=self.log_p_leaves.dtype,
                            device=self.log_p_leaves.device)
        if self.is_leaf.any():
            log_p[self.is_leaf] = self.log_p_leaves
        # The parentheses matter: `~self.is_leaf.any()` would mean "no leaves at all"
        # and would leave the non-terminal rows uninitialised on mixed levels.
        if (~self.is_leaf).any():
            log_p[~self.is_leaf] = x

        # contrib[k, i] = log sum_j P_k[i, j] * exp(log_p[k, j])
        contrib = torch.logsumexp(log_p[:, None, :] + log_transition, dim=2)
        # Rows (2i, 2i + 1) are siblings; their contributions multiply (add in log space).
        return contrib[::2] + contrib[1::2]

In [189]:
class TopLevel(nn.Module):
    """Top level of the tree: like TreeLevel, but all nodes are combined together.

    Let the level contain n nodes, m of which are non-terminal. Let
    D = exp_rate.shape[-1] be the number of states.

    Constructor arguments:
        distance: branch lengths of the nodes of this level, shape (n,). Buffer.
        log_p_leaves: log_p of the leaf nodes, in level order, shape (n - m, D). Buffer.
        is_leaf: boolean mask of the leaf nodes, shape (n,). Buffer.
        exp_rate: (D, D) generator matrix, shared with the other levels.

    forward(x):
        x: contributions from the level below, one row per non-terminal node,
            shape (m, D). Not read when the level has no non-terminal nodes.
        Returns the sum over all nodes of their log-space contributions,
        shape (D,). Unlike TreeLevel, siblings are not paired, so the level may
        contain any number of nodes.
    """

    def __init__(self, distance, log_p_leaves, is_leaf, exp_rate: nn.Parameter):
        super().__init__()
        self.exp_rate = exp_rate
        # TODO(optimize): the leaf rows of log_p could be filled once here instead of per call.
        self.register_buffer("distance", distance)
        self.register_buffer("is_leaf", is_leaf)
        self.register_buffer("log_p_leaves", log_p_leaves)

    def forward(self, x):
        # One (D, D) transition matrix per node (matrix_exp is batched over dim 0).
        log_transition = torch.log(
            torch.linalg.matrix_exp(self.distance[:, None, None] * self.exp_rate[None, :, :]))

        n_nodes = self.distance.shape[0]
        n_states = self.exp_rate.shape[-1]
        # Use the module's device/dtype (buffers follow .to()), not the input's.
        log_p = torch.empty((n_nodes, n_states), dtype=self.log_p_leaves.dtype,
                            device=self.log_p_leaves.device)
        if self.is_leaf.any():
            log_p[self.is_leaf] = self.log_p_leaves
        # Parenthesised on purpose: `~self.is_leaf.any()` would mean "no leaves at all".
        if (~self.is_leaf).any():
            log_p[~self.is_leaf] = x

        # contrib[k, i] = log sum_j P_k[i, j] * exp(log_p[k, j])
        contrib = torch.logsumexp(log_p[:, None, :] + log_transition, dim=2)
        return torch.sum(contrib, dim=0)

In [190]:
# Jukes-Cantor-style rate matrix (generator) with unit rate: -1 on the diagonal
# and 1/(DIM-1) off the diagonal. Each row sums to zero, so
# matrix_exp(t * rate_example) is a stochastic matrix for every t >= 0.
rate_example = (
    -torch.eye(DIM, dtype=float)
    + (torch.ones((DIM, DIM), dtype=float) - torch.eye(DIM, dtype=float)) / (DIM - 1))
# TreeLevel/TopLevel use exp_rate directly as the generator in matrix_exp, so it
# must be a valid rate matrix: no elementwise torch.exp here. Clone so that
# optimiser updates do not silently modify rate_example.
exp_rate = nn.Parameter(rate_example.clone())

In [191]:
def bfs_generator(root):
    """Yield the tree one level at a time, as lists of nodes.

    The first list is [root]. Each following list holds the children of the
    previous level's nodes, in order. Iteration stops after the first level
    whose nodes have no children.
    """
    frontier = [root]
    while len(frontier) > 0:
        yield frontier
        frontier = [child for parent in frontier for child in parent.children]

In [192]:
# Level-by-level iterator over the exported tree, rooted at the first child of `root`.
first_child = root.children[0]
start, bfs_iter = first_child, bfs_generator(first_child)

In [193]:
# Level 0 of the BFS is just [start]; consume and discard it.
next(bfs_iter)
# Level 1 (the children of `start`) is turned into the top layer below.
root_level = next(bfs_iter)

In [195]:
def root_level_to_module(root_level: list, position: int = 0, rate=None):
    """Build the TopLevel module for the top level of the tree.

    Args:
        root_level: nodes of the level, in the order produced by bfs_generator.
        position: index of the residue taken from each leaf's `seq`
            (default 0, the first residue).
        rate: matrix passed to TopLevel as `exp_rate`; defaults to the
            notebook-level `exp_rate` parameter.

    Returns:
        A TopLevel with buffers `distance` (one per node) and `is_leaf` (one per
        node). Its `log_p_leaves` buffer is float64 with one row per leaf, in
        level order; its shape is (0, DIM) if the level has no leaves.
    """
    if rate is None:
        rate = exp_rate
    leaf_flags = [node.is_leaf for node in root_level]
    distance = torch.tensor([node.distance for node in root_level])
    is_leaf = torch.tensor(leaf_flags)
    # TODO: iterate over all alignment columns instead of a single position.
    residues = [node.seq[position] for node in compress(root_level, leaf_flags)]
    if residues:
        probs = torch.tensor([soft_nuc_mapping[r] for r in residues], dtype=torch.float64)
        log_p_leaves = torch.log(probs)
    else:
        log_p_leaves = torch.empty((0, DIM), dtype=torch.float64)
    return TopLevel(distance, log_p_leaves, is_leaf, rate)

In [196]:
def level_to_module(level: list, position: int = 0, rate=None):
    """Build the TreeLevel module for one BFS level of the tree.

    Args:
        level: nodes of the level, in the order produced by bfs_generator.
        position: index of the residue taken from each leaf's `seq`
            (default 0, the first residue).
        rate: matrix passed to TreeLevel as `exp_rate`; defaults to the
            notebook-level `exp_rate` parameter.

    Returns:
        A TreeLevel with buffers `distance` (one per node) and `is_leaf` (one per
        node). Its `log_p_leaves` buffer is float64 with one row per leaf, in
        level order; its shape is (0, DIM) if the level has no leaves.
    """
    if rate is None:
        rate = exp_rate
    leaf_flags = [node.is_leaf for node in level]
    distance = torch.tensor([node.distance for node in level])
    is_leaf = torch.tensor(leaf_flags)
    # TODO: iterate over all alignment columns instead of a single position.
    residues = [node.seq[position] for node in compress(level, leaf_flags)]
    if residues:
        probs = torch.tensor([soft_nuc_mapping[r] for r in residues], dtype=torch.float64)
        log_p_leaves = torch.log(probs)
    else:
        log_p_leaves = torch.empty((0, DIM), dtype=torch.float64)
    return TreeLevel(distance, log_p_leaves, is_leaf, rate)

In [197]:
class Tree(nn.Module):
    """Whole-tree model: apply the per-level modules bottom-up, then the top level.

    `intermediate_layers` is expected to be ordered top-down (the order in which
    BFS levels are produced), so it is applied from its last element to its first.
    """

    def __init__(self, top_layer, intermediate_layers: nn.ModuleList):
        super().__init__()
        self.top_layer = top_layer
        self.layers = intermediate_layers

    def forward(self, x):
        h = x
        for idx in range(len(self.layers) - 1, -1, -1):
            h = self.layers[idx](h)
        return self.top_layer(h)

In [198]:
# Module for the children of `start` (the first level below it).
top_layer = root_level_to_module(root_level=root_level)

In [199]:
#example_input = torch.rand((3,DIM), dtype=torch.float64)
#top_layer.forward(example_input)

In [200]:
# Rebuild the levels here instead of consuming the shared `bfs_iter`. A generator
# can be consumed only once, so re-running this cell would otherwise silently
# build a tree with no intermediate layers.
# Skip level 0 ([start]) and level 1 (the children of `start`, already in top_layer).
intermediate_levels = list(bfs_generator(start))[2:]
tree = Tree(top_layer, nn.ModuleList([level_to_module(level) for level in intermediate_levels]))

In [201]:
# Call the module (not .forward) so that nn.Module hooks run. The deepest BFS
# level has no non-terminal nodes, so this empty placeholder input is never read.
result = tree(torch.empty((0,)))

In [106]:
#example_input = torch.rand((5,DIM), dtype=torch.float64)

In [109]:
#level_to_module(level).forward(example_input)

tensor([[-3.3051, -3.1501, -3.2613,  0.5986],
        [ 1.1588,  1.0472,  1.2382,  0.8457],
        [ 0.9705,  0.9271,  0.1356,  0.5694]], dtype=torch.float64,
       grad_fn=<AddBackward0>)

In [None]:
#tree = nn.ModuleList(generate_layers)

TODO:
- The result is wrong, need to debug it
- Replace all doubles with float32

Questions:
- Do we need to pass the rate matrix through the layer to save its gradient? (We need it to accumulate the gradient, and also to be in scope for the forward method, e.g. on GPU, but should we copy it all the time?)
  - No, this will just be an `nn.Parameter` passed to the model during eval.