In [1]:
import numpy as np
import pandas as pd

import anndata as ad
import scanpy as sc
import squidpy as sq
import tangram as tg

from scipy import sparse
import gc
import os
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib import rcParams
import matplotlib.colors as clr

  from pkg_resources import DistributionNotFound, get_distribution


In [2]:
import torch
# Select the "high" float32 matmul precision mode (PyTorch's default is "highest").
torch.set_float32_matmul_precision("high")
# Last expression of the cell: displays whether a CUDA device is usable.
torch.cuda.is_available()

True

In [3]:
from scipy import sparse
from scipy import io
import os
import gzip

In [4]:
# Load the annotated MERFISH AnnData object from disk.
merfish = sc.read_h5ad("data/ns-atlas.merfish_baysor.scanvi_integrated.cellcharter.anndata.annotated.h5ad")

In [5]:
# Reset X from the stored 'counts' layer, apply scanpy's total-count normalization
# with default settings, and keep the normalized matrix as its own 'norm' layer.
# Starting from the 'counts' layer each time keeps this cell safe to re-run.
merfish.X = merfish.layers['counts'].copy()
sc.pp.normalize_total(merfish)
merfish.layers['norm'] = merfish.X.copy()

In [6]:
# Lower-case the gene names; the scRNA-seq reference var_names get the same
# treatment further down, so both objects use matching lower-case names.
merfish.var_names = merfish.var_names.str.lower().values.tolist()
# Last expression of the cell: the displayed integer is gc.collect()'s return value.
gc.collect()

39681

For training, we want to use the same genes for every sample. Because each imaging batch is mapped separately, the training genes must not only be present in both gene panels but also have at least one count in every batch.

In [7]:
# Sorted, de-duplicated batch labels (np.unique sorts); the Tangram loop below
# iterates over these and maps each batch separately.
merfish_batches = np.unique(merfish.obs['batch'].values.tolist())

In [8]:
# n_top_genes is set to the total number of genes and subset=False, so no gene is
# removed here; the call is used for the per-batch annotations it writes to
# merfish.var (e.g. highly_variable_nbatches, highly_variable_intersection).
# NOTE(review): reads layer 'log1p', which is not created in this notebook —
# presumably it is already stored in the loaded h5ad; confirm.
sc.pp.highly_variable_genes(merfish, layer='log1p', n_top_genes=len(merfish.var_names), batch_key='batch', subset=False)

In [16]:
# Display the per-gene annotation table, including the HVG columns added above.
merfish.var

Unnamed: 0,mean,std,highly_variable,means,dispersions,dispersions_norm,highly_variable_nbatches,highly_variable_intersection,gene_symbol,gene_panel,pathways,cell_category,cell_type,cell_type.detailed,Notes,in_both_panels
aadac,0.001948,0.038527,True,0.003182,0.292109,-0.302599,17,False,AADAC,NSv2,,Stroma,Fib,Fib,,False
abi3bp,0.042268,0.181868,True,0.068665,0.502537,-0.315946,25,True,ABI3BP,Both,,Epithelia,Ecc Gland,Ecc Gland,Ecc gland,True
acan,0.001811,0.035629,True,0.002572,0.334071,-0.518449,25,True,ACAN,Both,,Stroma,DS,DS,DS,True
ackr1,0.059810,0.274727,True,0.124364,1.183953,1.171005,25,True,ACKR1,Both,,Stroma,EC,EC,EC,True
ackr2,0.009997,0.087661,True,0.016648,0.498767,0.034042,25,True,ACKR2,Both,,Stroma,EC,EC,EC,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
wnt5a,0.015567,0.101244,True,0.022572,0.371165,-0.389344,25,True,WNT5A,Both,Wnt L-R,,,,,True
wnt7b,0.008498,0.082435,True,0.013407,0.449020,-0.124062,25,True,WNT7B,Both,Wnt L-R,,,,,True
xcl1,0.001085,0.027053,True,0.001564,0.298894,-0.644429,25,True,XCL1,Both,,Immune,NK,NK,,True
znf331,0.039304,0.168637,True,0.057845,0.491066,-0.215072,25,True,ZNF331,Both,,Immune,T Cell,T Cell,T cell (RGCC),True


