# LineageOT benchmark

In [1]:
# pip install git+https://github.com/aforr/LineageOT@master cvxopt

In [2]:
from jax.config import config
config.update("jax_enable_x64", True)

from ott.geometry.geometry import Geometry
from ott.geometry.pointcloud import PointCloud
from jax import numpy as jnp
import seaborn as sns

from time import perf_counter
from moscot import FusedGW, Unbalanced, GW
import pickle
import os

In [3]:
import copy
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import ot

import lineageot.simulation as sim
import lineageot.evaluation as sim_eval
import lineageot.inference as sim_inf

from typing import Literal, Optional, Sequence, Dict
import traceback
from collections import namedtuple, defaultdict

In [4]:
# Result of one benchmark run, as built by `benchmark_moscot` and
# `benchmark_lineageOT`: the transport matrix, the OT costs at the early and
# late time points, the norm of the difference to the true coupling, the
# solver's convergence information (None for LineageOT) and the wall time.
bnt = namedtuple(
    "bnt",
    ["tmat", "early_cost", "late_cost", "norm_diff", "converged", "time"],
)

# Everything `init_sim` produces for one simulated flow type.
# The typename matches the variable it is bound to so the class can be looked
# up by name (pickle resolves classes as `<module>.<typename>`); the previous
# typename "sim" is also the alias of the `lineageot.simulation` module.
stn = namedtuple(
    "stn",
    [
        "sim_info",  # `init_sim` stores the `sim_inf` module here (used for `OT_cost`)
        "ancestor_info",
        "rna_arrays",
        "true_coupling",
        "true_distances",
        "barcode_arrays",
        "fitted_tree_distances_early",
        "fitted_tree_distances_late",
        "hamming_distances_late",
        "early_time_rna_cost",
        "late_time_rna_cost",
    ],
)

In [5]:
def create_geometry(cost_matrix: np.ndarray, scale='max') -> Geometry:
    """Wrap a cost matrix in an OTT ``Geometry``, optionally rescaling it first.

    Parameters
    ----------
    cost_matrix
        Pairwise cost matrix; it is converted to a JAX array.
    scale
        ``None`` keeps the costs as they are; ``'max'``, ``'mean'`` and
        ``'median'`` divide the matrix by the respective statistic.

    Returns
    -------
    ``Geometry`` built from the (rescaled) cost matrix.

    Raises
    ------
    NotImplementedError
        If ``scale`` is none of the supported values.
    AssertionError
        If any entry of the resulting matrix is negative, or if the maximum
        is not exactly 1 after ``'max'`` scaling.
    """
    costs = jnp.array(cost_matrix)

    if scale is not None:
        if scale == 'max':
            costs = costs / costs.max()
            # sanity check: the largest entry must now be exactly one
            assert costs.max() == 1.0
        elif scale == 'mean':
            costs = costs / np.mean(costs)
        elif scale == 'median':
            costs = costs / np.median(costs)
        else:
            raise NotImplementedError(scale)

    assert (costs >= 0).all()
    return Geometry(cost_matrix=costs)


