In [1]:
import scipy.sparse
import matplotlib
from functools import partial
from scipy import interpolate
from itertools import cycle
import pandas as pd
import pyro
import torch
import anndata
import scanpy as sc 
import numpy as np
import sys
import matplotlib.pyplot as plt 
import matplotlib as mpl

sys.path.insert(1, '/nfs/team205/vk7/sanger_projects/BayraktarLab/scvi-tools/')

import scvi


%config InlineBackend.figure_format='retina'

In [2]:
# Index of this notebook's hyperparameter table; it also selects the GPU used below.
NOTEBOOK_NUMBER = 0

# Locations of inputs, results and annotations used throughout the document
table_dir = '/nfs/team205/vk7/sanger_projects/collaborations/fetal_gut_mapping/results/hyperparameters/'
sp_data_folder = '/nfs/team205/vk7/sanger_projects/large_data/gut_kj_re/oxford_visium/'
sc_data_folder = '/nfs/team205/vk7/sanger_projects/large_data/gut_kj_re/'
results_folder = '/nfs/team205/vk7/sanger_projects/collaborations/fetal_gut_mapping/results/'
sc_results_folder = '/nfs/team205/vk7/sanger_projects/cell2location_proj/notebooks/results/gut/'
annotations_folder = '/nfs/team205/vk7/sanger_projects/collaborations/fetal_gut_mapping/results/tissue_annotation/oxford/'

scvi_run_name_global = f'{results_folder}/hyperparameters/c2l_v3_nonamortised_fulldata_epochs20k_lr0002_Adam_oxford_adult_paed_ref_signatures_N25'

# Output folder of the regression model providing the reference signatures
regression_model_output = 'v1_ye_signatures_lr0002_Adam'
reg_path = f'{results_folder}regression_model/{regression_model_output}/'

# Pin this notebook to its own GPU so several sweeps can run side by side
torch.cuda.set_device(f'cuda:{NOTEBOOK_NUMBER}')

In [3]:
# hyperparameter combinations for this notebook (iterated row by row in the training loop)
conditions_csv = f'{table_dir}/param_tables/table_{NOTEBOOK_NUMBER}.csv'
conditions = pd.read_csv(conditions_csv)

In [4]:
## snRNAseq reference (raw counts) as saved by the regression model
adata_snrna_raw = sc.read(f'{reg_path}sc.h5ad')

# Bring in the .obsm slots of the full original object, restricted to (and ordered by)
# the cells present in the regression-model reference
adata_snrna_raw2 = anndata.read_h5ad(sc_data_folder + "FINAL_OBJECT_raw_nosoupx.h5ad")
ref_cells = adata_snrna_raw.obs_names
adata_snrna_raw.obsm = adata_snrna_raw2[ref_cells, :].obsm

  res = method(*args, **kwargs)


In [5]:
# export estimated expression in each cluster
# The per-cluster means live either in .varm (newer exports) or as .var columns.
factor_names = adata_snrna_raw.uns['mod']['factor_names']
signature_cols = [f'means_per_cluster_mu_fg_{name}' for name in factor_names]
signature_table = (adata_snrna_raw.varm['means_per_cluster_mu_fg']
                   if 'means_per_cluster_mu_fg' in adata_snrna_raw.varm.keys()
                   else adata_snrna_raw.var)
inf_aver = signature_table[signature_cols].copy()
# use plain cluster names as column labels
inf_aver.columns = factor_names
inf_aver.iloc[0:5, 0:5]

Unnamed: 0,Activated CD4 T,Activated CD8 T,Adult Glia,BEST2+ Goblet cell,BEST4+ epithelial
ENSG00000187634,0.000475,0.000341,0.008211,0.001663,0.001148
ENSG00000188976,0.056281,0.082341,0.061379,0.256258,0.149516
ENSG00000187583,0.002949,0.01627,0.003423,0.007914,0.001061
ENSG00000188290,0.04108,0.058555,0.128032,0.130205,1.816961
ENSG00000187608,0.123889,0.235275,0.197328,0.110768,0.701299


