## Demonstration metrics

Import packages

In [107]:
import scanpy as sc
from scipy import sparse
import pandas as pd
import numpy as np
from sklearn.metrics import f1_score
import mudata as mu
from muon import atac as ac

from celldreamer.paths import DATA_DIR

from celldreamer.eval.compute_evaluation_metrics import process_labels, compute_evaluation_metrics
from celldreamer.eval.distribution_distances import train_knn_real_data
from celldreamer.eval.compute_evaluation_metrics import process_labels, compute_evaluation_metrics

In [108]:
DATA_DIR

PosixPath('/ictstr01/home/icb/alessandro.palma/environment/cfgen/project_folder/datasets')

Util function

In [109]:
def add_to_dict(d, metrics):
    for metric in metrics:
        if metric not in d:
            d[metric] = [metrics[metric]]
        else:
            d[metric]+=[metrics[metric]]
    return d

Collect metrics

In [110]:
results_celldreamer_atac = {}
results_celldreamer_rna = {}

Read dataset 

In [111]:
adata_real = mu.read(DATA_DIR / "processed/atac/pbmc/pbmc10k_multiome_test.h5mu")



In [112]:
knn_pca_rna = train_knn_real_data(adata_real_rna, "cell_type", use_pca=True, n_neighbors=30)
knn_data_rna = train_knn_real_data(adata_real_rna, "cell_type", use_pca=False, n_neighbors=30)
knn_pca_atac = train_knn_real_data(adata_real_atac, "cell_type", use_pca=True, n_neighbors=30)
knn_data_atac = train_knn_real_data(adata_real_atac, "cell_type", use_pca=False, n_neighbors=30)

Preprocess

In [113]:
adata_real_rna = adata_real["rna"]
# Bring back counts 
adata_real_rna.X = adata_real_rna.layers["X_counts"].copy()
# Compute HVG (don't subset)
sc.pp.highly_variable_genes(adata_real_rna,
                            flavor="seurat_v3",
                            n_top_genes=2000,
                            layer="X_counts",
                            subset=False)

# Pick 30 pcs
sc.pp.normalize_total(adata_real_rna, target_sum=1e4)
sc.pp.log1p(adata_real_rna)
sc.tl.pca(adata_real_rna, n_comps=30)

In [114]:
# ATAC 
adata_real_atac = adata_real["atac"]
# Bring back counts 
adata_real_atac.obs["cell_type"] = adata_real_rna.obs["cell_type"]  # Harmonize annotation
adata_real_atac.X = adata_real_atac.layers["X_counts"].copy()
ac.pp.tfidf(adata_real_atac, scale_factor=1e4)
# Compute highly variable peaks 
sc.pp.highly_variable_genes(adata_real_atac, n_top_genes=10000, subset=False)
sc.tl.pca(adata_real_rna, n_comps=30)

In [115]:
celltype_unique = np.unique(adata_real_rna.obs["cell_type"])  # unique cell type 
adata_real_rna = adata_real_rna[:, adata_real_rna.var.highly_variable]
adata_real_atac = adata_real_atac[:, adata_real_atac.var.highly_variable]

In [116]:
adata_generated_path_celldreamer_rna = DATA_DIR / "generated/pbmc10k_multimodal/generated_cells_0_rna.h5ad"
adata_generated_celldreamer_rna = sc.read_h5ad(adata_generated_path_celldreamer_rna)
adata_generated_celldreamer_rna.var = vars_rna
adata_generated_celldreamer_rna = adata_generated_celldreamer_rna[:, adata_generated_celldreamer_rna.var.highly_variable]
adata_generated_celldreamer_rna.obsm["X_pca"] = adata_generated_celldreamer_rna.X.A.dot(adata_real_rna.varm["PCs"])

  adata_generated_celldreamer_rna.obsm["X_pca"] = adata_generated_celldreamer_rna.X.A.dot(adata_real_rna.varm["PCs"])


In [117]:
adata_generated_path_celldreamer_atac = DATA_DIR / "generated/pbmc10k_multimodal/generated_cells_0_atac.h5ad"
adata_generated_celldreamer_atac = sc.read_h5ad(adata_generated_path_celldreamer_atac)
adata_generated_celldreamer_atac.var = vars_atac
ac.pp.tfidf(adata_generated_celldreamer_atac, scale_factor=1e4)
adata_generated_celldreamer_atac = adata_generated_celldreamer_atac[:, adata_generated_celldreamer_atac.var.highly_variable]
adata_generated_celldreamer_atac.obsm["X_pca"] = adata_generated_celldreamer_atac.X.A.dot(adata_real_atac.varm["PCs"])

  adata_generated_celldreamer_atac.obsm["X_pca"] = adata_generated_celldreamer_atac.X.A.dot(adata_real_atac.varm["PCs"])