def fgw_solver(C1, C2, C12, alpha=0, epsilon=1e-2, loss_fun='square_loss',
              p=None, q=None, max_iterations=100, rtol=1e-9, atol=1e-9, verbose=False):
    """Entropic (fused) Gromov-Wasserstein reference solver built on POT.

    Parameters
    ----------
    C1, C2
        Intra-domain cost matrices of the source and the target.
    C12
        Cross-domain cost matrix (the linear term).
    alpha
        Interpolation weight in ``[0, 1]``: ``0`` solves the linear problem on
        ``C12`` only, ``1`` delegates to POT's entropic Gromov-Wasserstein,
        anything in between runs the iterative scheme below.
    epsilon
        Entropic regularisation. For ``epsilon < 0.1`` the epsilon-scaling
        Sinkhorn variant is used instead of plain Sinkhorn.
    loss_fun
        Loss name forwarded to POT.
    p, q
        Marginals; uniform over the rows of ``C1`` / ``C2`` when omitted.
    max_iterations
        Outer-iteration cap for the fused case; Sinkhorn iteration cap for
        ``alpha == 0``. It is not forwarded when ``alpha == 1``.
    rtol, atol
        Relative / absolute change of the GW loss below which the fused
        iteration stops. ``rtol`` is also the stopping threshold handed to POT
        in the two pure cases.
    verbose
        Print one line per outer iteration.

    Returns
    -------
    The transport plan.
    """
    from ot.gromov import init_matrix, gwggrad, gwloss

    assert 0 <= alpha <= 1, alpha

    # default to uniform marginals
    if p is None:
        p = np.ones((C1.shape[0],), dtype=np.float64) / C1.shape[0]
    if q is None:
        q = np.ones((C2.shape[0],), dtype=np.float64) / C2.shape[0]

    if epsilon >= 0.1:
        sinkhorn = ot.sinkhorn
    else:
        sinkhorn = ot.bregman.sinkhorn_epsilon_scaling

    C1 = np.asarray(C1, dtype=np.float64)
    C2 = np.asarray(C2, dtype=np.float64)

    # Pure linear problem: a single Sinkhorn solve on the cross-domain cost.
    if alpha == 0:
        linear_cost = np.asarray(C12, dtype=np.float64)
        return sinkhorn(p, q, linear_cost, reg=epsilon, numItermax=max_iterations, stopThr=rtol)

    # Pure quadratic problem: hand over to POT.
    if alpha == 1:
        return ot.gromov.entropic_gromov_wasserstein(
            C1, C2, p=p, q=q, loss_fun=loss_fun, epsilon=epsilon, tol=rtol
        )

    # Fused case: the linear part is weighted once up front.
    linear_term = (1 - alpha) * np.asarray(C12, dtype=np.float64)
    loss = 0
    T = np.outer(p, q)
    constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)

    if verbose:
        header = "{:5s}|{:12s}|{:8s}|{:8s}".format(
            "It.", "Loss", "Rel. loss   ", "Abs. loss   "
        )
        print(header + "\n" + "-" * 83)

    for iteration in range(1, max_iterations + 1):
        previous_loss = loss

        # linearise the quadratic term around the current plan and re-solve
        linearised_cost = linear_term + alpha * gwggrad(constC, hC1, hC2, T)
        T = sinkhorn(p, q, linearised_cost, reg=epsilon)

        # convergence is monitored on the GW loss of the new plan
        loss = gwloss(constC, hC1, hC2, T)
        abs_change = abs(loss - previous_loss)
        rel_change = abs_change / abs(loss)

        if verbose:
            print(f"{iteration:5d}|{loss:8e}|{rel_change:8e}|{abs_change:8e}")

        if rel_change <= rtol or abs_change <= atol:
            break

    return T