In [6]:
def read_and_qc(sample_name, path=sp_data_folder):
    r""" Read one 10X Visium experiment into an AnnData object and compute QC metrics.

    Modify this function if required by your workflow.

    :param sample_name: Name of the sample (sub-folder of ``path``; also used as the
        prefix of the returned spot ids)
    :param path: path to the folder containing the sample folders
    :return: AnnData with ENSEMBL ids as ``var_names``, gene symbols in ``var['SYMBOL']``,
        a boolean ``var['mt']`` flag for mitochondria-encoded genes, ``obs['mt_frac']``
        (fraction of counts from those genes) and ``obs_names`` of the form
        ``'{sample}_{barcode}'``.
    """
    adata = sc.read_visium(path + str(sample_name),
                           count_file='filtered_feature_bc_matrix.h5', load_images=True)
    adata.obs['sample'] = sample_name
    # Use ENSEMBL ids as var_names and keep the gene symbols in a column
    adata.var['SYMBOL'] = adata.var_names
    adata.var.rename(columns={'gene_ids': 'ENSEMBL'}, inplace=True)
    adata.var_names = adata.var['ENSEMBL']
    adata.var.drop(columns='ENSEMBL', inplace=True)

    # Calculate QC metrics
    sc.pp.calculate_qc_metrics(adata, inplace=True)
    # Mitochondria-encoded genes are prefixed 'MT-' in human and 'mt-' in mouse;
    # compare case-insensitively so they are flagged for either species.
    adata.var['mt'] = [str(gene).upper().startswith('MT-') for gene in adata.var['SYMBOL']]
    adata.obs['mt_frac'] = adata[:, adata.var['mt'].tolist()].X.sum(1).A.squeeze() / adata.obs['total_counts']

    # add sample name to obs names
    adata.obs["sample"] = [str(i) for i in adata.obs['sample']]
    adata.obs_names = adata.obs["sample"] + '_' + adata.obs_names
    adata.obs.index.name = 'spot_id'

    return adata

def select_slide(adata, s, s_col='sample'):
    r""" Select the data for one slide from a multi-experiment spatial AnnData object.

    The returned copy keeps only the ``uns['spatial']`` entry belonging to ``s``.
    An exact key match is preferred; otherwise the first key containing ``s`` is used.
    Preferring the exact match prevents e.g. ``'A1'`` from picking the images of ``'A10'``.

    :param adata: Anndata object with multiple spatial experiments
    :param s: name of selected experiment
    :param s_col: column in adata.obs listing experiment name for each location
    :return: a copy of ``adata`` restricted to the locations of ``s``
    :raises KeyError: if no ``uns['spatial']`` key matches ``s``
    """
    slide = adata[adata.obs[s_col].isin([s]), :].copy()
    s_keys = list(slide.uns['spatial'].keys())

    if s in s_keys:
        s_spatial = s
    else:
        matches = [k for k in s_keys if s in k]
        if not matches:
            raise KeyError(f"no uns['spatial'] entry matches sample {s!r}; available: {s_keys}")
        s_spatial = matches[0]

    slide.uns['spatial'] = {s_spatial: slide.uns['spatial'][s_spatial]}

    return slide

#######################
# Table of spatial experiments to load (one row per sample)
sample_data = pd.DataFrame({'sample_name': ['A1', 'A2']})

# Read every experiment into its own anndata object (QC metrics included)
slides = [read_and_qc(name, path=sp_data_folder) for name in sample_data['sample_name']]

# Merge all experiments into one object; "sample" records the experiment of each spot
first_slide, *other_slides = slides
adata = first_slide.concatenate(
    other_slides,
    batch_key="sample",
    uns_merge="unique",
    batch_categories=sample_data['sample_name'],
    index_unique=None
)
#######################

Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.


In [7]:
# mitochondria-encoded (MT) genes should be removed for spatial mapping.
# Human symbols use the 'MT-' prefix and mouse 'mt-'; match case-insensitively so the
# genes are actually found for human data.
mt_mask = adata.var['SYMBOL'].astype(str).str.upper().str.startswith('MT-').values
adata.var['mt'] = mt_mask
# keep the per-spot MT counts before dropping the genes
adata.obsm['mt'] = adata[:, mt_mask].X.toarray()
# Explicit copy: subsetting returns a view, and later cells assign to `.obs`,
# which would otherwise trigger an implicit view->copy.
adata = adata[:, ~mt_mask].copy()

  res = method(*args, **kwargs)


In [8]:
# read histology-based annotations (one CSV per sample, indexed by spot barcode):
adata.obs['Annotation'] = ''

for s in adata.obs['sample'].unique():
    annot_path = f'{annotations_folder}{s}_annotation.csv'
    annot_ = pd.read_csv(annot_path, index_col='Barcode')
    # spot ids in `adata` have the form '{sample}_{barcode}'
    annot_ = annot_.rename(index=lambda barcode: f'{s}_{barcode}')
    adata.obs.loc[annot_.index, 'Annotation'] = annot_['Annotation']

