In [1]:
import anndata
import numpy as np
import pandas as pd
import sys
sys.path.append('../../')
import velovae as vv

In [2]:
# Base directory that the data and figure locations below are derived from.
root = "/scratch/blaauw_root/blaauw1/gyichen"
# Sub-directories of `root`; the per-method result folders are built on
# `data_path_base` further down.
data_path_base = root + "/data"
figure_path_base = root + "/figures"

# Datasets and Related Attributes

In [3]:
# Datasets evaluated in this notebook. Each name is also used to index
# `cluster_edges` and `genes`, and to build the result file paths.
datasets = ["Pancreas", "Erythroid", "Braindev", "IPSC", "IO_EU"]

# For every dataset: a list of (cluster label, cluster label) pairs that is
# handed to vv.post_analysis through its `cluster_edges` argument.
# NOTE(review): the labels presumably have to match the cluster annotation in
# each .h5ad file exactly (e.g. the double space in "Mid  Erythroid") — confirm.
cluster_edges = {
    "Pancreas": [
        ("Ngn3 low EP", "Ngn3 high EP"),
        ("Ngn3 high EP", "Pre-endocrine"),
        ("Pre-endocrine", "Delta"),
        ("Pre-endocrine", "Beta"),
        ("Pre-endocrine", "Epsilon"),
        ("Pre-endocrine", "Alpha"),
    ],
    "Erythroid": [
        ("Blood progenitors 1", "Blood progenitors 2"),
        ("Blood progenitors 2", "Erythroid1"),
        ("Erythroid1", "Erythroid2"),
        ("Erythroid2", "Erythroid3"),
    ],
    "IPSC": [
        ("MET", "Epithelial"),
        ("Epithelial", "IPS"),
        ("Epithelial", "Neural"),
        ("Epithelial", "Trophoblast"),
        ("Epithelial", "Stromal"),
    ],
    "BMMC": [
        ("HSC", "LMPP"),
        ("LMPP", "GMP"),
        ("GMP", "CD14 Mono"),
        ("CD14 Mono", "CD16 Mono"),
        ("Prog DC", "cDc2"),
        ("Prog B 1", "Prog B 2"),
        ("Prog MK", "Prog RBC"),
    ],
    "Braindev": [
        ("Neural tube", "Radial glia"),
        ("Radial glia", "Neuroblast"),
        ("Radial glia", "Glioblast"),
        ("Radial glia", "Oligodendrocyte"),
        ("Radial glia", "Ependymal"),
        ("Neural crest", "Mesenchyme"),
        ("Mesenchyme", "Fibroblast"),
    ],
    # No edges are listed for this dataset.
    "HIO": [],
    "Bonemarrow": [
        ("HSC_1", "Ery_1"),
        ("HSC_1", "HSC_2"),
        ("Ery_1", "Ery_2"),
    ],
    "Dentategyrus": [
        ("OPC", "OL"),
    ],
    "Erythroid_Human": [
        ("MEMP", "Early Erythroid"),
        ("Early Erythroid", "Mid  Erythroid"),
        ("Mid  Erythroid", "Late Erythroid"),
    ],
    "Hindbrain_pons": [
        ("COPs", "NFOLs"),
        ("NFOLs", "MFOLs"),
    ],
    "IO_EU": [
        ("Stem cells", "TA cells"),
        ("Stem cells", "Goblet cells"),
        ("Stem cells", "Tuft cells"),
        ("TA cells", "Enterocytes"),
    ],
    "Neuron_scNT": [
        ("0", "15"),
        ("15", "30"),
        ("30", "60"),
        ("60", "120"),
    ],
    "Retina": [
        ("Neuroblast", "PR"),
        ("Neuroblast", "AC/HC"),
        ("Neuroblast", "RGC"),
    ],
}

