In [None]:
import anndata
import torch
import stPlus
import os
import random
import warnings
import pickle

import squidpy as sq
import numpy as np
import scanpy as sc
import pandas as pd
import spatialdm as sdm

from sklearn.model_selection import KFold
from matplotlib import pyplot as plt
from transpa.eval_util import calc_corr
from transpa.util import expTransImp, leiden_cluster, compute_autocorr, plot_genes
# from benchmark import SpaGE_impute, Tangram_impute
from sklearn.metrics import adjusted_rand_score, adjusted_mutual_info_score, normalized_mutual_info_score, homogeneity_score
from sklearn.cluster import AgglomerativeClustering
from sklearn.preprocessing import StandardScaler
from exp_spatialdm import spatialdm


# NOTE(review): this blanket filter also hides library warnings (e.g. AnnData
# view-modification warnings); consider narrowing it to specific categories.
warnings.filterwarnings('ignore')
# Preprocessed dataset bundles; the loading cell picks one of them by index.
pre_datapaths = ["../../output/preprocessed_dataset/seqFISH_single_cell.pkl",
                 "../../output/preprocessed_dataset/merfish_moffit.pkl",
                 "../../output/preprocessed_dataset/osmFISH_allenvisp.pkl",
                 "../../output/preprocessed_dataset/starmap_allenvisp.pkl"
                 ]

seed = 10
# Seed every global RNG up front so a fresh-kernel "Run All" is reproducible;
# `seed` is additionally passed explicitly to expTransImp.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
pre_datapath = pre_datapaths[0]
# Only the unpickling needs the open file handle.
# NOTE(review): pickle.load can execute arbitrary code -- only load bundles produced
# by the project's own preprocessing step.
with open(pre_datapath, 'rb') as infile:
    spa_adata, scrna_adata, raw_spatial_df, raw_scrna_df, raw_shared_gene = pickle.load(infile)

# scRNA-seq cluster labels, passed to expTransImp as `classes`.
cls_key = 'leiden'
classes = scrna_adata.obs[cls_key]
ct_list = np.unique(classes)

In [3]:
# Flag the 3000 most variable genes of the scRNA-seq reference; the flags and
# statistics are written into scrna_adata.var ('highly_variable', 'means', ...).
sc.pp.highly_variable_genes(scrna_adata, n_top_genes=3000, inplace=True)
scrna_adata

AnnData object with n_obs × n_vars = 32844 × 29452
    obs: 'cell', 'barcode', 'sample', 'pool', 'stage', 'sequencing.batch', 'theiler', 'doub.density', 'doublet', 'cluster', 'cluster.sub', 'cluster.stage', 'cluster.theiler', 'stripped', 'celltype', 'colour', 'sizeFactor', 'leiden'
    var: 'ENSEMBL', 'SYMBOL', 'SymbolUniq', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'log1p', 'hvg'

In [4]:
# Filter small/tiny celltypes (# cells <= 10), restrict to highly variable genes,
# then find celltype-specific marker genes (Wilcoxon test + fold-change filter).
tp, counts = np.unique(scrna_adata.obs.celltype, return_counts=True)
kept_celltypes = [_tp for _tp, _ct in zip(tp, counts) if _ct > 10]
# `.copy()` materialises the subset: rank_genes_groups writes into `.uns`, which on an
# AnnData *view* only works through an implicit view-to-copy conversion whose warning
# is hidden by the global warnings filter.
sub_scrna_adata = scrna_adata[
    scrna_adata.obs.celltype.isin(kept_celltypes),
    scrna_adata.var['highly_variable'],
].copy()
sc.tl.rank_genes_groups(sub_scrna_adata, 'celltype', method='wilcoxon')
# Stores the filtered ranking in uns['rank_genes_groups_filtered'] (read by the next cell).
sc.tl.filter_rank_genes_groups(sub_scrna_adata, min_fold_change=2)

In [5]:
# Union over all cell types of the top-`top_k` ranked (filtered) marker genes.
top_k = 30
ranked_names = sub_scrna_adata.uns['rank_genes_groups_filtered']['names']
# Row `rank` holds the rank-th gene of every cell type; non-string entries
# (placeholders for genes removed by the fold-change filter) are skipped.
candidate_genes = {
    gene
    for rank in range(top_k)
    for gene in ranked_names[rank]
    if type(gene) is str
}
print(f"# unique sc celltype marker genes: {len(candidate_genes)}")

