In [1]:
import anndata
import copy
import datetime
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import os
import ot
import pandas as pd
import pickle
import scanpy as sc
import scipy.stats
import seaborn as sb
import warnings
import anndata


import lineageot.inference
import lineageot.simulation
import lineageot.evaluation

In [2]:
from jax.config import config
config.update("jax_enable_x64", True)

from ott.geometry.geometry import Geometry
from jax import numpy as jnp

from moscot import FusedGW

from typing import *
from collections import namedtuple, defaultdict
import traceback
from time import perf_counter

In [3]:
data = namedtuple("data", "edist ldist rna_dist a b")
bnt = namedtuple("bnt", "tmat early_cost late_cost norm_diff converged time")

In [4]:
start_time = datetime.datetime.now()

In [5]:
data_path = "data/"
save_dir = "plots/"
if not os.path.exists(save_dir):
    os.makedirs(save_dir)  

In [6]:
#with open(data_path + 'pickled_filtered_anndata.p', 'rb') as file:
#    adata = pickle.load(file)
adata = anndata.read("filtered.h5ad")
adata

AnnData object with n_obs × n_vars = 46151 × 20222
    obs: 'Unnamed: 0', 'Unnamed: 0.1', 'cell', 'n.umi', 'time.point', 'batch', 'Size_Factor', 'cell.type', 'cell.subtype', 'plot.cell.type', 'raw.embryo.time', 'embryo.time', 'embryo.time.bin', 'raw.embryo.time.bin', 'lineage', 'passed_initial_QC_or_later_whitelisted', 'random_precise_lineage'
    var: 'Unnamed: 0', 'id', 'gene_short_name'

In [7]:
with open(data_path + "packer_pickle_lineage_tree.p", 'rb') as file:
    full_reference_tree = pickle.load(file)

### Normalization and preprocessing

In [8]:
# Removing partially lineage-labeled cells

adata = adata[adata.obs['lineage'].to_numpy() == adata.obs['random_precise_lineage'].to_numpy()].copy()
adata

AnnData object with n_obs × n_vars = 5123 × 20222
    obs: 'Unnamed: 0', 'Unnamed: 0.1', 'cell', 'n.umi', 'time.point', 'batch', 'Size_Factor', 'cell.type', 'cell.subtype', 'plot.cell.type', 'raw.embryo.time', 'embryo.time', 'embryo.time.bin', 'raw.embryo.time.bin', 'lineage', 'passed_initial_QC_or_later_whitelisted', 'random_precise_lineage'
    var: 'Unnamed: 0', 'id', 'gene_short_name'

#### Filtering cells + genes

In [9]:
adata.obs['n_counts'] = np.array(adata.X.sum(1)).squeeze()
adata.obs['log_counts'] = np.log(adata.obs['n_counts'])
adata.obs['n_genes'] = (adata.X > 0).sum(1)

In [10]:
%%time
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
sc.pp.pca(adata)
adata

CPU times: user 29.2 s, sys: 1min 3s, total: 1min 32s
Wall time: 14.8 s


AnnData object with n_obs × n_vars = 5123 × 20222
    obs: 'Unnamed: 0', 'Unnamed: 0.1', 'cell', 'n.umi', 'time.point', 'batch', 'Size_Factor', 'cell.type', 'cell.subtype', 'plot.cell.type', 'raw.embryo.time', 'embryo.time', 'embryo.time.bin', 'raw.embryo.time.bin', 'lineage', 'passed_initial_QC_or_later_whitelisted', 'random_precise_lineage', 'n_counts', 'log_counts', 'n_genes'
    var: 'Unnamed: 0', 'id', 'gene_short_name'
    uns: 'log1p', 'pca'
    obsm: 'X_pca'
    varm: 'PCs'

In [11]:

adata.obs.index = adata.obs['cell']

adata = adata[adata.obs.index.sort_values()].copy()
# for some reason on linux we need to redo this
# possibly due to version change of anndata (or dependencies)
# it seems that the previous indexing changed the index
adata.obs.index = adata.obs['cell']

In [12]:
print("AnnData", anndata.__version__)
# issue above observed for AnnData 0.7.4

AnnData 0.7.6


### Embryo time selection

In [13]:
def assign_time_to_bin(bin_string):
    if bin_string == '< 100':
        return 75
    elif bin_string == '> 650':
        # arbitrary choice here
        return 700
    else:
        # end of time range
        return (0*int(bin_string[0:3]) + 1*int(bin_string[-3:]))
    return