In [6]:
def init_sim(flow_type: Literal['bifurcation', 'convergent', 'partial_convergent', 'mismatched_clusters'],
             seed: int = 257, plot: bool = True, **kwargs) -> "stn":
    """Run one LineageOT simulation and collect the inputs of the benchmarks.

    Parameters
    ----------
    flow_type
        Name of the simulated flow; forwarded to ``sim.SimulationParameters``.
        ``'bifurcation'`` uses a timescale of 1, every other flow 100.
    seed
        Seed passed to ``np.random.seed`` before simulating.
    plot
        Whether to draw a scatter plot of the early and late cells.
    kwargs
        Extra keyword arguments for ``sim.SimulationParameters``.

    Returns
    -------
    An ``stn`` tuple. ``sim_info`` holds the ``sim_inf`` module,
    ``hamming_distances_late`` the late Hamming distance matrix (barcodes plus
    an all-zero root barcode) divided by the estimated mutation rate.

    Notes
    -----
    If ``"{flow_type}_sim.pickle"`` exists in the working directory it is
    loaded and returned instead of simulating.
    """
    fpath = f"{flow_type}_sim.pickle"
    if os.path.isfile(fpath):
        # NOTE(review): unpickling executes arbitrary code from the file; only
        # keep cache files here that you created yourself.
        with open(fpath, "rb") as fin:
            cached = pickle.load(fin)
        # The (disabled) dump at the end of this function stores a plain tuple;
        # restore the named fields so callers can use attribute access.
        return cached if isinstance(cached, stn) else stn(*cached)

    start = perf_counter()
    np.random.seed(seed)
    timescale = 1 if flow_type == 'bifurcation' else 100

    x0_speed = 1/timescale
    sim_params = sim.SimulationParameters(division_time_std = 0.01*timescale,
                                          flow_type = flow_type,
                                          x0_speed = x0_speed,
                                          mutation_rate = 1/timescale,
                                          mean_division_time = 1.1*timescale,
                                          timestep = 0.001*timescale,
                                          **kwargs)

    # These parameters can be adjusted freely.
    # As is, they replicate the plots in the paper for the fully convergent simulation.
    mean_x0_early = 2
    time_early = 7.4*timescale # Time when early cells are sampled
    time_late = time_early + 4*timescale # Time when late cells are sampled
    x0_initial = mean_x0_early - time_early*x0_speed
    initial_cell = sim.Cell(np.array([x0_initial, 0, 0]), np.zeros(sim_params.barcode_length))
    sample_times = {'early': time_early, 'late': time_late}

    ## Running the simulation
    sample = sim.sample_descendants(initial_cell.deepcopy(), time_late, sim_params)

    # Extracting trees and barcode matrices
    true_trees = {'late': sim_inf.list_tree_to_digraph(sample)}
    true_trees['late'].nodes['root']['cell'] = initial_cell
    true_trees['early'] = sim_inf.truncate_tree(true_trees['late'], sample_times['early'], sim_params)

    # Computing the ground-truth coupling
    true_coupling = sim_inf.get_true_coupling(true_trees['early'], true_trees['late'])

    # Each extraction yields (rna array, barcode array, ...); extract once per tree.
    data_arrays = {'late': sim_inf.extract_data_arrays(true_trees['late']),
                   'early': sim_inf.extract_data_arrays(true_trees['early'])}
    rna_arrays = {'late': data_arrays['late'][0], 'early': data_arrays['early'][0]}
    barcode_arrays = {'early': data_arrays['early'][1], 'late': data_arrays['late'][1]}
    num_cells = {'early': rna_arrays['early'].shape[0], 'late': rna_arrays['late'].shape[0]}

    print("Times:", sample_times)
    print("Number of cells:", num_cells)

    # Creating a copy of the true tree for use in LineageOT
    true_trees['late, annotated'] = copy.deepcopy(true_trees['late'])
    sim_inf.add_node_times_from_division_times(true_trees['late, annotated'])

    sim_inf.add_nodes_at_time(true_trees['late, annotated'], sample_times['early'])

    if plot:
        # Choosing which of the three dimensions to show
        if flow_type == 'mismatched_clusters':
            dimensions_to_plot = [1, 2]
        else:
            dimensions_to_plot = [0, 1]

        # Scatter plot of cell states
        cmap = "coolwarm"
        colors = [plt.get_cmap(cmap)(0), plt.get_cmap(cmap)(256)]
        for a, label, c in zip([rna_arrays['early'], rna_arrays['late']], ['Early cells', 'Late cells'], colors):
            plt.scatter(a[:, dimensions_to_plot[0]],
                        a[:, dimensions_to_plot[1]], alpha = 0.4, label = label, color = c)

        plt.xlabel('Gene ' + str(dimensions_to_plot[0] + 1))
        plt.ylabel('Gene ' + str(dimensions_to_plot[1] + 1))
        plt.legend()

    # Infer ancestor locations for the late cells based on the true lineage tree
    observed_nodes = list(sim_inf.get_leaves(true_trees['late, annotated'], include_root=False))
    sim_inf.add_conditional_means_and_variances(true_trees['late, annotated'], observed_nodes)

    ancestor_info = {'true tree': sim_inf.get_ancestor_data(true_trees['late, annotated'], sample_times['early'])}

    # True distances
    true_distances = {key: sim_inf.compute_tree_distances(true_trees[key]) for key in true_trees}

    rate_estimate = sim_inf.rate_estimator(barcode_arrays['late'], sample_times['late'])

    print("Fraction unmutated barcodes: ", {key:np.sum(barcode_arrays[key] == 0)/barcode_arrays[key].size
                                            for key in barcode_arrays})
    print("Rate estimate: ", rate_estimate)
    print("True rate: ", sim_params.mutation_rate / sim_params.barcode_length)
    print("Rate accuracy: ", rate_estimate*sim_params.barcode_length/sim_params.mutation_rate)

    # Compute Hamming distance matrices for neighbor joining;
    # an all-zero barcode is appended to stand in for the root.
    root_barcode = np.zeros([1, sim_params.barcode_length])
    hamming_distances_with_roots = {
        key: sim_inf.barcode_distances(np.concatenate([barcode_arrays[key], root_barcode]))
        for key in ('early', 'late')
    }
    fitted_tree = sim_inf.neighbor_join(hamming_distances_with_roots['late'])
    fitted_tree_early = sim_inf.neighbor_join(hamming_distances_with_roots['early'])

    # Annotate fitted tree with internal node times
    sim_inf.add_leaf_barcodes(fitted_tree, barcode_arrays['late'])
    sim_inf.add_leaf_x(fitted_tree, rna_arrays['late'])
    sim_inf.add_leaf_times(fitted_tree, sample_times['late'])
    sim_inf.annotate_tree(fitted_tree,
                          rate_estimate*np.ones(sim_params.barcode_length),
                          time_inference_method = 'least_squares')

    # Add inferred ancestor nodes and states.
    # This is done exactly once: the original repeated these two calls after
    # every consumer of `fitted_tree` had already run.
    sim_inf.add_node_times_from_division_times(fitted_tree)
    sim_inf.add_nodes_at_time(fitted_tree, sample_times['early'])
    observed_nodes = list(sim_inf.get_leaves(fitted_tree, include_root = False))
    sim_inf.add_conditional_means_and_variances(fitted_tree, observed_nodes)
    ancestor_info['fitted tree'] = sim_inf.get_ancestor_data(fitted_tree, sample_times['early'])

    fitted_tree_distances = sim_inf.compute_tree_distances(fitted_tree)
    hamming_distances_late = hamming_distances_with_roots['late'] / rate_estimate

    # Same annotation for the tree fitted on the early barcodes
    sim_inf.add_leaf_barcodes(fitted_tree_early, barcode_arrays['early'])
    sim_inf.add_leaf_x(fitted_tree_early, rna_arrays['early'])
    sim_inf.add_leaf_times(fitted_tree_early, sample_times['early'])
    sim_inf.annotate_tree(fitted_tree_early,
                          rate_estimate*np.ones(sim_params.barcode_length),
                          time_inference_method = 'least_squares')
    sim_inf.add_node_times_from_division_times(fitted_tree_early)

    fitted_tree_early_distances = sim_inf.compute_tree_distances(fitted_tree_early)

    end = perf_counter() - start
    print(f"Time: {end}")

    early_time_rna_cost = ot.utils.dist(
        rna_arrays['early'],
        sim_inf.extract_ancestor_data_arrays(true_trees['late'], sample_times['early'], sim_params)[0],
    )
    late_time_rna_cost = ot.utils.dist(rna_arrays['late'], rna_arrays['late'])

    # `hamming_distances_late` is the rate-normalised late matrix computed
    # above; previously the whole `hamming_distances_with_roots` dict was
    # stored under that field and the local went unused.
    res = stn(sim_inf, ancestor_info, rna_arrays, true_coupling, true_distances, barcode_arrays,
              fitted_tree_early_distances, fitted_tree_distances, hamming_distances_late,
              early_time_rna_cost, late_time_rna_cost)

    # NOTE(review): writing the cache is disabled; the first field is a module
    # object, which pickle cannot serialise - presumably why this was turned off.
    #with open(fpath, "wb") as fout:
    #    pickle.dump(tuple(res), fout)

    return res

