In [1]:
import mudata as mu
from deconvatac.tl import tangram
import pandas as pd
import numpy as np
import seml
import pandas as pd
import glob
import deconvatac as de
import seaborn as sns
import tqdm
import os

  from .autonotebook import tqdm as notebook_tqdm


## Run Tangram on Cluster Mode

In [16]:
class ExperimentWrapper:
    """
    A simple wrapper around a sacred experiment, making use of sacred's captured functions with prefixes.
    This allows a modular design of the configuration, where certain sub-dictionaries (e.g., "data") are parsed by
    specific method. This avoids having one large "main" function which takes all parameters as input.
    """

    def __init__(self, init_all=True):
        if init_all:
            self.init_all()

    def init_dataset(self, mdata_spatial_path, mdata_reference_path, var_HVF_column, labels_key, modality):

        self.spatial_path = mdata_spatial_path
        self.adata_spatial = mu.read_h5mu(mdata_spatial_path).mod[modality]
        self.adata_reference = mu.read_h5mu(mdata_reference_path).mod[modality]
        # subset on HVFs
        self.adata_spatial = self.adata_spatial[:, self.adata_reference.var[var_HVF_column]]
        self.adata_reference = self.adata_reference[:, self.adata_reference.var[var_HVF_column]]

        self.modality = modality
        self.labels_key = labels_key
        self.var_HVF_column = var_HVF_column

    def init_method(self, method_id):
        self.method_id = method_id

    def init_all(self):
        self.init_dataset()
        self.init_method()

    def run(self, output_path):

        dataset = self.spatial_path.split("/")[-1].split(".")[0]
        dataset_var_column = dataset + "_" + self.var_HVF_column
        output_path = output_path + self.modality + "/" + dataset_var_column

        tangram(
            adata_spatial=self.adata_spatial,
            adata_ref=self.adata_reference,
            labels_key=self.labels_key,
            run_rank_genes=False,
            result_path=output_path,
            device="cuda:0",
            num_epochs=1000,
            **{"cluster_label": self.labels_key, "mode": "clusters"},
        )

        results = {
            "result_path": output_path + "/tangram_ct_pred.csv",
            "dataset": dataset,
            "modality": self.modality,
            "var_HVF_column": self.var_HVF_column,
        }
        return results

In [13]:
def run_tangram(ref_path, spatial_path, modality): 
    mdata_reference_path = ref_path
    mdata_spatial_path = spatial_path
    method_id =  "Tangram"
    output_path =  "/vol/storage/data/deconvolution_results/test2/cluster_mode/" 
    labels_key = "cell_type"
    modality = modality
    var_HVF_column = "highly_variable"
    ex = ExperimentWrapper(init_all=False)
    ex.init_dataset(mdata_spatial_path, mdata_reference_path, var_HVF_column, labels_key, modality)
    ex.init_method(method_id)
    ex.run(output_path)
    if modality == "atac": 
        var_HVF_column = "highly_accessible"
        ex = ExperimentWrapper(init_all=False)
        ex.init_dataset(mdata_spatial_path, mdata_reference_path, var_HVF_column, labels_key, modality)
        ex.init_method(method_id)
        ex.run(output_path)


Heart Datasets

In [None]:
ref_paths = "/vol/storage/data/simulations/human_cardiac_niches.h5mu"

In [5]:
spatial_paths = ["/vol/storage/data/simulations/Heart_1.h5mu", "/vol/storage/data/simulations/Heart_2.h5mu", 
                 "/vol/storage/data/simulations/Heart_3.h5mu", "/vol/storage/data/simulations/Heart_4.h5mu"]

In [None]:
for i in range(len(spatial_paths)): 
    run_tangram(ref_paths, spatial_paths[i], "rna")
    run_tangram(ref_paths, spatial_paths[i], "atac")

Russell Dataset

In [14]:
ref_path = "/vol/storage/data/simulations/russel_ref.h5mu"
spatial_path = "/vol/storage/data/simulations/russell_250.h5mu"

In [17]:
run_tangram(ref_path, spatial_path, "rna")
run_tangram(ref_path, spatial_path, "atac")

  if not is_categorical_dtype(df_full[k]):
INFO:root:4000 training genes are saved in `uns``training_genes` of both single cell and spatial Anndatas.
INFO:root:4000 overlapped genes are saved in `uns``overlap_genes` of both single cell and spatial Anndatas.
INFO:root:uniform based density prior is calculated and saved in `obs``uniform_density` of the spatial Anndata.
INFO:root:rna count based density prior is calculated and saved in `obs``rna_count_based_density` of the spatial Anndata.
INFO:root:Allocate tensors for mapping.
INFO:root:Begin training with 4000 genes and rna_count_based density_prior in clusters mode...


Running tangram


INFO:root:Printing scores every 100 epochs.