In [14]:
# If you would like to save time by only running the evaluation
# on a subset of the cells, change num_subbatches to an integer larger than 1.

# randomly splitting each batch into subbatches 
# so that there are fewer cells and the evaluation runs faster
num_subbatches = 1
adata.obs['assigned_batch_time'] = adata.obs['embryo.time.bin'].map(lambda x: assign_time_to_bin(x)
                                                                   + np.random.randint(num_subbatches))

In [15]:
batches = adata.obs['assigned_batch_time'].unique()
batches = np.sort(batches)
print(batches)

[ 75 130 170 210 270 330 390 450 510 580 650 700]


In [16]:
def create_lineage_tree(adata, batch_time, reference_tree = full_reference_tree, obsm_key = 'X_pca'):
    selected_cells = adata[adata.obs['assigned_batch_time'] == batch_time]
    
    # no need to remove any nodes from the reference
    # (unobserved subtrees have no effect on the inference)
    new_tree = copy.deepcopy(full_reference_tree)
    
    bad_cell_list = []
    good_cell_list = []
    for cell in selected_cells.obs.index:
        cell_label = selected_cells.obs['random_precise_lineage'][cell]
        cell_index = adata.obs.index.get_loc(cell)
        
        parent = next(reference_tree.predecessors(cell_label))
        
        if batch_time <= new_tree.nodes[parent]['time']:
            warnings.warn('Nonpositive time to parent ('
                          + str(batch_time - new_tree.nodes[parent]['time'])
                          + ') encountered in batch '
                          + str(batch_time) 
                          + ' for cell ' + cell)
            bad_cell_list.append(cell)
            # filter this cell out
            continue
        else:
            good_cell_list.append(cell)
        
        new_tree.add_node(cell)
        new_tree.add_edge(parent, cell)
        
        new_tree.nodes[cell]['name'] = cell_label
        new_tree.nodes[cell]['time'] = batch_time
        new_tree.nodes[cell]['time_to_parent'] = batch_time - new_tree.nodes[parent]['time']
        new_tree.nodes[cell]['cell'] = lineageot.simulation.Cell(adata.obsm[obsm_key][cell_index, :], cell_label)
        
        assert(new_tree.nodes[cell]['time_to_parent'] >= 0)
        
        
    return new_tree, good_cell_list, bad_cell_list

In [17]:
%%time
trees_by_batch = [create_lineage_tree(adata, batch_time) for batch_time in batches]



CPU times: user 2.17 s, sys: 58.6 ms, total: 2.23 s
Wall time: 2.22 s


In [18]:
# Counting the number of cells kept vs. filtered out
# for having a measurement time at or before their
# reference birth time
{batch_time : (len(t[1]), len(t[2])) for t,batch_time in zip(trees_by_batch, batches)}

{75: (0, 1),
 130: (14, 1),
 170: (151, 14),
 210: (755, 79),
 270: (1215, 2),
 330: (1534, 0),
 390: (678, 0),
 450: (463, 0),
 510: (206, 0),
 580: (2, 0),
 650: (5, 0),
 700: (3, 0)}

#### Making a ground truth coupling

In [19]:
def is_ancestor(late_cell, early_cell):
    if not late_cell[-1] in 'aplrdvx':
        warnings.warn('Ancestor checking not implemented for cell ' + late_cell + ' yet.')
        return
    else:
        return early_cell in late_cell

In [20]:
def ground_truth_coupling(early_tree, late_tree, early_cells, late_cells):
    n_early = len(early_cells)
    n_late = len(late_cells)
    coupling = np.zeros([n_early, n_late])
    for c_early, i in zip(early_cells, range(n_early)):
        for c_late, j in zip(late_cells, range(n_late)):
            if is_ancestor(late_tree.nodes[c_late]['name'],
                           early_tree.nodes[c_early]['name']):
                coupling[i, j] = 1
    
    # filter out zero rows and columns
    kept_early_cells = np.where(np.sum(coupling, 1) > 0)[0]
    kept_late_cells = np.where(np.sum(coupling, 0) > 0)[0]
    coupling = coupling[np.ix_(kept_early_cells, kept_late_cells)]
    
    
    # normalize to uniform marginal on late cells
    coupling = np.dot(coupling, np.diag(np.sum(coupling, 0)**(-1)))/len(kept_late_cells)
    
    return coupling, kept_early_cells, kept_late_cells
    

