In [None]:
import glob
import os

import mudata as mu
import numpy as np
import pandas as pd
import tqdm

import deconvatac as de
from deconvatac.tl import tangram


  from .autonotebook import tqdm as notebook_tqdm


## Run Tangram

In [None]:
class ExperimentWrapper:
    """
    Loads one spatial/reference dataset pair and runs Tangram on it.

    The ``init_dataset`` / ``init_method`` / ``init_all`` / ``run`` layout mirrors a sacred experiment wrapper,
    in which sub-dictionaries of the configuration (e.g., "data") are parsed by dedicated methods. No sacred
    decorators are applied in this notebook, so the ``init_*`` methods must be called explicitly with their
    arguments.
    """

    def __init__(self, init_all=True):
        # NOTE(review): ``init_all()`` calls the ``init_*`` methods without arguments, which raises a
        # TypeError unless they are wrapped as sacred captured functions. Use ``init_all=False`` here.
        if init_all:
            self.init_all()

    def init_dataset(self, mdata_spatial_path, mdata_reference_path, var_HVF_column, labels_key, modality):
        """
        Read both MuData files, keep the requested modality and subset it to the selected features.

        Parameters
        ----------
        mdata_spatial_path : path of the spatial ``.h5mu`` file.
        mdata_reference_path : path of the reference ``.h5mu`` file.
        var_HVF_column : column of the reference ``.var`` used to select features.
        labels_key : label column forwarded to Tangram in ``run``.
        modality : key into ``.mod`` of both MuData objects.
        """
        self.spatial_path = mdata_spatial_path
        self.adata_spatial = mu.read_h5mu(mdata_spatial_path).mod[modality]
        self.adata_reference = mu.read_h5mu(mdata_reference_path).mod[modality]
        # subset on HVFs; the selection is taken from the *reference* and applied to both objects
        # NOTE(review): presumably both objects share the same feature order -- confirm
        feature_selection = self.adata_reference.var[var_HVF_column]
        self.adata_spatial = self.adata_spatial[:, feature_selection]
        self.adata_reference = self.adata_reference[:, feature_selection]

        self.modality = modality
        self.labels_key = labels_key
        self.var_HVF_column = var_HVF_column

    def init_method(self, method_id):
        """Store the method identifier."""
        self.method_id = method_id

    def init_all(self):
        """Call every ``init_*`` method (only usable with sacred-captured methods, see ``__init__``)."""
        self.init_dataset()
        self.init_method()

    def run(self, output_path, device="cuda:0", num_epochs=1000):
        """
        Run Tangram and return a description of the produced result.

        Results go to ``<output_path>/<modality>/<dataset>_<var_HVF_column>``, where ``dataset`` is the
        spatial file name up to its first dot. ``output_path`` may be given with or without a trailing slash.

        Parameters
        ----------
        output_path : root directory for the results.
        device : device string forwarded to Tangram.
        num_epochs : number of epochs forwarded to Tangram.

        Returns
        -------
        dict with the keys ``result_path``, ``dataset``, ``modality`` and ``var_HVF_column``.
        """
        dataset = self.spatial_path.split("/")[-1].split(".")[0]
        dataset_var_column = dataset + "_" + self.var_HVF_column
        # os.path.join inserts the separator only where needed, so a missing trailing slash no longer
        # glues the root and the modality together
        output_path = os.path.join(output_path, self.modality, dataset_var_column)

        tangram(
            adata_spatial=self.adata_spatial,
            adata_ref=self.adata_reference,
            labels_key=self.labels_key,
            run_rank_genes=False,
            result_path=output_path,
            device=device,
            num_epochs=num_epochs,
        )

        results = {
            "result_path": os.path.join(output_path, "tangram_ct_pred.csv"),
            "dataset": dataset,
            "modality": self.modality,
            "var_HVF_column": self.var_HVF_column,
        }
        return results

### Create new reference


