In [1]:
import anndata
import torch
import stPlus

import squidpy as sq
import numpy as np
import scanpy as sc
import pandas as pd

from sklearn.model_selection import KFold
from transpa.eval_util import calc_corr
from transpa.util import expTransImp, leiden_cluster, compute_autocorr
from benchmark import SpaGE_impute, Tangram_impute
import warnings

warnings.filterwarnings('ignore')

seed = 10
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
spa_adata = sc.read_h5ad("../../data/ST/melanoma/spatial.h5ad")
spa_adata = spa_adata[:, (spa_adata.var_names.values != 'MARCH1') & (spa_adata.var_names.values != 'MARCH2')].copy()
scrna_adata = sc.read_h5ad("../../data/scRNAseq/melanoma/Tirosh_raw.h5ad")
scrna_adata = scrna_adata[:, (scrna_adata.var_names.values != 'MARCH1') & (scrna_adata.var_names.values != 'MARCH2')].copy()

classes, ct_list = leiden_cluster(scrna_adata)
cls_key = 'leiden'
# sc.pp.normalize_total(spa_adata)
# sc.pp.normalize_total(scrna_adata)
# sc.pp.log1p(spa_adata)
# sc.pp.log1p(scrna_adata)

scrna_adata.obs[cls_key] = classes

In [3]:
scrna_adata, spa_adata

(AnnData object with n_obs × n_vars = 4645 × 23682
     obs: 'celltype', 'malignant', 'leiden',
 AnnData object with n_obs × n_vars = 293 × 16146
     obs: 'B', 'CAF', 'Endo', 'Macro', 'Melanoma or unclassified', 'NK', 'T'
     obsm: 'spatial')

In [4]:
sq.gr.spatial_neighbors(spa_adata, n_neighs=4)
sq.gr.spatial_autocorr(
    spa_adata,
    n_jobs=10,
)

In [5]:
spa_adata =  spa_adata[:, spa_adata.uns['moranI'].I > 0].copy()

In [6]:
raw_spatial_df  = pd.DataFrame(spa_adata.X, columns=spa_adata.var_names)
raw_spatial_df.to_csv('../../output/melanoma_raw.csv')
raw_scrna_df    = pd.DataFrame(scrna_adata.X, columns=scrna_adata.var_names)
raw_shared_gene = np.intersect1d(raw_spatial_df.columns, raw_scrna_df.columns)
raw_spatial_df.shape, raw_scrna_df.shape, raw_shared_gene.shape,

((293, 7444), (4645, 23682), (7168,))

In [7]:
np.save('../../output/melanoma_locations.npy', spa_adata.obsm['spatial'])
sq.gr.spatial_neighbors(spa_adata)


In [8]:
kf = KFold(n_splits=5, shuffle=True, random_state=0)
kf.get_n_splits(raw_shared_gene)

df_transImpSpa = pd.DataFrame(np.zeros((spa_adata.n_obs, len(raw_shared_gene))), columns=raw_shared_gene)
df_transImpCls = pd.DataFrame(np.zeros((spa_adata.n_obs, len(raw_shared_gene))), columns=raw_shared_gene)
df_transImpClsSpa = pd.DataFrame(np.zeros((spa_adata.n_obs, len(raw_shared_gene))), columns=raw_shared_gene)
df_transImp = pd.DataFrame(np.zeros((spa_adata.n_obs, len(raw_shared_gene))), columns=raw_shared_gene)
df_stplus_res = pd.DataFrame(np.zeros((spa_adata.n_obs, len(raw_shared_gene))), columns=raw_shared_gene)
df_spaGE_res = pd.DataFrame(np.zeros((spa_adata.n_obs, len(raw_shared_gene))), columns=raw_shared_gene)
df_tangram_res = pd.DataFrame(np.zeros((spa_adata.n_obs, len(raw_shared_gene))), columns=raw_shared_gene)

