In [1]:
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
import pandas as pd
import anndata
import numpy as np
import metrics
from tqdm import tqdm
import pickle
import scipy
import os
from cellflow.metrics import compute_r_squared, compute_e_distance, compute_scalar_mmd

  from optuna import progress_bar as pbar_module


In [2]:
def save_with_pickle(m, path):
    """Pickle ``m`` and write the result to ``path``.

    The destination is opened in binary write mode, so any file already
    present at ``path`` is replaced. Nothing is returned.
    """
    with open(path, mode='wb') as sink:
        pickle.dump(m, sink)

In [3]:
# Working directory for this notebook. The relative '../../data/...' paths used
# by later cells are resolved against this location.
# NOTE(review): `Path` is a plain string here, not `pathlib.Path`.
Path = '/lustre/groups/ml01/projects/CellOT_comparison/cellflow/data/4i/'
os.chdir(Path)  # process-wide side effect: changes how every later relative path resolves


In [4]:
# Drug names: first column of drugs.txt, read as a headerless table.
drug_names_raw = pd.read_csv('../../data/4i/drugs.txt', header=None)[0].values
# Drop the final character of every entry except 'vindesine'.
# NOTE(review): presumably this removes a trailing separator/whitespace character
# that the 'vindesine' entry lacks -- confirm, and prefer an explicit strip if so.
drugs = [name if name == 'vindesine' else name[:-1] for name in drug_names_raw]
# `adata` is not read here: the next cell loads the processed AnnData into that
# name, so reading '8h.h5ad' at this point would only be discarded.

In [5]:
# AnnData object used by every evaluation loop below. Binding `adata` here
# discards whatever that name referred to before this cell ran.
adata = sc.read_h5ad("/lustre/groups/ml01/projects/CellOT_comparison/adata_processed.h5ad")
    

In [6]:
# Display the distinct values of the "drug" annotation in `adata.obs`.
adata.obs["drug"].unique()

['ixazomib', 'everolimus', 'olaparib', 'paclitaxel', 'vemurafenib_cobimetinib', ..., 'hydroxyurea', 'ixazomib_lenalidomide_dexamethasone', 'melphalan', 'ulixertinib', 'dasatinib']
Length: 36
Categories (36, object): ['cisplatin', 'cisplatin_olaparib', 'control', 'crizotinib', ..., 'trametinib_panobinostat', 'ulixertinib', 'vemurafenib_cobimetinib', 'vindesine']

In [7]:
def compute_sinkhorn_div(x: np.ndarray, y: np.ndarray, epsilon: float) -> float:
    """Return the Sinkhorn divergence between ``x`` and ``y`` as a Python float.

    Delegates to OTT's ``sinkhorn_divergence`` with a ``PointCloud`` geometry,
    a squared-Euclidean cost and ``scale_cost=1.0``, and converts the first
    element of its return value to ``float``.

    Args:
        x: First sample, forwarded unchanged as ``x``.
        y: Second sample, forwarded unchanged as ``y``.
        epsilon: Regularisation strength, forwarded unchanged as ``epsilon``.

    Returns:
        ``float(result[0])`` where ``result`` is what ``sinkhorn_divergence``
        returned.
    """
    # OTT is imported lazily, exactly as before, so the notebook can be loaded
    # without it until this metric is actually requested.
    from ott.geometry import costs, pointcloud
    from ott.tools.sinkhorn_divergence import sinkhorn_divergence

    solver_kwargs = dict(
        x=x,
        y=y,
        cost_fn=costs.SqEuclidean(),
        epsilon=epsilon,
        scale_cost=1.0,
    )
    result = sinkhorn_divergence(pointcloud.PointCloud, **solver_kwargs)
    divergence = result[0]
    return float(divergence)

In [8]:
# Identity baseline: for every drug, the control cells of the test split are
# scored as if they were the prediction for that drug's test cells.
dfs = []
for drug in tqdm(drugs):
    # Per-drug split table; its 'Unnamed: 0' column is matched against adata.obs.index.
    split_df = pd.read_csv(f'../../data/splits/{drug}.csv')
    test_cells = list(split_df[split_df['split']=='test']['Unnamed: 0'])
    # Only the test subset is materialised; the train subset is not needed for scoring.
    adata_test = adata[adata.obs.index.isin(test_cells)].copy()

    true = adata_test[adata_test.obs["drug"]==drug].X
    prediction = adata_test[adata_test.obs['drug']=='control'].X

    # Named `drug_metrics` so the `metrics` module imported at the top of the
    # notebook is not shadowed by this dictionary.
    drug_metrics = {
        "r_squared": compute_r_squared(true, prediction),
        "e_distance": compute_e_distance(true, prediction),
        "mmd": compute_scalar_mmd(true, prediction),
        # Sinkhorn divergence at three regularisation strengths.
        "sinkhorn_div_1": compute_sinkhorn_div(true, prediction, 1.0),
        "sinkhorn_div_10": compute_sinkhorn_div(true, prediction, 10.0),
        "sinkhorn_div_100": compute_sinkhorn_div(true, prediction, 100.0),
    }
    # One row per drug (index = drug name), one column per metric.
    metrics_df = pd.DataFrame.from_dict(drug_metrics, orient="index", columns=[drug]).T
    metrics_df["drug"] = metrics_df.index
    metrics_df["method"] = "Identity"
    dfs.append(metrics_df)