In [10]:
# Boolean mask over genes: keep a gene for training only if it is flagged
# 'in_both_panels' AND 'highly_variable_intersection' in merfish.var.
training_genes = merfish.var['in_both_panels'] & merfish.var['highly_variable_intersection']
# Last expression of the cell: displays how many genes pass / fail the mask.
training_genes.value_counts()

True     425
False    137
Name: count, dtype: int64

In [11]:
# Turn the training mask into an array of gene names.
# The selection is recomputed from merfish.var instead of indexing
# `training_genes` with itself, so re-running this cell is safe: the
# self-referential form fails on a second run, once the name already holds
# gene names rather than the boolean mask.
training_genes = merfish.var_names[
    (merfish.var['in_both_panels'] & merfish.var['highly_variable_intersection']).to_numpy()
].values

For imputation, the scRNAseq object as-is is too large to use as a reference. To ease memory constraints and maximize biologically informative predictions, we need to downsample genes in our reference.

In [17]:
# Load the down-sampled ("mini") scRNA-seq reference object.
scrna = sc.read_h5ad("data/normal_skin.scrna.harmony.integrated.reclustered.annotated.mini.h5ad")
# Capture the gene symbols in their original case before var_names are
# lower-cased in a later cell: once as a plain list, once as a per-gene column.
all_genes = scrna.var_names.tolist()
scrna.var['gene_symbol'] = list(all_genes)

In [18]:
# Same treatment as the MERFISH object: reset X from the 'counts' layer, apply
# scanpy's default total-count normalization, and store the result as 'norm'.
scrna.X = scrna.layers['counts'].copy()
sc.pp.normalize_total(scrna)
scrna.layers['norm'] = scrna.X.copy()

In [19]:
# Lower-case the reference gene names to match the lower-cased MERFISH names;
# the original-case symbols were saved in scrna.var['gene_symbol'] beforehand.
scrna.var_names = scrna.var_names.str.lower().values.tolist()

In [20]:
# Display a summary of the reference object (dimensions and annotation keys).
scrna

AnnData object with n_obs × n_vars = 273178 × 5688
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'pct.mito', 'pct.ribo', 'pct.hemo', 'study_id', 'sample_barcode', 'donor_id', 'donor_sex', 'donor_age', 'anatomic_site', 'reported.cell_type', 'anatomic_site.detailed', 'harmony.snn_res.0.2', 'harmony.snn_res.0.4', 'harmony.snn_res.0.5', 'harmony.snn_res.0.6', 'harmony.snn_res.0.8', 'harmony.snn_res.1', 'harmony.snn_res.1.2', 'harmony.snn_res.1.5', 'harmony.snn_res.2', 'harmony.snn_res.2.5', 'seurat_clusters', 'cell_barcode', 'cell_type.broad', 'cell_category', 'cell_type', 'cell_type.broad.res_0.2', 'cell_type.detailed', 'cell_type.reclustered', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'pct_counts_hb', 'n_counts', 'leiden_0.5', 'leiden_1', 'leiden_1.5', 'leiden_2', 'leiden_2.5'
    var: 'mt', 'ribo', 'hb', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'highly_var

In [21]:
# Request expandable segments from PyTorch's CUDA caching allocator.
# The option is read from the process environment as an "option:value" string;
# binding a Python dict to a variable of the same name (as this cell did before)
# has no effect on PyTorch.
# NOTE(review): the allocator reads this setting when it initializes, so this
# should ideally run before the first CUDA allocation (safest: before importing
# torch) — confirm it takes effect at this point in the notebook.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Release cached GPU memory back to the driver, then run a garbage-collection pass.
torch.cuda.empty_cache()
# Last expression of the cell: the displayed integer is gc.collect()'s return value.
gc.collect()

10433

In [22]:
# Name of the scrna.obs column passed as `cluster_label` to Tangram's
# cluster-mode mapping and gene projection below.
cluster_label = "leiden_2.5"

In [26]:
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib import rcParams
from matplotlib.pyplot import rc_context
from sklearn import metrics

In [27]:
# Collects one AUC value per batch; filled by the Tangram loop below.
# Re-running this cell resets the list.
auc_score_list = []