In [21]:
couplings = {}

In [22]:
# Available base times: 170, 210, 270, 330, 390, 450

# The rest of the script only works with the coupling between
# early_batch_time and late_batch time
early_batch_time = 210 
late_batch_time = 270 
early_batch_index = np.where(batches == early_batch_time)[0][0]
late_batch_index = np.where(batches == late_batch_time)[0][0]
early_cells = trees_by_batch[early_batch_index][1]
late_cells = trees_by_batch[late_batch_index][1]

In [23]:
%%time
print((batches[early_batch_index], batches[late_batch_index]))
couplings['true'], kept_early_cells_index, kept_late_cells_index = ground_truth_coupling(trees_by_batch[early_batch_index][0],
                                                                             trees_by_batch[late_batch_index][0],
                                                                             trees_by_batch[early_batch_index][1],
                                                                             trees_by_batch[late_batch_index][1])

(210, 270)
CPU times: user 2.69 s, sys: 40.6 ms, total: 2.73 s
Wall time: 2.57 s


In [24]:
early_marginal = np.sum(couplings['true'], 1)
late_marginal = np.sum(couplings['true'], 0)

In [25]:
def remove_unlisted_leaves(tree, kept_leaves, max_depth = 10):
    """
    Removes all leaves not listed in kept_leaves from the tree
    """
    
    for i in range(max_depth):
        all_leaves = lineageot.inference.get_leaves(tree, include_root = False)
        for leaf in all_leaves:
            if not leaf in kept_leaves:
                tree.remove_node(leaf)
                
                
    return 

#### LineageOT coupling

In [26]:
lineageOT_tree = copy.deepcopy(trees_by_batch[late_batch_index][0])




In [27]:
remove_unlisted_leaves(lineageOT_tree,
                       np.array(trees_by_batch[late_batch_index][1])[kept_late_cells_index], max_depth = 15)


In [28]:
lineageot.inference.add_nodes_at_time(lineageOT_tree, early_batch_time)
lineageot.inference.add_times_to_edges(lineageOT_tree)

<networkx.classes.digraph.DiGraph at 0x7f75d228ad60>

In [29]:
observed_nodes = trees_by_batch[late_batch_index][1]
lineageot.inference.add_conditional_means_and_variances(lineageOT_tree, observed_nodes)

In [30]:
ancestor_info = lineageot.inference.get_ancestor_data(lineageOT_tree, early_batch_time)


In [31]:
# Optimal transport coupling

In [32]:
lineage_rna_cost = ot.utils.dist(adata[early_cells].obsm['X_pca'],
                                 ancestor_info[0])@np.diag(ancestor_info[1]**(-1))
rna_cost = ot.utils.dist(adata[early_cells].obsm['X_pca'],
                         adata[late_cells].obsm['X_pca'])

late_time_rna_cost = ot.utils.dist(adata[late_cells].obsm['X_pca'],
                                   adata[late_cells].obsm['X_pca'])
# note early_time_rna_cost here is different from in the simulation evaluation because we don't have
# a single true ancestor for each late cell
early_time_rna_cost = ot.utils.dist(adata[early_cells].obsm['X_pca'],
                                    adata[early_cells].obsm['X_pca'])



lineage_rna_cost = lineage_rna_cost[kept_early_cells_index, :]
rna_cost = rna_cost[np.ix_(kept_early_cells_index, kept_late_cells_index)]
early_time_rna_cost = early_time_rna_cost[np.ix_(kept_early_cells_index, kept_early_cells_index)]
late_time_rna_cost = late_time_rna_cost[np.ix_(kept_late_cells_index, kept_late_cells_index)]

In [33]:
# checking that shapes match where they should
[x.shape for x in [lineage_rna_cost, rna_cost, early_time_rna_cost, late_time_rna_cost]]

[(740, 1024), (740, 1024), (740, 740), (1024, 1024)]

In [40]:
def prepare_tree(*, late: bool):
    if late:
        batch_index = late_batch_index
        kept_cells_index = kept_late_cells_index
    else:
        batch_index = early_batch_index
        kept_cells_index = kept_early_cells_index
    tree = copy.deepcopy(trees_by_batch[batch_index][0])
    remove_unlisted_leaves(tree,
                            np.array(trees_by_batch[batch_index][1])[kept_cells_index], max_depth = 15)
    n_added = lineageot.inference.add_nodes_at_time(tree, early_batch_time)
    if late:
        assert n_added > 0, n_added
    else:
        assert n_added == 0, n_added
    lineageot.inference.add_times_to_edges(tree)
    
    return tree