df_id = pd.concat(dfs)
df_id.to_csv("/lustre/groups/ml01/projects/CellOT_comparison/results_id.csv")

  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((se

In [9]:
# scGen: score the stored scGen predictions against each drug's test cells.
# Results are collected in a list owned by this cell rather than one shared with
# other cells, so the CSV holds scGen rows only and re-running the cell does not
# append duplicates.
scgen_dfs = []
for drug in tqdm(drugs):
    # Per-drug split table; its 'Unnamed: 0' column is matched against adata.obs.index.
    split_df = pd.read_csv(f'../../data/splits/{drug}.csv')
    test_cells = list(split_df[split_df['split']=='test']['Unnamed: 0'])
    # Only the test subset is materialised; the train subset is not needed for scoring.
    adata_test = adata[adata.obs.index.isin(test_cells)].copy()

    prediction = sc.read_h5ad(f'/lustre/groups/ml01/projects/CellOT_comparison/cellflow/data/scgen_preds/imputed_{drug}.h5ad').X
    true = adata_test[adata_test.obs["drug"]==drug].X

    # Named `drug_metrics` so the `metrics` module imported at the top of the
    # notebook is not shadowed by this dictionary.
    drug_metrics = {
        "r_squared": compute_r_squared(true, prediction),
        "e_distance": compute_e_distance(true, prediction),
        "mmd": compute_scalar_mmd(true, prediction),
        # Sinkhorn divergence at three regularisation strengths.
        "sinkhorn_div_1": compute_sinkhorn_div(true, prediction, 1.0),
        "sinkhorn_div_10": compute_sinkhorn_div(true, prediction, 10.0),
        "sinkhorn_div_100": compute_sinkhorn_div(true, prediction, 100.0),
    }
    # One row per drug (index = drug name), one column per metric.
    metrics_df = pd.DataFrame.from_dict(drug_metrics, orient="index", columns=[drug]).T
    metrics_df["drug"] = metrics_df.index
    metrics_df["method"] = "scGen"
    scgen_dfs.append(metrics_df)

df_scgen = pd.concat(scgen_dfs)
df_scgen.to_csv("/lustre/groups/ml01/projects/CellOT_comparison/results_scgen.csv")

  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((se

  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
100%|███████████████████

In [13]:
# CellOT
# Score the stored CellOT predictions against each drug's test cells.
# Results are collected in a list owned by this cell rather than one shared with
# other cells, so the CSV holds CellOT rows only and re-running the cell does not
# append duplicates.
cellot_dfs = []
for drug in tqdm(drugs):
    # Per-drug split table; its 'Unnamed: 0' column is matched against adata.obs.index.
    split_df = pd.read_csv(f'../../data/splits/{drug}.csv')
    test_cells = list(split_df[split_df['split']=='test']['Unnamed: 0'])
    # Only the test subset is materialised; the train subset is not needed for scoring.
    adata_test = adata[adata.obs.index.isin(test_cells)].copy()

    prediction = sc.read_h5ad(f'/lustre/groups/ml01/projects/CellOT_comparison/cellot-main/results/4i/drug-{drug}/model-cellot/evals_iid_data_space/imputed.h5ad').X
    true = adata_test[adata_test.obs["drug"]==drug].X

    # Named `drug_metrics` so the `metrics` module imported at the top of the
    # notebook is not shadowed by this dictionary.
    drug_metrics = {
        "r_squared": compute_r_squared(true, prediction),
        "e_distance": compute_e_distance(true, prediction),
        "mmd": compute_scalar_mmd(true, prediction),
        # Sinkhorn divergence at three regularisation strengths.
        "sinkhorn_div_1": compute_sinkhorn_div(true, prediction, 1.0),
        "sinkhorn_div_10": compute_sinkhorn_div(true, prediction, 10.0),
        "sinkhorn_div_100": compute_sinkhorn_div(true, prediction, 100.0),
    }
    # One row per drug (index = drug name), one column per metric.
    metrics_df = pd.DataFrame.from_dict(drug_metrics, orient="index", columns=[drug]).T
    metrics_df["drug"] = metrics_df.index
    metrics_df["method"] = "CellOT"
    cellot_dfs.append(metrics_df)

df_cellot = pd.concat(cellot_dfs)
df_cellot.to_csv("/lustre/groups/ml01/projects/CellOT_comparison/results_cellot.csv")

  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
  lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
  errors = -jnp.ones((se