In [28]:
# Per-batch Tangram run: map the scRNA-seq clusters onto each MERFISH imaging
# batch, project the reference genes into space, and save the mapping, the
# imputed object, per-gene comparison statistics and a QC figure.
outdir = "data/tangram/tangram_outs/"
qc_figdir = "figures/tangram/qc"
# Create both output locations up front so a missing directory cannot abort the
# loop after a batch has already gone through the expensive GPU mapping.
os.makedirs(outdir, exist_ok=True)
os.makedirs(qc_figdir, exist_ok=True)

for batch_id in merfish_batches:
    print(batch_id)
    # Independent copy of this batch's cells; filter genes with min_counts=1 and
    # use the normalized layer as the expression matrix handed to Tangram.
    ad_sp = merfish[merfish.obs['batch'] == batch_id].copy()
    sc.pp.filter_genes(ad_sp, min_counts=1)
    ad_sp.X = ad_sp.layers['norm'].copy()
    gc.collect()

    # Restrict training to the shared gene list computed earlier.
    tg.pp_adatas(scrna, ad_sp, genes=training_genes)
    gc.collect()

    ad_map = tg.map_cells_to_space(scrna, ad_sp,
                                   density_prior='uniform',
                                   device='cuda:0',
                                   mode='clusters',
                                   num_epochs=500,
                                   cluster_label=cluster_label)
    ad_map.write_h5ad(f"{outdir}/{batch_id}.tangram.clusters_mode.mapped_object.5k_mini.h5ad", compression='gzip')
    ad_map.uns['train_genes_df'].to_csv(f'{outdir}/{batch_id}.tangram.clusters_mode.5k_mini.training_stats.csv')

    # Project the reference genes onto the spatial cells and store the result as
    # a float64 CSR matrix with one column per gene in `all_genes`.
    ad_ge = tg.project_genes(adata_map=ad_map, adata_sc=scrna, cluster_label=cluster_label)
    ad_ge.X = sparse.csr_matrix(ad_ge.X, shape=(ad_ge.shape[0], len(all_genes)), dtype=np.float64)
    ad_ge.write_h5ad(f"{outdir}/{batch_id}.tangram.clusters_mode.5k_mini.imputed_object.h5ad")

    # Put the overlap gene lists of all three objects into the same (sorted)
    # order before comparing measured and projected expression.
    ad_sp.uns['overlap_genes'] = np.sort(ad_sp.uns['overlap_genes'])
    ad_ge.uns['overlap_genes'] = np.sort(ad_ge.uns['overlap_genes'])
    scrna.uns['overlap_genes'] = np.sort(scrna.uns['overlap_genes'])

    df_all_genes = tg.compare_spatial_geneexp(ad_ge, ad_sp, scrna)
    df_all_genes.to_csv(f"{outdir}/{batch_id}.tangram.clusters_mode.5k_mini.imputation_genestats.csv")

    # Trapezoidal area under the (score, sparsity_sp) points.
    # NOTE(review): sklearn's auc requires monotonic x, so this relies on
    # df_all_genes being ordered by 'score' — confirm.
    auc_score = metrics.auc(df_all_genes['score'], df_all_genes['sparsity_sp'])
    auc_score_list.append(auc_score)

    ax = sns.regplot(
        data=df_all_genes,
        x="score", y="sparsity_sp",
        logistic=True, color=".3",
        line_kws=dict(color="r"),
        label='AUC = %0.2f' % auc_score
    )
    ax.set_title(batch_id + ' (AUC = %0.2f' % auc_score + ")")
    plt.savefig(f"{qc_figdir}/{batch_id}.tangram.clusters_mode.5k_mini.imputation.auc_curve.png")
    plt.close()

    # Free GPU and host memory before the next batch.
    torch.cuda.empty_cache()
    gc.collect()
    