In [None]:
def create_reference(parquet, ref_path=None, power=None, seed=None):
    """
    Build reference MuData file(s) from the cells used in one simulation.

    The cell IDs are read from the ``cell_id`` column of ``parquet`` and used to subset the reference
    at ``/vol/storage/submission_data/data/human_cardiac_niches.h5mu``.

    Parameters
    ----------
    parquet : path of the parquet file whose ``cell_id`` column holds arrays of cell IDs.
    ref_path : output path of the reference; required when ``power`` is None, ignored otherwise.
    power : None, or an iterable of exponents. For each exponent ``p`` a reference is written that holds
        the simulation cells plus ``len(cell_ids) * 2**p`` additional cells drawn without replacement
        from the remaining cells (capped at the number of remaining cells). These files go to
        ``/vol/storage/submission_data/data/simulations/test/power<p>/<parquet stem>_ref.h5mu``.
    seed : optional seed for the random draw. With None (default) the global NumPy random state is used,
        as before.

    Raises
    ------
    ValueError if ``power`` is None and no ``ref_path`` is given.
    """
    sample_cells = pd.read_parquet(parquet)
    cell_ids = np.unique(np.concatenate(sample_cells['cell_id'].values))
    ref = mu.read_h5mu("/vol/storage/submission_data/data/human_cardiac_niches.h5mu")

    def _write(mdata, path):
        # make sure the target directory exists before writing
        out_dir = os.path.dirname(path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        mdata.write(path)

    if power is None:
        if ref_path is None:
            raise ValueError("ref_path is required when power is None")
        # copy so that a standalone object (not a view) is written, as in the power branch
        ref_new = ref[cell_ids].copy()
        print(ref_new)
        _write(ref_new, ref_path)
        return

    # sorted: set iteration order is not stable across runs, which would defeat seeding
    candidate_ids = sorted(set(ref.obs.index) - set(cell_ids))
    sampler = np.random if seed is None else np.random.default_rng(seed)
    dataset = parquet.split("/")[-1].split(".")[0]
    for exponent in power:
        # never request more cells than are available outside the simulation
        n_draw = min(int(len(cell_ids) * 2 ** exponent), len(candidate_ids))
        drawn_ids = sampler.choice(candidate_ids, n_draw, replace=False)
        resulting_ids = np.concatenate([cell_ids, drawn_ids])
        ref_new = ref[resulting_ids].copy()
        print(ref_new)
        _write(
            ref_new,
            "/vol/storage/submission_data/data/simulations/test/power" + str(exponent) + "/" + dataset + "_ref.h5mu",
        )

In [4]:
# Simulated Heart datasets 1-4: parquet files of each simulation and the target paths of the
# references built from them.
parquets = ["/vol/storage/data/simulations/test/Heart_" + str(k) + ".pq" for k in (1, 2, 3, 4)]
ref_paths = ["/vol/storage/data/simulations/test/Heart" + str(k) + "_ref.h5mu" for k in (1, 2, 3, 4)]

Create references that contain only the cells used in each simulation (no additional cells drawn):

In [None]:
# One reference per simulation, written to the path at the same position in ref_paths.
for idx, parquet_file in enumerate(parquets):
    create_reference(parquet_file, ref_paths[idx])

Create references that additionally contain randomly drawn cells (2^power times the number of simulation cells):

In [None]:
# Exponents passed to create_reference; one reference is written per exponent and simulation.
power = [0, 1, 2, 3]
for parquet_file in parquets:
    create_reference(parquet_file, power=power)

### Run Tangram

In [8]:
def run_tangram(ref_path, spatial_path, modality):
    """
    Run Tangram for one reference/spatial pair and one modality.

    Results go to ``/vol/storage/data/deconvolution_results/test2/``. If the directory that holds
    ``ref_path`` has "power" in its name, that directory name is appended as a sub-folder.
    The "highly_variable" feature column is always run; for ``modality == "atac"`` a second run
    with "highly_accessible" follows.

    Parameters
    ----------
    ref_path : path of the reference ``.h5mu`` file.
    spatial_path : path of the spatial ``.h5mu`` file.
    modality : modality key forwarded to ``ExperimentWrapper.init_dataset`` (e.g. "rna" or "atac").
    """
    method_id = "Tangram"
    labels_key = "cell_type"
    output_path = "/vol/storage/data/deconvolution_results/test2/"

    # Name of the directory containing the reference; empty for a bare file name, where the
    # previous ``split("/")[-2]`` raised an IndexError.
    parent_dir = os.path.basename(os.path.dirname(ref_path))
    if "power" in parent_dir:
        output_path = output_path + parent_dir + "/"

    feature_columns = ["highly_variable"]
    if modality == "atac":
        feature_columns.append("highly_accessible")

    for var_HVF_column in feature_columns:
        ex = ExperimentWrapper(init_all=False)
        ex.init_dataset(spatial_path, ref_path, var_HVF_column, labels_key, modality)
        ex.init_method(method_id)
        ex.run(output_path)



In [3]:
# Spatial MuData files of the simulated Heart datasets 1-4.
spatial_paths = ["/vol/storage/data/simulations/test/Heart_" + str(k) + ".h5mu" for k in (1, 2, 3, 4)]

In [None]:
# Run both modalities for every dataset against the reference at the same list position.
for idx, spatial_file in enumerate(spatial_paths):
    for modality_name in ("rna", "atac"):
        run_tangram(ref_paths[idx], spatial_file, modality_name)

In [None]:
# Run both modalities for every dataset against each of its power<p> references.
# NOTE(review): references are read from /vol/storage/data/simulations/test/power<p>/ while
# create_reference writes to /vol/storage/submission_data/data/simulations/test/power<p>/ -- confirm
# that both locations hold the same files.
power = [0, 1, 2, 3]
for spatial_file in spatial_paths:
    dataset_name = spatial_file.split("/")[-1].split(".")[0]
    for exponent in power:
        ref_path = "/vol/storage/data/simulations/test/power" + str(exponent) + "/" + dataset_name + "_ref.h5mu"
        print(ref_path, spatial_file)
        run_tangram(ref_path, spatial_file, "rna")
        run_tangram(ref_path, spatial_file, "atac")

## Evaluate results

In [6]:
def get_proportions(adata):
    """
    Return ``adata.obsm["proportions"]`` as a DataFrame.

    Rows are labelled with ``adata.obs_names`` and columns with ``adata.uns["proportion_names"]``.
    """
    return pd.DataFrame(
        data=adata.obsm["proportions"],
        index=adata.obs_names,
        columns=adata.uns["proportion_names"],
    )

def load_table(path, index_col):
    """
    Read a prediction table from CSV and return row-normalised proportions.

    Steps, in order:
    - if the first column name contains "q05cell_abundance_w_sf_" or "meanscell_abundance_w_sf_",
      every column name is replaced by the part after that prefix;
    - if the first index label is not 0, the index is cast to int and shifted down by one;
    - the index is cast to str;
    - a "cell_ID" column is dropped if present;
    - each row is divided by its row sum.

    Parameters
    ----------
    path : CSV path (or buffer) handed to ``pd.read_csv``.
    index_col : ``index_col`` argument handed to ``pd.read_csv``.
    """
    table = pd.read_csv(path, index_col=index_col)

    first_column = table.columns[0]
    for prefix in ("q05cell_abundance_w_sf_", "meanscell_abundance_w_sf_"):
        if prefix in first_column:
            stripped = table.columns.to_series().str.split(prefix, expand=True)
            table.columns = stripped.loc[:, 1].values
            break

    if table.index[0] != 0:
        table.index = table.index.astype(int) - 1
    table.index = table.index.astype(str)

    if "cell_ID" in table.columns:
        table = table.drop(columns="cell_ID")

    row_totals = table.sum(axis=1)
    return table.div(row_totals, axis=0)

In [7]:
def evaluate_results(data_paths, modalities, mapping_dict, results_path):
    """
    Score every prediction table found under ``data_paths`` and save the scores as CSV.

    Expected layout below each root: ``<root>/<modality>/<dataset>_<features>/<result file>``,
    where ``<features>`` consists of the last two "_"-separated tokens (e.g. "highly_variable").
    If the last component of ``data_paths[0]`` contains "power", every entry of ``data_paths`` is scanned
    and its directory name is stored in a ``power`` column; otherwise only ``data_paths[0]`` is scanned.

    For each table the ground truth proportions are read from the spatial MuData file of its dataset,
    cell types missing from the prediction are added with proportion 0, and JSD and RMSE are computed
    with ``de.tl.jsd`` / ``de.tl.rmse``.

    Parameters
    ----------
    data_paths : list of result root directories.
    modalities : modality sub-directory names to scan (also used as MuData modality keys).
    mapping_dict : maps dataset name -> path of the spatial ``.h5mu`` file holding the ground truth.
    results_path : path of the CSV file to write.

    Raises
    ------
    ValueError if ``data_paths`` is empty.
    FileNotFoundError if no result file is found.
    KeyError if a dataset found on disk has no entry in ``mapping_dict``.
    """
    if len(data_paths) == 0:
        raise ValueError("data_paths must contain at least one directory")

    # normpath drops a trailing slash so that basename yields the last directory name
    with_power = "power" in os.path.basename(os.path.normpath(data_paths[0]))
    roots = data_paths if with_power else data_paths[:1]

    # Derive the columns from the known directory structure instead of fixed positions in the split
    # path, so the function works for roots at any depth.
    records = []
    for root in roots:
        for modality in modalities:
            for path in glob.glob(os.path.join(root, modality, "*", "*")):
                record = {"path": path}
                if with_power:
                    record["power"] = os.path.basename(os.path.normpath(root))
                record["modality"] = modality
                record["dataset_features"] = os.path.basename(os.path.dirname(path))
                records.append(record)
    if not records:
        raise FileNotFoundError("no result files found below: " + ", ".join(str(r) for r in roots))
    df = pd.DataFrame(records)

    df["method"] = "tangram"
    # split once from the right: the dataset name may itself contain any number of underscores
    name_parts = df["dataset_features"].str.rsplit("_", n=2)
    df["dataset"] = name_parts.str[0]
    df["features"] = name_parts.str[1:].str.join("_")
    df["mdata_spatial_path"] = df["dataset"].map(mapping_dict)

    unmapped = sorted(df.loc[df["mdata_spatial_path"].isna(), "dataset"].unique())
    if unmapped:
        raise KeyError("no entry in mapping_dict for dataset(s): " + ", ".join(unmapped))

    # Ground truth depends only on (file, modality); read each one once instead of once per row.
    targets_cache = {}
    jsd = []
    rmse = []
    for _, row in tqdm.tqdm(df.iterrows(), total=len(df)):
        # load ground truth
        cache_key = (row["mdata_spatial_path"], row["modality"])
        if cache_key not in targets_cache:
            target_adata = mu.read(row["mdata_spatial_path"])
            targets_cache[cache_key] = get_proportions(target_adata[row["modality"]])
        targets = targets_cache[cache_key]

        # load table
        predictions = load_table(row["path"], index_col=(None if row["method"] == "moscot" else 0))
        missing_cell_types = [cell_type for cell_type in targets.columns if cell_type not in predictions.columns]
        predictions = predictions.assign(**dict.fromkeys(missing_cell_types, 0))
        # align rows and columns with the ground truth before scoring
        predictions = predictions.loc[targets.index, targets.columns]
        jsd.append(de.tl.jsd(predictions, targets))
        rmse.append(de.tl.rmse(predictions, targets))
    df["jsd"] = jsd
    df["rmse"] = rmse

    results_dir = os.path.dirname(results_path)
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)
    df.to_csv(results_path)
    