Score: 0.260, KL reg: 0.553
Score: 0.497, KL reg: 0.003
Score: 0.498, KL reg: 0.003
Score: 0.499, KL reg: 0.003
Score: 0.499, KL reg: 0.003
Score: 0.499, KL reg: 0.003
Score: 0.499, KL reg: 0.003
Score: 0.499, KL reg: 0.003
Score: 0.499, KL reg: 0.003
Score: 0.500, KL reg: 0.003


INFO:root:Saving results..
INFO:root:spatial prediction dataframe is saved in `obsm` `tangram_ct_pred` of the spatial AnnData.


Tangram Done, now saving results


INFO:root:20000 training genes are saved in `uns``training_genes` of both single cell and spatial Anndatas.
INFO:root:20000 overlapped genes are saved in `uns``overlap_genes` of both single cell and spatial Anndatas.
INFO:root:uniform based density prior is calculated and saved in `obs``uniform_density` of the spatial Anndata.
INFO:root:rna count based density prior is calculated and saved in `obs``rna_count_based_density` of the spatial Anndata.
INFO:root:Allocate tensors for mapping.
INFO:root:Begin training with 20000 genes and rna_count_based density_prior in clusters mode...
INFO:root:Printing scores every 100 epochs.


Running tangram
Score: 0.282, KL reg: 0.563
Score: 0.517, KL reg: 0.001
Score: 0.519, KL reg: 0.001
Score: 0.520, KL reg: 0.001
Score: 0.520, KL reg: 0.001
Score: 0.520, KL reg: 0.001
Score: 0.520, KL reg: 0.001
Score: 0.520, KL reg: 0.001
Score: 0.520, KL reg: 0.001
Score: 0.520, KL reg: 0.001


INFO:root:Saving results..
INFO:root:spatial prediction dataframe is saved in `obsm` `tangram_ct_pred` of the spatial AnnData.


Tangram Done, now saving results


INFO:root:20000 training genes are saved in `uns``training_genes` of both single cell and spatial Anndatas.
INFO:root:20000 overlapped genes are saved in `uns``overlap_genes` of both single cell and spatial Anndatas.
INFO:root:uniform based density prior is calculated and saved in `obs``uniform_density` of the spatial Anndata.
INFO:root:rna count based density prior is calculated and saved in `obs``rna_count_based_density` of the spatial Anndata.
INFO:root:Allocate tensors for mapping.
INFO:root:Begin training with 20000 genes and rna_count_based density_prior in clusters mode...
INFO:root:Printing scores every 100 epochs.


Running tangram
Score: 0.294, KL reg: 0.570
Score: 0.557, KL reg: 0.000
Score: 0.558, KL reg: 0.000
Score: 0.559, KL reg: 0.000
Score: 0.559, KL reg: 0.000
Score: 0.559, KL reg: 0.000
Score: 0.559, KL reg: 0.000
Score: 0.559, KL reg: 0.000
Score: 0.559, KL reg: 0.000
Score: 0.559, KL reg: 0.000


INFO:root:Saving results..
INFO:root:spatial prediction dataframe is saved in `obsm` `tangram_ct_pred` of the spatial AnnData.


Tangram Done, now saving results


## Evaluate results

In [18]:
def get_proportions(adata):
    df = pd.DataFrame(adata.obsm["proportions"], columns=adata.uns["proportion_names"], index=adata.obs_names)
    return df

def load_table(path, index_col):
    res = pd.read_csv(path, index_col=index_col)
    if "q05cell_abundance_w_sf_" in res.columns[0]:
        res.columns = res.columns.to_series().str.split("q05cell_abundance_w_sf_", expand=True).loc[:, 1].values
    elif "meanscell_abundance_w_sf_" in res.columns[0]:
        res.columns = res.columns.to_series().str.split("meanscell_abundance_w_sf_", expand=True).loc[:, 1].values
    if res.index[0] != 0:
        res.index = res.index.astype(int) - 1
    res.index = res.index.astype(str)
    if "cell_ID" in res.columns:
        res.drop("cell_ID", axis=1, inplace=True)
    res = res.div(res.sum(axis=1), axis=0)
    return res