MSSM_00
Score: 0.112, KL reg: 0.040
Score: 0.290, KL reg: 0.003
Score: 0.292, KL reg: 0.003
Score: 0.292, KL reg: 0.003
Score: 0.292, KL reg: 0.003
MSSM_01
Score: 0.111, KL reg: 0.040
Score: 0.297, KL reg: 0.004
Score: 0.298, KL reg: 0.003
Score: 0.298, KL reg: 0.003
Score: 0.298, KL reg: 0.003
MSSM_02
Score: 0.119, KL reg: 0.040
Score: 0.313, KL reg: 0.002
Score: 0.314, KL reg: 0.002
Score: 0.315, KL reg: 0.002
Score: 0.315, KL reg: 0.002
MSSM_03
Score: 0.141, KL reg: 0.040
Score: 0.348, KL reg: 0.003
Score: 0.350, KL reg: 0.003
Score: 0.350, KL reg: 0.003
Score: 0.350, KL reg: 0.003
MSSM_04
Score: 0.121, KL reg: 0.040
Score: 0.313, KL reg: 0.003
Score: 0.315, KL reg: 0.003
Score: 0.315, KL reg: 0.003
Score: 0.315, KL reg: 0.003
MSSM_05
Score: 0.098, KL reg: 0.040
Score: 0.287, KL reg: 0.004
Score: 0.288, KL reg: 0.004
Score: 0.289, KL reg: 0.004
Score: 0.289, KL reg: 0.004
MSSM_06
Score: 0.110, KL reg: 0.039
Score: 0.299, KL reg: 0.003
Score: 0.300, KL reg: 0.003
Score: 0.300, KL reg

In [29]:
# Manual garbage-collection pass after the Tangram loop; the displayed integer is
# gc.collect()'s return value.
gc.collect()

71329

In [30]:
# `ad_ge` is whatever the last loop iteration left behind (the final batch).
# This renames its genes to the original-case symbols in memory only — the
# per-batch h5ad files were already written inside the loop.
ad_ge.var_names = ad_ge.var['gene_symbol'].values.tolist().copy()

In [31]:
# Display the gene annotations of the last batch's imputed object.
ad_ge.var

Unnamed: 0,mt,ribo,hb,n_cells_by_counts,mean_counts,pct_dropout_by_counts,total_counts,highly_variable,means,dispersions,dispersions_norm,highly_variable_nbatches,highly_variable_intersection,gene_symbol,n_cells,sparsity,is_training
A1BG,False,False,False,27896,0.630981,89.788343,172370.0,False,0.135896,0.891595,-0.062615,14,True,A1BG,27896,0.897883,False
A2M,False,False,False,25439,0.434189,90.687757,118611.0,True,0.313068,2.187632,1.869681,14,True,A2M,25439,0.906878,False
A2M-AS1,False,False,False,1337,0.087961,99.510576,24029.0,True,0.007811,1.071408,0.299354,14,True,A2M-AS1,1337,0.995106,False
A2ML1,False,False,False,7791,0.089645,97.148013,24489.0,True,0.029538,1.643422,1.450835,14,True,A2ML1,7791,0.971480,False
A4GALT,False,False,False,13819,0.078143,94.941394,21347.0,True,0.078814,1.068986,0.294479,14,True,A4GALT,13819,0.949414,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ZSCAN4,False,False,False,96,0.000780,99.964858,213.0,True,0.000426,1.126660,0.410579,11,False,ZSCAN4,96,0.999649,False
ZSWIM4,False,False,False,7241,0.041592,97.349347,11362.0,True,0.029834,0.972179,0.099604,14,True,ZSWIM4,7241,0.973493,False
ZSWIM5,False,False,False,1267,0.020917,99.536200,5714.0,True,0.007878,1.470241,1.102215,14,True,ZSWIM5,1267,0.995362,False
ZSWIM6,False,False,False,31720,0.233932,88.388523,63905.0,True,0.154845,1.033158,0.222356,14,True,ZSWIM6,31720,0.883885,False


In [33]:
# Manual garbage-collection pass; the displayed integer is gc.collect()'s return value.
gc.collect()

0

In [34]:
import gc
import glob

In [35]:
# Concatenate the per-batch imputed objects written by the Tangram loop.
# glob returns paths in arbitrary, filesystem-dependent order; sorting makes the
# row order of the merged object reproducible across runs and machines.
flist = sorted(glob.glob("data/tangram/tangram_outs/*.clusters_mode.5k_mini.imputed_object.h5ad"))
adata_list = []
for f in flist:
    print(f)
    # Open each file in read-only backed mode; the data is materialized only once,
    # by to_adata() below.
    adata_list.append(ad.read_h5ad(f, backed='r'))