In [118]:
for ct in celltype_unique:
    adata_real_ct_atac = adata_real_atac[adata_real_atac.obs["cell_type"]==ct]
    adata_real_ct_rna = adata_real_rna[adata_real_rna.obs["cell_type"]==ct]
    adata_generated_celldreamer_rna_ct = adata_generated_celldreamer_rna[adata_generated_celldreamer_rna.obs["cell_type"]==ct]
    adata_generated_celldreamer_atac_ct = adata_generated_celldreamer_atac[adata_generated_celldreamer_atac.obs["cell_type"]==ct]
    results_celldreamer_rna_ct = compute_evaluation_metrics(adata_real_ct_rna, 
                                                            adata_generated_celldreamer_rna_ct, 
                                                            "cell_type",
                                                            "celldreamer_rna",
                                                            nn=10, 
                                                            original_space=True, 
                                                            knn_pca=knn_pca_rna, 
                                                            knn_data=knn_data_rna)

    results_celldreamer_atac_ct = compute_evaluation_metrics(adata_real_ct_atac, 
                                                                adata_generated_celldreamer_atac_ct,
                                                                "cell_type", 
                                                                "celldreamer_atac",
                                                                nn=10, 
                                                                original_space=True,
                                                                knn_pca=knn_pca_atac, 
                                                                knn_data=knn_data_atac)

    results_celldreamer_rna_ct["ct"] = ct
    results_celldreamer_atac_ct["ct"] = ct
    results_celldreamer_rna = add_to_dict(results_celldreamer_rna, results_celldreamer_rna_ct)
    results_celldreamer_atac = add_to_dict(results_celldreamer_atac, results_celldreamer_atac_ct)

Evaluating for celldreamer_rna
Real (79, 2000)
Generated (83, 2000)
Evaluating for celldreamer_atac
Real (79, 10000)
Generated (83, 10000)
Evaluating for celldreamer_rna
Real (97, 2000)
Generated (83, 2000)
Evaluating for celldreamer_atac
Real (97, 10000)
Generated (83, 10000)
Evaluating for celldreamer_rna
Real (21, 2000)
Generated (22, 2000)
Evaluating for celldreamer_atac
Real (21, 10000)
Generated (22, 10000)
Evaluating for celldreamer_rna
Real (366, 2000)
Generated (358, 2000)
Evaluating for celldreamer_atac
Real (366, 10000)
Generated (358, 10000)
Evaluating for celldreamer_rna
Real (84, 2000)
Generated (93, 2000)
Evaluating for celldreamer_atac
Real (84, 10000)
Generated (93, 10000)
Evaluating for celldreamer_rna
Real (151, 2000)
Generated (158, 2000)
Evaluating for celldreamer_atac
Real (151, 10000)
Generated (158, 10000)
Evaluating for celldreamer_rna
Real (105, 2000)
Generated (106, 2000)
Evaluating for celldreamer_atac
Real (105, 10000)
Generated (106, 10000)
Evaluating for 

In [119]:
results_celldreamer_rna_df = pd.DataFrame(results_celldreamer_rna)
results_celldreamer_atac_df = pd.DataFrame(results_celldreamer_atac)

In [120]:
results_celldreamer_rna_df.groupby("ct").mean()