def compute_tree_distances(tree):
    """
    Computes the matrix of pairwise distances between leaves of the tree
    """
    leaves = lineageot.inference.get_leaves(tree, include_root=False)
    num_leaves = len(leaves)
    distances = np.zeros([num_leaves, num_leaves])
    for i, leaf in enumerate(leaves):
        distance_dictionary, tmp = nx.multi_source_dijkstra(tree.to_undirected(), [leaf], weight = 'time')
        for j, target_leaf in enumerate(leaves):
            distances[i, j] = distance_dictionary[target_leaf]
    return distances


def cost(cmat: np.array, *, late: bool) -> float:
    cmat = np.asarray(cmat)
    if late:
        cost = lineageot.inference.OT_cost(
            lineageot.evaluation.expand_coupling(cmat, couplings['true'], late_time_rna_cost), late_time_rna_cost
        ) / ind_late_cost
    else:
        cost = lineageot.inference.OT_cost(lineageot.evaluation.expand_coupling(
            cmat.T, couplings['true'].T, early_time_rna_cost).T, early_time_rna_cost
        ) / ind_early_cost
        
    return cost


def create_data() -> data:
    etree = prepare_tree(late=False)
    ltree = prepare_tree(late=True)
    edist = compute_tree_distances(etree)
    ldist = compute_tree_distances(ltree)
    return data(edist, ldist, rna_cost, early_marginal, late_marginal)


def benchmark_moscot(d: data, *, alpha: float, epsilon: Optional[float] = None,
                     rescale: bool = True, **kwargs):
    def create_geometry(cost_matrix: np.ndarray) -> Geometry:
        cost_matrix = jnp.array(cost_matrix)
        if rescale:
            cost_matrix /= cost_matrix.max()
            assert cost_matrix.max() == 1.0
        assert (cost_matrix >= 0).all()
        return Geometry(cost_matrix=cost_matrix)
    
    e = create_geometry(d.edist)
    l = create_geometry(d.ldist)
    joint = create_geometry(d.rna_dist)
    
    max_iterations = kwargs.pop("max_iterations", 20)
    rtol = kwargs.pop("rtol", 1e-6)
    atol = kwargs.pop("atol", 1e-6)
    
    fgw = FusedGW(alpha=alpha, epsilon=epsilon, **kwargs)
    start = perf_counter()
    fgw.fit(e, l, joint,
            a=jnp.asarray(d.a), b=jnp.asarray(d.b),  # marginals
            linesearch=False, verbose=False, max_iterations=max_iterations, rtol=rtol, atol=atol)
    time = perf_counter() - start
    tmat = np.array(fgw.matrix)
    print(f"Time: {time}")
    
    early_cost = cost(tmat, late=False)
    late_cost = cost(tmat, late=True)
    norm_diff = np.linalg.norm(tmat - couplings['true'])
        
    return bnt(tmat, early_cost, late_cost, norm_diff, fgw.converged_sinkhorn, time)


def gridsearch(d: data, *, alphas: Sequence[float], epsilons: Sequence[float], rescale: bool = True, **kwargs) -> Dict[float, Dict[float, bnt]]:
    res = defaultdict(defaultdict)
        
    for alpha in alphas:
        for epsilon in epsilons:
            try:
                print(f"alpha={alpha}, epsilon={epsilon}")
                res[alpha][epsilon] = benchmark_moscot(d, alpha=alpha, epsilon=epsilon, rescale=rescale, **kwargs)
            except Exception as e:
                print(traceback.format_exc())
                res[alpha][epsilon] = None
    
    return {a: {e: v for e, v in vs.items()} for a, vs in res.items()}

In [35]:
d = create_data()

In [36]:
%%time
couplings['independent'] = np.outer(early_marginal, late_marginal)
ind_early_cost = lineageot.inference.OT_cost(lineageot.evaluation.expand_coupling(couplings['independent'].T, couplings['true'].T, early_time_rna_cost ).T
                                                 , early_time_rna_cost)
print(ind_early_cost)

