In [1]:
import mudata as mu
from deconvatac.tl import tangram
import pandas as pd
import numpy as np
import seml
import pandas as pd
import glob
import deconvatac as de
import seaborn as sns
import tqdm
import os

  from .autonotebook import tqdm as notebook_tqdm


## Run Tangram

In [2]:
class ExperimentWrapper:
    """
    A simple wrapper around a sacred experiment, making use of sacred's captured functions with prefixes.
    This allows a modular design of the configuration, where certain sub-dictionaries (e.g., "data") are parsed by
    specific method. This avoids having one large "main" function which takes all parameters as input.
    """

    def __init__(self, init_all=True):
        if init_all:
            self.init_all()

    def init_dataset(self, mdata_spatial_path, mdata_reference_path, var_HVF_column, labels_key, modality):

        self.spatial_path = mdata_spatial_path
        self.adata_spatial = mu.read_h5mu(mdata_spatial_path).mod[modality]
        self.adata_reference = mu.read_h5mu(mdata_reference_path).mod[modality]
        # subset on HVFs
        self.adata_spatial = self.adata_spatial[:, self.adata_reference.var[var_HVF_column]]
        self.adata_reference = self.adata_reference[:, self.adata_reference.var[var_HVF_column]]

        self.modality = modality
        self.labels_key = labels_key
        self.var_HVF_column = var_HVF_column

    def init_method(self, method_id):
        self.method_id = method_id

    def init_all(self):
        self.init_dataset()
        self.init_method()

    def run(self, output_path):

        dataset = self.spatial_path.split("/")[-1].split(".")[0]
        dataset_var_column = dataset + "_" + self.var_HVF_column
        output_path = output_path + self.modality + "/" + dataset_var_column

        tangram(
            adata_spatial=self.adata_spatial,
            adata_ref=self.adata_reference,
            labels_key=self.labels_key,
            run_rank_genes=False,
            result_path=output_path,
            device="cuda:0",
            num_epochs=1000,
        )

        results = {
            "result_path": output_path + "/tangram_ct_pred.csv",
            "dataset": dataset,
            "modality": self.modality,
            "var_HVF_column": self.var_HVF_column,
        }
        return results

### Create new reference


In [3]:
def create_reference(parquet, ref_path): 
    sample_cells = pd.read_parquet(parquet)
    ref = mu.read_h5mu("/vol/storage/data/cellxgene/human_cardiac_niches/human_cardiac_niches.h5mu")
    cell_ids = np.concatenate(sample_cells['cell_id'].values)
    ref = ref[np.unique(cell_ids)]
    print(ref)
    ref.write(ref_path)

In [4]:
parquets = ["/vol/storage/data/simulations/test/Heart_1.pq", "/vol/storage/data/simulations/test/Heart_2.pq", 
            "/vol/storage/data/simulations/test/Heart_3.pq", "/vol/storage/data/simulations/test/Heart_4.pq"]
ref_paths = ["/vol/storage/data/simulations/test_ref/Heart1_ref.h5mu", "/vol/storage/data/simulations/test_ref/Heart2_ref.h5mu"
             , "/vol/storage/data/simulations/test_ref/Heart3_ref.h5mu", "/vol/storage/data/simulations/test_ref/Heart4_ref.h5mu"]

In [None]:
for i in range(len(parquets)):
    create_reference(parquets[i], ref_paths[i])

### Run Tangram

In [6]:
def run_tangram(ref_path, spatial_path, modality): 
    mdata_reference_path = ref_path
    mdata_spatial_path = spatial_path
    method_id =  "Tangram"
    output_path =  "/vol/storage/data/deconvolution_results/test2/"
    labels_key = "cell_type"
    modality = modality
    var_HVF_column = "highly_variable"
    ex = ExperimentWrapper(init_all=False)
    ex.init_dataset(mdata_spatial_path, mdata_reference_path, var_HVF_column, labels_key, modality)
    ex.init_method(method_id)
    ex.run(output_path)
    if modality == "atac": 
        var_HVF_column = "highly_accessible"
        ex = ExperimentWrapper(init_all=False)
        ex.init_dataset(mdata_spatial_path, mdata_reference_path, var_HVF_column, labels_key, modality)
        ex.init_method(method_id)
        ex.run(output_path)