In [7]:
def benchmark_moscot(sim: stn, *, alpha: float, epsilon: Optional[float] = None,
                     tree_type: str = 'fitted tree', scale: str = "max", **kwargs):  
    """Solve one moscot problem on a simulation and score the coupling.

    Parameters
    ----------
    sim
        Simulation bundle produced by ``init_sim``.
    alpha
        ``0`` runs ``Unbalanced`` on the joint RNA cost only, ``1`` runs ``GW``
        on the two lineage geometries, any other value runs ``FusedGW``.
    epsilon
        Entropic regularisation handed to the solver.
    tree_type
        Source of the lineage distances: ``'barcodes'``, ``'fitted tree'`` or
        ``'true tree'``.
    scale
        Scaling mode forwarded to ``create_geometry`` for all three costs.
    kwargs
        ``max_iterations`` (default 20), ``rtol`` and ``atol`` (default 1e-6)
        are passed to ``FusedGW.fit``; what remains goes to the ``FusedGW``
        constructor. They are unused for ``alpha`` equal to 0 or 1.

    Returns
    -------
    ``bnt`` with the coupling, early/late OT cost, distance to the true
    coupling, the solver's convergence information and the solve time.

    Raises
    ------
    NotImplementedError
        For an unknown ``tree_type``.
    AssertionError
        If the coupling contains non-finite values.
    """
    max_iterations = kwargs.pop("max_iterations", 20)
    rtol = kwargs.pop("rtol", 1e-6)
    atol = kwargs.pop("atol", 1e-6)

    # pick the lineage distances at both time points
    if tree_type == 'barcodes':
        early_dist = sim_inf.barcode_distances(sim.barcode_arrays['early'])
        late_dist = sim_inf.barcode_distances(sim.barcode_arrays['late'])
    elif tree_type == 'fitted tree':
        early_dist = sim.fitted_tree_distances_early
        late_dist = sim.fitted_tree_distances_late
    elif tree_type == 'true tree':
        early_dist = sim.true_distances['early']
        late_dist = sim.true_distances['late']
    else:
        raise NotImplementedError(tree_type)

    early_geom = create_geometry(early_dist, scale=scale)
    late_geom = create_geometry(late_dist, scale=scale)
    rna_cost = ot.utils.dist(sim.rna_arrays['early'], sim.rna_arrays['late'])
    joint_geom = create_geometry(rna_cost, scale=scale)

    # only the solve itself is timed
    start = perf_counter()
    if alpha == 0:
        solver = Unbalanced(epsilon=epsilon)
        solver.fit(joint_geom)
        tmat = np.asarray(solver.matrix)
        conv = [solver.converged]
    elif alpha == 1:
        solver = GW(epsilon=epsilon)
        solver.fit(early_geom, late_geom)
        tmat = np.asarray(solver.matrix)
        conv = solver.converged_sinkhorn
    else:
        solver = FusedGW(alpha=alpha, epsilon=epsilon, **kwargs)
        solver.fit(early_geom, late_geom, joint_geom, linesearch=False, verbose=False,
                   max_iterations=max_iterations, rtol=rtol, atol=atol)
        tmat = np.asarray(solver.matrix)
        conv = solver.converged_sinkhorn
    elapsed = perf_counter() - start
    print(f"Time: {elapsed}")

    if not np.isfinite(tmat).all():
        raise AssertionError("Convergence issue - not all values are finite.")

    # score the coupling against the ground truth
    ot_cost = sim.sim_info.OT_cost
    early_cost = float(ot_cost(tmat, sim.early_time_rna_cost))
    expanded = sim_eval.expand_coupling(tmat, sim.true_coupling, sim.late_time_rna_cost)
    late_cost = float(ot_cost(expanded, sim.late_time_rna_cost))
    norm_diff = np.linalg.norm(tmat - sim.true_coupling)

    return bnt(tmat, early_cost, late_cost, norm_diff, conv, elapsed)