# unique sc celltype marker genes: 363


In [6]:
# Marker genes absent from the spatial panel: their values exist only via imputation.
extra_genes = np.setdiff1d(np.asarray(list(candidate_genes)), raw_shared_gene)
print('# extra genes not in Spa ST: {}'.format(extra_genes.size))

# extra genes not in Spa ST: 217


In [7]:
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics.pairwise import cosine_similarity

# Train TransImp on the genes shared by both modalities and impute the shared genes
# followed by the extra scRNA-only marker genes (this order defines the columns of res[0]).
# NOTE(review): later cells treat res[0] as the imputed cells x genes matrix and res[1]
# as a per-gene score -- confirm against transpa's expTransImp documentation.
imputation_genes = np.concatenate([raw_shared_gene, extra_genes])
res = expTransImp(
    df_ref=raw_scrna_df,
    df_tgt=raw_spatial_df,
    train_gene=raw_shared_gene,
    test_gene=imputation_genes,
    classes=classes,
    signature_mode='cell',
    mapping_mode='lowrank',
    n_simulation=200,
    n_epochs=2000,
    seed=seed,
    device=device,
)

[TransImp] Epoch: 2000/2000, loss: 0.781017, (IMP) 0.781017: 100%|██████████| 2000/2000 [00:30<00:00, 65.41it/s]


In [8]:
# Inspect the spatial AnnData: obs['celltype_mapped_refined'] is the ground-truth
# annotation and obsp['spatial_connectivities'] the connectivity used by the
# spatial clustering below.
spa_adata

AnnData object with n_obs × n_vars = 57536 × 351
    obs: 'uniqueID', 'embryo', 'pos', 'z', 'x_global', 'y_global', 'x_global_affine', 'y_global_affine', 'embryo_pos', 'embryo_pos_z', 'Area', 'UMAP1', 'UMAP2', 'celltype_mapped_refined', 'segmentation_vertices_x_global_affine', 'segmentation_vertices_y_global_affine'
    uns: 'log1p', 'spatial_neighbors'
    obsm: 'spatial'
    layers: 'normalized'
    obsp: 'spatial_connectivities', 'spatial_distances'

In [9]:
# Spatial cluster: AgglomerativeClustering + spatial constraint
n_clusters = len(spa_adata.obs.celltype_mapped_refined.unique())


def spatial_cluster(expr, prefix, adata=None, label_key='celltype_mapped_refined', n_clusters=None):
    """Cluster cells with spatially constrained agglomerative clustering and score it.

    Parameters
    ----------
    expr : array-like
        Feature matrix with one row per cell, rows aligned with ``adata.obs``.
    prefix : str
        Column label of the returned frame (describes the feature set).
    adata : AnnData, optional
        Dataset supplying the reference labels and ``obsp['spatial_connectivities']``;
        defaults to the global ``spa_adata``.
    label_key : str, optional
        ``adata.obs`` column with the reference annotation.
    n_clusters : int, optional
        Number of clusters; defaults to the number of distinct reference labels
        (the same rule as the module-level ``n_clusters``).

    Returns
    -------
    pandas.DataFrame
        One column named ``prefix`` with rows ARS (adjusted Rand score), AMIS
        (adjusted mutual information), HOMO (homogeneity) and NMI (normalized
        mutual information).
    """
    if adata is None:
        adata = spa_adata
    labels_true = adata.obs[label_key]
    if n_clusters is None:
        n_clusters = len(labels_true.unique())

    pred_clss = AgglomerativeClustering(
        n_clusters=n_clusters,
        # only cells adjacent in the spatial neighbour graph may be merged
        connectivity=adata.obsp['spatial_connectivities'],
    ).fit_predict(expr)

    metrics = {
        "ARS": adjusted_rand_score(labels_true, pred_clss),
        "AMIS": adjusted_mutual_info_score(labels_true, pred_clss),
        'HOMO': homogeneity_score(labels_true, pred_clss),
        'NMI': normalized_mutual_info_score(labels_true, pred_clss),
    }
    return pd.DataFrame({prefix: metrics})

In [10]:
# Boolean column mask selecting the sc marker genes among the imputed genes
# (columns follow the shared-then-extra order passed to expTransImp).
imputation_genes = np.concatenate([raw_shared_gene, extra_genes])
candidate_msk = [gene in candidate_genes for gene in imputation_genes]
np.sum(candidate_msk)