ind_late_cost = lineageot.inference.OT_cost(lineageot.evaluation.expand_coupling(couplings['independent'], couplings['true'], late_time_rna_cost ).T
                                                 , late_time_rna_cost)
print(ind_late_cost)

96.81868794217223
113.41152063604068
CPU times: user 23 s, sys: 480 ms, total: 23.5 s
Wall time: 22.8 s


In [37]:
epsilons = [None, 1e-4, 5e-3, 1e-3, 5e-2] + list(np.logspace(-2, 3, 15))
alphas = list(np.linspace(0.025, 0.975, 20))
print(epsilons)
print(alphas)
len(epsilons), len(alphas)

[None, 0.0001, 0.005, 0.001, 0.05, 0.01, 0.022758459260747887, 0.05179474679231213, 0.11787686347935872, 0.2682695795279726, 0.6105402296585329, 1.3894954943731375, 3.1622776601683795, 7.196856730011521, 16.378937069540648, 37.27593720314942, 84.83428982440725, 193.06977288832496, 439.3970560760795, 1000.0]
[0.025, 0.075, 0.125, 0.175, 0.22499999999999998, 0.27499999999999997, 0.325, 0.375, 0.425, 0.475, 0.5249999999999999, 0.575, 0.625, 0.6749999999999999, 0.725, 0.7749999999999999, 0.825, 0.875, 0.9249999999999999, 0.975]


(20, 20)

In [41]:
res = gridsearch(d, alphas=alphas, epsilons=epsilons, max_iterations=100, rtol=1e-9, atol=1e-9)
with open(f"c_elegans.pickle", "wb") as fout:
    pickle.dump(res, fout)

alpha=0.025, epsilon=None
Time: 8.217234213021584


In [39]:
sentinel

NameError: name 'sentinel' is not defined

In [None]:
epsilons = np.array([0.1, 0.5,
                     1,
                     2,
                     2.5,
                     3, 5, 10, 20, 50, 100, 200, 300
                    ])*1


couplings['OT'] = ot.emd([],[],rna_cost)
couplings['lineageOT'] = ot.emd([], [], lineage_rna_cost)
for e in epsilons:
    print("Working on couplings for epsilon = " + str(e))
    if e < 1:
        f = ot.bregman.sinkhorn_epsilon_scaling
    else:
        f = ot.sinkhorn
    couplings['entropic rna ' + str(e)] = f(early_marginal,late_marginal,rna_cost, e)
    couplings['lineage entropic rna ' + str(e)] = f(early_marginal, late_marginal, lineage_rna_cost, e*np.mean(ancestor_info[1]**(-1)))


In [None]:
# Removing couplings computed incorrectly with numerical errors (likely from too small epsilon)
improper_couplings = [key for key in couplings if abs(np.sum(couplings[key])-1) > 0.01]

for key in improper_couplings:
    print(key)
    couplings[key] = np.nan*np.ones(couplings[key].shape)

### Evaluation of fitted couplings

In [None]:
def print_metrics(couplings, cost_func, cost_func_name, log = False):
    
    l = max([len(c) for c in couplings.keys()])
    print(cost_func_name)
    for c in couplings.keys():
        loss = cost_func(couplings[c])
        if log:
            loss = np.log(loss)
        print(c.ljust(l), ": ", "{:.3f}".format(loss))
    print("\n")
    return


def plot_metrics(couplings, cost_func, cost_func_name, epsilons, scale = 1, log = False):
    zero_offset = epsilons[0]/2
    all_ys = []
    if "lineageOT" in couplings.keys():
        ys = []
        for c, e in zip([couplings['lineage entropic rna ' + str(e)] for e in epsilons], epsilons):
            print("Working on cost function for LineageOT coupling with epsilon = " + str(e))
            ys.append(cost_func(c))
            print("Finished LineageOT coupling with epsilon = " + str(e))
            print("Cost: " + str(ys[-1]/scale) + "\n")
        plt.plot(epsilons, ys/scale, label = "LineageOT, true tree")
        all_ys.append(ys)
        #plt.scatter([zero_offset], [cost_func(couplings["lineageOT"])])
    if "OT" in couplings.keys():
        ys = []
        for c, e in zip([couplings['entropic rna ' + str(e)] for e in epsilons], epsilons):
            print("Working on cost function for OT coupling with epsilon = " + str(e))
            ys.append(cost_func(c))
            print("Finished OT coupling with epsilon = " + str(e))
            print("Cost: " + str(ys[-1]/scale) + "\n")
        plt.plot(epsilons, ys/scale, label = "Entropic OT")
        all_ys.append(ys)
        #plt.scatter([zero_offset], [cost_func(couplings["OT"])])

    
    plt.ylabel(cost_func_name)
    plt.xlabel("Entropy parameter")
    plt.xscale("log")
    plt.xlim([epsilons[0], epsilons[-1]])
    plt.ylim([0,None])
    plt.legend()
    return all_ys



