In [1]:
import mudata as mu
from deconvatac.tl import tangram
import pandas as pd
import numpy as np
import seml
import pandas as pd
import glob
import deconvatac as de
import seaborn as sns
import tqdm
import os

  from .autonotebook import tqdm as notebook_tqdm


## Run Tangram

In [2]:
class ExperimentWrapper:
    """
    A simple wrapper around a sacred experiment, making use of sacred's captured functions with prefixes.
    This allows a modular design of the configuration, where certain sub-dictionaries (e.g., "data") are parsed by
    specific method. This avoids having one large "main" function which takes all parameters as input.
    """

    def __init__(self, init_all=True):
        if init_all:
            self.init_all()

    def init_dataset(self, mdata_spatial_path, mdata_reference_path, var_HVF_column, labels_key, modality):

        self.spatial_path = mdata_spatial_path
        self.adata_spatial = mu.read_h5mu(mdata_spatial_path).mod[modality]
        self.adata_reference = mu.read_h5mu(mdata_reference_path).mod[modality]
        # subset on HVFs
        self.adata_spatial = self.adata_spatial[:, self.adata_reference.var[var_HVF_column]]
        self.adata_reference = self.adata_reference[:, self.adata_reference.var[var_HVF_column]]

        self.modality = modality
        self.labels_key = labels_key
        self.var_HVF_column = var_HVF_column

    def init_method(self, method_id):
        self.method_id = method_id

    def init_all(self):
        self.init_dataset()
        self.init_method()

    def run(self, output_path):

        dataset = self.spatial_path.split("/")[-1].split(".")[0]
        dataset_var_column = dataset + "_" + self.var_HVF_column
        output_path = output_path + self.modality + "/" + dataset_var_column

        tangram(
            adata_spatial=self.adata_spatial,
            adata_ref=self.adata_reference,
            labels_key=self.labels_key,
            run_rank_genes=False,
            result_path=output_path,
            device="cuda:0",
            num_epochs=1000,
        )

        results = {
            "result_path": output_path + "/tangram_ct_pred.csv",
            "dataset": dataset,
            "modality": self.modality,
            "var_HVF_column": self.var_HVF_column,
        }
        return results

### Create new reference


In [3]:
def create_reference(parquet, ref_path): 
    sample_cells = pd.read_parquet(parquet)
    ref = mu.read_h5mu("/vol/storage/data/cellxgene/human_cardiac_niches/human_cardiac_niches.h5mu")
    cell_ids = np.concatenate(sample_cells['cell_id'].values)
    ref = ref[np.unique(cell_ids)]
    print(ref)
    ref.write(ref_path)

In [4]:
parquets = ["/vol/storage/data/simulations/test/Heart_1.pq", "/vol/storage/data/simulations/test/Heart_2.pq", 
            "/vol/storage/data/simulations/test/Heart_3.pq", "/vol/storage/data/simulations/test/Heart_4.pq"]
ref_paths = ["/vol/storage/data/simulations/test_ref/Heart1_ref.h5mu", "/vol/storage/data/simulations/test_ref/Heart2_ref.h5mu"
             , "/vol/storage/data/simulations/test_ref/Heart3_ref.h5mu", "/vol/storage/data/simulations/test_ref/Heart4_ref.h5mu"]

In [5]:
for i in range(len(parquets)):
    if i == 0: 
        print(i)
        create_reference(parquets[i], ref_paths[i])

0


  if not is_categorical_dtype(df_full[k]):


View of MuData object with n_obs × n_vars = 1665 × 462560
  var:	'highly_variable'
  2 modalities
    atac:	1665 x 429828
      obs:	'sangerID', 'combinedID', 'donor', 'donor_type', 'region', 'region_finest', 'age', 'gender', 'facility', 'cell_or_nuclei', 'modality', 'kit_10x', 'flushed', 'batch_key', 'cell_type', 'cell_state'
      var:	'highly_variable', 'highly_accessible'
      uns:	'log1p'
      layers:	'log_norm', 'tfidf_normalized'
    rna:	1665 x 32732
      obs:	'sangerID', 'combinedID', 'donor', 'donor_type', 'region', 'region_finest', 'age', 'gender', 'facility', 'cell_or_nuclei', 'modality', 'kit_10x', 'flushed', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'scrublet_score', 'scrublet_leiden', 'cluster_scrublet_score', 'doublet_pval', 'doublet_bh_pval', 'batch_key', 'leiden_scVI', 'cell_type', 'cell_state_HCAv1', 'cell_state_scNym', 'cell_state_scNym_confidence', 'cell_state', 'latent_RT_efficien

