In [1]:
"""
Script to generate a regression model for cell2location from raw sn-seq data
"""

import anndata
import pandas as pd
import scanpy as sc
from numpy.random import default_rng
import numpy as np
import os
from pathlib import Path

# this line forces theano to use the GPU and should go before importing cell2location
os.environ["THEANO_FLAGS"] = 'device=cuda0,floatX=float32,force_device=True'

import cell2location

from cell2location.utils.filtering import filter_genes
from cell2location.models import RegressionModel

import argparse

Global seed set to 0
  @numba.jit()
  @numba.jit()
  @numba.jit()
  @numba.jit()


In [13]:
test = sc.read_h5ad("/home/philipp/Work/VisiumMS/data/cellbender_out/MS466/cell_bender_matrix_filtered_qc_annotated.h5ad")
test.X.todense()

In [17]:
test = sc.read_h5ad("/home/philipp/Work/VisiumMS/data/cellbender_out/MS466/cell_bender_matrix_filtered_qc.h5")
test.X.todense()

matrix([[0, 1, 4, ..., 0, 0, 0],
        [0, 0, 1, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]], dtype=uint32)

In [20]:
test = sc.read_10x_h5("/home/philipp/Work/VisiumMS/data/cellbender_out/MS466/cell_bender_matrix_filtered.h5")
test.X.todense()

  utils.warn_names_duplicates("var")


matrix([[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 1, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]], dtype=uint32)

In [2]:
# TODO: harcoded config
perc_cells = 0.4
sample_id = "sample_id"
label_name = "cell_type"
labels_to_remove = ["unannotated"]

In [3]:

#current_folder = Path(__file__).parent
current_folder = globals()['_dh'][0]
output_dir = current_folder / ".." / ".." / "data" / "cellbender_out"

samples = [sample for sample in os.listdir(output_dir) if not sample.startswith(".")]

adata_objects = {sample: sc.read_h5ad(output_dir / sample / "cell_bender_matrix_filtered_qc_annotated.h5ad") for sample in samples}

adata_raw = sc.concat(list(adata_objects.values()), join="outer", label=sample_id, keys=list(adata_objects.keys()), index_unique="_")
adata_raw.var_names_make_unique()

In [4]:
adata_raw = adata_raw[~adata_raw.obs[label_name].isin(labels_to_remove), :]

In [5]:
sample_meta = pd.read_excel(current_folder / ".." / ".." / "data" / "Metadata_all.xlsx")
# we will make one model per condition
sample_meta.Condition

0          MS
1          MS
2          MS
3          MS
4          MS
5          MS
6          MS
7          MS
8          MS
9          MS
10         MS
11         MS
12         MS
13         MS
14         MS
15         MS
16         MS
17         MS
18         MS
19         MS
20         MS
21    Control
22    Control
23    Control
24    Control
25    Control
26    Control
27    Control
28    Control
Name: Condition, dtype: object

In [6]:
#ms_samples = sample_meta["Brain bank ID"][sample_meta.Condition=="MS"]
#ctrl_samles = sample_meta["Brain bank ID"][sample_meta.Condition=="Control"]

# ms samples are all sample ids in adata_raw that start with MS
ms_samples = [sample for sample in adata_raw.obs[sample_id].unique() if sample.startswith("MS")]

# all the other samples are control
ctrl_samles = [sample for sample in adata_raw.obs[sample_id].unique() if sample not in ms_samples]

In [7]:
ms_samples

['MS411', 'MS466', 'MS497T', 'MS377I', 'MS377T', 'MS549T', 'MS497I', 'MS549H']

In [8]:
ctrl_samles

['CO74', 'CO85', 'CO40']

In [9]:
ms_adata_raw = adata_raw[adata_raw.obs[sample_id].isin(ms_samples), :].copy()
ctrl_adata_raw = adata_raw[adata_raw.obs[sample_id].isin(ctrl_samles), :].copy()

In [10]:
adata_raw.obs[sample_id]

TCCACCACAGCTTCCT-1_MS411      MS411
AGCTTCCAGGTGTGAC-1_MS411      MS411
CATACTTTCATTACGG-1_MS411      MS411
CACGAATGTCTCTCCA-1_MS411      MS411
AACACACAGTAGAATC-1_MS411      MS411
                              ...  
TCCGTGTGTCAGTCTA-1_MS549H    MS549H
TGGGCTGTCCATATGG-1_MS549H    MS549H
GAGTGAGTCGTAGGAG-1_MS549H    MS549H
TGAGGGATCCGTCAAA-1_MS549H    MS549H
CGTGCTTAGGCATCGA-1_MS549H    MS549H
Name: sample_id, Length: 75555, dtype: category
Categories (11, object): ['MS411', 'MS466', 'MS497T', 'CO74', ..., 'MS497I', 'CO85', 'CO40', 'MS549H']

In [11]:
# not that these are the corrected counts from cellbender
ms_adata_raw.X[0:6, 0:6].todense()

matrix([[0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1]], dtype=uint32)

In [None]:
# coarse to integers
ms_adata_raw.X = np.round(ms_adata_raw.X).astype(int)
ctrl_adata_raw.X = np.round(ctrl_adata_raw.X).astype(int)
ms_adata_raw.X[0:12, 0:12].todense()

In [None]:
adata_raw.obs[label_name].unique()

In [None]:
# subset for testing
rng = default_rng(seed=42)

t_cell_ids = []

for cell_type in adata_raw.obs[label_name].unique():
    
    # Select cells from a cell type
    msk = adata_raw.obs[label_name] == cell_type
    cell_ids = adata_raw.obs.index[msk]
    
    n_cells = int(np.ceil(perc_cells * len(cell_ids)))
    
    cell_ids = rng.choice(cell_ids, size=n_cells, replace=False)
    t_cell_ids.extend(cell_ids)
    
adata_raw = adata_raw[t_cell_ids, :]

In [None]:
adata_raw

In [None]:
selected = filter_genes(adata_raw, cell_count_cutoff=5, cell_percentage_cutoff2=0.03, nonz_mean_cutoff=1.12)
adata_raw = adata_raw[:, selected].copy()

In [None]:
adata_raw

In [None]:
adata_raw.obs.cell_type.value_counts()

In [None]:
# get cell types for which the count is below 10 in adata_raw.obs.cell_type.value_counts()
ct_to_remove = adata_raw.obs.cell_type.value_counts().index[adata_raw.obs.cell_type.value_counts() < 10]
print(ct_to_remove)

# subset the adata object
adata_raw = adata_raw[~adata_raw.obs.cell_type.isin(ct_to_remove), :]

In [None]:
cell2location.models.RegressionModel.setup_anndata(adata=adata_raw,
                              # 10X reaction / sample / batch
                              batch_key=sample_id,
                              # cell type, covariate used for constructing signatures
                              labels_key=label_name
)

In [None]:
mod = RegressionModel(adata_raw)
mod.view_anndata_setup()

In [None]:
mod.train(max_epochs=250, batch_size=2500, train_size=1, lr=0.002, use_gpu=True)

In [None]:
import matplotlib.pyplot as plt
mod.plot_history(20)