# LineageOT benchmark

In [1]:
# pip install git+https://github.com/aforr/LineageOT@master cvxopt ete3

In [2]:
from jax.config import config
config.update("jax_enable_x64", True)

from ott.geometry.geometry import Geometry
from ott.geometry.pointcloud import PointCloud
from jax import numpy as jnp
import seaborn as sns

from time import perf_counter
from moscot import FusedGW
import pickle
import os

In [3]:
import copy
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import ot

import lineageot.simulation as sim
import lineageot.evaluation as sim_eval
import lineageot.inference as sim_inf

from typing import Literal, Optional, Sequence, Dict
import traceback
from collections import namedtuple, defaultdict

In [4]:
bnt = namedtuple("bnt", "tmat early_cost late_cost norm_diff converged time")
stn = namedtuple("sim",
                 "sim_info "
                 "ancestor_info "
                 "rna_arrays "
                 "true_coupling "
                 "true_distances "
                 "barcode_arrays "
                 "fitted_tree_distances_early "
                 "fitted_tree_distances_late "
                 "hamming_distances_late "
                 "early_time_rna_cost "
                 "late_time_rna_cost")

In [5]:
def init_sim(flow_type: Literal['bifurcation', 'convergent', 'partial_convergent', 'mistmatched_clusters'],
             seed: int = 257, plot: bool = True, **kwargs):
    fpath = f"{flow_type}_sim.pickle"
    if os.path.isfile(fpath):
        with open(fpath, "rb") as fin:
            return pickle.load(fin)
    
    start = perf_counter()
    np.random.seed(seed) 
    if flow_type == 'bifurcation':
        timescale = 1
    else:
        timescale = 100

    x0_speed = 1/timescale
    sim_params = sim.SimulationParameters(division_time_std = 0.01*timescale,
                                          flow_type = flow_type,
                                          x0_speed = x0_speed,
                                          mutation_rate = 1/timescale,
                                          mean_division_time = 1.1*timescale,
                                          timestep = 0.001*timescale,
                                          **kwargs)

    # These parameters can be adjusted freely.
    # As is, they replicate the plots in the paper for the fully convergent simulation.
    mean_x0_early = 2
    time_early = 7.4*timescale # Time when early cells are sampled
    time_late = time_early + 4*timescale # Time when late cells are sampled
    x0_initial = mean_x0_early -time_early*x0_speed
    initial_cell = sim.Cell(np.array([x0_initial, 0, 0]), np.zeros(sim_params.barcode_length))
    sample_times = {'early' : time_early, 'late' : time_late}



    # Choosing which of the three dimensions to show in later plots
    if flow_type == 'mismatched_clusters':
        dimensions_to_plot = [1, 2]
    else:
        dimensions_to_plot = [0, 1]

    ## Running the simulation
    sample = sim.sample_descendants(initial_cell.deepcopy(), time_late, sim_params)

    # Extracting trees and barcode matrices
    true_trees = {'late': sim_inf.list_tree_to_digraph(sample)}
    true_trees['late'].nodes['root']['cell'] = initial_cell
    true_trees['early'] = sim_inf.truncate_tree(true_trees['late'], sample_times['early'], sim_params)

    # Computing the ground-truth coupling
    true_coupling = sim_inf.get_true_coupling(true_trees['early'], true_trees['late'])
    
    data_arrays = {'late': sim_inf.extract_data_arrays(true_trees['late']),
                   'early': sim_inf.extract_data_arrays(true_trees['early'])}
    rna_arrays = {'late': data_arrays['late'][0]}
    barcode_arrays = {'early': data_arrays['early'][1], 'late': data_arrays['late'][1]}

    rna_arrays['early'] = sim_inf.extract_data_arrays(true_trees['early'])[0]
    num_cells = {'early': rna_arrays['early'].shape[0], 'late': rna_arrays['late'].shape[0]}

    print("Times:", sample_times)
    print("Number of cells:", num_cells)
    
        # Creating a copy of the true tree for use in LineageOT
    true_trees['late, annotated'] = copy.deepcopy(true_trees['late'])
    sim_inf.add_node_times_from_division_times(true_trees['late, annotated'])

    sim_inf.add_nodes_at_time(true_trees['late, annotated'], sample_times['early'])
    
    if plot:
        # Scatter plot of cell states
        cmap = "coolwarm"
        colors = [plt.get_cmap(cmap)(0), plt.get_cmap(cmap)(256)]
        for a,label, c in zip([rna_arrays['early'], rna_arrays['late']], ['Early cells', 'Late cells'], colors):
            plt.scatter(a[:, dimensions_to_plot[0]],
                        a[:, dimensions_to_plot[1]], alpha = 0.4, label = label, color = c)

        plt.xlabel('Gene ' + str(dimensions_to_plot[0] + 1))
        plt.ylabel('Gene ' + str(dimensions_to_plot[1] + 1))
        plt.legend()
        
    # Infer ancestor locations for the late cells based on the true lineage tree
    observed_nodes = [n for n in sim_inf.get_leaves(true_trees['late, annotated'], include_root=False)]
    sim_inf.add_conditional_means_and_variances(true_trees['late, annotated'], observed_nodes)

    ancestor_info = {'true tree': sim_inf.get_ancestor_data(true_trees['late, annotated'], sample_times['early'])}
    
    # True distances
    true_distances = {key: sim_inf.compute_tree_distances(true_trees[key]) for key in true_trees}
    
    rate_estimate = sim_inf.rate_estimator(barcode_arrays['late'], sample_times['late'])

    print("Fraction unmutated barcodes: ", {key:np.sum(barcode_arrays[key] == 0)/barcode_arrays[key].size
                                            for key in barcode_arrays})
    print("Rate estimate: ", rate_estimate)
    print("True rate: ", sim_params.mutation_rate / sim_params.barcode_length)
    print("Rate accuracy: ", rate_estimate*sim_params.barcode_length/sim_params.mutation_rate)
    
    # Compute Hamming distance matrices for neighbor joining

    hamming_distances_with_roots = {
        'early': sim_inf.barcode_distances(np.concatenate([barcode_arrays['early'],
                                                           np.zeros([1,sim_params.barcode_length])])),
        'late': sim_inf.barcode_distances(np.concatenate([barcode_arrays['late'],
                                                          np.zeros([1,sim_params.barcode_length])]))
    }
    fitted_tree = sim_inf.neighbor_join(hamming_distances_with_roots['late'])
    fitted_tree_early = sim_inf.neighbor_join(hamming_distances_with_roots['early'])
    
    # Annotate fitted tree with internal node times

    sim_inf.add_leaf_barcodes(fitted_tree, barcode_arrays['late'])
    sim_inf.add_leaf_x(fitted_tree, rna_arrays['late'])
    sim_inf.add_leaf_times(fitted_tree, sample_times['late'])
    sim_inf.annotate_tree(fitted_tree,
                          rate_estimate*np.ones(sim_params.barcode_length),
                          time_inference_method = 'least_squares')
    
    # Add inferred ancestor nodes and states
    sim_inf.add_node_times_from_division_times(fitted_tree)
    sim_inf.add_nodes_at_time(fitted_tree, sample_times['early'])
    observed_nodes = [n for n in sim_inf.get_leaves(fitted_tree, include_root = False)]
    sim_inf.add_conditional_means_and_variances(fitted_tree, observed_nodes)
    ancestor_info['fitted tree'] = sim_inf.get_ancestor_data(fitted_tree, sample_times['early'])
    
    fitted_tree_distances = sim_inf.compute_tree_distances(fitted_tree)
    hamming_distances_late = hamming_distances_with_roots['late'] / rate_estimate
    
    
    sim_inf.add_leaf_barcodes(fitted_tree_early, barcode_arrays['early'])
    sim_inf.add_leaf_x(fitted_tree_early, rna_arrays['early'])
    sim_inf.add_leaf_times(fitted_tree_early, sample_times['early'])
    sim_inf.annotate_tree(fitted_tree_early,
                          rate_estimate*np.ones(sim_params.barcode_length),
                          time_inference_method = 'least_squares')
    sim_inf.add_node_times_from_division_times(fitted_tree_early)
    
    fitted_tree_early_distances = sim_inf.compute_tree_distances(fitted_tree_early)

    
    # Add inferred ancestor nodes and states
    sim_inf.add_node_times_from_division_times(fitted_tree)
    sim_inf.add_nodes_at_time(fitted_tree, sample_times['early'])

    end = perf_counter() - start
    print(f"Time: {end}")
    
    early_time_rna_cost = ot.utils.dist(rna_arrays['early'], sim_inf.extract_ancestor_data_arrays(true_trees['late'], sample_times['early'], sim_params)[0])
    late_time_rna_cost = ot.utils.dist(rna_arrays['late'], rna_arrays['late'])
    
    res = stn(sim_inf, ancestor_info, rna_arrays, true_coupling, true_distances, barcode_arrays,
              fitted_tree_early_distances, fitted_tree_distances, hamming_distances_with_roots,
              early_time_rna_cost, late_time_rna_cost)
    
    #with open(fpath, "wb") as fout:
    #    pickle.dump(tuple(res), fout)
        
    return res