### Run Tangram

In [6]:
def run_tangram(ref_path, spatial_path, modality): 
    mdata_reference_path = ref_path
    mdata_spatial_path = spatial_path
    method_id =  "Tangram"
    output_path =  "/vol/storage/data/deconvolution_results/test2/"
    labels_key = "cell_type"
    modality = modality
    var_HVF_column = "highly_variable"
    ex = ExperimentWrapper(init_all=False)
    ex.init_dataset(mdata_spatial_path, mdata_reference_path, var_HVF_column, labels_key, modality)
    ex.init_method(method_id)
    ex.run(output_path)
    if modality == "atac": 
        var_HVF_column = "highly_accessible"
        ex = ExperimentWrapper(init_all=False)
        ex.init_dataset(mdata_spatial_path, mdata_reference_path, var_HVF_column, labels_key, modality)
        ex.init_method(method_id)
        ex.run(output_path)



In [7]:
spatial_paths = ["/vol/storage/data/simulations/test/Heart_1.h5mu", "/vol/storage/data/simulations/test/Heart_2.h5mu", 
                 "/vol/storage/data/simulations/test/Heart_3.h5mu", "/vol/storage/data/simulations/test/Heart_4.h5mu"]

In [None]:
for i in range(len(spatial_paths)): 
    print(i)
    run_tangram(ref_paths[i], spatial_paths[i], "rna")
    run_tangram(ref_paths[i], spatial_paths[i], "atac")


## Evaluate the results

In [None]:
data_path = "/vol/storage/data/deconvolution_results/test2"
methods = ["tangram"]
modalities = ["atac", "rna"]

In [None]:
# go over all methods and modalities
df = [pd.DataFrame({'path': glob.glob(os.path.join(data_path, method, modality, "*", "*"))}) for method in methods for modality in modalities]
df = pd.concat(df)
df

In [None]:
df[['method', 'modality', 'dataset_features']] = df['path'].str.split('/', expand=True).iloc[:, 5:-1]
df['dataset'] = df['dataset_features'].str.rsplit("_", n=2).str[0]
df["features"] = df["dataset_features"].str.split("_", n=2).str[-1]
df

In [None]:
mapping_dict = {
    "russell_250": "/vol/storage/data/simulations/test/russell_250.h5mu",
    "Heart_1": "/vol/storage/data/simulations/test/Heart_1.h5mu",
    "Heart_2": "/vol/storage/data/simulations/test/Heart_2.h5mu",
    "Heart_3": "/vol/storage/data/simulations/test/Heart_3.h5mu",
    "Heart_4": "/vol/storage/data/simulations/test/Heart_4.h5mu",
    "Brain_1": "/vol/storage/data/simulations/test/Brain_1.h5mu",
    "Brain_2": "/vol/storage/data/simulations/test/Brain_2.h5mu",
    "Brain_3": "/vol/storage/data/simulations/test/Brain_3.h5mu",
    "Brain_4": "/vol/storage/data/simulations/test/Brain_4.h5mu",
}

In [147]:
df["mdata_spatial_path"] = df['dataset'].map(mapping_dict)

In [148]:
df = df.query("(dataset == 'russell_250' | dataset == 'Heart_1') ")

In [149]:
df

Unnamed: 0,path,method,modality,dataset_features,dataset,features,mdata_spatial_path
2,/vol/storage/data/deconvolution_results/tangra...,tangram,atac,russell_250_highly_variable,russell_250,highly_variable,/vol/storage/data/simulations/russell_250.h5mu
4,/vol/storage/data/deconvolution_results/tangra...,tangram,atac,russell_250_highly_accessible,russell_250,highly_accessible,/vol/storage/data/simulations/russell_250.h5mu
5,/vol/storage/data/deconvolution_results/tangra...,tangram,atac,russell_250_highly_accessible,russell_250,highly_accessible,/vol/storage/data/simulations/russell_250.h5mu
9,/vol/storage/data/deconvolution_results/tangra...,tangram,atac,Heart_1_highly_accessible,Heart_1,highly_accessible,/vol/storage/data/simulations/Heart_1.h5mu
16,/vol/storage/data/deconvolution_results/tangra...,tangram,atac,Heart_1_highly_variable,Heart_1,highly_variable,/vol/storage/data/simulations/Heart_1.h5mu
16,/vol/storage/data/deconvolution_results/test/a...,tangram,atac,Heart_1_highly_variable,Heart_1,highly_variable,/vol/storage/data/simulations/Heart_1.h5mu