def benchmark_lineageOT(sim: stn, *, epsilon: float, tree_type: str = 'fitted tree',  **kwargs):
    """Solve the LineageOT problem with POT's Sinkhorn and score the coupling.

    Parameters
    ----------
    sim
        Simulation bundle produced by ``init_sim``.
    epsilon
        Regularisation strength; it is multiplied by the mean of the inverted
        second ancestor array before being handed to the solver.
    tree_type
        Key into ``sim.ancestor_info``.
    kwargs
        Forwarded to the Sinkhorn routine.

    Returns
    -------
    ``bnt`` with the coupling, early/late OT cost, distance to the true
    coupling, ``None`` as convergence information and the solve time.
    """
    ancestors = sim.ancestor_info[tree_type]
    weights = ancestors[1] ** (-1)
    # cost to each inferred ancestor, with columns weighted by `weights`
    cmat = ot.utils.dist(sim.rna_arrays['early'], ancestors[0]) @ np.diag(weights)

    # Epsilon scaling is more robust at smaller epsilon, but slower than simple sinkhorn
    if epsilon >= 0.1:
        sinkhorn = ot.sinkhorn
    else:
        sinkhorn = ot.bregman.sinkhorn_epsilon_scaling

    start = perf_counter()
    # empty marginals are passed through to POT unchanged
    tmat = sinkhorn([], [], cmat, epsilon * np.mean(weights), **kwargs)
    elapsed = perf_counter() - start

    ot_cost = sim.sim_info.OT_cost
    early_cost = float(ot_cost(tmat, sim.early_time_rna_cost))
    expanded = sim_eval.expand_coupling(tmat, sim.true_coupling, sim.late_time_rna_cost)
    late_cost = float(ot_cost(expanded, sim.late_time_rna_cost))
    norm_diff = np.linalg.norm(tmat - sim.true_coupling)

    return bnt(tmat, early_cost, late_cost, norm_diff, None, elapsed)