In [6]:
def benchmark_moscot(sim: stn, *, alpha: float, epsilon: Optional[float] = None,
                     tree_type: str = 'fitted tree', rescale: bool = True, **kwargs):
    def create_geometry(cost_matrix: np.ndarray) -> Geometry:
        cost_matrix = jnp.array(cost_matrix)
        if rescale:
            cost_matrix /= cost_matrix.max()
            assert cost_matrix.max() == 1.0
        assert (cost_matrix >= 0).all()
        return Geometry(cost_matrix=cost_matrix)
    
    if tree_type == 'fitted tree':
        # e = sim_inf.barcode_distances(sim.barcode_arrays['early'])
        # l = sim_inf.barcode_distances(sim.barcode_arrays['late'])
        e = sim.fitted_tree_distances_early
        l = sim.fitted_tree_distances_late
    else:
        e = create_geometry(sim.true_distances['early'])
        l = create_geometry(sim.true_distances['late'])
    joint = create_geometry(ot.utils.dist(sim.rna_arrays['early'], sim.rna_arrays['late']))
    
    max_iterations = kwargs.pop("max_iterations", 20)
    rtol = kwargs.pop("rtol", 1e-6)
    atol = kwargs.pop("atol", 1e-6)
    
    fgw = FusedGW(alpha=alpha, epsilon=epsilon, **kwargs)
    start = perf_counter()
    fgw.fit(e, l, joint, linesearch=False, verbose=False, max_iterations=max_iterations, rtol=rtol, atol=atol)
    time = perf_counter() - start
    tmat = np.array(fgw.matrix)
    print(f"Time: {time}")
    
    early_cost = float(sim.sim_info.OT_cost(tmat, sim.early_time_rna_cost))
    late_cost = float(sim.sim_info.OT_cost(sim_eval.expand_coupling(tmat, sim.true_coupling, sim.late_time_rna_cost),
                                          sim.late_time_rna_cost))
    norm_diff = np.linalg.norm(tmat - sim.true_coupling)
        
    return bnt(tmat, early_cost, late_cost, norm_diff, fgw.converged_sinkhorn, time)