adata.obs['Annotation'].value_counts(dropna=False)

Trying to set attribute `.obs` of view, copying.


Epithelial            1752
Other_tissue          1649
NaN                    596
LowQ_tissue            527
Stem_cell_zone         396
Lymphoid_structure      45
Name: Annotation, dtype: int64

In [9]:
# select locations from good quality tissue (spots labelled 'LowQ_tissue' or
# left unannotated are excluded):
good_tissue = ['Stem_cell_zone', 'Epithelial', 'Other_tissue', 'Lymphoid_structure']
tissue_ind = adata.obs['Annotation'].isin(good_tissue)
print(tissue_ind.sum())

adata = adata[tissue_ind, :]
adata.obs['Annotation'].value_counts(dropna=False)

3842


Epithelial            1752
Other_tissue          1649
Stem_cell_zone         396
Lymphoid_structure      45
Name: Annotation, dtype: int64

In [10]:
# Independent (non-view) copy of the filtered spatial data used for mapping.
adata_vis = adata.copy()
# Snapshot the current gene set in `.raw` before genes are restricted to those
# shared with the reference signatures.
adata_vis.raw = adata_vis

In [11]:
# Keep only genes present in both the spatial data and the reference signatures
intersect = np.intersect1d(adata_vis.var_names, inf_aver.index)
inf_aver = inf_aver.loc[intersect].copy()
adata_vis = adata_vis[:, intersect].copy()

# Register the anndata with scvi, treating each Visium sample as a batch
scvi.data.setup_anndata(adata=adata_vis, batch_key="sample")
scvi.data.view_anndata_setup(adata_vis)