In [150]:
def load_table(path, index_col):
    res = pd.read_csv(path, index_col=index_col)
    if "q05cell_abundance_w_sf_" in res.columns[0]:
        res.columns = res.columns.to_series().str.split("q05cell_abundance_w_sf_", expand=True).loc[:, 1].values
    elif "meanscell_abundance_w_sf_" in res.columns[0]:
        res.columns = res.columns.to_series().str.split("meanscell_abundance_w_sf_", expand=True).loc[:, 1].values
    if res.index[0] != 0:
        res.index = res.index.astype(int) - 1
    res.index = res.index.astype(str)
    if "cell_ID" in res.columns:
        res.drop("cell_ID", axis=1, inplace=True)
    res = res.div(res.sum(axis=1), axis=0)
    return res

In [151]:
df = pd.concat([df, df.iloc[[-1]]])

In [152]:
df["path"].iloc[-1] = "/vol/storage/data/deconvolution_results/test/atac/Heart_1_highly_variable/tangram_ct_pred.csv"

In [153]:
df

Unnamed: 0,path,method,modality,dataset_features,dataset,features,mdata_spatial_path
2,/vol/storage/data/deconvolution_results/tangra...,tangram,atac,russell_250_highly_variable,russell_250,highly_variable,/vol/storage/data/simulations/russell_250.h5mu
4,/vol/storage/data/deconvolution_results/tangra...,tangram,atac,russell_250_highly_accessible,russell_250,highly_accessible,/vol/storage/data/simulations/russell_250.h5mu
5,/vol/storage/data/deconvolution_results/tangra...,tangram,atac,russell_250_highly_accessible,russell_250,highly_accessible,/vol/storage/data/simulations/russell_250.h5mu
9,/vol/storage/data/deconvolution_results/tangra...,tangram,atac,Heart_1_highly_accessible,Heart_1,highly_accessible,/vol/storage/data/simulations/Heart_1.h5mu
16,/vol/storage/data/deconvolution_results/tangra...,tangram,atac,Heart_1_highly_variable,Heart_1,highly_variable,/vol/storage/data/simulations/Heart_1.h5mu
16,/vol/storage/data/deconvolution_results/test/a...,tangram,atac,Heart_1_highly_variable,Heart_1,highly_variable,/vol/storage/data/simulations/Heart_1.h5mu
16,/vol/storage/data/deconvolution_results/test/a...,tangram,atac,Heart_1_highly_variable,Heart_1,highly_variable,/vol/storage/data/simulations/Heart_1.h5mu


In [None]:
def get_proportions(adata):
    df = pd.DataFrame(adata.obsm["proportions"], columns=adata.uns["proportion_names"], index=adata.obs_names)
    return df

In [154]:
jsd = []
rmse = []
for _, row in tqdm.tqdm(df.iterrows()):
    if i == 
    # load ground truth
    target_adata = mu.read(row["mdata_spatial_path"])
    targets = get_proportions(target_adata[row["modality"]])

    # load table
    predictions = load_table(row["path"], index_col=(None if row["method"] == "moscot" else 0))
    predictions = predictions.loc[targets.index, targets.columns]
    jsd.append(de.tl.jsd(predictions, targets))
    rmse.append(de.tl.rmse(predictions, targets))
df["jsd"] = jsd
df["rmse"] = rmse

5it [00:23,  4.74s/it]


KeyError: "['Adipocyte', 'Atrial Cardiomyocyte', 'Endothelial cell', 'Fibroblast', 'Lymphatic Endothelial cell', 'Lymphoid', 'Myeloid', 'Neural cell', 'Ventricular Cardiomyocyte'] not in index"

In [144]:
load_table(row["path"], index_col=(None if row["method"] == "moscot" else 0))

Unnamed: 0,Mural cell,Mast cell,Mesothelial cell
0,0.941892,0.010925,0.047183
1,0.585719,0.414263,0.000018
2,0.920166,0.016500,0.063334
3,0.842040,0.080708,0.077253
4,0.761817,0.153722,0.084461
...,...,...,...
956,0.871320,0.073932,0.054749
957,0.828523,0.143127,0.028349
958,0.611134,0.355417,0.033449
959,0.957805,0.001252,0.040943
