In [1]:
import os
import pickle
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Sequence

import cassiopeia as cas
import newick
import pandas as pd
from tqdm.auto import tqdm

In [2]:
# Record the installed Cassiopeia version for provenance. The package reports a
# generic version string ('0.0.1'), so the trailing hash presumably pins the
# exact git commit of the installed checkout — TODO confirm.
cas.__version__  # 3e85462cbc10b12352d44435b2f6bed4384a46cb

'0.0.1'

In [3]:
# Root of the tumor data checkout. Set the KPTRACER_ROOT environment variable to
# point elsewhere instead of editing this user-specific default path.
root = Path(os.environ.get("KPTRACER_ROOT", "/home/marius.lange/server_home/tumors/"))
# Directory with the per-tumor `<id>_tree.nwk`, `<id>_character_matrix.txt`
# and `<id>_priors.pkl` files.
td = root / "KPTracer-Data" / "trees"

In [4]:
# Tumors of interest, copied verbatim from the source listing (duplicates and spaces included).
toi_raw = [
    "3726 NT T2", "3435 NT T4", "3513 NT T3", "3432 NT T1", "3726 NT T1",
    "3726 NT T2", "3435 NT T4", "3730 NT T1", "3435 NT T1", "3726 NT T2",
]
# Data files are named with underscores: normalise, de-duplicate, and sort for a stable order.
toi = sorted(set(tumor.replace(" ", "_") for tumor in toi_raw))
toi, len(toi)

(['3432_NT_T1',
  '3435_NT_T1',
  '3435_NT_T4',
  '3513_NT_T3',
  '3726_NT_T1',
  '3726_NT_T2',
  '3730_NT_T1'],
 7)

In [5]:
def load_tree(t: str) -> cas.data.CassiopeiaTree:
    """Load one tumor from ``td`` as a ``CassiopeiaTree``.

    Reads ``<t>_priors.pkl``, ``<t>_character_matrix.txt`` and ``<t>_tree.nwk``
    and combines them into a single tree object.

    Parameters
    ----------
    t
        Tumor id in file-name form, e.g. ``"3432_NT_T1"``.

    Returns
    -------
    cas.data.CassiopeiaTree
        Tree built from the Newick file, with the integer character matrix
        attached, ``-1`` as the missing-state indicator, and the pickled priors.
    """
    priors_path = td / f"{t}_priors.pkl"
    matrix_path = td / f"{t}_character_matrix.txt"
    newick_path = td / f"{t}_tree.nwk"

    # NOTE: pickle.load can execute arbitrary code; only load trusted local files.
    with open(priors_path, "rb") as handle:
        priors = pickle.load(handle)

    # Missing entries are written as "-"; rewrite them as "-1" so every column
    # can be cast to int and matches missing_state_indicator below.
    character_matrix = (
        pd.read_csv(matrix_path, sep='\t', index_col=0)
        .replace('-', '-1')
        .astype(int)
    )

    return cas.data.CassiopeiaTree(
        character_matrix,
        tree=str(newick_path),
        missing_state_indicator=-1,
        priors=priors,
    )

In [6]:
def reconstruct_branch_lenghts(toi: List[str],
                               eles: Sequence[str] = ('mle', 'custom', 'custom_mam')) -> Dict[str, Any]:
    """Estimate branch lengths for every tumor tree with one or more strategies.

    Each (tumor, strategy) pair gets a freshly loaded tree, because every
    strategy modifies the tree in place.

    Parameters
    ----------
    toi
        Tumor ids, passed one at a time to ``load_tree``.
    eles
        Strategies to apply, any of:

        - ``'mle'``: reconstruct ancestral characters, then fit branch lengths
          with ``IIDExponentialMLE``.
        - ``'custom'``: per-edge length from the mutations along the edge,
          with missing data not counted as mutations.
        - ``'custom_mam'``: same as ``'custom'``, but missing data is counted
          as mutations.

    Returns
    -------
    dict
        Maps ``f"{tumor_id}_{strategy}"`` to the processed ``CassiopeiaTree``.

    Raises
    ------
    NotImplementedError
        If ``eles`` contains an unknown strategy. Raised before any tree is
        loaded, so a typo does not waste a long run.
    """

    def _custom(tree, mam=False):
        # Length is 1 if the edge carries any mutation, else 0.
        # NOTE(review): min(1, n) makes every length binary (0/1); confirm this
        # is intended rather than len(mutations) or max(1, len(mutations)).
        tree.reconstruct_ancestral_characters()
        for parent, child in tree.depth_first_traverse_edges(source=tree.root):
            mutations = tree.get_mutations_along_edge(parent, child, treat_missing_as_mutations=mam)
            tree.set_branch_length(parent, child, min(1, len(mutations)))

    def _mle(tree):
        tree.reconstruct_ancestral_characters()
        cas.tools.branch_length_estimator.IIDExponentialMLE().estimate_branch_lengths(tree)

    strategies = {
        'mle': _mle,
        'custom': lambda tree: _custom(tree, mam=False),
        'custom_mam': lambda tree: _custom(tree, mam=True),
    }

    # Fail fast: validate every strategy before the expensive per-tree work.
    unknown = [ele for ele in eles if ele not in strategies]
    if unknown:
        raise NotImplementedError(", ".join(unknown))

    res = {}
    for t_id in tqdm(toi):
        for ele in eles:
            tree = load_tree(t_id)
            strategies[ele](tree)
            res[f"{t_id}_{ele}"] = tree

    return res


# Correctly spelled alias; the original name is kept for existing callers.
reconstruct_branch_lengths = reconstruct_branch_lenghts

In [None]:
# Run every default strategy on each tumor of interest (slow: one tree load per pair).
data = reconstruct_branch_lenghts(toi=toi)

  0%|          | 0/7 [00:00<?, ?it/s]

In [None]:
# Persist all reconstructed trees. The notebook-level result is `data`;
# `res` exists only as a local inside reconstruct_branch_lenghts, so dumping
# it here raised a NameError on a fresh kernel.
# Protocol 4 (Python >= 3.4) supports very large objects.
with open("res.pkl", "wb") as fout:
    pickle.dump(data, fout, protocol=4)