gc.collect()

from anndata.experimental.multi_files import AnnCollection
# Genes: keep only those shared by every batch ('inner'); obs columns: keep the union.
dataset = AnnCollection(adata_list, join_vars='inner', join_obs='outer', join_obsm='inner', label='dataset')
adata = dataset[dataset.obs_names].to_adata()
del dataset
gc.collect()

# NOTE(review): assumes obs['center_x'] / obs['center_y'] hold the per-cell
# spatial coordinates — confirm.
adata.obsm['spatial'] = adata.obs[['center_x', 'center_y']].values.copy()
# Keep the unscaled imputed values before sc.pp.scale overwrites X.
adata.layers['tangram'] = adata.X.copy()
sc.pp.scale(adata, max_value=10)
# From here on adata.X holds the scaled values (also stored as layers['scaled']).
adata.layers['scaled'] = adata.X.copy()
gc.collect()
adata.write_h5ad("data/tangram/merfish.tangram_imputed.5k_mini.h5ad")

data/tangram/tangram_outs/MSSM_19.tangram.clusters_mode.5k_mini.imputed_object.h5ad
data/tangram/tangram_outs/MSSM_11.tangram.clusters_mode.5k_mini.imputed_object.h5ad
data/tangram/tangram_outs/MSSM_02.tangram.clusters_mode.5k_mini.imputed_object.h5ad
data/tangram/tangram_outs/MSSM_04.tangram.clusters_mode.5k_mini.imputed_object.h5ad
data/tangram/tangram_outs/MSSM_21.tangram.clusters_mode.5k_mini.imputed_object.h5ad
data/tangram/tangram_outs/MSSM_24.tangram.clusters_mode.5k_mini.imputed_object.h5ad
data/tangram/tangram_outs/MSSM_23.tangram.clusters_mode.5k_mini.imputed_object.h5ad
data/tangram/tangram_outs/MSSM_22.tangram.clusters_mode.5k_mini.imputed_object.h5ad
data/tangram/tangram_outs/MSSM_16.tangram.clusters_mode.5k_mini.imputed_object.h5ad
data/tangram/tangram_outs/MSSM_08.tangram.clusters_mode.5k_mini.imputed_object.h5ad
data/tangram/tangram_outs/MSSM_12.tangram.clusters_mode.5k_mini.imputed_object.h5ad
data/tangram/tangram_outs/MSSM_17.tangram.clusters_mode.5k_mini.imputed_obje

In [36]:
# Manual garbage-collection pass; the displayed integer is gc.collect()'s return value.
gc.collect()

0

In [37]:
# Restore the original-case gene symbols on the merged object.
# The assignment is positional, so first check that both objects list the same
# (lower-cased) genes in the same order; otherwise genes would be silently
# mislabelled.
if list(adata.var_names) != list(scrna.var_names):
    raise ValueError(
        "adata.var_names and scrna.var_names are not aligned; "
        "cannot copy gene symbols by position"
    )
adata.var_names = scrna.var['gene_symbol'].values.tolist()

In [38]:
# Overwrite the merged file so the saved object carries the restored gene symbols.
adata.write_h5ad("data/tangram/merfish.tangram_imputed.5k_mini.h5ad")

In [4]:
# Reload the merged imputed object from disk (the execution counter restarts
# here, i.e. this cell was run in a fresh kernel session).
adata = sc.read_h5ad("data/tangram/merfish.tangram_imputed.5k_mini.h5ad")

In [5]:
import numpy as np
from anndata import AnnData