def benchmark_lineageOT(sim: stn, *, epsilon: float, tree_type: str = 'fitted tree',  **kwargs):
    cmat = ot.utils.dist(sim.rna_arrays['early'], sim.ancestor_info[tree_type][0]) @ np.diag(sim.ancestor_info[tree_type][1] ** (-1))
    
    # Epsilon scaling is more robust at smaller epsilon, but slower than simple sinkhorn
    f = ot.sinkhorn if epsilon >= 0.1 else ot.bregman.sinkhorn_epsilon_scaling
    start = perf_counter()
    tmat = f([], [], cmat, epsilon * np.mean(sim.ancestor_info[tree_type][1] ** (-1)), **kwargs)
    time = perf_counter() - start
    
    early_cost = float(sim.sim_info.OT_cost(tmat, sim.early_time_rna_cost))
    late_cost = float(sim.sim_info.OT_cost(sim_eval.expand_coupling(tmat, sim.true_coupling, sim.late_time_rna_cost),
                                          sim.late_time_rna_cost))
    norm_diff = np.linalg.norm(tmat - sim.true_coupling)
    
    return bnt(tmat, early_cost, late_cost, norm_diff, None, time)

In [7]:
def gridsearch(sim: stn, *, alphas: Sequence[float], epsilons: Sequence[float], rescale: bool = True, **kwargs) -> Dict[float, Dict[float, bnt]]:
    res = defaultdict(defaultdict)
    for alpha in alphas:
        for epsilon in epsilons:
            try:
                print(f"alpha={alpha}, epsilon={epsilon}")
                res[alpha][epsilon] = benchmark_moscot(sim, alpha=alpha, epsilon=epsilon, rescale=rescale, **kwargs)
            except Exception as e:
                print(traceback.format_exc())
                res[alpha][epsilon] = None
    
    return {a: {e: v for e, v in vs.items()} for a, vs in res.items()}

def gridsearch_lineageOT(sim: stn, *, epsilons: Sequence[float], **kwargs) -> Dict[float, Dict[float, bnt]]:
    res = defaultdict(defaultdict)
    for alpha in [None]:
        for epsilon in epsilons:
            if epsilon is None:
                res[alpha][epsilon] = None
                continue
            try:
                print(f"alpha={alpha}, epsilon={epsilon}")
                res[alpha][epsilon] = benchmark_lineageOT(sim, epsilon=epsilon, **kwargs)
            except Exception as e:
                print(traceback.format_exc())
                res[alpha][epsilon] = None
    
    return {a: {e: v for e, v in vs.items()} for a, vs in res.items()}