# Genes passed to vv.post_analysis (`genes=`) for each dataset.
# The keys must be spelled exactly like the dataset names used in `datasets`
# and `cluster_edges`, because all three are looked up with the same name.
genes = {
    "Pancreas": ["Gng12", "Smoc1", "Ppp3ca", "Nnat"],
    "Erythroid": ["Smim1", "Blvrb", "Hba-x", "Lmo2"],
    "Braindev": ['Mapt', 'Tmsb10', 'Fabp7', 'Npm1'],
    "IPSC": ["Vim","Nr2f1","Krt7","H19"],
    "BMMC": ['SPINK2', 'AZU1', 'MPO', 'LYZ', 'CD74', 'HBB'],
    "HIO": ['PLP1','ECSCR', 'COL1A1', 'EPCAM'],
    "Bonemarrow": ['CD44','CELF2','TAOK3'],
    # Key was previously misspelled "Dentatgyrus", which did not match the
    # "Dentategyrus" entry of `cluster_edges` and would raise KeyError on lookup.
    "Dentategyrus": ['Tmsb10', 'Fam155a', 'Hn1', 'Rpl6'],
    "Erythroid_Human": ['CNN3','CYR61','ABCG2','HBA1'],
    "Hindbrain_pons": ['Ptprz1','Enpp6','Rras2','Mal'],
    "IO_EU": ["Apoa1","Dgat1","Gsta4","Lgr5"],
    "Neuron_scNT": ['Fosb','Rfx3','Cebpg','Homer1'],
    "Retina": ['Mcm6','Cdk1','Esco2','Cenpa'],
}

# RNA Velocity Methods

In [None]:
# The three lists below are parallel: entry k of each describes the same method.
# Display name of each method being compared.
methods = ["scVelo", "VeloVAE", "FullVB", "UniTVelo", "DeepVelo", "VeloVI"]
# Key of each method, passed to vv.post_analysis as `keys`.
keys = ["fit", "velovae", "fullvb", "utv", "dv", "velovi"]
# Folder that holds each method's `<dataset>/<dataset>.h5ad` result files.
# A comma was missing after the scVelo entry, so Python silently concatenated
# the first two strings: the list was one element short and every later
# method pointed at the wrong folder.
# NOTE(review): VeloVAE and FullVB share one folder — presumably both fits are
# stored in the same .h5ad file; confirm.
out_folders = [f"{data_path_base}/scvelo",
               f"{data_path_base}/velovae/continuous",
               f"{data_path_base}/velovae/continuous",
               f"{data_path_base}/utv",
               f"{data_path_base}/deepvelo",
               f"{data_path_base}/velovi"]
# Guard against the parallel lists drifting out of sync again.
assert len(methods) == len(keys) == len(out_folders), \
    "methods, keys and out_folders must have one entry per method"

# Performance Comparison

In [None]:
# Run the post-analysis of every method on every dataset.
# `res_final` maps a dataset name to the transposed concatenation of the
# per-method results returned by vv.post_analysis.
res_final = {}
for dataset in datasets:
    per_method_results = []
    for idx in range(len(methods)):
        method = methods[idx]
        h5ad_file = f"{out_folders[idx]}/{dataset}/{dataset}.h5ad"
        adata = anndata.read_h5ad(h5ad_file)
        if method == "VeloVI":
            # Copy the train/test index entries stored under the "vi_" prefix
            # to the "velovi_" prefix, which matches the key used for this
            # method ("velovi").
            # NOTE(review): presumably post_analysis looks them up as
            # "<key>_train_idx" / "<key>_test_idx" — confirm.
            adata.uns["velovi_train_idx"] = adata.uns["vi_train_idx"]
            adata.uns["velovi_test_idx"] = adata.uns["vi_test_idx"]
        method_result = vv.post_analysis(
            adata,
            method,
            [method],
            [keys[idx]],
            compute_metrics=True,
            genes=genes[dataset],
            plot_type=['all'],
            cluster_edges=cluster_edges[dataset],
            save_path=f"figures/{dataset}",
        )
        per_method_results.append(method_result)
    dataset_summary = pd.concat(per_method_results, axis=1).T
    res_final[dataset] = dataset_summary
    print(dataset_summary)

In [None]:
def fix_na(df):
    """Replace every cell equal to the string 'N/A' with NaN.

    The frame is modified in place; the same object is returned so the call
    can be used in an assignment.

    Args:
        df: pandas DataFrame whose cells are compared one by one against 'N/A'.

    Returns:
        The same DataFrame `df`, with the matching cells set to ``np.nan``.
    """
    # Locate the placeholder cells first, then overwrite them.
    placeholder_cells = [
        (row_label, col_label)
        for col_label in df.columns
        for row_label in df.index
        if df.loc[row_label, col_label] == 'N/A'
    ]
    for row_label, col_label in placeholder_cells:
        df.loc[row_label, col_label] = np.nan
    return df