In [8]:
# Dataset name -> spatial MuData file holding the ground truth proportions of that dataset.
mapping_dict = {
    dataset_name: "/vol/storage/data/simulations/test/" + dataset_name + ".h5mu"
    for dataset_name in (
        "russell_250",
        "Heart_1", "Heart_2", "Heart_3", "Heart_4",
        "Brain_1", "Brain_2", "Brain_3", "Brain_4",
    )
}

In [25]:
# Result root of the runs without additionally drawn cells (a one-element list, as expected by
# evaluate_results).
data_path = ["/vol/storage/data/deconvolution_results/test2"]
# NOTE(review): `methods` is not referenced by any cell visible in this notebook.
methods = ["tangram"]
# Modality sub-directories to scan for results.
modalities = ["atac", "rna"]

In [None]:
# Score the runs without additionally drawn cells and save the table.
evaluate_results(data_paths=data_path, modalities=modalities, mapping_dict=mapping_dict, results_path="../results/tables/results_table_tangram_no_draw.csv")

12it [00:23,  1.98s/it]


In [None]:
# Reload the saved score table; the first CSV column is the index written by to_csv.
df = pd.read_csv('../results/tables/results_table_tangram_no_draw.csv', index_col=0)
df.head()

Unnamed: 0,path,modality,dataset_features,method,dataset,features,mdata_spatial_path,jsd,rmse
0,/vol/storage/data/deconvolution_results/test2/...,atac,Heart_3_highly_accessible,tangram,Heart_3,highly_accessible,/vol/storage/data/simulations/test/Heart_3.h5mu,0.326849,0.068327
1,/vol/storage/data/deconvolution_results/test2/...,atac,Heart_1_highly_accessible,tangram,Heart_1,highly_accessible,/vol/storage/data/simulations/test/Heart_1.h5mu,0.442848,0.179567
2,/vol/storage/data/deconvolution_results/test2/...,atac,Heart_4_highly_variable,tangram,Heart_4,highly_variable,/vol/storage/data/simulations/test/Heart_4.h5mu,0.401321,0.096223
3,/vol/storage/data/deconvolution_results/test2/...,atac,Heart_2_highly_accessible,tangram,Heart_2,highly_accessible,/vol/storage/data/simulations/test/Heart_2.h5mu,0.327318,0.061107
4,/vol/storage/data/deconvolution_results/test2/...,atac,Heart_3_highly_variable,tangram,Heart_3,highly_variable,/vol/storage/data/simulations/test/Heart_3.h5mu,0.325744,0.06592