363

In [11]:
"""
Spatial clustering on 
1. seqFISH raw ST data (351 genes)
2. imputed sc celltype marker genes (top 30, 363 genes = 217 extra + 46 overlapped imputed)
3-6. imputed genes selected by prediction confidence scores, top 500, 400, 300, 200, 100, 50

Return averaged clustering scores against manual annotations `celltype_mapped_refined` in spa_adata.obs
"""
df_cls = spatial_cluster(spa_adata.X.toarray(), "SeqFISH_Raw (351) vs Ground.Annotation")
df_cls = pd.concat([df_cls, spatial_cluster(res[0][:, candidate_msk], f"SCImputedMarkers ({len(candidate_genes)}) vs Ground.Annotation")], axis=1)
df_cls = pd.concat([df_cls, spatial_cluster(res[0], f"SCImputedAll ({len(candidate_msk)}) vs Ground.Annotation")], axis=1)
df_cls = pd.concat([df_cls, spatial_cluster(res[0][:, np.argsort(res[1])[:500]], "SCImputedTopConfident (500) vs Ground.Annotation")], axis=1)
df_cls = pd.concat([df_cls, spatial_cluster(res[0][:, np.argsort(res[1])[:400]], "SCImputedTopConfident (400) vs Ground.Annotation")], axis=1)
df_cls = pd.concat([df_cls, spatial_cluster(res[0][:, np.argsort(res[1])[:300]], "SCImputedTopConfident (300) vs Ground.Annotation")], axis=1)
df_cls = pd.concat([df_cls, spatial_cluster(res[0][:, np.argsort(res[1])[:200]], "SCImputedTopConfident (200) vs Ground.Annotation")], axis=1)
df_cls = pd.concat([df_cls, spatial_cluster(res[0][:, np.argsort(res[1])[:100]], "SCImputedTopConfident (100) vs Ground.Annotation")], axis=1)
df_cls = pd.concat([df_cls, spatial_cluster(res[0][:, np.argsort(res[1])[:50]], "SCImputedTopConfident (50) vs Ground.Annotation")], axis=1)
df_cls

Unnamed: 0,SeqFISH_Raw (351) vs Ground.Annotation,SCImputedMarkers (363) vs Ground.Annotation,SCImputedAll (568) vs Ground.Annotation,SCImputedTopConfident (500) vs Ground.Annotation,SCImputedTopConfident (400) vs Ground.Annotation,SCImputedTopConfident (300) vs Ground.Annotation,SCImputedTopConfident (200) vs Ground.Annotation,SCImputedTopConfident (100) vs Ground.Annotation,SCImputedTopConfident (50) vs Ground.Annotation
AMIS,0.339416,0.34586,0.336304,0.345257,0.347428,0.35308,0.343398,0.326686,0.306348
ARS,0.183854,0.328984,0.324906,0.333656,0.346293,0.353473,0.341988,0.311472,0.248058
HOMO,0.350695,0.329384,0.324979,0.333307,0.335329,0.345384,0.33406,0.307818,0.269785
NMI,0.340482,0.347003,0.337451,0.346389,0.348556,0.354185,0.344524,0.327882,0.307635


In [12]:
# Average of the four clustering metrics (rows) for each feature set (columns).
df_cls.mean(axis='index')

SeqFISH_Raw (351) vs Ground.Annotation              0.303612
SCImputedMarkers (363) vs Ground.Annotation         0.337808
SCImputedAll (568) vs Ground.Annotation             0.330910
SCImputedTopConfident (500) vs Ground.Annotation    0.339652
SCImputedTopConfident (400) vs Ground.Annotation    0.344402
SCImputedTopConfident (300) vs Ground.Annotation    0.351530
SCImputedTopConfident (200) vs Ground.Annotation    0.340993
SCImputedTopConfident (100) vs Ground.Annotation    0.318464
SCImputedTopConfident (50) vs Ground.Annotation     0.282956
dtype: float64

In [14]:
# Persist the per-metric scores (rows = metric, columns = feature set).
# rename_axis labels a copy instead of mutating df_cls in place, and the filename
# follows `top_k` so it stays accurate if the marker cutoff changes.
# NOTE(review): "segfish" looks like a typo for "seqfish"; kept unchanged so existing
# readers of this file keep working.
df_cls.rename_axis('metric').to_csv(f"../../output/segfish_cluster_with_extra_genes_top{top_k}.csv")