In [8]:
def gridsearch(sim: stn, *, alphas: Sequence[float], epsilons: Sequence[float], scale: str = "max", **kwargs) -> Dict[float, Dict[float, bnt]]:
    """Run ``benchmark_moscot`` for every (alpha, epsilon) combination.

    A run that raises is reported by printing its traceback and recorded as
    ``None``, so one failing configuration does not abort the sweep.

    Returns
    -------
    Nested dict ``{alpha: {epsilon: bnt or None}}``.
    """
    results = {}
    for alpha in alphas:
        for epsilon in epsilons:
            try:
                print(f"alpha={alpha}, epsilon={epsilon} scale={scale}")
                outcome = benchmark_moscot(sim, alpha=alpha, epsilon=epsilon, scale=scale, **kwargs)
            except Exception:
                print(traceback.format_exc())
                outcome = None
            # the per-alpha dict is only created once there is something to store
            results.setdefault(alpha, {})[epsilon] = outcome

    return results

def gridsearch_lineageOT(sim: stn, *, epsilons: Sequence[float], **kwargs) -> Dict[float, Dict[float, bnt]]:
    """Run ``benchmark_lineageOT`` for every epsilon.

    The result mirrors the layout of ``gridsearch`` with ``None`` as the only
    alpha key. An epsilon of ``None`` is skipped and recorded as ``None``; a
    run that raises has its traceback printed and is recorded as ``None``.

    Returns
    -------
    Nested dict ``{None: {epsilon: bnt or None}}``.
    """
    alpha = None  # LineageOT has no alpha; kept for a uniform result layout
    results = {}
    for epsilon in epsilons:
        outcome = None
        if epsilon is not None:
            try:
                print(f"alpha={alpha}, epsilon={epsilon}")
                outcome = benchmark_lineageOT(sim, epsilon=epsilon, **kwargs)
            except Exception:
                print(traceback.format_exc())
                outcome = None
        results.setdefault(alpha, {})[epsilon] = outcome

    return results

In [9]:
# Extra keyword arguments handed to `init_sim` (none by default).
sim_params = dict()
# original epsilons; `None` comes first and is passed to the solvers as-is
epsilons = [None, *sorted([1e-4, 5e-3, 1e-3, 5e-2, 1e-2, 5e-1, 1e-1, 1])]
# 21 evenly spaced interpolation weights between 0 and 1, rounded to 2 decimals
alpha_grid = np.linspace(0.0, 1, 21, dtype=np.float64)
alphas = list(alpha_grid.round(2))
# directory the result pickles are written to
root = 'temp4'
print(epsilons)
print(alphas)
len(epsilons), len(alphas)