In [None]:
%%time
ot_early_cost = lineageot.inference.OT_cost(lineageot.evaluation.expand_coupling(couplings['OT'].T, couplings['true'].T, early_time_rna_cost ).T
                                                 , early_time_rna_cost)
print(ot_early_cost)
ot_late_cost = lineageot.inference.OT_cost(lineageot.evaluation.expand_coupling(couplings['OT'], couplings['true'], late_time_rna_cost ).T
                                                 , late_time_rna_cost)
print(ot_late_cost)

In [None]:
%%time
lineageOT_early_cost = lineageot.inference.OT_cost(lineageot.evaluation.expand_coupling(couplings['lineageOT'].T, couplings['true'].T, early_time_rna_cost ).T
                                                 , early_time_rna_cost)
print(lineageOT_early_cost)
lineageOT_late_cost = lineageot.inference.OT_cost(lineageot.evaluation.expand_coupling(couplings['lineageOT'], couplings['true'], late_time_rna_cost ).T
                                                 , late_time_rna_cost)
print(lineageOT_late_cost)

In [None]:
%%time
late_time_errors = plot_metrics(couplings, lambda x:lineageot.inference.OT_cost(lineageot.evaluation.expand_coupling(x,
                                                                                             couplings['true'], 
                                                                                             late_time_rna_cost)
                                                 , late_time_rna_cost),
                                'Normalized descendant error', 
                                epsilons,
                                scale = ind_late_cost
                               )

In [None]:
%%time
early_time_errors = plot_metrics(couplings, lambda x:lineageot.inference.OT_cost(lineageot.evaluation.expand_coupling(x.T,
                                                                                             couplings['true'].T, 
                                                                                             early_time_rna_cost ).T
                                                 , early_time_rna_cost),
                                'Normalized ancestor error', 
                                epsilons,
                                scale = ind_early_cost
                               )

In [None]:
def plot_precomputed_metrics(all_ys, cost_func_name, epsilons, scale = 1, log = False, label_font_size = 18, tick_font_size = 12):
        
    plt.plot(epsilons, all_ys[0]/scale, label = "LineageOT, true tree")

    plt.plot(epsilons, all_ys[1]/scale, label = "Entropic OT")
        
    
    plt.ylabel(cost_func_name, fontsize=label_font_size)
    plt.xlabel("Entropy parameter", fontsize=label_font_size)
    plt.xscale("log")
    plt.xlim([epsilons[0], epsilons[-1]])
    plt.ylim([0,None])
    plt.legend(fontsize=tick_font_size)                 
    plt.xticks(fontsize=tick_font_size)  
    plt.yticks(fontsize=tick_font_size)  
                                          
    return 


In [None]:
if (early_batch_time, late_batch_time) == (210,270):
    plot_precomputed_metrics(early_time_errors, "Normalized ancestor error", epsilons, scale = ind_early_cost)
    plt.savefig(save_dir + "figure_3c.pdf", bbox_inches = "tight")

In [None]:
if (early_batch_time, late_batch_time) == (210,270):
    plot_precomputed_metrics(late_time_errors, "Normalized descendant error", epsilons, scale = ind_late_cost)
    plt.savefig(save_dir + "figure_3d.pdf", bbox_inches = "tight")

### Which cells does LineageOT predict better?

In [None]:
ancestor_errors = {}
descendant_errors = {}
label_font_size = 18
tick_font_size = 12

In [None]:
best_coupling_keys = ['lineage entropic rna 5.0', 'entropic rna 10.0']
for key in best_coupling_keys:
    ancestor_errors[key] = [ot.emd2(couplings[key][:, i]/np.sum(couplings[key][:, i]),
                                    couplings['true'][:, i]/np.sum(couplings['true'][:, i]),
                                    early_time_rna_cost)
                            for i in range(couplings['true'].shape[1])]
    descendant_errors[key] = [ot.emd2(couplings[key][i, :]/np.sum(couplings[key][i, :]),
                                      couplings['true'][i, :]/np.sum(couplings['true'][i, :]),
                                      late_time_rna_cost)
                            for i in range(couplings['true'].shape[0])]