In [8]:
def plot(res, suptitle="", figsize=(12, 3), dpi=300):
    vals = {
        "early_cost": defaultdict(defaultdict),
        "late_cost": defaultdict(defaultdict),
        "norm_diff": defaultdict(defaultdict),
        "time": defaultdict(defaultdict),
    }
    
    for alpha, vs in res.items():
        for epsilon, bench in vs.items():
            for attr, container in vals.items():
                container[alpha][epsilon] = getattr(res[alpha][epsilon], attr)
                
    fig, axes = plt.subplots(1, 4, tight_layout=True, dpi=dpi, figsize=figsize)
    fig.suptitle(flow_type)
    axes = np.ravel([axes])
    
    for ax, (attr, val) in zip(axes, vals.items()):
        data = pd.DataFrame(val)
        sns.heatmap(data, ax=ax, annot=True, fmt='.3g', cmap='viridis')
        ax.set_title(attr)
        ax.set_xlabel("alpha")
        ax.set_ylabel("epsilon")
        
    fig.savefig(f"{suptitle}.pdf")

In [9]:
sim_params = {}
# original epsilons
epsilons = [None, 1e-4, 5e-3, 1e-3, 5e-2] + list(np.logspace(-2, 3, 15))
alphas = list(np.linspace(0.025, 0.975, 20))
print(epsilons)
print(alphas)
len(epsilons), len(alphas)

[None, 0.0001, 0.005, 0.001, 0.05, 0.01, 0.022758459260747887, 0.05179474679231213, 0.11787686347935872, 0.2682695795279726, 0.6105402296585329, 1.3894954943731375, 3.1622776601683795, 7.196856730011521, 16.378937069540648, 37.27593720314942, 84.83428982440725, 193.06977288832496, 439.3970560760795, 1000.0]
[0.025, 0.075, 0.125, 0.175, 0.22499999999999998, 0.27499999999999997, 0.325, 0.375, 0.425, 0.475, 0.5249999999999999, 0.575, 0.625, 0.6749999999999999, 0.725, 0.7749999999999999, 0.825, 0.875, 0.9249999999999999, 0.975]


(20, 20)

In [10]:
def plot_metrics(couplings, cost_func, cost_func_name, epsilons, log = False, points=False, scale=1.0, label_font_size=18, tick_font_size=12):
    """
    Plots cost_func evaluated as a function of epsilon
    """
    zero_offset = epsilons[0]/2
    all_ys = []
    if "lineageOT" in couplings.keys():
        ys = np.array([cost_func(c) for c in [couplings['lineage entropic rna ' + str(e)] for e in epsilons]])
        plt.plot(epsilons, ys/scale, label = "LineageOT, true tree")
        if points:
            plt.scatter([zero_offset], [cost_func(couplings["lineageOT"])/scale])
        all_ys.append(ys)
    if "OT" in couplings.keys():
        ys = np.array([cost_func(c) for c in [couplings['entropic rna ' + str(e)] for e in epsilons]])
        plt.plot(epsilons, ys/scale, label = "Entropic OT")
        if points:
            plt.scatter([zero_offset], [cost_func(couplings["OT"])/scale])
        all_ys.append(ys)
    if True: #"lineageOT, fitted" in couplings.keys():
        ys = np.array([cost_func(c) for c in [couplings['fitted lineage rna ' + str(e)] for e in epsilons]])
        plt.plot(epsilons, ys/scale, label = "LineageOT, fitted tree")
        if points:
            plt.scatter([zero_offset], [cost_func(couplings["lineageOT, fitted"])/scale])
        all_ys.append(ys)

    plt.ylabel(cost_func_name, fontsize=label_font_size)
    plt.xlabel("Entropy parameter", fontsize=label_font_size)
    plt.xscale("log")

    plt.xticks(fontsize=tick_font_size)
    plt.yticks(fontsize=tick_font_size)

    if points:
        plt.xlim([0.9*zero_offset, epsilons[-1]])
    else:
        plt.xlim([epsilons[0], epsilons[-1]])

    ylims = plt.ylim([0, None])
    # upper limit should be at least 1
    plt.ylim([0, ylims[1]])

    plt.legend(fontsize=tick_font_size)
    return all_ys

