In [4]:
from typing import Tuple, Union, Optional
 
from scipy.sparse import issparse, csr_matrix, dok_matrix
from moscot.backends.ott import LinearOutput, LRLinearOutput
from moscot.problems.base import BirthDeathProblem
from anndata import AnnData
from ott.core import sinkhorn
from ott.geometry import pointcloud
import numpy as np
import jax.numpy as jnp
import numpy.typing as npt
import networkx as nx
 
Output = Union[LinearOutput, LRLinearOutput]
 
 
def distance_between_pushed_masses(
   gex_data_source: npt.ArrayLike,
   gex_data_target: npt.ArrayLike,
   output: Union[npt.ArrayLike, csr_matrix, BirthDeathProblem],
   true_coupling: Union[npt.ArrayLike, csr_matrix],
   eps: float = 0.1,
   seed: Optional[int] = None,
   n_samples: Optional[int] = None,
) -> float:
   source_space_cost = _distance_pushed_masses(
       gex_data_source, output, true_coupling, forward=False, eps=eps, random=False, seed=seed, n_samples=n_samples
   )
   target_space_cost = _distance_pushed_masses(
       gex_data_target, output, true_coupling, forward=True, eps=eps, random=False, seed=seed, n_samples=n_samples
   )
   independent_coupling_source_space_cost = _distance_pushed_masses(
       gex_data_source, output, true_coupling, forward=False, random=True, seed=seed, n_samples=n_samples
   )
   independent_coupling_target_space_cost = _distance_pushed_masses(
       gex_data_target, output, true_coupling, forward=True, random=True, seed=seed, n_samples=n_samples
   )
   source_error = source_space_cost / independent_coupling_source_space_cost
   target_error = target_space_cost / independent_coupling_target_space_cost
   mean_error = (source_error + target_error) / 2
   return mean_error
 
 
def _get_masses_moscot(
   i: int,
   output: Union[npt.ArrayLike, csr_matrix, BirthDeathProblem],
   true_coupling: Union[npt.ArrayLike, csr_matrix],
   n: int,
   m: int,
   forward: bool,
   random: bool,
) -> Tuple[npt.ArrayLike]:
   if forward:
       pushed_mass_true = true_coupling[i, :]
       pushed_mass_true /= pushed_mass_true.sum()
       mass = np.zeros(n, dtype="float64")
       mass[i] = 1
       pushed_mass = output.solution.push(mass).squeeze()
       weight_factor = jnp.sum(pushed_mass)
       if random:
           pushed_mass = output.b[:, -1]
           pushed_mass /= pushed_mass.sum()
       else:
           pushed_mass /= weight_factor
   else:
       pushed_mass_true = true_coupling.T[i, :]
       pushed_mass_true /= pushed_mass_true.sum()
       mass = np.zeros(m, dtype="float64")
       mass[i] = 1
       pushed_mass = output.solution.pull(mass).squeeze()
       weight_factor = jnp.sum(pushed_mass)
       if random:
           pushed_mass = output.a[:, -1]
           pushed_mass /= pushed_mass.sum()
       else:
           pushed_mass /= weight_factor
   return pushed_mass_true.astype("float64"), pushed_mass, weight_factor
 
 
def _get_masses_ndarray(
   i: int,
   output: Union[npt.ArrayLike, csr_matrix],
   true_coupling: Union[npt.ArrayLike, csr_matrix],
   forward: bool,
   random: bool,
) -> Tuple[npt.ArrayLike, ...]:
   if forward:
       pushed_mass_true = true_coupling[i, :]
       pushed_mass_true /= pushed_mass_true.sum()
       weight_factor = output[i, :].sum()
       if random:
           pushed_mass = output.sum(0)
           pushed_mass /= pushed_mass.sum()
       else:
           pushed_mass = output[i, :]
           pushed_mass /= weight_factor
   else:
       pushed_mass_true = true_coupling.T[i, :]
       pushed_mass_true /= pushed_mass_true.sum()
       weight_factor = output.T[i, :].sum()
       if random:
           pushed_mass = output.sum(1)
           pushed_mass /= pushed_mass.sum()
       else:
           pushed_mass = output.T[i, :]
           pushed_mass /= weight_factor
   return pushed_mass_true.astype("float64"), pushed_mass.astype("float64"), weight_factor
 
 
def _distance_pushed_masses(
    gex_data: npt.ArrayLike,
    output: Union[npt.ArrayLike, csr_matrix, BirthDeathProblem],
    true_coupling: Union[npt.ArrayLike, csr_matrix],
    forward: bool,
    eps: float = 0.5,
    random: bool = False,
    seed: Optional[int] = None,
    n_samples: Optional[int] = None,
    ) -> float:
    rng = np.random.RandomState(seed=seed)
    n, m = output.shape if isinstance(output, np.ndarray) else output.solution.shape
    wasserstein_d = 0
    total_weight = 0
    if n_samples is None:
        samples = range(n if forward else m)
    else:
        samples = rng.choice(n if forward else m, size=n_samples)
    for i in samples:
        if isinstance(output, np.ndarray):
            pushed_mass_true, pushed_mass, weight_factor = _get_masses_ndarray(
                i, output, true_coupling, forward, random
            )
        elif isinstance(output, BirthDeathProblem):
            pushed_mass_true, pushed_mass, weight_factor = _get_masses_moscot(
                i, output, true_coupling, n, m, forward, random
            )
        else:
            raise TypeError(f"Return type is {type(output)}")
        if issparse(pushed_mass_true):
            pushed_mass_true = np.squeeze(pushed_mass_true.A)
        geom = pointcloud.PointCloud(gex_data, gex_data, epsilon=eps, scale_cost="mean")
        out = sinkhorn.sinkhorn(geom, pushed_mass_true, pushed_mass, max_iterations=1e7)
        wasserstein_d += float(out.reg_ot_cost) * weight_factor
        total_weight = +weight_factor
        del geom
        del out

    return wasserstein_d / (len(samples) * len(pushed_mass) * total_weight)

 
