# scNetNB - COVID-19 integration with PBMC and HCL atlases

In [None]:
import os

# Run the notebook from the scnet directory three levels up; relative paths
# used later (e.g. "./models/scNetNB/") are resolved against this directory.
# NOTE(review): presumably this also makes the local `scnet` package
# importable -- confirm against the repository layout.
os.chdir("../../../scnet/")

# Expose only the GPU with index 1 to CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [None]:
import scnet as sn
import scanpy as sc

In [None]:
# Set scanpy's figure defaults to 150 dpi for every plot produced below.
sc.settings.set_figure_params(dpi=150)

In [None]:
# Names of the `adata.obs` columns used throughout the notebook:
#   study_key      -- study of origin; used to split reference/query and as
#                     the condition column for the model
#   cell_type_key  -- cell-type annotation; used to colour the UMAP plots
study_key, cell_type_key = "study", "celltype"

In [None]:
# Studies held out of the reference data and used later as the query set.
query_studies = [
    "COVID-19",
    "PBMC 68K",
]

## Load data

In [None]:
# Absolute path on the original author's machine -- adjust for your setup.
data_path = "/home/mohsen/data/covid_integrated_normalized_hvg.h5ad"
adata = sn.data.read(data_path)
# Bare expression: shows the object summary as the cell output.
adata

### If `adata.X` contains raw read counts, run the following cell first to normalize the data and select highly variable genes.

In [None]:
# adata = sn.data.normalize_hvg(adata, batch_key=study_key, target_sum=1e4, n_top_genes=2000)

In [None]:
# Reference set: every cell whose study is NOT one of the query studies.
reference_mask = ~adata.obs[study_key].isin(query_studies)
reference_adata = adata[reference_mask]
# Bare expression: shows the object summary as the cell output.
reference_adata

In [None]:
# Query set: only the cells that belong to one of the query studies.
query_mask = adata.obs[study_key].isin(query_studies)
query_adata = adata[query_mask]
# Bare expression: shows the object summary as the cell output.
query_adata

In [None]:
# Count the distinct studies in the reference data; this becomes the number
# of conditions the reference model is built with.
reference_studies = reference_adata.obs[study_key].unique()
n_studies = len(reference_studies)
n_studies

## Create scNetNB object

In [None]:
# Configuration of the reference model. The input width is taken from the
# number of columns of the reference data and the number of conditions from
# the number of reference studies; the remaining values are fixed
# hyper-parameters.
# NOTE(review): the class is called scNetNB but is configured with
# loss_fn="mse" -- confirm this is the intended loss for normalized input.
reference_model_config = dict(
    task_name="COVID_reference",
    x_dimension=reference_adata.shape[1],
    z_dimension=20,
    architecture=[128, 32],
    n_conditions=n_studies,
    use_batchnorm=False,
    alpha=0.001,
    scale_factor=1.0,
    clip_value=3,
    loss_fn="mse",
    model_path="./models/scNetNB/",  # relative to the current working directory
    dropout_rate=0.05,
    output_activation="relu",
)
network = sn.models.scNetNB(**reference_model_config)

## Train scNetNB

In [None]:
# Train the reference model on the reference studies, using the study column
# as the condition label.
# NOTE(review): presumably train_size=0.8 is the train/validation split and
# retrain=True forces training even if a saved model exists -- confirm
# against the scnet API.
reference_train_options = {
    "train_size": 0.8,
    "condition_key": study_key,
    "n_epochs": 300,
    "batch_size": 128,
    "save": True,
    "retrain": True,
}
network.train(reference_adata, **reference_train_options)

In [None]:
# Take a random subset of the reference cells (copied, so the reference data
# is left untouched) to keep the neighbor graph and UMAP tractable.
# The sample size is capped at the number of available cells: subsampling is
# done without replacement, so asking for more cells than exist would raise.
plot_adata = sc.pp.subsample(
    reference_adata,
    n_obs=min(50000, reference_adata.n_obs),
    copy=True,
)

### This automatically returns the `z` latent representation, since the model is not trained with MMD regularization

In [None]:
# Encode the subsampled cells with the reference model. `study_key` is passed
# as the condition column, the same column used as `condition_key` in training.
latent_adata = network.get_latent(plot_adata, study_key)
# Bare expression: shows the object summary as the cell output.
latent_adata

In [None]:
# Build the neighborhood graph on the latent representation first; the UMAP
# embedding computed next relies on that graph, so the order matters.
sc.pp.neighbors(latent_adata)
sc.tl.umap(latent_adata)

In [None]:
# Draw one UMAP panel per annotation: study of origin and cell type.
umap_color_keys = [study_key, cell_type_key]
sc.pl.umap(latent_adata, color=umap_color_keys, frameon=False)

In [None]:
# Derive a new model from the trained reference model, extended with the
# query studies as additional conditions.
operate_options = {
    "new_task_name": "COVID_query",
    "new_conditions": query_studies,
    "version": "scNet",
}
new_network = sn.operate(network, **operate_options)

In [None]:
# Train the extended model on the query studies only, again using the study
# column as the condition label. Same settings as the reference run except
# for the shorter schedule (100 epochs).
query_train_options = {
    "train_size": 0.8,
    "condition_key": study_key,
    "n_epochs": 100,
    "batch_size": 128,
    "save": True,
    "retrain": True,
}
new_network.train(query_adata, **query_train_options)

In [None]:
plot_adata = sc.pp.subsample(reference_adata, n_obs=50000, copy=True)

In [None]:
latent_adata = network.get_latent(plot_adata, study_key)
latent_adata

In [None]:
# Build the neighborhood graph on the latent representation first; the UMAP
# embedding computed next relies on that graph, so the order matters.
sc.pp.neighbors(latent_adata)
sc.tl.umap(latent_adata)

In [None]:
# Draw one UMAP panel per annotation: study of origin and cell type.
umap_color_keys = [study_key, cell_type_key]
sc.pl.umap(latent_adata, color=umap_color_keys, frameon=False)