In [None]:
for kind in ['lineageOT', 'moscot']:
    for flow_type in ["bifurcation"]:
        s = init_sim(flow_type, plot=False, **sim_params)
        if kind == 'lineageOT':
            fname = f"{flow_type}_lot.pickle"
            res = gridsearch_lineageOT(s, epsilons=epsilons, numItermax=100, stopThr=1e-9)
        else:
            fname = f"{flow_type}.pickle"
            res = gridsearch(s, alphas=alphas, epsilons=epsilons, max_iterations=100, rtol=1e-9, atol=1e-9)

        with open(f"temp3/{fname}", "wb") as fout:
            pickle.dump(res, fout)

Times: {'early': 7.4, 'late': 11.4}
Number of cells: {'early': 64, 'late': 1024}
Fraction unmutated barcodes:  {'early': 0.6104166666666667, 'late': 0.47311197916666664}
Rate estimate:  0.06565115579693222
True rate:  0.06666666666666667
Rate accuracy:  0.9847673369539832
     pcost       dcost       gap    pres   dres
 0: -6.3033e+04 -6.5712e+04  2e+04  3e-01  3e-01
 1: -6.2703e+04 -6.8508e+04  1e+04  9e-02  1e-01
 2: -6.2500e+04 -6.6193e+04  4e+03  3e-02  3e-02
 3: -6.2907e+04 -6.4419e+04  2e+03  1e-02  1e-02
 4: -6.2914e+04 -6.4359e+04  2e+03  8e-03  9e-03
 5: -6.3065e+04 -6.3855e+04  8e+02  3e-03  3e-03
 6: -6.3152e+04 -6.3566e+04  4e+02  1e-16  2e-16
 7: -6.3240e+04 -6.3320e+04  8e+01  1e-16  4e-16
 8: -6.3259e+04 -6.3266e+04  8e+00  1e-16  3e-16
 9: -6.3260e+04 -6.3261e+04  4e-01  1e-16  9e-16
10: -6.3261e+04 -6.3261e+04  1e-02  1e-16  8e-16
Optimal solution found.
     pcost       dcost       gap    pres   dres
 0: -1.4120e+03 -1.6186e+03  1e+03  4e-01  4e-01
 1: -1.4091e+03 -1.



alpha=0.025, epsilon=None
Time: 9.465250011067837
alpha=0.025, epsilon=0.0001
Time: 421.75122000603005


If total mass - 1 is small, this may not significantly affect downstream results.
Total mass - 1: 8.213648656107964e-06


alpha=0.025, epsilon=0.005
Time: 44.81513454997912


If total mass - 1 is small, this may not significantly affect downstream results.
Total mass - 1: -6.666492002871394e-08


alpha=0.025, epsilon=0.001
Time: 45.893453346099705


If total mass - 1 is small, this may not significantly affect downstream results.
Total mass - 1: 1.6268406621833265e-07


alpha=0.025, epsilon=0.05
Time: 39.76682848902419
alpha=0.025, epsilon=0.01
Time: 128.88580902805552


If total mass - 1 is small, this may not significantly affect downstream results.
Total mass - 1: 2.9434099158009985e-08


alpha=0.025, epsilon=0.022758459260747887
Time: 45.7970696198754


If total mass - 1 is small, this may not significantly affect downstream results.
Total mass - 1: 5.7247792284442767e-08


alpha=0.025, epsilon=0.05179474679231213
Time: 43.55963896890171
alpha=0.025, epsilon=0.11787686347935872
Time: 45.14254011097364
alpha=0.025, epsilon=0.2682695795279726
Time: 37.87993436399847
alpha=0.025, epsilon=0.6105402296585329
Time: 54.98434304399416
alpha=0.025, epsilon=1.3894954943731375
Time: 149.5640204299707
alpha=0.025, epsilon=3.1622776601683795
Time: 71.87857722095214
alpha=0.025, epsilon=7.196856730011521
Time: 115.23756146989763
alpha=0.025, epsilon=16.378937069540648
Time: 69.53246525605209
alpha=0.025, epsilon=37.27593720314942
Time: 64.98596913879737
alpha=0.025, epsilon=84.83428982440725
Time: 64.75873325089924
alpha=0.025, epsilon=193.06977288832496
Time: 89.0975817530416
alpha=0.025, epsilon=439.3970560760795
Time: 56.61757718003355
alpha=0.025, epsilon=1000.0
Time: 20.01967085408978
alpha=0.075, epsilon=None
Time: 8.822277083992958
alpha=0.075, epsilon=0.0001
Time: 394.3702399979811


If total mass - 1 is small, this may not significantly affect downstream results.
Total mass - 1: -2.4135585588558328e-05


alpha=0.075, epsilon=0.005
Time: 26.998654825845733


If total mass - 1 is small, this may not significantly affect downstream results.
Total mass - 1: -9.099730419181995e-07


alpha=0.075, epsilon=0.001