[None, 0.0001, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1]
[0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0]


(9, 21)

In [10]:
# Create the output directory (and any missing parents); a no-op if it already
# exists. Pure Python instead of the `!mkdir -p $root` shell escape, so the
# cell also works outside IPython, without a POSIX shell, and for paths that
# contain spaces.
os.makedirs(root, exist_ok=True)

In [11]:
# Sweep every method/scale for each simulated flow and pickle the results.
# Other available flows: "convergent", "partial_convergent", "mismatched_clusters".
for flow_type in ["bifurcation"]:
    s = init_sim(flow_type, plot=False, **sim_params)

    # `None` stands for the LineageOT baseline, which takes no cost scaling and
    # therefore runs exactly once, before the three moscot scalings.
    for scale in (None, "max", "mean", "median"):
        if scale is None:
            fname = f"{flow_type}_lot.pickle"
            res = gridsearch_lineageOT(s, epsilons=epsilons, numItermax=100, stopThr=1e-9)
        else:
            fname = f"{flow_type}_{scale}.pickle"
            res = gridsearch(s, alphas=alphas, epsilons=epsilons, scale=scale,
                             max_iterations=100, rtol=1e-9, atol=1e-9)

        with open(f"{root}/{fname}", "wb") as fout:
            pickle.dump(res, fout)

Times: {'early': 7.4, 'late': 11.4}
Number of cells: {'early': 64, 'late': 1024}
Fraction unmutated barcodes:  {'early': 0.6104166666666667, 'late': 0.47311197916666664}
Rate estimate:  0.06565115579693222
True rate:  0.06666666666666667
Rate accuracy:  0.9847673369539832
     pcost       dcost       gap    pres   dres
 0: -6.3033e+04 -6.5712e+04  2e+04  3e-01  3e-01
 1: -6.2703e+04 -6.8508e+04  1e+04  9e-02  1e-01
 2: -6.2500e+04 -6.6193e+04  4e+03  3e-02  3e-02
 3: -6.2907e+04 -6.4419e+04  2e+03  1e-02  1e-02
 4: -6.2914e+04 -6.4359e+04  2e+03  8e-03  9e-03
 5: -6.3065e+04 -6.3855e+04  8e+02  3e-03  3e-03
 6: -6.3152e+04 -6.3566e+04  4e+02  1e-16  2e-16
 7: -6.3240e+04 -6.3320e+04  8e+01  1e-16  4e-16
 8: -6.3259e+04 -6.3266e+04  8e+00  1e-16  3e-16
 9: -6.3260e+04 -6.3261e+04  4e-01  1e-16  9e-16
10: -6.3261e+04 -6.3261e+04  1e-02  1e-16  8e-16
Optimal solution found.
     pcost       dcost       gap    pres   dres
 0: -1.4120e+03 -1.6186e+03  1e+03  4e-01  4e-01
 1: -1.4091e+03 -1.



alpha=0.0, epsilon=None scale=max
Time: 1.1252877071965486
alpha=0.0, epsilon=0.0001 scale=max
Time: 0.053816302912309766
alpha=0.0, epsilon=0.001 scale=max
Time: 0.05900649703107774
alpha=0.0, epsilon=0.005 scale=max
Time: 0.04259821795858443
alpha=0.0, epsilon=0.01 scale=max
Time: 0.04231681511737406
alpha=0.0, epsilon=0.05 scale=max
Time: 0.041424970142543316
alpha=0.0, epsilon=0.1 scale=max
Time: 0.043217157013714314
alpha=0.0, epsilon=0.5 scale=max
Time: 0.0430318471044302
alpha=0.0, epsilon=1 scale=max
Time: 0.04277858906425536
alpha=0.05, epsilon=None scale=max
Time: 1.2184976530261338
alpha=0.05, epsilon=0.0001 scale=max


KeyboardInterrupt: 