In [19]:
def evaluate_results(data_paths,  modalities, mapping_dict, results_path):
    df = [pd.DataFrame({'path': glob.glob(os.path.join(data_paths[0], modality, "*", "*"))}) for modality in modalities]
    df = pd.concat(df)
    df[['modality', 'dataset_features']] = df['path'].str.split('/', expand=True).iloc[:, 7:-1]
    df[['method']] = "tangram"
    df['dataset'] = df['dataset_features'].str.rsplit("_", n=2).str[0]
    df["features"] = df["dataset_features"].str.split("_", n=2).str[-1]
    df["mdata_spatial_path"] = df['dataset'].map(mapping_dict)

    jsd = []
    rmse = []
    for _, row in tqdm.tqdm(df.iterrows()):
        # load ground truth
        target_adata = mu.read(row["mdata_spatial_path"])
        targets = get_proportions(target_adata[row["modality"]])

        # load table
        predictions = load_table(row["path"], index_col=(None if row["method"] == "moscot" else 0))
        missing_cell_types = [cell_type for cell_type in targets.columns if cell_type not in predictions.columns]
        predictions = predictions.assign(**dict.fromkeys(missing_cell_types, 0))
        predictions = predictions.loc[targets.index, targets.columns]
        jsd.append(de.tl.jsd(predictions, targets))
        rmse.append(de.tl.rmse(predictions, targets))
    df["jsd"] = jsd
    df["rmse"] = rmse
    
    df.to_csv(results_path)
    

In [23]:
mapping_dict = {
    "russell_250": "/vol/storage/data/simulations/russell_250.h5mu",
    "Heart_1": "/vol/storage/data/simulations/Heart_1.h5mu",
    "Heart_2": "/vol/storage/data/simulations/Heart_2.h5mu",
    "Heart_3": "/vol/storage/data/simulations/Heart_3.h5mu",
    "Heart_4": "/vol/storage/data/simulations/Heart_4.h5mu",
    "Brain_1": "/vol/storage/data/simulations/Brain_1.h5mu",
    "Brain_2": "/vol/storage/data/simulations/Brain_2.h5mu",
    "Brain_3": "/vol/storage/data/simulations/Brain_3.h5mu",
    "Brain_4": "/vol/storage/data/simulations/Brain_4.h5mu",
}

In [None]:
data_path = ["/vol/storage/data/deconvolution_results/test2/cluster_mode"]
methods = ["tangram"]
modalities = ["atac", "rna"]

In [24]:
evaluate_results(data_paths=data_path, modalities=modalities, mapping_dict=mapping_dict, results_path="results_table_cluster_mode.csv")

0it [00:00, ?it/s]

14it [02:21, 10.11s/it]


In [25]:
df = pd.read_csv('results_table_cluster_mode.csv', index_col=0)
df.head()

Unnamed: 0,path,modality,dataset_features,method,dataset,features,mdata_spatial_path,jsd,rmse
0,/vol/storage/data/deconvolution_results/test2/...,atac,russell_250_highly_variable,tangram,russell_250,highly_variable,/vol/storage/data/simulations/russell_250.h5mu,0.494151,0.18302
1,/vol/storage/data/deconvolution_results/test2/...,atac,russell_250_highly_accessible,tangram,russell_250,highly_accessible,/vol/storage/data/simulations/russell_250.h5mu,0.508469,0.181423
2,/vol/storage/data/deconvolution_results/test2/...,atac,Heart_1_highly_accessible,tangram,Heart_1,highly_accessible,/vol/storage/data/simulations/Heart_1.h5mu,0.924113,0.241885
3,/vol/storage/data/deconvolution_results/test2/...,atac,Heart_4_highly_variable,tangram,Heart_4,highly_variable,/vol/storage/data/simulations/Heart_4.h5mu,0.722076,0.177725
4,/vol/storage/data/deconvolution_results/test2/...,atac,Heart_2_highly_accessible,tangram,Heart_2,highly_accessible,/vol/storage/data/simulations/Heart_2.h5mu,0.678814,0.154144


In [30]:
df.groupby(['method', "dataset",'features', 'modality'])[['jsd']].mean().sum(axis=1)#.sort_values()

method   dataset      features           modality
tangram  Heart_1      highly_accessible  atac        0.924113
                      highly_variable    atac        0.904940
                                         rna         0.780529
         Heart_2      highly_accessible  atac        0.678814
                      highly_variable    atac        0.624794
                                         rna         0.490643
         Heart_3      highly_variable    atac        0.630467
                                         rna         0.463121
         Heart_4      highly_accessible  atac        0.778559
                      highly_variable    atac        0.722076
                                         rna         0.571554
         russell_250  highly_accessible  atac        0.508469
                      highly_variable    atac        0.494151
                                         rna         0.462513
dtype: float64

In [31]:
df.groupby(['method', "dataset", 'features', 'modality'])[['rmse']].mean().sum(axis=1)#.sort_values()

method   dataset      features           modality
tangram  Heart_1      highly_accessible  atac        0.241885
                      highly_variable    atac        0.229489
                                         rna         0.187021
         Heart_2      highly_accessible  atac        0.154144
                      highly_variable    atac        0.135193
                                         rna         0.087696
         Heart_3      highly_variable    atac        0.149514
                                         rna         0.099352
         Heart_4      highly_accessible  atac        0.197400
                      highly_variable    atac        0.177725
                                         rna         0.127264
         russell_250  highly_accessible  atac        0.181423
                      highly_variable    atac        0.183020
                                         rna         0.166592
dtype: float64