def minmax_normalize_by_gene_per_sample(adata: "AnnData", sample_key: str = 'sample_barcode', layer_name: str = 'minmax') -> None:
    """
    Min-max normalize gene expression per gene (column) within each sample group,
    and store the result in adata.layers[layer_name].

    For every sample, each gene is rescaled to [0, 1] as
    (x - min) / (max - min), with min and max taken over that sample's cells.
    A gene that is constant within a sample maps to 0 for that sample.

    Parameters
    ----------
    adata : AnnData
        AnnData object with .X as dense or sparse array. Only .X is read; it is
        not modified.
    sample_key : str
        Column in adata.obs identifying the sample grouping (e.g., 'sample_barcode').
        Rows whose value in this column is missing are left as 0 in the output.
    layer_name : str
        Name of the .layers entry to store the result (default: 'minmax').

    Returns
    -------
    None
        Modifies adata in-place by writing a dense floating-point array to
        adata.layers[layer_name].
    """
    from scipy.sparse import issparse

    X = adata.X.toarray() if issparse(adata.X) else np.asarray(adata.X)
    # Integer (or boolean) input must be promoted first: otherwise the output
    # buffer inherits the integer dtype and every fraction is truncated to 0 or 1.
    if not np.issubdtype(X.dtype, np.floating):
        X = X.astype(np.float64)

    norm_X = np.zeros_like(X)
    sample_labels = adata.obs[sample_key].to_numpy()

    # dropna(): a missing label would select no rows (NaN != NaN) and make the
    # min/max reductions fail on an empty array.
    for barcode in adata.obs[sample_key].dropna().unique():
        mask = sample_labels == barcode
        sub_X = X[mask]  # shape: [cells_in_sample, genes]

        min_vals = sub_X.min(axis=0, keepdims=True)  # shape: [1, genes]
        value_range = sub_X.max(axis=0, keepdims=True) - min_vals

        # Avoid division by zero for genes that are constant within the sample.
        value_range[value_range == 0] = 1
        norm_X[mask] = (sub_X - min_vals) / value_range

    adata.layers[layer_name] = norm_X

In [6]:
# Adds adata.layers['minmax'] (per-sample, per-gene min-max of adata.X).
# NOTE(review): the function reads adata.X, which the concatenation cell left
# holding the sc.pp.scale output (clipped at 10) before the file was saved —
# confirm that normalizing the scaled values rather than layers['tangram'] is intended.
minmax_normalize_by_gene_per_sample(adata)

In [7]:
# Manual garbage-collection pass; the displayed integer is gc.collect()'s return value.
gc.collect()

84193

In [8]:
# Overwrite the merged file again so it also contains the 'minmax' layer.
adata.write_h5ad("data/tangram/merfish.tangram_imputed.5k_mini.h5ad")

In [6]:
# Small ad-hoc gene list for spot checks.
# "TNFRSF1A" was previously misspelled "TNFSFR1A", which would never match a
# gene name; the spelling now agrees with the one used in `ligand_genes`.
test_genes = ["KRT1", "KRT5", "KRT14", "KRT10", "KRT15", "KRT17",
              "DSG1", "KRTDAP", "DMKN", "SBSN", "FABP5", "COL1A2",
              "SCGB1B2P", "KRT23", "MGST1", "KRT19", "TNFRSF1A", "TNF",
              'CXCL12', "CXCR4", 'PECAM1']

In [6]:
import matplotlib.colors as clr
# Three sequential colormaps that all start at light grey and blend into a
# single accent colour (blue, red, teal).
gray2blue, gray2red, gray2teal = (
    sns.blend_palette([sns.xkcd_rgb["light grey"], sns.xkcd_rgb[accent]], as_cmap=True)
    for accent in ("electric blue", "bright red", "turquoise")
)

In [5]:
# Marker genes to plot spatially; the plotting loops below skip any entry that
# is not present in adata.var_names.
marker_genes = ["KRT1", "KRT5", "KRT14", "KRT10", "KRT15", "KRT17", 
                "DSG1", "KRTDAP", "DMKN", "SBSN", "FABP5", "COL1A2",
                "SCGB1B2P", "KRT23", "MGST1", "KRT19", "COL1A1", 
                "KRT79", "ADIPOQ", "MRC1", "CCL19", "APOE", 
                "CD4", "PDGFRA", "FLG", "ACKR1", "CLDN5", "VWF", "VIM",
               "PLVAP", "PODXL", "PROX1"]

In [7]:
# Genes plotted by the "ligand" loops below.
# NOTE(review): these entries (e.g. CD44, CXCR4, TNFRSF1A/B, integrins) look like
# receptors rather than ligands — the names of this list and `receptor_genes`
# may be swapped; confirm before relying on the variable names.
ligand_genes = ["CD74", "CD44", "BSG", "TNFRSF1A", "TNFRSF1B", "CXCR4", "ITGB1", "ITGA3"]