In [None]:
xmax = np.max(descendant_errors[best_coupling_keys[1]])
plt.scatter(descendant_errors[best_coupling_keys[1]],
            descendant_errors[best_coupling_keys[0]],
            alpha = 0.2
           )
plt.plot([0,1+xmax], [0, 1+xmax], color = 'r')
plt.ylabel('LineageOT descendant error', fontsize=label_font_size)
plt.xlabel('OT descendant error', fontsize=label_font_size)

plt.xticks(fontsize=tick_font_size)
plt.yticks(fontsize=tick_font_size)

if (early_batch_time, late_batch_time) == (210,270):
    plt.savefig(save_dir + "figure_S5a.pdf", bbox_inches = "tight")

In [None]:
xmax = np.max(ancestor_errors[best_coupling_keys[1]])
plt.scatter(ancestor_errors[best_coupling_keys[1]],
            ancestor_errors[best_coupling_keys[0]],
            alpha = 0.2
           )
plt.plot([0,1+xmax], [0, 1+xmax], color = 'r')
plt.ylabel('LineageOT ancestor error', fontsize=label_font_size)
plt.xlabel('OT ancestor error', fontsize=label_font_size)

plt.xticks(fontsize=tick_font_size)
plt.yticks(fontsize=tick_font_size)

if (early_batch_time, late_batch_time) == (210,270):
    plt.savefig(save_dir + "figure_3b.pdf", bbox_inches = "tight")

#### Visualizing error

In [None]:
early_cells_to_plot = [early_cells[i] for i in kept_early_cells_index]
late_cells_to_plot = [late_cells[i] for i in kept_late_cells_index]
cells_to_plot = early_cells_to_plot + late_cells_to_plot
len(cells_to_plot)

In [None]:
adata_to_plot = adata[cells_to_plot]
sc.pp.neighbors(adata_to_plot)
sc.tl.umap(adata_to_plot)

In [None]:
adata_to_plot.obs['descendant error, lineageOT'] = pd.Series(descendant_errors[best_coupling_keys[0]], index=early_cells_to_plot)
adata_to_plot.obs['descendant error, OT'] = pd.Series(descendant_errors[best_coupling_keys[1]], index=early_cells_to_plot)
adata_to_plot.obs['descendant OT - lineageOT'] = pd.Series(np.array(descendant_errors[best_coupling_keys[1]]) - np.array(descendant_errors[best_coupling_keys[0]]),
                                                           index=early_cells_to_plot)

adata_to_plot.obs['ancestor error, lineageOT'] = pd.Series(ancestor_errors[best_coupling_keys[0]], index=late_cells_to_plot)
adata_to_plot.obs['ancestor error, OT'] = pd.Series(ancestor_errors[best_coupling_keys[1]], index=late_cells_to_plot)
adata_to_plot.obs['ancestor OT - lineageOT'] = pd.Series(np.array(ancestor_errors[best_coupling_keys[1]]) - np.array(ancestor_errors[best_coupling_keys[0]])
                                                         , index=late_cells_to_plot)


#normalizing for colormap
adata_to_plot.obs['descendant normalized'] = adata_to_plot.obs['descendant OT - lineageOT']/(2*np.abs(adata_to_plot.obs['descendant OT - lineageOT'].min())) + 1/2
adata_to_plot.obs['ancestor normalized'] = adata_to_plot.obs['ancestor OT - lineageOT']/(2*np.abs(adata_to_plot.obs['ancestor OT - lineageOT'].min())) + 1/2


In [None]:
pointsize = 10
label_font_size = 18
cmap = "coolwarm"
colors = [plt.get_cmap(cmap)(0), plt.get_cmap(cmap)(256)]

coord_key = 'X_umap'
plt.scatter(adata_to_plot[adata_to_plot.obs['assigned_batch_time'] == early_batch_time].obsm[coord_key][:, 0],
            adata_to_plot[adata_to_plot.obs['assigned_batch_time'] == early_batch_time].obsm[coord_key][:, 1],
            alpha = 1,
            s = pointsize,
            label = str(early_batch_time) + ' minutes',
            color = colors[0]
           )