for idx, (train_ind, test_ind) in enumerate(kf.split(raw_shared_gene)):    
    print(f"\n===== Fold {idx+1} =====\nNumber of train genes: {len(train_ind)}, Number of test genes: {len(test_ind)}")
    train_gene = raw_shared_gene[train_ind]
    test_gene  = raw_shared_gene[test_ind]
    
    test_spatial_df = raw_spatial_df[test_gene]
    spatial_df = raw_spatial_df[train_gene]
    scrna_df   = raw_scrna_df

    df_transImpSpa[test_gene] = expTransImp(
            df_ref=raw_scrna_df,
            df_tgt=raw_spatial_df,
            train_gene=train_gene,
            test_gene=test_gene,
            signature_mode='cell',
            mapping_mode='lowrank',
            mapping_lowdim=128,
            clip_max=2,
        #     lr=1e-2,
            spa_adj=spa_adata.obsp['spatial_connectivities'].tocoo(),
            seed=seed,
            device=device)

    corr_transImp_res = calc_corr(raw_spatial_df, df_transImpSpa, test_gene)
    print(f'fold {idx}, median correlation: {np.median(corr_transImp_res)} (TransImpSpa)')

    df_transImpCls[test_gene] = expTransImp(
            df_ref=raw_scrna_df,
            df_tgt=raw_spatial_df,
            train_gene=train_gene,
            test_gene=test_gene,
            ct_list=ct_list,
            classes=classes,
            signature_mode='cluster',
            mapping_mode='full',
            seed=seed,
            device=device)

    corr_transImp_res = calc_corr(raw_spatial_df, df_transImpCls, test_gene)
    print(f'fold {idx}, median correlation: {np.median(corr_transImp_res)} (TransImpCls)')

    df_transImp[test_gene] = expTransImp(
            df_ref=raw_scrna_df,
            df_tgt=raw_spatial_df,
            train_gene=train_gene,
            test_gene=test_gene,
            signature_mode='cell',
            mapping_mode='lowrank',
            mapping_lowdim=128,
            clip_max=2,
        #     lr=1e-2,
            seed=seed,
            device=device)

    corr_transImp_res = calc_corr(raw_spatial_df, df_transImp, test_gene)
    print(f'fold {idx}, median correlation: {np.median(corr_transImp_res)} (TransImp)')

    df_transImpClsSpa[test_gene] = expTransImp(
            df_ref=raw_scrna_df,
            df_tgt=raw_spatial_df,
            train_gene=train_gene,
            test_gene=test_gene,
            ct_list=ct_list,
            classes=classes,
            spa_adj=spa_adata.obsp['spatial_connectivities'].tocoo(),
            signature_mode='cluster',
            mapping_mode='full',
            seed=seed,
            device=device)

    corr_transImp_res = calc_corr(raw_spatial_df, df_transImpClsSpa, test_gene)
    print(f'fold {idx}, median correlation: {np.median(corr_transImp_res)} (TransImpClsSpa)')

    df_stplus_res[test_gene] = stPlus.stPlus(spatial_df, scrna_df, test_gene, "tmp_ug", verbose=False, random_seed=seed, device=device)
    corr_res_stplus = calc_corr(raw_spatial_df, df_stplus_res, test_gene)
    print(f'\t\t\t{np.median(corr_res_stplus)} (stPlus)')

    df_spaGE_res[test_gene]  = SpaGE_impute(scrna_df, spatial_df, train_gene, test_gene)
    corr_res_spaGE = calc_corr(raw_spatial_df, df_spaGE_res, test_gene)
    print(f'\t\t\t{np.median(corr_res_spaGE)} (spaGE)')

    df_tangram_res[test_gene] = Tangram_impute(scrna_adata, spa_adata, train_gene, test_gene, device, cls_key)
    corr_res_tangram = calc_corr(raw_spatial_df, df_tangram_res, test_gene)
    print(f'\t\t\t{np.median(corr_res_tangram)} (Tangram)')

corr_transImpSpa_res = calc_corr(raw_spatial_df, df_transImpSpa, raw_shared_gene)
corr_transImp_res = calc_corr(raw_spatial_df, df_transImp, raw_shared_gene)
corr_transImpCls_res = calc_corr(raw_spatial_df, df_transImpCls, raw_shared_gene)
corr_transImpClsSpa_res = calc_corr(raw_spatial_df, df_transImpClsSpa, raw_shared_gene)
corr_res_stplus = calc_corr(raw_spatial_df, df_stplus_res, raw_shared_gene)
corr_res_spaGE = calc_corr(raw_spatial_df, df_spaGE_res, raw_shared_gene)
corr_res_tangram = calc_corr(raw_spatial_df, df_tangram_res, raw_shared_gene)   