In [7]:
spatial_paths = ["/vol/storage/data/simulations/test/Heart_1.h5mu", "/vol/storage/data/simulations/test/Heart_2.h5mu", 
                 "/vol/storage/data/simulations/test/Heart_3.h5mu", "/vol/storage/data/simulations/test/Heart_4.h5mu"]

In [None]:
for i in range(len(spatial_paths)): 
    run_tangram(ref_paths[i], spatial_paths[i], "rna")
    run_tangram(ref_paths[i], spatial_paths[i], "atac")


## Evaluate the results

In [28]:
data_path = "/vol/storage/data/deconvolution_results/test2"
methods = ["tangram"]
modalities = ["atac", "rna"]

In [29]:
# go over all methods and modalities
df = [pd.DataFrame({'path': glob.glob(os.path.join(data_path, modality, "*", "*"))}) for modality in modalities]
df = pd.concat(df)
df

Unnamed: 0,path
0,/vol/storage/data/deconvolution_results/test2/...
1,/vol/storage/data/deconvolution_results/test2/...
2,/vol/storage/data/deconvolution_results/test2/...
3,/vol/storage/data/deconvolution_results/test2/...
4,/vol/storage/data/deconvolution_results/test2/...
5,/vol/storage/data/deconvolution_results/test2/...
6,/vol/storage/data/deconvolution_results/test2/...
7,/vol/storage/data/deconvolution_results/test2/...
0,/vol/storage/data/deconvolution_results/test2/...
1,/vol/storage/data/deconvolution_results/test2/...


In [30]:
df[['modality', 'dataset_features']] = df['path'].str.split('/', expand=True).iloc[:, 6:-1]
df[['method']] = "tangram"
df['dataset'] = df['dataset_features'].str.rsplit("_", n=2).str[0]
df["features"] = df["dataset_features"].str.split("_", n=2).str[-1]
df

Unnamed: 0,path,modality,dataset_features,method,dataset,features
0,/vol/storage/data/deconvolution_results/test2/...,atac,Heart_3_highly_accessible,tangram,Heart_3,highly_accessible
1,/vol/storage/data/deconvolution_results/test2/...,atac,Heart_1_highly_accessible,tangram,Heart_1,highly_accessible
2,/vol/storage/data/deconvolution_results/test2/...,atac,Heart_4_highly_variable,tangram,Heart_4,highly_variable
3,/vol/storage/data/deconvolution_results/test2/...,atac,Heart_2_highly_accessible,tangram,Heart_2,highly_accessible
4,/vol/storage/data/deconvolution_results/test2/...,atac,Heart_3_highly_variable,tangram,Heart_3,highly_variable
5,/vol/storage/data/deconvolution_results/test2/...,atac,Heart_4_highly_accessible,tangram,Heart_4,highly_accessible
6,/vol/storage/data/deconvolution_results/test2/...,atac,Heart_2_highly_variable,tangram,Heart_2,highly_variable
7,/vol/storage/data/deconvolution_results/test2/...,atac,Heart_1_highly_variable,tangram,Heart_1,highly_variable
0,/vol/storage/data/deconvolution_results/test2/...,rna,Heart_4_highly_variable,tangram,Heart_4,highly_variable
1,/vol/storage/data/deconvolution_results/test2/...,rna,Heart_3_highly_variable,tangram,Heart_3,highly_variable


In [31]:
mapping_dict = {
    "russell_250": "/vol/storage/data/simulations/test/russell_250.h5mu",
    "Heart_1": "/vol/storage/data/simulations/test/Heart_1.h5mu",
    "Heart_2": "/vol/storage/data/simulations/test/Heart_2.h5mu",
    "Heart_3": "/vol/storage/data/simulations/test/Heart_3.h5mu",
    "Heart_4": "/vol/storage/data/simulations/test/Heart_4.h5mu",
    "Brain_1": "/vol/storage/data/simulations/test/Brain_1.h5mu",
    "Brain_2": "/vol/storage/data/simulations/test/Brain_2.h5mu",
    "Brain_3": "/vol/storage/data/simulations/test/Brain_3.h5mu",
    "Brain_4": "/vol/storage/data/simulations/test/Brain_4.h5mu",
}

In [32]:
df["mdata_spatial_path"] = df['dataset'].map(mapping_dict)
df