In [5]:
# Mean JSD per (method, features, modality), sorted ascending. `.sum(axis=1)` over the single
# 'jsd' column only turns the one-column frame into a Series.
df.groupby(['method', 'features', 'modality'])[['jsd']].mean().sum(axis=1).sort_values()

method   features           modality
tangram  highly_variable    rna         0.369342
         highly_accessible  atac        0.372315
         highly_variable    atac        0.374485
dtype: float64

In [6]:
# Mean RMSE per (method, features, modality), sorted ascending. `.sum(axis=1)` over the single
# 'rmse' column only turns the one-column frame into a Series.
df.groupby(['method', 'features', 'modality'])[['rmse']].mean().sum(axis=1).sort_values()

method   features           modality
tangram  highly_variable    rna         0.097772
                            atac        0.100740
         highly_accessible  atac        0.101655
dtype: float64

In [9]:
# Result roots of the runs with additionally drawn cells, one directory per exponent 0-3.
data_paths = ["/vol/storage/data/deconvolution_results/test2/power" + str(k) + "/" for k in (0, 1, 2, 3)]
# Method label and modality sub-directories used for the evaluation.
methods = ["tangram"]
modalities = ["atac", "rna"]

In [None]:
# Score the runs with additionally drawn cells (all power<p> roots) and save the table.
evaluate_results(data_paths=data_paths, modalities=modalities, mapping_dict=mapping_dict, results_path="../results/tables/results_table_tangram_draw.csv")