def get_leaf_descendants(tree, node):
    """
    copied from https://github.com/aforr/LineageOT/blob/8c66c630d61da289daa80e29061e888b1331a05a/lineageot/inference.py#L657

    Returns a list of the leaf nodes of the tree that are
    descendants of node
    """
    if tree.out_degree(node) == 0:
        return [node]
    else:
        children = tree.successors(node)
        leaf_descendants = []
        for child in children:
            leaf_descendants = leaf_descendants + get_leaf_descendants(tree, child)
        return leaf_descendants
    return
 
 
def get_true_coupling(early_tree, late_tree):
    """
    adapted from https://github.com/aforr/LineageOT/blob/8c66c630d61da289daa80e29061e888b1331a05a/lineageot/inference.py#L657

    Returns the coupling between leaves of early_tree and their descendants in
    late_tree. Assumes that early_tree is a truncated version of late_tree
    The marginal over the early cells is uniform; if cells have different
    numbers of descendants, the marginal over late cells will not be uniform.
    """
    num_cells_early = len(get_leaves(early_tree)) - 1
    num_cells_late = len(get_leaves(late_tree)) - 1

    coupling = dok_matrix((num_cells_early, num_cells_late))

    cells_early = get_leaves(early_tree, include_root = False)


    for cell in cells_early:
        parent = next(early_tree.predecessors(cell))
        late_tree_cell = None
        for child in late_tree.successors(parent):
            if late_tree.nodes[child]['cell'].seed == early_tree.nodes[cell]['cell'].seed:
                late_tree_cell = child
                break
        if late_tree_cell == None:
            raise ValueError("A leaf in early_tree does not appear in late_tree. Cannot find coupling." +
                            "\nCheck whether either tree has been modified since truncating.")
        descendants = get_leaf_descendants(late_tree, late_tree_cell)
        coupling[cell, descendants] = 1/(num_cells_early*len(descendants))

    return coupling
 
def newick2digraph(tree: str) -> nx.DiGraph:
    def trav(clade, prev: Any, depth: int) -> None:
        nonlocal cnt
        if depth == 0:
            name = "root"
        else:
            name = clade.name
            if name is None:
                name = cnt
                cnt -= 1
            else:
                name = int(name[1:]) - 1

        G.add_node(name, node_depth=depth)
        if prev is not None:
            G.add_edge(prev, name)

        for c in clade.clades:
            trav(c, name, depth + 1)

    G = nx.DiGraph()
    cnt = -1
    tree = Phylo.read(io.StringIO(tree), "newick")
    trav(tree.clade, None, 0)

    start = max([n for n in G.nodes if n != "root"]) + 1
    for n in list(nx.dfs_preorder_nodes(G)):
        if n == "root":
            pass
        if is_leaf(G, n):
            continue

        assert start not in G.nodes
        G = nx.relabel_nodes(G, {n: start}, copy=False)
        start += 1

    return G
 
 
def cut_at_depth(G: nx.DiGraph, *, max_depth: Optional[int] = None) -> nx.DiGraph:
    if max_depth is None:
        return deepcopy(G)
    selected_nodes = [n for n in G.nodes if G.nodes[n]["node_depth"] <= max_depth]
    G = deepcopy(G.subgraph(selected_nodes).copy())
 
    # relabel because of LOT
    leaves = [n for n in G.nodes if not len(list(G.successors(n)))]
    for new_name, n in enumerate(leaves):
        G = nx.relabel_nodes(G, {n: new_name}, copy=False)
    return G
 
def get_true_coupling(
    adata: AnnData,
    depth: int,
    ) -> csr_matrix:
    adata = sc.read(fpath)
    tree = adata.uns["tree"]
    meta = adata.obs
  
   metadata = [meta.iloc[nid].to_dict() for nid in range(len(adata))]
   G = newick2digraph(tree)
   G = annotate(G, cell_arr_adata, metadata, ttp=ttp)
   trees = {"early": cut_at_depth(G, max_depth=depth), "late": cut_at_depth(G)}
  
   true_coupling = get_true_coupling(true_trees["early"], true_trees["late"])
 
   return true_coupling

def annotate(
    G: nx.DiGraph,
    meta: List[Dict[str, Any]],
    ttp: int = 100,
    ) -> nx.DiGraph:
    G = G.copy()
    n_leaves = len([n for n in G.nodes if not len(list(G.successors(n)))])
    assert (n_leaves & (n_leaves - 1)) == 0, f"{n_leaves} is not power of 2"
    max_depth = int(np.log2(n_leaves))

    n_expected_nodes = 2 ** (max_depth + 1) - 1
    assert len(G) == n_expected_nodes, "graph is not a full binary tree"

    for nid in G.nodes:
        depth = G.nodes[nid]["node_depth"]
        metadata = {
            **meta[nid],  # contains `depth`, which is different from `node_depth`
            "cell": cell_arr_data[nid],
            "nid": nid,
            "time": depth * ttp,
            "time_to_parent": ttp,
        }
        G.nodes[nid].update(metadata)
    return nx.relabel_nodes(G, {n_leaves: "root"}, copy=False)


IndentationError: unindent does not match any outer indentation level (<tokenize>, line 260)

In [1]:
adata = sc.read("/lustre/groups/ml01/datasets/projects/2022-02-25_moscot/moscot-lineage_reproducibility/notebooks/analysis_notebooks/tedsim/data_generation/adatas_large/adata_1024.h5ad")

NameError: name 'sc' is not defined

In [None]:
adata