In [27]:

import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from pathlib import Path

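# Legacy setting: asks Theano to run on the GPU; it must be set before cell2location is imported.
# NOTE(review): the RegressionModel below is trained with `mod.train(..., use_gpu=True)` -- confirm this flag is still needed.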
os.environ["THEANO_FLAGS"] = 'device=cuda0,floatX=float32,force_device=True'

import cell2location

from cell2location.utils.filtering import filter_genes
from cell2location.models import RegressionModel

# TODO: hardcoded config
sample_id = "sample_id"              # obs column that sc.concat fills with the sample of origin
label_name = "cell_type"             # obs column holding the cell type annotation
labels_to_remove = ["unannotated"]   # annotation labels that are dropped before modelling

# `_dh` is IPython's directory history. Depending on the IPython version its
# entries are plain strings or Path objects, so normalise to Path before
# joining path components with `/`.
current_folder = Path(globals()['_dh'][0])
#current_folder = Path(__file__).parent
output_dir = current_folder / ".." / ".." / "data" / "cellbender_out"
model_out = current_folder / ".." / ".." / "data" / "c2l_models"
# Make sure the model output folder exists (no error if it already does).
model_out.mkdir(parents=True, exist_ok=True)

# Every non-hidden entry of the CellBender output folder is treated as one sample
samples = list(filter(lambda entry: not entry.startswith("."), os.listdir(output_dir)))

# Read the annotated, QC-filtered CellBender matrix of each sample
adata_objects = {
    name: sc.read_h5ad(output_dir / name / "cell_bender_matrix_filtered_qc_annotated.h5ad")
    for name in samples
}

# Stack all samples into one object; the sample of origin is stored in
# obs[sample_id] and appended to each cell barcode with "_"
adata_raw = sc.concat(
    list(adata_objects.values()),
    join="outer",
    label=sample_id,
    keys=list(adata_objects.keys()),
    index_unique="_",
)
adata_raw.var_names_make_unique()

# Keep only the cells whose annotation is not listed in labels_to_remove
adata_raw = adata_raw[~adata_raw.obs[label_name].isin(labels_to_remove)]

# Read the snRNA-seq metadata sheet and collect the sample ids of each condition
sample_meta = pd.read_excel(
    current_folder / ".." / ".." / "data" / "Metadata_all.xlsx",
    sheet_name="snRNA-seq",
)
ms_samples = sample_meta.loc[sample_meta["Condition"] == "MS", "sample_id"]
ctrl_samples = sample_meta.loc[sample_meta["Condition"] == "Control", "sample_id"]

# Separate the ids that were found among the loaded samples from those that were not
ms_found = ms_samples.isin(samples)
missing_ms_samples = ms_samples[~ms_found]
ms_samples = ms_samples[ms_found]

ctrl_found = ctrl_samples.isin(samples)
missing_ctrl_samples = ctrl_samples[~ctrl_found]
ctrl_samples = ctrl_samples[ctrl_found]

# Report metadata entries without a loaded sample
if not missing_ms_samples.empty:
    print(f"Missing MS samples:\n{missing_ms_samples}")
if not missing_ctrl_samples.empty:
    print(f"Missing control samples:\n{missing_ctrl_samples}")

# Build one independent AnnData copy per condition and summarise each of them

# MS cells
ms_adata_raw = adata_raw[adata_raw.obs[sample_id].isin(ms_samples)].copy()
print(f"MS dataset:\n{ms_adata_raw}")
print(f"MS samples:\n{ms_adata_raw.obs[sample_id].unique()}")

# Control cells
ctrl_adata_raw = adata_raw[adata_raw.obs[sample_id].isin(ctrl_samples)].copy()
print(f"Control dataset:\n{ctrl_adata_raw}")
print(f"Control samples:\n{ctrl_adata_raw.obs[sample_id].unique()}")

# All cells: simply another name for the concatenated object (no copy is made)
all_adata_raw = adata_raw
print(f"All dataset:\n{all_adata_raw}")
print(f"All samples:\n{all_adata_raw.obs[sample_id].unique()}")

Missing MS samples: 0      MS197
1      MS229
2     MS371N
3     MS377N
12     MS586
Name: sample_id, dtype: object
Missing control samples: 13    CO37
15    CO45
16    CO41
Name: sample_id, dtype: object
MS dataset:
AnnData object with n_obs × n_vars = 59090 × 31226
    obs: 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'doublet_score', 'predicted_doublet', 'diss_score', 'cell_type', 'cell_type_forced', 'leiden', 'leiden_forced', 'sample_id'
MS samples:
['MS411', 'MS466', 'MS497T', 'MS377I', 'MS377T', 'MS549T', 'MS497I', 'MS549H']
Categories (8, object): ['MS411', 'MS466', 'MS497T', 'MS377I', 'MS377T', 'MS549T', 'MS497I', 'MS549H']
Control dataset:
AnnData object with n_obs × n_vars = 16383 × 31226
    obs: 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'doublet_score', 'predicted_doublet', 'diss_score', 'cell_type', 'cell_type_forced', 'leiden', 'leiden_forced', 'sample_id'
Control samples:
['CO74', 'CO85', 'C

In [26]:
# Inspect the non-hidden entries that were found in output_dir.
samples

['MS411',
 'MS466',
 'MS497T',
 'CO74',
 'MS377I',
 'MS377T',
 'MS549T',
 'MS497I',
 'CO85',
 'CO40',
 'MS549H']

In [25]:
# Inspect the control sample ids from the metadata that are not present in `samples`.
missing_ctrl_samples

Series([], Name: sample_id, dtype: object)

In [22]:
# Inspect the MS sample ids from the metadata that are not present in `samples`.
missing_ms_samples

array([], dtype=object)

In [15]:
# Reload the "snRNA-seq" sheet of the metadata workbook and display it for inspection.
# NOTE(review): this repeats the read_excel call of the first cell.
sample_meta = pd.read_excel(current_folder / ".." / ".." / "data" / "Metadata_all.xlsx", sheet_name="snRNA-seq")
sample_meta

Unnamed: 0,patient_id,sample_id,Condition,lesion_type,Age,Sex,RIN,Batch,visium,snRNA-seq
0,MS197 P2D3,MS197,MS,CA,52,F,9.0,4,True,True
1,MS229 P2C2,MS229,MS,CA,53,M,7.0,4,True,True
2,MS371 A3D6,MS371N,MS,A,40,M,7.6,4,True,True
3,MS377 A2D2,MS377N,MS,CA,50,F,8.9,4,True,True
4,MS377 A2D4,MS377I,MS,CA,50,F,6.5,1,True,True
5,MS377 A2D4,MS377T,MS,CA,50,F,6.5,1,True,True
6,MS411 A2A2,MS411,MS,CA,61,M,5.9,1,True,True
7,MS466 A1D6,MS466,MS,CI,65,F,6.5,1,True,True
8,MS497 A3C2,MS497I,MS,CI,60,F,6.1,3,True,True
9,MS497 A3C2,MS497T,MS,CI,60,F,6.1,2,True,True


In [17]:
# Inspect which samples are represented in the MS subset.
ms_adata_raw.obs[sample_id].unique()

['MS411', 'MS466', 'MS497T', 'MS377I', 'MS377T', 'MS549T', 'MS497I', 'MS549H']
Categories (8, object): ['MS411', 'MS466', 'MS497T', 'MS377I', 'MS377T', 'MS549T', 'MS497I', 'MS549H']

In [18]:
# Re-derive the per-condition sample ids directly from the metadata sheet.
# NOTE(review): unlike the first cell, no filtering against `samples` follows here,
# so running this cell replaces the filtered `ms_samples` / `ctrl_samples` with the
# unfiltered metadata lists.
ms_samples = sample_meta.sample_id[sample_meta.Condition=="MS"]
ctrl_samples = sample_meta.sample_id[sample_meta.Condition=="Control"]

0      MS197
1      MS229
2     MS371N
3     MS377N
12     MS586
Name: sample_id, dtype: object

In [None]:
# Run one model for MS and for healthy controls.
# For each condition: filter genes, train the cell2location reference regression
# model, save the training curve and export the per-cell-type expression
# signatures to <model_out>/<identifier>_reg_model/inf_aver.csv.
for adata, identifier in zip([ms_adata_raw, ctrl_adata_raw], ["MS", "Control"]):

    tmp_out = model_out / (identifier + "_reg_model")
    tmp_out.mkdir(parents=True, exist_ok=True)
    print(f"Running regression model for {identifier}, saving in {tmp_out}")

    # Filter by cell2loc thresholds; the result is used to select the gene columns
    selected = filter_genes(adata, cell_count_cutoff=5, cell_percentage_cutoff2=0.03, nonz_mean_cutoff=1.12)
    adata = adata[:, selected].copy()

    # use integer encoding for sample and celltype covariates (scvi utility)
    RegressionModel.setup_anndata(adata=adata,
                                  # 10X reaction / sample / batch
                                  batch_key=sample_id,
                                  # cell type, covariate used for constructing signatures
                                  labels_key=label_name
    )

    # Run regression model
    # See https://github.com/BayraktarLab/cell2location/blob/a583a836b3a932ac6b4de54edd56b8dcf235245a/cell2location/models/reference/_reference_module.py#L13
    mod = RegressionModel(adata)
    mod.view_anndata_setup()

    # Training
    mod.train(max_epochs=250, batch_size=2500, train_size=1, lr=0.002, use_gpu=True)

    # Save training plot
    # NOTE(review): the first argument (20) presumably skips the initial epochs in the plot -- confirm
    fig, ax = plt.subplots(1,1, facecolor='white')
    mod.plot_history(20, ax=ax)
    fig.savefig(tmp_out / "training_plot.png", dpi=300, bbox_inches='tight')

    # In this section, we export the estimated cell abundance (summary of the posterior distribution).
    adata = mod.export_posterior(
        adata, sample_kwargs={'num_samples': 1000, 'batch_size': 2500, 'use_gpu': True}
    )

    # export estimated expression in each cluster
    # The factor names must be read from the object returned by export_posterior
    # (`adata`); the concatenated `adata_raw` was never passed through the model
    # and carries no 'mod' entry in .uns.
    factor_names = adata.uns['mod']['factor_names']
    signature_columns = [f'means_per_cluster_mu_fg_{i}' for i in factor_names]
    if 'means_per_cluster_mu_fg' in adata.varm.keys():
        inf_aver = adata.varm['means_per_cluster_mu_fg'][signature_columns].copy()
    else:
        inf_aver = adata.var[signature_columns].copy()
    inf_aver.columns = factor_names

    inf_aver.to_csv(tmp_out / "inf_aver.csv")