In [8]:
# Genes plotted by the "receptor" loops below.
# NOTE(review): these entries (e.g. CXCL12, MIF, TNF) look like ligands rather
# than receptors — the names of this list and `ligand_genes` may be swapped;
# confirm before relying on the variable names.
receptor_genes = ["COMP", "CXCL12", "MIF", "TNF", "PPIA"]

In [10]:
import matplotlib.pyplot as plt
import seaborn as sns
def plot_spatial_featureplot_by_sample(
    adata,
    gene_key,
    sample_key="sample_barcode",
    layer=None,
    cmap="magma",
    figsize=(5, 5),
    output_directory='.',
    file_prefix='',
    dpi=300,
    show=False
):
    """
    Save a multi-panel spatial feature plot of one gene, one panel per sample.

    Panels are laid out in a grid of at most 6 columns, ordered by sorted sample
    label. The figure is written to
    ``{output_directory}/{file_prefix}.{gene_key}.spatial_plot.png``.

    Parameters
    ----------
    adata : AnnData
        Object with spatial coordinates in ``adata.obsm['spatial']``.
    gene_key : str
        Gene (or other key accepted by ``sc.pl.embedding``'s ``color``) to plot.
    sample_key : str
        Column in ``adata.obs`` that identifies the sample of each cell. May be
        categorical or not; only labels that actually occur are plotted.
    layer : str or None
        Layer to take expression values from; ``None`` uses scanpy's default
        source. ``adata`` itself is not modified.
    cmap : str or Colormap
        Colormap passed to scanpy.
    figsize : tuple
        Size (width, height) of a single panel in inches.
    output_directory : str
        Directory for the PNG; created if it does not exist.
    file_prefix : str
        Prefix of the output file name.
    dpi : int
        Resolution of the saved figure.
    show : bool
        If True, also display the figure before closing it.

    Raises
    ------
    ValueError
        If ``adata.obs[sample_key]`` contains no non-missing labels.
    """
    import os

    sample_col = adata.obs[sample_key]
    # Only labels that occur in the data, so unused categories do not produce
    # empty panels and non-categorical columns are supported as well.
    samples = sorted(sample_col.dropna().unique())
    n_samples = len(samples)
    if n_samples == 0:
        raise ValueError(f"No samples found in adata.obs['{sample_key}']")

    ncols = min(n_samples, 6)
    nrows = (n_samples + ncols - 1) // ncols

    # squeeze=False always yields a 2-D array of axes, even for a single panel.
    fig, axes = plt.subplots(nrows, ncols,
                             figsize=(figsize[0] * ncols, figsize[1] * nrows),
                             squeeze=False)
    axes = axes.ravel()

    for ax, sample in zip(axes, samples):
        adata_sub = adata[sample_col == sample]
        ax.set_aspect('equal', adjustable='datalim')
        ax.set_xticks([])
        ax.set_yticks([])

        # The layer is handed to scanpy directly instead of overwriting adata.X,
        # so plotting leaves the caller's object untouched.
        sc.pl.embedding(adata_sub,
                        basis='spatial',
                        title=f"{sample}",
                        color=gene_key,
                        layer=layer,
                        cmap=cmap,
                        ax=ax,
                        legend_loc='on data',
                        show=False)

        ax.axis("off")

    # Hide the grid cells that have no sample.
    for unused_ax in axes[n_samples:]:
        unused_ax.set_visible(False)

    # Add a global panel title
    fig.suptitle(f"{gene_key}", fontsize=14, y=0.99)

    fig.tight_layout()
    os.makedirs(output_directory, exist_ok=True)
    fig.savefig(f"{output_directory}/{file_prefix}.{gene_key}.spatial_plot.png",
                bbox_inches='tight', dpi=dpi)
    if show:
        plt.show()
    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)

In [17]:
from tqdm import tqdm

# One spatial figure per marker gene, coloured by the per-sample min-max layer.
# Iterating through tqdm advances the bar for every gene, including skipped
# ones, so it always reaches 100%.
for gene in tqdm(marker_genes, desc="Plotting Genes", unit="Gene", ncols=100, leave=True):
    if gene not in adata.var_names:
        # Report the skip instead of dropping the gene silently.
        tqdm.write(f"Skipping {gene}: not present in adata.var_names")
        continue
    plot_spatial_featureplot_by_sample(
        adata,
        output_directory="figures/tangram/feature_spatial/",
        file_prefix='tangram_imputed.minmax',
        layer='minmax',
        gene_key=gene,
        cmap=gray2red,
        sample_key='sample_barcode')
    # Free figure-related memory between genes.
    gc.collect()
        