Unnamed: 0_level_0,1-Wasserstein_PCA,2-Wasserstein_PCA,Linear_MMD_PCA,Poly_MMD_PCA,RBF_MMD_PCA,KNN identity,KNN identity PCA,precision,recall,density,coverage,precision_PCA,recall_PCA,density_PCA,coverage_PCA,KNN category,KNN category PCA
ct,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
CD56 (bright) NK cells,13.901233,13.964469,164.203293,56601.136719,0.9492,0.795353,1.0,0.46988,0.772152,0.33253,0.936709,0.0,0.0,0.0,0.0,0.27972,0.322981
CD56 (dim) NK cells,14.047998,14.101661,170.08522,41730.402344,0.944416,0.777559,1.0,0.457831,0.845361,0.312048,0.907216,0.0,0.0,0.0,0.0,0.201439,0.457516
MAIT T cells,14.268649,14.293301,160.174957,29739.082031,1.135221,0.790244,1.0,0.818182,0.952381,0.931818,1.0,0.0,0.0,0.0,0.0,0.0,0.021739
classical monocytes,13.616992,13.694651,153.14386,192156.390625,0.717225,0.694954,1.0,0.413408,0.789617,0.21257,0.754098,0.0,0.0,0.0,0.0,0.18206,0.478894
effector CD8 T cells,14.366617,14.430753,166.115646,28716.902344,0.862769,0.738728,1.0,0.451613,0.702381,0.41828,0.988095,0.0,0.0,0.0,0.0,0.120301,0.216463
intermediate monocytes,13.538697,13.618433,152.127472,190892.796875,0.761672,0.401392,1.0,0.139241,0.960265,0.031646,0.211921,0.0,0.0,0.0,0.0,0.312009,0.475083
memory B cells,13.775068,13.825666,162.348572,72655.289062,0.88761,0.401753,1.0,0.075472,0.942857,0.032075,0.257143,0.0,0.0,0.0,0.0,1.0,0.495238
memory CD4 T cells,14.195562,14.251471,172.443268,24748.679688,0.847929,0.76538,1.0,0.157738,0.801205,0.132738,0.689759,0.0,0.0,0.0,0.0,0.311111,1.0
myeloid DC,13.217733,13.346975,129.637451,138912.0625,0.776389,0.355072,1.0,0.225,0.959184,0.055,0.367347,0.0,0.0,0.0,0.0,0.306306,0.236559
naive B cells,13.994476,14.031326,165.514069,78944.445312,0.992338,0.445567,1.0,0.169492,0.927273,0.044068,0.363636,0.0,0.0,0.0,0.0,0.142222,0.385417


In [121]:
results_celldreamer_atac_df.groupby("ct").mean()

Unnamed: 0_level_0,1-Wasserstein_PCA,2-Wasserstein_PCA,Linear_MMD_PCA,Poly_MMD_PCA,RBF_MMD_PCA,KNN identity,KNN identity PCA,precision,recall,density,coverage,precision_PCA,recall_PCA,density_PCA,coverage_PCA,KNN category,KNN category PCA
ct,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
CD56 (bright) NK cells,17.549646,18.153809,141.890717,46748.230469,1.115963,0.338776,1.0,0.096386,0.177215,0.261446,1.0,0.0,0.0,0.0,0.0,0.0,0.023529
CD56 (dim) NK cells,17.073234,17.71109,166.354004,44632.304688,1.132415,0.370725,1.0,0.204819,0.154639,1.118072,1.0,0.975904,0.0,0.09759,0.010309,0.007937,0.0
MAIT T cells,17.802936,18.763143,95.825935,25358.958984,1.255773,0.338462,1.0,0.545455,0.619048,0.986364,1.0,1.0,0.0,0.622727,0.47619,0.0,0.0
classical monocytes,18.084646,19.047541,136.628464,77019.96875,0.757537,0.348843,1.0,0.00838,0.07377,0.056983,0.434426,0.907821,0.0,0.102793,0.032787,0.0,0.495063
effector CD8 T cells,23.598734,25.774977,227.464279,290929.0,0.998872,0.755203,1.0,0.064516,0.214286,0.225806,1.0,1.0,0.0,1.234409,0.25,0.0,0.0
intermediate monocytes,19.016929,20.085134,143.168152,87880.015625,0.789143,0.394541,1.0,0.012658,0.238411,0.096203,1.0,1.0,0.0,0.262658,0.086093,0.0,0.238482
memory B cells,16.921557,17.679772,123.306808,119630.023438,0.84078,0.815149,0.985777,0.132075,0.352381,0.749057,1.0,0.943396,0.0,0.570755,0.219048,0.0,0.495238
memory CD4 T cells,17.051328,18.060486,98.650558,40413.828125,0.826709,0.332,1.0,0.005952,0.13253,0.179762,1.0,0.970238,0.0,0.365179,0.045181,0.497006,0.487023
myeloid DC,18.982511,20.158217,18.45129,160260.0,0.793022,0.310078,0.966138,0.45,0.571429,0.81,1.0,1.0,0.0,1.0,0.714286,0.0,0.25
naive B cells,15.08401,15.625309,89.797333,58658.367188,0.860073,0.760037,1.0,0.186441,0.527273,0.689831,1.0,1.0,0.0,0.240678,0.109091,0.0,0.22963