print(np.median(corr_transImpSpa_res), "(TransImpSpa)", 
      np.median(corr_transImp_res), "(TransImp)", 
      np.median(corr_transImpCls_res), "(TransImpCls)", 
      np.median(corr_transImpClsSpa_res), "(TransImpClsSpa)", 
      np.median(corr_res_stplus), "(stPlus)", 
      np.median(corr_res_spaGE), "(spaGE)",
      np.median(corr_res_tangram), "(Tangram)"
      )


===== Fold 1 =====
Number of train genes: 5734, Number of test genes: 1434


[TransImp] Epoch: 1000/1000, loss: 0.864786, (IMP) 0.862186, (SPA) 1.0 x 0.002600: 100%|██████████| 1000/1000 [00:05<00:00, 196.70it/s]


fold 0, median correlation: 0.20205526461951495 (TransImpSpa)


[TransImp] Epoch: 1000/1000, loss: 0.857502, (IMP) 0.857502, (SPA) 1.0 x 0.000000: 100%|██████████| 1000/1000 [00:02<00:00, 453.57it/s]


fold 0, median correlation: 0.236695587855856 (TransImpCls)


[TransImp] Epoch: 1000/1000, loss: 0.851238, (IMP) 0.851238, (SPA) 1.0 x 0.000000: 100%|██████████| 1000/1000 [00:03<00:00, 301.91it/s]


fold 0, median correlation: 0.2406388653036972 (TransImp)


[TransImp] Epoch: 1000/1000, loss: 0.868970, (IMP) 0.864923, (SPA) 1.0 x 0.004048: 100%|██████████| 1000/1000 [00:03<00:00, 281.45it/s]


fold 0, median correlation: 0.19977112200921238 (TransImpClsSpa)
			0.14062728667366875 (stPlus)


In [None]:
df_transImp.to_csv('../../output/melanoma_melanoma_transImpute.csv')
df_transImpSpa.to_csv('../../output/melanoma_melanoma_transImpSpa.csv')
df_transImpCls.to_csv('../../output/melanoma_melanoma_transImpCls.csv')
df_transImpClsSpa.to_csv('../../output/melanoma_melanoma_transImpClsSpa.csv')
df_spaGE_res.to_csv('../../output/melanoma_melanoma_spaGE.csv')
df_stplus_res.to_csv('../../output/melanoma_melanoma_stPlus.csv')
df_tangram_res.to_csv('../../output/melanoma_melanoma_Tangram.csv')


0

In [None]:
dict_df = {
            "TransImp":df_transImp, 
           "TransImpSpa":df_transImpSpa, 
           "TransImpCls":df_transImpCls,
           "TransImpClsSpa":df_transImpClsSpa,
        #    "spaGE": df_spaGE_res, "stPlus": df_stplus_res,
            "Tangram":df_tangram_res
            }
# spa_adata.X = spa_adata.X.toarray()
sq.gr.spatial_autocorr(
    spa_adata,
    n_jobs=10,
)
sq.gr.spatial_autocorr(
    spa_adata,
    n_jobs=10,
    mode='geary',
)

dict_adata = {name: compute_autocorr(spa_adata[:, raw_shared_gene].copy(), df) for name, df in dict_df.items()}


In [None]:
from sklearn.metrics import mean_squared_error
moranIs = {name:mean_squared_error(spa_adata.uns['moranI'].loc[raw_shared_gene].I, imp_adata.uns['moranI'].loc[raw_shared_gene].I) for name, imp_adata in dict_adata.items()}
gearyCs = {name:mean_squared_error(spa_adata.uns['gearyC'].loc[raw_shared_gene].C, imp_adata.uns['gearyC'].loc[raw_shared_gene].C) for name, imp_adata in dict_adata.items()}

print("Mean Squared Error\nMoran's I:\n")
print("\n".join([f"\tTrue vs {method}: {score:.6f}" for method, score in moranIs.items()]))
print("Geary's C:\n")
print("\n".join([f"\tTrue vs {method}: {score:.6f}" for method, score in gearyCs.items()]))


Mean Squared Error
Moran's I:

	True vs TransImp: 0.162054
	True vs TransImpSpa: 0.005394
	True vs TransImpCls: 0.270401
	True vs TransImpClsSpa: 0.004791
	True vs Tangram: 0.270404
Geary's C:

	True vs TransImp: 0.162223
	True vs TransImpSpa: 0.005548
	True vs TransImpCls: 0.270203
	True vs TransImpClsSpa: 0.005039
	True vs Tangram: 0.272670
