# Topic Modeling with Amortized LDA

## Load required libraries

In [1]:
import numpy as np
import pandas as pd
import scanpy as sc
from glob import glob
from pathlib import Path
import scrublet as scr
import pickle
import os
import anndata
import matplotlib.pyplot as plt
import scvi

# Scanpy logging verbosity: controls how much progress output sc.* calls print
# (e.g. the "reading ..." / timing lines shown under the loading cells).
sc.settings.verbosity = 3

### Initial setup: output folder and figure resolution

In [4]:
# Render figures at 150 dpi.
sc.settings.set_figure_params(dpi=150)

# Every analysis output is written under Result_dir, created here if missing.
Result_dir = "../results/"
Path(Result_dir).mkdir(exist_ok=True, parents=True)

# AnnData file that stores the merged samples / analysis results.
results_file = f"{Result_dir}SCC.h5ad"

## Read Datasets

In [5]:
# NOTE(review): OBJs is never read anywhere in this notebook; candidate for removal.
OBJs = []

## SCC

In [6]:
# List the cSCC count-matrix files that the next cell loads.
!ls ../filtered_feature_bc_matrix_h5/*cSCC* 

filtered_feature_bc_matrix_h5/P10_cSCC.h5
filtered_feature_bc_matrix_h5/P1_1_cSCC.h5
filtered_feature_bc_matrix_h5/P1_2_cSCC.h5
filtered_feature_bc_matrix_h5/P2_cSCC.h5
filtered_feature_bc_matrix_h5/P3_1_cSCC.h5
filtered_feature_bc_matrix_h5/P3_2_cSCC.h5
filtered_feature_bc_matrix_h5/P4_cSCC.h5
filtered_feature_bc_matrix_h5/P5_cSCC.h5
filtered_feature_bc_matrix_h5/P6_cSCC.h5
filtered_feature_bc_matrix_h5/P7_cSCC.h5
filtered_feature_bc_matrix_h5/P8_1_cSCC.h5
filtered_feature_bc_matrix_h5/P8_2_cSCC.h5
filtered_feature_bc_matrix_h5/P9_cSCC.h5


In [7]:
# Load every cSCC sample. For each one: read its 10x h5 matrix, make gene
# names unique, and tag each cell with type='SCC' and sample='SCC_<id>'.
# One loop replaces 13 copy-pasted blocks. The per-sample variables bound
# below are the names the merge cell concatenates.
_scc_ids = ["P8_1", "P8_2", "P7", "P3_1", "P2", "P5", "P3_2",
            "P4", "P1_2", "P9", "P10", "P1_1", "P6"]
_scc_loaded = {}
for _pid in _scc_ids:
    _ad = sc.read_10x_h5(f"../filtered_feature_bc_matrix_h5/{_pid}_cSCC.h5")
    _ad.var_names_make_unique()
    _ad.obs['type'] = 'SCC'
    _ad.obs['sample'] = f"SCC_{_pid}"
    _scc_loaded[_pid] = _ad

data_SCC_P8_1 = _scc_loaded["P8_1"]
data_SCC_P8_2 = _scc_loaded["P8_2"]
data_SCC_P7 = _scc_loaded["P7"]
data_SCC_P3_1 = _scc_loaded["P3_1"]
data_SCC_P2 = _scc_loaded["P2"]
data_SCC_P5 = _scc_loaded["P5"]
data_SCC_P3_2 = _scc_loaded["P3_2"]
data_SCC_P4 = _scc_loaded["P4"]
data_SCC_P1_2 = _scc_loaded["P1_2"]
data_SCC_P9 = _scc_loaded["P9"]
data_SCC_P10 = _scc_loaded["P10"]
data_SCC_P1_1 = _scc_loaded["P1_1"]
data_SCC_P6 = _scc_loaded["P6"]

# Drop the loop's extra references. Otherwise deleting the data_SCC_*
# variables after merging would not release the objects.
del _scc_ids, _scc_loaded, _pid, _ad

reading filtered_feature_bc_matrix_h5/P8_1_cSCC.h5
 (0:00:00)
reading filtered_feature_bc_matrix_h5/P8_2_cSCC.h5


  utils.warn_names_duplicates("var")


 (0:00:00)
reading filtered_feature_bc_matrix_h5/P7_cSCC.h5


  utils.warn_names_duplicates("var")


 (0:00:00)
reading filtered_feature_bc_matrix_h5/P3_1_cSCC.h5


  utils.warn_names_duplicates("var")


 (0:00:00)
reading filtered_feature_bc_matrix_h5/P2_cSCC.h5


  utils.warn_names_duplicates("var")


 (0:00:00)
reading filtered_feature_bc_matrix_h5/P5_cSCC.h5


  utils.warn_names_duplicates("var")


 (0:00:00)
reading filtered_feature_bc_matrix_h5/P3_2_cSCC.h5
 (0:00:00)
reading filtered_feature_bc_matrix_h5/P4_cSCC.h5


  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")


 (0:00:01)


  utils.warn_names_duplicates("var")


reading filtered_feature_bc_matrix_h5/P1_2_cSCC.h5
 (0:00:00)
reading filtered_feature_bc_matrix_h5/P9_cSCC.h5


  utils.warn_names_duplicates("var")


 (0:00:00)
reading filtered_feature_bc_matrix_h5/P10_cSCC.h5


  utils.warn_names_duplicates("var")


 (0:00:00)
reading filtered_feature_bc_matrix_h5/P1_1_cSCC.h5


  utils.warn_names_duplicates("var")


 (0:00:00)
reading filtered_feature_bc_matrix_h5/P6_cSCC.h5


  utils.warn_names_duplicates("var")


 (0:00:00)


  utils.warn_names_duplicates("var")


## Normal

In [8]:
# Load every Normal sample. For each one: read its 10x h5 matrix, make gene
# names unique, and tag each cell with type='Normal' and sample='nrl_<id>'.
# One loop replaces 10 copy-pasted blocks. The per-sample variables bound
# below are the names the merge cell concatenates.
_nrl_ids = ["P4", "P1_1", "P10", "P3", "P2", "P9", "P8", "P7", "P6", "P5"]
_nrl_loaded = {}
for _pid in _nrl_ids:
    _ad = sc.read_10x_h5(f"../filtered_feature_bc_matrix_h5/{_pid}.h5")
    _ad.var_names_make_unique()
    _ad.obs['type'] = 'Normal'
    _ad.obs['sample'] = f"nrl_{_pid}"
    _nrl_loaded[_pid] = _ad

data_nrl_P4 = _nrl_loaded["P4"]
data_nrl_P1_1 = _nrl_loaded["P1_1"]
data_nrl_P10 = _nrl_loaded["P10"]
data_nrl_P3 = _nrl_loaded["P3"]
data_nrl_P2 = _nrl_loaded["P2"]
data_nrl_P9 = _nrl_loaded["P9"]
data_nrl_P8 = _nrl_loaded["P8"]
data_nrl_P7 = _nrl_loaded["P7"]
data_nrl_P6 = _nrl_loaded["P6"]
data_nrl_P5 = _nrl_loaded["P5"]

# Drop the loop's extra references. Otherwise deleting the data_nrl_*
# variables after merging would not release the objects.
del _nrl_ids, _nrl_loaded, _pid, _ad

reading filtered_feature_bc_matrix_h5/P4.h5
 (0:00:00)
reading filtered_feature_bc_matrix_h5/P1_1.h5


  utils.warn_names_duplicates("var")


 (0:00:00)
reading filtered_feature_bc_matrix_h5/P10.h5


  utils.warn_names_duplicates("var")


 (0:00:00)
reading filtered_feature_bc_matrix_h5/P3.h5


  utils.warn_names_duplicates("var")


 (0:00:00)
reading filtered_feature_bc_matrix_h5/P2.h5


  utils.warn_names_duplicates("var")


 (0:00:00)
reading filtered_feature_bc_matrix_h5/P9.h5


  utils.warn_names_duplicates("var")


 (0:00:00)
reading filtered_feature_bc_matrix_h5/P8.h5


  utils.warn_names_duplicates("var")


 (0:00:00)
reading filtered_feature_bc_matrix_h5/P7.h5


  utils.warn_names_duplicates("var")


 (0:00:00)
reading filtered_feature_bc_matrix_h5/P6.h5


  utils.warn_names_duplicates("var")


 (0:00:00)
reading filtered_feature_bc_matrix_h5/P5.h5


  utils.warn_names_duplicates("var")


 (0:00:00)


  utils.warn_names_duplicates("var")


## Merge all samples into one object and delete the per-sample objects

In [9]:
%%time

adata = data_SCC_P1_1.concatenate(data_SCC_P10, data_SCC_P1_2, data_SCC_P2, data_SCC_P3_1, data_SCC_P3_2, 
                                  data_SCC_P4, data_SCC_P5, data_SCC_P6, data_SCC_P7, data_SCC_P8_1, 
                                  data_SCC_P8_2, data_SCC_P9, data_nrl_P10, data_nrl_P1_1, data_nrl_P2, 
                                  data_nrl_P3, data_nrl_P4, data_nrl_P5, data_nrl_P6, data_nrl_P7, 
                                  data_nrl_P8, data_nrl_P9)

# and delete individual datasets to save space
del(data_SCC_P10, data_SCC_P1_1, data_SCC_P1_2, data_SCC_P2, 
    data_SCC_P3_1, data_SCC_P3_2, data_SCC_P4, data_SCC_P5, 
    data_SCC_P6, data_SCC_P7, data_SCC_P8_1, data_SCC_P8_2, 
    data_SCC_P9, data_nrl_P10, data_nrl_P1_1, data_nrl_P2, 
    data_nrl_P3, data_nrl_P4, data_nrl_P5, data_nrl_P6, 
    data_nrl_P7, data_nrl_P8, data_nrl_P9)


  [AnnData(sparse.csr_matrix(a.shape), obs=a.obs) for a in all_adatas],


CPU times: user 1.91 s, sys: 314 ms, total: 2.23 s
Wall time: 2.22 s


### Save data to disk temporarily

In [10]:
# Persist the merged object to `results_file` (defined in the setup cell).
adata.write(results_file)

In [11]:
# Cells per sample, followed by the AnnData summary.
sample_counts = adata.obs['sample'].value_counts()
print(sample_counts)

adata

SCC_P4      12474
nrl_P9       7595
nrl_P6       7066
nrl_P2       6205
SCC_P7       5292
SCC_P2       5123
SCC_P10      5053
SCC_P8_1     4459
SCC_P6       4453
nrl_P10      4324
nrl_P1_1     3584
SCC_P9       3261
nrl_P7       3105
SCC_P5       2323
nrl_P3       1553
nrl_P5       1551
SCC_P8_2     1531
SCC_P1_1     1194
SCC_P1_2      925
nrl_P8        767
SCC_P3_2      366
nrl_P4        290
SCC_P3_1      127
Name: sample, dtype: int64


AnnData object with n_obs × n_vars = 82621 × 36601
    obs: 'type', 'sample', 'batch'
    var: 'gene_ids', 'feature_types', 'genome'

## Figure settings for QC plots

In [12]:
sc.set_figure_params(figsize=(4, 4))

%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'

Global seed set to 0


### Read data from disk

In [13]:
# Reload the merged object from the exact file the "Save data to disk" cell
# wrote (`results_file`). The previous hard-coded "Results/SCC.h5ad" pointed
# at a different directory than "../results/", so a fresh run would not find
# the file, or would read a stale copy.
adata = sc.read(results_file)

adata

AnnData object with n_obs × n_vars = 82621 × 36601
    obs: 'type', 'sample', 'batch'
    var: 'gene_ids', 'feature_types', 'genome'

In [14]:
# Confirm the reloaded object matches: cells per sample, then the summary.
sample_counts = adata.obs['sample'].value_counts()
print(sample_counts)

adata

SCC_P4      12474
nrl_P9       7595
nrl_P6       7066
nrl_P2       6205
SCC_P7       5292
SCC_P2       5123
SCC_P10      5053
SCC_P8_1     4459
SCC_P6       4453
nrl_P10      4324
nrl_P1_1     3584
SCC_P9       3261
nrl_P7       3105
SCC_P5       2323
nrl_P3       1553
nrl_P5       1551
SCC_P8_2     1531
SCC_P1_1     1194
SCC_P1_2      925
nrl_P8        767
SCC_P3_2      366
nrl_P4        290
SCC_P3_1      127
Name: sample, dtype: int64


AnnData object with n_obs × n_vars = 82621 × 36601
    obs: 'type', 'sample', 'batch'
    var: 'gene_ids', 'feature_types', 'genome'

## Filter low quality data

In [16]:
# Drop cells with fewer than 200 detected genes, then genes detected in fewer
# than 3 of the remaining cells (both operate on adata in place).
MIN_GENES_PER_CELL = 200
MIN_CELLS_PER_GENE = 3
sc.pp.filter_cells(adata, min_genes=MIN_GENES_PER_CELL)
sc.pp.filter_genes(adata, min_cells=MIN_CELLS_PER_GENE)

filtered out 1997 cells that have less than 200 genes expressed
filtered out 8815 genes that are detected in less than 3 cells


In [17]:
# Flag mitochondrial genes (names starting with 'MT-') so the QC metrics also
# report each cell's mitochondrial share (pct_counts_mt).
is_mito = adata.var_names.str.startswith('MT-')
adata.var['mt'] = is_mito
sc.pp.calculate_qc_metrics(
    adata,
    qc_vars=['mt'],
    percent_top=None,
    log1p=False,
    inplace=True,
)

In [18]:
# Per-cell distributions of genes detected, total counts and % mitochondrial counts.
qc_keys = ['n_genes_by_counts', 'total_counts', 'pct_counts_mt']
sc.pl.violin(adata, qc_keys, multi_panel=True, jitter=0.4)

In [19]:
# Total counts plotted against mitochondrial fraction, then against genes detected.
for y_key in ('pct_counts_mt', 'n_genes_by_counts'):
    sc.pl.scatter(adata, x='total_counts', y=y_key)

In [20]:
# Keep cells with < 2500 detected genes and < 5 % mitochondrial counts.
# NOTE(review): the upper gene cut is presumably meant to remove doublets;
# confirm intent.
# .copy() turns the slice into a real AnnData rather than a view, so the
# next cell's adata.layers assignment acts on the filtered data directly.
MAX_GENES_PER_CELL = 2500
MAX_PCT_MT = 5
keep_cells = (
    (adata.obs.n_genes_by_counts < MAX_GENES_PER_CELL)
    & (adata.obs.pct_counts_mt < MAX_PCT_MT)
)
adata = adata[keep_cells, :].copy()

In [21]:
# Keep a copy of the count matrix in layer "counts"; the LDA model is set up
# on this layer in the next section.
adata.layers["counts"] = adata.X.copy()

## Create the model for training

In [22]:
n_topics = 10  # number of latent topics the model learns

# Register the "counts" layer with scvi-tools, then build the topic model on it.
lda_cls = scvi.model.AmortizedLDA
lda_cls.setup_anndata(adata, layer="counts")
model = lda_cls(adata, n_topics=n_topics)

In [23]:
# Display the model created above. The previous line constructed a second,
# separate AmortizedLDA instance only to show its summary, which wasted
# resources and was not the model that gets trained.
model



## Train the model 

In [24]:
# Seed explicitly right before the stochastic training run, so the learned
# topics are reproducible regardless of which cells ran earlier in the kernel
# (scvi-tools reported "Global seed set to 0" above).
scvi.settings.seed = 0
model.train()

## Save the trained model (Optional)

In [25]:
# overwrite=True: without it, saving fails once the directory exists, which
# breaks any re-run of the notebook.
model.save("SCC_Topic_model/", overwrite=True)

In [26]:
# `load` is a classmethod; calling it through the class makes clear that a new
# model object is returned (bound to the current adata).
model = scvi.model.AmortizedLDA.load("SCC_Topic_model/", adata=adata)

[34mINFO    [0m File SCC_Topic_model/model.pt already downloaded                                    


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 521/521: 100%|██████████| 521/521 [46:35<00:00,  5.37s/it, v_num=1, elbo_train=1.06e+9]


In [27]:
# Inspect the object after QC filtering and model setup.
adata

AnnData object with n_obs × n_vars = 38379 × 27786
    obs: 'type', 'sample', 'batch', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt'
    var: 'gene_ids', 'feature_types', 'genome', 'n_cells', 'mt', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts'
    uns: '_scvi_uuid', '_scvi_manager_uuid'
    layers: 'counts'

## Visualizing learned topics

In [31]:
# Per-cell topic proportions: one row per cell barcode, one topic_<i> column per topic.
topic_prop = model.get_latent_representation()
topic_prop.head()

Unnamed: 0,topic_0,topic_1,topic_2,topic_3,topic_4,topic_5,topic_6,topic_7,topic_8,topic_9
AACGTTGAGAAGGACA-1-0,7.6e-05,0.000237,3.9e-05,0.000122,0.000207,0.645719,7.6e-05,4.4e-05,0.1525,0.200979
AAGACCTAGCGATTCT-1-0,0.005141,0.001239,2.4e-05,0.339639,0.000676,0.057522,0.075478,0.013058,0.506585,0.000638
AAGGCAGTCGGCCGAT-1-0,0.00016,1.5e-05,2e-06,2.3e-05,9.3e-05,0.998916,7e-05,5e-06,0.000383,0.000333
ACGAGCCAGACCTAGG-1-0,0.000497,0.069345,0.00012,0.009453,0.008865,0.001548,0.025615,0.514983,0.368926,0.000647
ACGATACCACTTCTGC-1-0,0.012737,0.002933,0.00011,0.005164,0.002966,0.923207,0.000273,0.000866,0.050827,0.000917


### Save topic proportions in obsm and obs columns.

In [32]:
# Store the full proportion matrix in obsm, plus one obs column per topic
# (LDA_topic_<i>) for easy plotting. Each obs column is assigned a Series,
# not a one-column DataFrame, so pandas aligns it on the cell index as a
# plain column.
adata.obsm["X_LDA"] = topic_prop
for topic_col in topic_prop.columns:
    adata.obs[f"LDA_{topic_col}"] = topic_prop[topic_col]

## Find top genes per topic

In [36]:
# Feature-by-topic weights. The row index holds the feature (gene) names, so
# it must be written out; index=False produced a CSV whose rows could not be
# mapped back to genes.
# Output goes to Result_dir, like the other results, not a separate "Results/" folder.
feature_by_topic = model.get_feature_by_topic()
feature_by_topic.to_csv(Result_dir + "Feature_by_topic.csv")
feature_by_topic.head()

In [39]:
# For each topic, list features from highest to lowest weight next to the
# weights themselves. Columns come in pairs: topic_<i> (names) and
# topic_<i>_prop (values), in topic order.
_rank_columns = {}
for k in range(n_topics):
    col = f"topic_{k}"
    ordered = feature_by_topic[col].sort_values(ascending=False)
    _rank_columns[col] = ordered.index
    _rank_columns[f"{col}_prop"] = ordered.values
rank_by_topic = pd.DataFrame(_rank_columns)

## Save rank by topic data

In [40]:
# Write the ranking table into Result_dir (same "../results/" path as before),
# then end with .head() so the preview is displayed. As the first statement,
# its result was discarded.
rank_by_topic.to_csv(Result_dir + "Rank_by_topic.csv", index=False)
rank_by_topic.head()

# END