Unnamed: 0,path,modality,dataset_features,method,dataset,features,mdata_spatial_path
0,/vol/storage/data/deconvolution_results/test2/...,atac,Heart_3_highly_accessible,tangram,Heart_3,highly_accessible,/vol/storage/data/simulations/test/Heart_3.h5mu
1,/vol/storage/data/deconvolution_results/test2/...,atac,Heart_1_highly_accessible,tangram,Heart_1,highly_accessible,/vol/storage/data/simulations/test/Heart_1.h5mu
2,/vol/storage/data/deconvolution_results/test2/...,atac,Heart_4_highly_variable,tangram,Heart_4,highly_variable,/vol/storage/data/simulations/test/Heart_4.h5mu
3,/vol/storage/data/deconvolution_results/test2/...,atac,Heart_2_highly_accessible,tangram,Heart_2,highly_accessible,/vol/storage/data/simulations/test/Heart_2.h5mu
4,/vol/storage/data/deconvolution_results/test2/...,atac,Heart_3_highly_variable,tangram,Heart_3,highly_variable,/vol/storage/data/simulations/test/Heart_3.h5mu
5,/vol/storage/data/deconvolution_results/test2/...,atac,Heart_4_highly_accessible,tangram,Heart_4,highly_accessible,/vol/storage/data/simulations/test/Heart_4.h5mu
6,/vol/storage/data/deconvolution_results/test2/...,atac,Heart_2_highly_variable,tangram,Heart_2,highly_variable,/vol/storage/data/simulations/test/Heart_2.h5mu
7,/vol/storage/data/deconvolution_results/test2/...,atac,Heart_1_highly_variable,tangram,Heart_1,highly_variable,/vol/storage/data/simulations/test/Heart_1.h5mu
0,/vol/storage/data/deconvolution_results/test2/...,rna,Heart_4_highly_variable,tangram,Heart_4,highly_variable,/vol/storage/data/simulations/test/Heart_4.h5mu
1,/vol/storage/data/deconvolution_results/test2/...,rna,Heart_3_highly_variable,tangram,Heart_3,highly_variable,/vol/storage/data/simulations/test/Heart_3.h5mu


In [33]:
def load_table(path, index_col):
    res = pd.read_csv(path, index_col=index_col)
    if "q05cell_abundance_w_sf_" in res.columns[0]:
        res.columns = res.columns.to_series().str.split("q05cell_abundance_w_sf_", expand=True).loc[:, 1].values
    elif "meanscell_abundance_w_sf_" in res.columns[0]:
        res.columns = res.columns.to_series().str.split("meanscell_abundance_w_sf_", expand=True).loc[:, 1].values
    if res.index[0] != 0:
        res.index = res.index.astype(int) - 1
    res.index = res.index.astype(str)
    if "cell_ID" in res.columns:
        res.drop("cell_ID", axis=1, inplace=True)
    res = res.div(res.sum(axis=1), axis=0)
    return res

In [34]:
def get_proportions(adata):
    df = pd.DataFrame(adata.obsm["proportions"], columns=adata.uns["proportion_names"], index=adata.obs_names)
    return df

In [44]:
jsd = []
rmse = []
for _, row in tqdm.tqdm(df.iterrows()):
    # load ground truth
    target_adata = mu.read(row["mdata_spatial_path"])
    targets = get_proportions(target_adata[row["modality"]])

    # load table
    predictions = load_table(row["path"], index_col=(None if row["method"] == "moscot" else 0))
    missing_cell_types = [cell_type for cell_type in targets.columns if cell_type not in predictions.columns]
    predictions = predictions.assign(**dict.fromkeys(missing_cell_types, 0))
    predictions = predictions.loc[targets.index, targets.columns]
    jsd.append(de.tl.jsd(predictions, targets))
    rmse.append(de.tl.rmse(predictions, targets))
df["jsd"] = jsd
df["rmse"] = rmse

12it [00:22,  1.87s/it]


In [46]:
df.groupby(['method', 'features', 'modality'])[['jsd']].mean().sum(axis=1).sort_values()

method   features           modality
tangram  highly_variable    rna         0.369342
         highly_accessible  atac        0.372315
         highly_variable    atac        0.374485
dtype: float64

In [47]:
df.groupby(['method', 'features', 'modality'])[['rmse']].mean().sum(axis=1).sort_values()

method   features           modality
tangram  highly_variable    rna         0.097772
                            atac        0.100740
         highly_accessible  atac        0.101655
dtype: float64