In [None]:
import matplotlib.pyplot as plt
# 2x2 figure with one bar chart per velocity metric; within a panel there is
# one bar group per dataset.
fig, ax = plt.subplots(2,2,figsize=(24,12),facecolor='white')
colors = vv.plotting.get_colors(len(datasets))  # NOTE(review): not used below

velocity_metrics = ['Cross-Boundary Direction Correctness',
                    'Cross-Boundary Direction Correctness (embed)',
                    'In-Cluster Coherence',
                    'corr']
# 'corr' is only plotted for these datasets.
corr_datasets = ["Erythroid", "Braindev", "IPSC"]

# ax.flat walks the grid row by row, i.e. in the order (0,0), (0,1), (1,0), (1,1).
for panel, metric in zip(ax.flat, velocity_metrics):
    if metric == 'corr':
        shown = corr_datasets
    else:
        # Label the columns with the keys actually present in res_final so the
        # labels cannot get out of step with the concatenated data.
        shown = list(res_final)
    res_plot = pd.concat([res_final[name][metric] for name in shown], axis=1)
    res_plot.columns = shown
    # Transpose so that each dataset becomes one row (= one bar group), then
    # turn 'N/A' placeholders into NaN.
    res_plot = fix_na(res_plot.T)
    res_plot.plot.bar(ax=panel, rot=0, legend=False)
    panel.set_title(metric, fontsize=20)

# One shared legend for the whole figure, taken from the first panel.
handles, labels = ax[0,0].get_legend_handles_labels()
lgd = fig.legend(handles, labels, fontsize=20, markerscale=1, bbox_to_anchor=(1.0,1.0), loc='upper left')
plt.tight_layout()
# The keyword is `bbox_extra_artists` (plural); the previous spelling
# `bbox_extra_artist` is not a savefig option, so the legend was never
# registered for the tight bounding box.
fig.savefig('figures/compare_vel.png', bbox_extra_artists=(lgd,), bbox_inches='tight')

In [None]:
# 2x3 figure with one bar chart per reconstruction-error / likelihood metric;
# within a panel there is one bar group per dataset.
fig, ax = plt.subplots(2,3,figsize=(24,12),facecolor='white')
colors = vv.plotting.get_colors(len(datasets))  # NOTE(review): not used below

error_metrics = ['MSE Train', 'MSE Test', 'MAE Train', 'MAE Test', 'LL Train', 'LL Test']

# ax.flat walks the grid row by row, matching the order of `error_metrics`.
for panel, metric in zip(ax.flat, error_metrics):
    # Label the columns with the keys actually present in res_final so the
    # labels cannot get out of step with the concatenated data.
    shown = list(res_final)
    res_plot = pd.concat([res_final[name][metric] for name in shown], axis=1)
    res_plot.columns = shown
    # Transpose so that each dataset becomes one row (= one bar group), then
    # turn 'N/A' placeholders into NaN.
    res_plot = fix_na(res_plot.T)
    res_plot.plot.bar(ax=panel, rot=0, legend=False)
    panel.set_title(metric, fontsize=20)
    if 'MSE' in metric or 'MAE' in metric:
        # Error metrics are drawn on a log axis.
        panel.set_yscale('log')
    else:
        # The 'LL' panels use a fixed y-range instead.
        panel.set_ylim(-3000, 5000)

# One shared legend for the whole figure, taken from the first panel.
handles, labels = ax[0,0].get_legend_handles_labels()
lgd = fig.legend(handles, labels, fontsize=20, markerscale=1, bbox_to_anchor=(1.0,1.0), loc='upper left')
plt.tight_layout()
# The keyword is `bbox_extra_artists` (plural); the previous spelling
# `bbox_extra_artist` is not a savefig option, so the legend was never
# registered for the tight bounding box.
fig.savefig('figures/compare_err.png', bbox_extra_artists=(lgd,), bbox_inches='tight')