[34mINFO    [0m Using batches from adata.obs[1m[[0m[32m"sample"[0m[1m][0m                                              
[34mINFO    [0m No label_key inputted, assuming all cells have same label                           
[34mINFO    [0m Using data from adata.X                                                             
[34mINFO    [0m Computing library size prior per batch                                              
[34mINFO    [0m Successfully registered anndata object containing [1;34m3842[0m cells, [1;34m14388[0m vars, [1;34m2[0m batches,
         [1;34m1[0m labels, and [1;34m0[0m proteins. Also registered [1;34m0[0m extra categorical covariates and [1;34m0[0m extra
         continuous covariates.                                                              
[34mINFO    [0m Please do not further modify adata until model is trained.                          


  res = method(*args, **kwargs)
  res = method(*args, **kwargs)


In [None]:
import os

from scvi.external.cell2location._cell2location_v3_module import LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel


def _savefig_and_close(path, fig=None):
    """Save ``fig`` (default: the current pyplot figure) with a tight bounding box, then close it.

    Closing each figure keeps memory bounded while many figures are produced in the loop.
    """
    if fig is None:
        fig = plt.gcf()
    fig.savefig(path, bbox_inches='tight')
    plt.close(fig)


def plot_spatial_per_cell_type(adata, cell_type=None, samples=('A1', 'A2'),
                               ncol=2, prefix=''):
    """Plot one ``adata.obs`` value on the tissue image of each sample, one panel per sample.

    :param adata: spatial AnnData with a 'sample' obs column and ``uns['spatial']``
    :param cell_type: name of the value to plot; defaults to the first entry of
        ``adata.uns['mod']['factor_names']``
    :param samples: sample names to plot, one panel each
    :param ncol: number of panel columns
    :param prefix: prefix of the obs column holding the values (``f'{prefix}{cell_type}'``).
        That column is copied to ``adata.obs[cell_type]`` (mutating ``adata``), so panels
        are coloured and titled by ``cell_type``.
    :return: the matplotlib Figure
    """
    if cell_type is None:
        cell_type = adata.uns['mod']['factor_names'][0]
    samples = list(samples)
    nrow = int(np.ceil(len(samples) / ncol))
    # squeeze=False keeps `axs` 2-D even for a single row and/or column
    fig, axs = plt.subplots(nrow, ncol, figsize=(12 * ncol, 8 * nrow), squeeze=False)

    col_name = f'{prefix}{cell_type}'
    # one colour scale shared by all samples, capped at the 99.999th percentile
    vmax = np.quantile(adata.obs[col_name].values, 0.99999)
    adata.obs[cell_type] = adata.obs[col_name].copy()

    # panels are filled column by column: (0, 0), (1, 0), ..., (0, 1), ...
    panel_pos = [(i, j) for j in range(ncol) for i in range(nrow)]

    for (row, col), s in zip(panel_pos, samples):
        ax = axs[row, col]
        sp_data_s = select_slide(adata, s, s_col='sample')
        sc.pl.spatial(sp_data_s, cmap='magma',
                      color=cell_type,
                      size=1.3, img_key='hires', alpha_img=1,
                      vmin=0, vmax=vmax, ax=ax, show=False)
        ax.title.set_text(cell_type + '\n' + s)

    fig.tight_layout(pad=0.5)

    return fig


for _, con in conditions.iterrows():
    # one output directory per hyperparameter combination
    scvi_run_name = f'{table_dir}/{con["names"]}'

    # create and train the model
    mod = scvi.external.Cell2location(
        adata_vis, cell_state_df=inf_aver,
        model_class=LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel,
        amortised=False,
        N_cells_per_location=con["N_cells_per_location"],
        A_factors_per_location=con["A_factors_per_location"],
        B_groups_per_location=con["B_groups_per_location"],
        detection_alpha=con["detection_alpha"],
    )

    mod.train(max_epochs=20000,
              batch_size=None,
              train_size=1,
              plan_kwargs={'optim': pyro.optim.Adam({'lr': con['lr']})},
              use_gpu=True)

    # In this section, we export the estimated cell abundance (summary of the posterior distribution).
    adata_vis = mod.export_posterior(
        adata_vis, sample_kwargs={'num_samples': 1000, 'batch_size': mod.adata.n_obs, 'use_gpu': False}
    )

    # Save model; it can be loaded later like this:
    # mod = scvi.external.Cell2location.load(scvi_run_name, adata_vis)
    mod.save(scvi_run_name, overwrite=True)

    # Save anndata object with results
    adata_file = f"{scvi_run_name}/sp.h5ad"
    adata_vis.write(adata_file)

    # plot ELBO loss history during training: skipping the first 5000 epochs, then all epochs
    mod.plot_history(5000)
    _savefig_and_close(f"{scvi_run_name}/training_ELBO_history_minus5k.png")
    mod.plot_history(0)
    _savefig_and_close(f"{scvi_run_name}/training_ELBO_history_all.png")

    # Examine reconstruction accuracy to assess if there are any issues with mapping;
    # the plot should be roughly diagonal, strong deviations will signal problems
    mod.plot_QC()
    _savefig_and_close(f"{scvi_run_name}/reconstruction_accuracy_histogram.png")

    # add 5% quantile, representing confident cell abundance, 'at least this amount is present',
    # to adata.obs with nice names for plotting
    clust_names = adata_vis.uns['mod']['factor_names']
    adata_vis.obs[clust_names] = adata_vis.obsm['q05_cell_abundance_w_sf']

    spatial_dir = f"{scvi_run_name}/spatial/"
    per_sample_dir = f"{spatial_dir}per_sample/"
    os.makedirs(per_sample_dir, exist_ok=True)

    # summary maps: posterior-mean w_sf summed across columns, q05 detection_y_s, raw counts
    adata_vis.obs['total_cell_abundance'] = adata_vis.uns['mod']['post_sample_means']['w_sf'].sum(1).flatten()
    adata_vis.obs['detection_y_s'] = adata_vis.uns['mod']['post_sample_q05']['detection_y_s']
    for obs_col, png_name in [('total_cell_abundance', 'total_cell_abundance'),
                              ('detection_y_s', 'detection_y_s'),
                              ('total_counts', 'total_RNA_counts')]:
        fig = plot_spatial_per_cell_type(adata_vis, cell_type=obs_col, prefix='')
        _savefig_and_close(f"{spatial_dir}{png_name}.png", fig)

    # per-sample grid of q05 abundance for every cell type
    with mpl.rc_context({"axes.facecolor": "black"}):
        for s in adata_vis.obs['sample'].unique():
            s_ind = adata_vis.obs['sample'] == s
            s_keys = list(adata_vis.uns['spatial'].keys())
            s_spatial = np.array(s_keys)[[s in k for k in s_keys]][0]

            sc.pl.spatial(adata_vis[s_ind, :], cmap='magma',
                          color=clust_names, ncols=5, library_id=s_spatial,
                          size=1.3, img_key='hires', alpha_img=1,
                          vmin=0, vmax='p99.2',
                          return_fig=True, show=False)
            _savefig_and_close(f"{per_sample_dir}W_cell_abundance_q05_{s}.png")

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5]


Epoch 7396/20000:  37%|███▋      | 7395/20000 [21:51<35:35,  5.90it/s, v_num=1, elbo_train=2.13e+7]  