Plotting Genes: 100%|█████████████████████████████████████████████| 32/32 [28:14<00:00, 52.94s/Gene]


In [18]:
# Imported here as well so the cell does not depend on the previous one having run.
from tqdm import tqdm

# One spatial figure per marker gene, coloured by the 'scaled' layer.
# Iterating through tqdm advances the bar for every gene, including skipped
# ones, so it always reaches 100%.
for gene in tqdm(marker_genes, desc="Plotting Genes", unit="Gene", ncols=100, leave=True):
    if gene not in adata.var_names:
        # Report the skip instead of dropping the gene silently.
        tqdm.write(f"Skipping {gene}: not present in adata.var_names")
        continue
    plot_spatial_featureplot_by_sample(
        adata,
        output_directory="figures/tangram/feature_spatial/",
        file_prefix='tangram_imputed.scaled',
        layer='scaled',
        gene_key=gene,
        cmap=gray2red,
        sample_key='sample_barcode')
    # Free figure-related memory between genes.
    gc.collect()
        

Plotting Genes: 100%|█████████████████████████████████████████████| 32/32 [28:35<00:00, 53.60s/Gene]


In [13]:
from tqdm import tqdm
with tqdm(total=len(ligand_genes), desc="Plotting Genes", unit="Gene", ncols=100, leave=True) as pbar:
    for gene in ligand_genes:
        plot_spatial_featureplot_by_sample(
           adata,
           output_directory="figures/tangram/feature_spatial/",
           file_prefix=f'tangram_imputed.minmax',
           layer='minmax',
           gene_key=gene,
           cmap=gray2red,
           sample_key='sample_barcode')
        pbar.update(1)
        gc.collect()
    gc.collect()


Plotting Genes: 100%|███████████████████████████████████████████████| 8/8 [06:44<00:00, 50.52s/Gene]


In [14]:
from tqdm import tqdm
with tqdm(total=len(ligand_genes), desc="Plotting Genes", unit="Gene", ncols=100, leave=True) as pbar:
    for gene in ligand_genes:
        plot_spatial_featureplot_by_sample(
           adata,
           output_directory="figures/tangram/feature_spatial/",
           file_prefix=f'tangram_imputed.scaled',
           layer='scaled',
           gene_key=gene,
           cmap=gray2red,
           sample_key='sample_barcode')
        pbar.update(1)
        gc.collect()
    gc.collect()


Plotting Genes: 100%|███████████████████████████████████████████████| 8/8 [06:50<00:00, 51.33s/Gene]


In [15]:
from tqdm import tqdm
with tqdm(total=len(receptor_genes), desc="Plotting Genes", unit="Gene", ncols=100, leave=True) as pbar:
    for gene in receptor_genes:
        plot_spatial_featureplot_by_sample(
           adata,
           output_directory="figures/tangram/feature_spatial/",
           file_prefix=f'tangram_imputed.minmax',
           layer='minmax',
           gene_key=gene,
           cmap=gray2blue,
           sample_key='sample_barcode')
        pbar.update(1)
        gc.collect()
    gc.collect()

Plotting Genes: 100%|███████████████████████████████████████████████| 5/5 [04:15<00:00, 51.16s/Gene]


In [16]:
from tqdm import tqdm
with tqdm(total=len(receptor_genes), desc="Plotting Genes", unit="Gene", ncols=100, leave=True) as pbar:
    for gene in receptor_genes:
        plot_spatial_featureplot_by_sample(
           adata,
           output_directory="figures/tangram/feature_spatial/",
           file_prefix=f'tangram_imputed.scaled',
           layer='scaled',
           gene_key=gene,
           cmap=gray2blue,
           sample_key='sample_barcode')
        pbar.update(1)
        gc.collect()
    gc.collect()

Plotting Genes: 100%|███████████████████████████████████████████████| 5/5 [04:14<00:00, 50.90s/Gene]