plt.scatter(adata_to_plot[adata_to_plot.obs['assigned_batch_time'] == late_batch_time].obsm[coord_key][:, 0],
            adata_to_plot[adata_to_plot.obs['assigned_batch_time'] == late_batch_time].obsm[coord_key][:, 1],
            alpha = 1,
            s = pointsize,
            label = str(late_batch_time) + ' minutes',
            color = colors[1]
           )

plt.legend(markerscale = 2)

plt.tick_params(
    axis='both',          # changes apply to the x-axis
    which='both',      # both major and minor ticks are affected
    bottom=False,      # ticks along the bottom edge are off
    left=False,
    top=False,         # ticks along the top edge are off
    right=False,
    labelbottom=False,
    labelleft=False) # labels along the bottom edge are off
plt.xlabel('UMAP 1', fontsize=label_font_size)
plt.ylabel('UMAP 2', fontsize=label_font_size)

if coord_key == 'X_umap':
    plt.savefig(save_dir + "figure_3e.pdf", bbox_inches = "tight")

plt.show()

In [None]:

error_cmap = 'seismic_r'
grey_alpha = 0.5
plt.scatter(adata_to_plot[adata_to_plot.obs['assigned_batch_time'] == early_batch_time].obsm[coord_key][:, 0],
            adata_to_plot[adata_to_plot.obs['assigned_batch_time'] == early_batch_time].obsm[coord_key][:, 1],
            alpha = grey_alpha,
            s = pointsize,
            label = str(early_batch_time) + ' minutes',
            color = 'gray'
           )

plt.scatter(adata_to_plot.obsm['X_umap'][:, 0],
            adata_to_plot.obsm['X_umap'][:, 1],
            c = adata_to_plot.obs['ancestor OT - lineageOT'],
            cmap=error_cmap,
            vmax = -adata_to_plot.obs['ancestor OT - lineageOT'].min(),
            vmin = adata_to_plot.obs['ancestor OT - lineageOT'].min(),
            s = pointsize
           )

cbar = plt.colorbar()
cbar.set_label("Ancestor error difference", fontsize=label_font_size - 2)
cbar.set_ticks([0])
plt.tick_params(
    axis='both',          # changes apply to the x-axis
    which='both',      # both major and minor ticks are affected
    bottom=False,      # ticks along the bottom edge are off
    left=False,
    top=False,         # ticks along the top edge are off
    right=False,
    labelbottom=False,
    labelleft=False) # labels along the bottom edge are off
plt.xlabel('UMAP 1', fontsize=label_font_size)
plt.ylabel('UMAP 2', fontsize=label_font_size)
plt.savefig(save_dir + "figure_3f.pdf", bbox_inches = "tight")
plt.show()








plt.scatter(adata_to_plot[adata_to_plot.obs['assigned_batch_time'] == late_batch_time].obsm[coord_key][:, 0],
            adata_to_plot[adata_to_plot.obs['assigned_batch_time'] == late_batch_time].obsm[coord_key][:, 1],
            alpha = grey_alpha,
            s = pointsize,
            label = str(late_batch_time) + ' minutes',
            color = 'gray'
           )

plt.scatter(adata_to_plot.obsm['X_umap'][:, 0],
            adata_to_plot.obsm['X_umap'][:, 1],
            c = adata_to_plot.obs['descendant OT - lineageOT'],
            cmap=error_cmap,
            vmax = -adata_to_plot.obs['descendant OT - lineageOT'].min(),
            vmin = adata_to_plot.obs['descendant OT - lineageOT'].min(),
            s = pointsize
           )
cbar = plt.colorbar()
cbar.set_label("Descendant error difference", fontsize=label_font_size - 2)
cbar.set_ticks([0])


plt.tick_params(
    axis='both',          # changes apply to the x-axis
    which='both',      # both major and minor ticks are affected
    bottom=False,      # ticks along the bottom edge are off
    left=False,
    top=False,         # ticks along the top edge are off
    right=False,
    labelbottom=False,
    labelleft=False) # labels along the bottom edge are off

plt.xlabel('UMAP 1', fontsize=label_font_size)
plt.ylabel('UMAP 2', fontsize=label_font_size)

plt.savefig(save_dir + "figure_3g.pdf", bbox_inches = "tight")
plt.show()

In [None]:
print("Total time elapsed: ", datetime.datetime.now() - start_time)