0it [00:00, ?it/s]

48it [01:35,  1.98s/it]


In [None]:
# Reload the saved score table of the power runs; note that this rebinds `df` from the earlier cells.
df = pd.read_csv('../results/tables/results_table_tangram_draw.csv', index_col=0)
df.head()

Unnamed: 0,path,power,modality,dataset_features,method,dataset,features,mdata_spatial_path,jsd,rmse
0,/vol/storage/data/deconvolution_results/test2/...,power0,atac,Heart_3_highly_accessible,tangram,Heart_3,highly_accessible,/vol/storage/data/simulations/test/Heart_3.h5mu,0.447103,0.09537
1,/vol/storage/data/deconvolution_results/test2/...,power0,atac,Heart_1_highly_accessible,tangram,Heart_1,highly_accessible,/vol/storage/data/simulations/test/Heart_1.h5mu,0.6279,0.159803
2,/vol/storage/data/deconvolution_results/test2/...,power0,atac,Heart_4_highly_variable,tangram,Heart_4,highly_variable,/vol/storage/data/simulations/test/Heart_4.h5mu,0.518654,0.116674
3,/vol/storage/data/deconvolution_results/test2/...,power0,atac,Heart_2_highly_accessible,tangram,Heart_2,highly_accessible,/vol/storage/data/simulations/test/Heart_2.h5mu,0.441142,0.081335
4,/vol/storage/data/deconvolution_results/test2/...,power0,atac,Heart_3_highly_variable,tangram,Heart_3,highly_variable,/vol/storage/data/simulations/test/Heart_3.h5mu,0.440834,0.093058


In [14]:
# Mean JSD per (power, method, features, modality), sorted ascending. `.sum(axis=1)` over the single
# 'jsd' column only turns the one-column frame into a Series.
df.groupby(['power','method', 'features', 'modality'])[['jsd']].mean().sum(axis=1).sort_values()

power   method   features           modality
power0  tangram  highly_variable    rna         0.503227
                 highly_accessible  atac        0.505794
                 highly_variable    atac        0.506328
power1  tangram  highly_variable    rna         0.560245
                                    atac        0.569520
                 highly_accessible  atac        0.572636
power2  tangram  highly_variable    rna         0.615094
                                    atac        0.630930
                 highly_accessible  atac        0.635952
power3  tangram  highly_variable    rna         0.657895
                                    atac        0.677503
                 highly_accessible  atac        0.685836
dtype: float64

In [15]:
# Mean RMSE per (power, method, features, modality), sorted ascending. `.sum(axis=1)` over the single
# 'rmse' column only turns the one-column frame into a Series.
df.groupby(['power','method', 'features', 'modality'])[['rmse']].mean().sum(axis=1).sort_values()

power   method   features           modality
power0  tangram  highly_variable    rna         0.110739
                                    atac        0.112276
                 highly_accessible  atac        0.113434
power1  tangram  highly_variable    rna         0.124068
                                    atac        0.126064
                 highly_accessible  atac        0.127212
power2  tangram  highly_variable    rna         0.138309
                                    atac        0.141058
                 highly_accessible  atac        0.142221
power3  tangram  highly_variable    rna         0.149177
                                    atac        0.153051
                 highly_accessible  atac        0.154847
dtype: float64