# **MixupVI**

#### Standard imports

In [1]:
import scanpy as sc
import scvi
import anndata as ad
import pandas as pd
import matplotlib.pyplot as plt

sc.set_figure_params(figsize=(4, 4))

# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'



# 1. Data preparation and model training (skip this section if the model and data are already saved)

## 1.1 Loading and preparing data 

#### Cross-tissue immune cell analysis reveals tissue-specific features in humans - Global

Despite their crucial role in health and disease, our knowledge of immune cells within human tissues remains limited. We surveyed the immune compartment of 16 tissues from 12 adult donors by single-cell RNA sequencing and VDJ sequencing generating a dataset of ~360,000 cells. To systematically resolve immune cell heterogeneity across tissues, we developed CellTypist, a machine learning tool for rapid and precise cell type annotation. Using this approach, combined with detailed curation, we determined the tissue distribution of finely phenotyped immune cell types, revealing hitherto unappreciated tissue-specific features and clonal architecture of T and B cells. Our multitissue approach lays the foundation for identifying highly resolved immune cell types by leveraging a common reference dataset, tissue-integrated expression analysis, and antigen receptor sequencing.

```{important}
All scvi-tools models require AnnData objects as input.
```

In [2]:
from benchmark_utils import (
  preprocess_scrna,
  split_dataset
)

In [4]:
SIGNATURE_CHOICE = "crosstissue_general_updated"  # ["laughney", "almudena", "crosstissue_general", "crosstissue_granular_updated"]
CELL_TYPE_GROUP = "updated_granular_groups" 

In [5]:
## CTI 
adata = sc.read("/home/owkin/data/cross-tissue/omics/raw/local.h5ad")

In [6]:
preprocess_scrna(adata=adata,
                 keep_genes=2500,
                 batch_key="assay")



### Batch effects exploration

In [7]:
# PCA
sc.pp.scale(adata, max_value=10)
sc.tl.pca(adata, svd_solver='arpack')


In [None]:
# UMAP (requires a neighborhood graph)
sc.pp.neighbors(adata)
sc.tl.umap(adata)

In [None]:
# restore raw counts (adata.X was scaled above) and keep a copy in a dedicated layer
adata.layers["counts"] = adata.raw.X.copy()
adata.X = adata.raw.X.copy()

adata.X.shape, adata.raw.X.shape, adata.layers["counts"].shape

#### Signature

In [None]:
from constants import GROUPS

from benchmark_utils import (
  read_almudena_signature,
  map_hgnc_to_ensg,
  perform_nnls,
  compute_correlations,
  create_signature,
  add_cell_types_grouped,
)

In [None]:
signature = create_signature(adata,
                             signature_type=SIGNATURE_CHOICE,
                             group=CELL_TYPE_GROUP)
add_cell_types_grouped(adata)

In [None]:
sc.pp.highly_variable_genes(
    adata,
    n_top_genes=5000,
    subset=True,
    layer="counts",
    flavor="seurat_v3",
    batch_key="assay",
)

Distribution of the batch ids

In [None]:
adata.obs["donor_id"].value_counts(), adata.obs["assay"].value_counts()

scVI is only trained on the set of most highly variable genes; therefore, some low-variance genes that are present in the signature matrix might be filtered out.

In [None]:
excluded_genes = set(signature.index) - set(adata.var_names)

len(excluded_genes)
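Any downstream comparison between the signature matrix and scVI outputs can only use the genes that survived the highly-variable-gene selection. A minimal sketch of restricting a signature to those genes, using a toy signature and a hypothetical `hvg_names` index in place of `adata.var_names`:

```python
import pandas as pd

# Toy signature matrix: genes x cell types (values are illustrative only)
signature = pd.DataFrame(
    {"T cells": [5.0, 0.1, 2.0, 0.0], "B cells": [0.2, 4.0, 1.5, 3.0]},
    index=["CD3E", "MS4A1", "PTPRC", "CD79A"],
)
# Hypothetical genes kept after highly-variable-gene selection (stands in for adata.var_names)
hvg_names = pd.Index(["CD3E", "MS4A1", "LYZ"])

# Signature genes that were filtered out, and the signature restricted to the kept genes
excluded_genes = signature.index.difference(hvg_names)
kept_genes = signature.index.intersection(hvg_names)
signature_hvg = signature.loc[kept_genes]

print(f"{len(excluded_genes)} of {len(signature)} signature genes excluded")
```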

Now it's time to run `setup_anndata()`, which alerts scvi-tools to the locations of various matrices inside the anndata. It's important to run this function with the correct arguments so scvi-tools is notified that your dataset has batches, annotations, etc. For example, if batches are registered with scvi-tools, the subsequent model will correct for batch effects. See the full documentation for details.

In this dataset, we register two categorical covariates, "assay" and "donor_id", using the `categorical_covariate_keys` argument. Continuous covariates such as "percent_mito" and "percent_ribo" could be registered with the `continuous_covariate_keys` argument (left commented out below). If you only have one categorical covariate, you can also use the `batch_key` argument instead.
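A categorical covariate is one label per cell that takes one of K levels. As an illustration only (this is not scvi-tools' internal code), the pandas sketch below shows the two usual numeric views of such labels: integer codes and one-hot indicator columns. The metadata values are hypothetical.

```python
import pandas as pd

# Toy per-cell metadata with two categorical covariates (values are hypothetical)
obs = pd.DataFrame({
    "assay": ["3prime", "5prime", "3prime", "5prime"],
    "donor_id": ["D1", "D1", "D2", "D3"],
})

# Integer view: each covariate becomes one code per cell (levels sorted alphabetically)
codes = {key: obs[key].astype("category").cat.codes for key in obs.columns}

# One-hot view: one indicator column per level of each covariate
one_hot = pd.get_dummies(obs, columns=["assay", "donor_id"])
print(one_hot.shape)  # (4, 5): 2 assay levels + 3 donor levels
```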

In [None]:
scvi.model.SCVI.setup_anndata(
    adata,
    layer="counts",
    categorical_covariate_keys=["assay", "donor_id"],
    # continuous_covariate_keys=["percent_mito", "percent_ribo"],
)

```{warning}
If the adata is modified after running `setup_anndata`, please run `setup_anndata` again, before creating an instance of a model.
```

## 1.2 Creating and training a model

While we highlight the scVI model here, the API is consistent across all scvi-tools models and is inspired by that of [scikit-learn](https://scikit-learn.org/stable/). For a full list of options, see the scvi [documentation](https://scvi-tools.org).

In [None]:
model = scvi.model.SCVI(adata)

# model.view_anndata_setup()

In [None]:
model.train(max_epochs=100)

Save model and anndata

In [None]:
# model.save("dirpath/")
# adata.write("dirpath/filename.h5ad")

```{important}
All scvi-tools models run faster when using a GPU. By default, scvi-tools will use a GPU if one is found to be available. Please see the installation page for more information about installing scvi-tools when a GPU is available.
```

# 2. Load the saved model and AnnData object

Load anndata

In [None]:
adata = ad.read_h5ad("/home/owkin/deepdeconv/notebooks/data/adata_cti_5000.h5ad")

adata

In [None]:
adata.obs._scvi_batch.value_counts()

#### [Optional]: retrain a model

In [None]:
# train test split
# from sklearn.model_selection import train_test_split

# cell_types_train, cell_types_test = train_test_split(
#     adata.obs_names,
#     test_size=0.5,
#     stratify=adata.obs.cell_types_grouped,
#     random_state=42,
# )
# adata_train = adata[cell_types_train, :]
# adata_test = adata[cell_types_test, :]
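The commented cell above uses scikit-learn's `train_test_split` with `stratify` so that both halves keep the same cell-type composition. A pandas-only alternative with the same intent is sketched below on toy metadata (in the notebook, `obs` would be `adata.obs`); it samples 50% of the cells within each cell type.

```python
import pandas as pd

# Toy cell metadata; in the notebook this would be adata.obs
obs = pd.DataFrame(
    {"cell_types_grouped": ["T"] * 6 + ["B"] * 4 + ["NK"] * 2},
    index=[f"cell_{i}" for i in range(12)],
)

# Sample half of the cells within each cell type (stratified split)
train_obs = obs.groupby("cell_types_grouped").sample(frac=0.5, random_state=42)
cell_types_train = train_obs.index
cell_types_test = obs.index.difference(cell_types_train)

print(obs.loc[cell_types_train, "cell_types_grouped"].value_counts().to_dict())
```

The resulting indices can then be used as in the commented cell, e.g. `adata[cell_types_train, :].copy()`.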

In [None]:
# adata_train = adata_train.copy()
scvi.model.SCVI.setup_anndata(
    adata,
    layer="counts",
    categorical_covariate_keys=["assay", "donor_id"],
    # continuous_covariate_keys=["percent_mito", "percent_ribo"],
)
model = scvi.model.SCVI(adata)
model.train(max_epochs=300)
model.save("models/cti_300_epochs/")

#### Load fitted models

In [None]:
import os

dir_path = "/home/owkin/deepdeconv/notebooks/models/"
params = ["100", "200", "400"]

models = {}

for param in params:
    model_name = f"cti_{param}_epochs"
    model = scvi.model.SCVI.load(dir_path=os.path.join(dir_path, model_name),
                                adata=adata,
                                use_gpu=True
                                )
    models[param] = model

Plot losses 

In [None]:
# plt.plot(model.history["elbo_train"])
# plt.plot(model.history["reconstruction_loss_train"])

# 3. Visualizations 

#### Latent space

In [None]:
# Regular scVI
latent = model.get_latent_representation()
adata.obsm["X_scVI"] = latent

latent.shape

### 2D Embedding plots

UMAP on PCA (without scVI, i.e. no batch correction)

In [None]:
# run PCA then generate UMAP plots
sc.tl.pca(adata)
sc.pp.neighbors(adata, n_pcs=30, n_neighbors=20)
sc.tl.umap(adata, min_dist=0.3)

In [None]:
sc.pl.umap(
    adata,
    color=["cell_types_grouped", "cell_type"],
    frameon=False,
)
sc.pl.umap(
    adata,
    color=["donor_id", "assay"],
    ncols=2,
    frameon=False,
)

UMAP on the scVI latent space

In [None]:
# build the neighborhood graph on the scVI latent space, then compute the UMAP
sc.pp.neighbors(adata, use_rep="X_scVI")
sc.tl.umap(adata, min_dist=0.3)

In [None]:
sc.pl.umap(
    adata,
    color=["cell_types_grouped", "cell_type"],
    frameon=False,
)
sc.pl.umap(
    adata,
    color=["donor_id", "assay"],
    ncols=2,
    frameon=False,
)

The `model.get...()` functions default to using the anndata that was used to initialize the model. It's possible to also query a subset of the anndata, or even use a completely independent anndata object as long as the anndata is organized in an equivalent fashion.

We use `model.get_normalized_expression()` to compute the **mean gene expression vector for each cell type present in the dataset**.
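Averaging cells per type amounts to a group-by mean followed by a transpose into the genes x cell types layout of a signature matrix. The sketch below uses a toy normalized expression table; in the notebook, `expr` would be a cells x genes DataFrame such as the output of `model.get_normalized_expression(...)` and the labels would come from `adata.obs["cell_types_grouped"]` (both are assumptions about how one would wire it in).

```python
import numpy as np
import pandas as pd

# Toy normalized expression: 5 cells x 3 genes (values are illustrative)
expr = pd.DataFrame(
    np.array([[1.0, 0.0, 3.0],
              [3.0, 0.0, 1.0],
              [0.0, 4.0, 2.0],
              [0.0, 6.0, 2.0],
              [2.0, 2.0, 2.0]]),
    index=[f"cell_{i}" for i in range(5)],
    columns=["GENE_A", "GENE_B", "GENE_C"],
)
cell_types = pd.Series(["T", "T", "B", "B", "NK"], index=expr.index, name="cell_type")

# Mean expression of each cell type, transposed to genes x cell types
signature = expr.groupby(cell_types).mean().T
print(signature)
```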

# 4. Building a signature matrix

### Denoised signature matrix (skip)

In [None]:
# Denoised signature matrix computation

# df_signature_denoised = pd.DataFrame()

# for cell_type in adata.obs["cell_types_grouped"].unique():
#     indices = adata[adata.obs.cell_types_grouped == cell_type].obs.index
#     integer_indices = adata.obs.index.get_indexer(indices)
#     # latent_subset = model.get_latent_representation(adata_subset)
#     denoised = model.get_normalized_expression(indices=integer_indices,
#                                                library_size=1e4)
#     df = denoised.mean(axis=0).to_frame()
#     df.columns = [cell_type]
#     df_signature_denoised = pd.concat([df_signature_denoised, df], axis=1)


# df_signature_denoised.drop(["To remove"], axis=1, inplace=True)

# keep_genes = list(set(signature.index) & set(df_signature_denoised.index))

# signature = signature.loc[keep_genes]
# df_signature_denoised = df_signature_denoised.loc[keep_genes]

# df_signature_denoised.to_csv("/home/owkin/project/Almudena/Output/Crosstiss_Immune/signature_cti_csv_5000.csv")

###### Vanilla signature matrix
# X_norm = sc.pp.normalize_total(adata,
#                                target_sum=1e4,
#                                layer="counts",
#                                inplace=False)['X']

# df_signature = pd.DataFrame()

# for cell_type in adata.obs["cell_type"].unique():
#     indices = adata[adata.obs.cell_type == cell_type].obs.index
#     integer_indices = adata.obs.index.get_indexer(indices)
#     df = pd.DataFrame(X_norm[integer_indices, :].mean(axis=0).T,
#              index=adata.var_names,
#              columns=[cell_type])
#     df_signature = pd.concat([df_signature, df], axis=1) 
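The vanilla signature relies on `sc.pp.normalize_total(..., target_sum=1e4)`, which rescales each cell so that its counts sum to `target_sum`. A NumPy sketch of that rescaling on a toy count matrix, assuming every cell has nonzero total counts (scanpy handles empty cells separately):

```python
import numpy as np

# Toy raw count matrix: 3 cells x 4 genes
counts = np.array([[10, 0, 30, 60],
                   [1, 1, 2, 0],
                   [0, 50, 50, 0]], dtype=float)

target_sum = 1e4
# Divide each cell (row) by its library size, then scale to target_sum
library_size = counts.sum(axis=1, keepdims=True)
X_norm = counts / library_size * target_sum

print(X_norm.sum(axis=1))
```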

Correlation between signature matrices

In [None]:
# import seaborn as sns 
# from scipy.stats import spearmanr, pearsonr
# import numpy as np

# cell_types = list(signature.columns)

# n_cols = 5

# fig, ax = plt.subplots(1, n_cols, figsize=(25, 5))

# for i in range(n_cols):
#     cell_type = cell_types[i]
#     x_sig = signature[cell_type].values
#     x_denoised = df_signature_denoised[cell_type].values
#     # keep_genes = pd.Index(list(set(marker_genes[cell_type]) & set(adata.var_names)))
#     # x_raw = df_signature.loc[keep_genes][cell_type].values
#     # x_denoised = df_signature_denoised.loc[keep_genes][cell_type].values
#     corr = spearmanr(x_sig, x_denoised)[0]
    
#     sns.scatterplot(x=x_sig,
#                     y=x_denoised,
#                     ax=ax[i]
#                     )
#     ax[i].set_xlabel(f"Signature average ({cell_type})")
#     ax[i].set_ylabel(f"Denoised average ({cell_type})")
#     ax[i].set_title(f"Correlation {np.around(corr,3)}")
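The cell above uses the Spearman correlation, which compares the rankings of genes rather than their raw values, so it is insensitive to scale differences between the reference signature and the denoised averages. The coefficient reported by `scipy.stats.spearmanr` equals the Pearson correlation of the ranks; a NumPy/pandas sketch of that definition (average ranks for ties), on toy vectors:

```python
import numpy as np
import pandas as pd

def spearman(x, y):
    """Spearman correlation: Pearson correlation of average ranks (ties share a rank)."""
    rx = pd.Series(x).rank(method="average").to_numpy()
    ry = pd.Series(y).rank(method="average").to_numpy()
    return float(np.corrcoef(rx, ry)[0, 1])

# Same gene ordering, different scales: the rank correlation is 1
x_sig = np.array([0.5, 2.0, 8.0, 3.0])
x_denoised = np.array([0.1, 1.0, 9.0, 4.0])
print(spearman(x_sig, x_denoised))
```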

### Denoised dataset

In [None]:
# adata_train / adata_test come from the optional train/test split cell above
denoised_train = model.get_normalized_expression(
    adata_train,
    library_size=1e4,
)

denoised_train

In [None]:
denoised_test = model.get_normalized_expression(
                                            adata_test,
                                            library_size=1e4)

denoised_test

In [None]:
len(set(denoised_train.index) & set(denoised_test.index))

Save denoised datasets

In [None]:
denoised_train.to_csv("/home/owkin/project/Almudena/Output/Crosstiss_Immune/denoised_train_cti_5000.csv")
denoised_test.to_csv("/home/owkin/project/Almudena/Output/Crosstiss_Immune/denoised_test_cti_5000.csv")

# 5. Latent space linearity sanity checks

In [None]:
import tqdm
from scvi_sanity_checks_utils import sanity_checks_metrics

In [None]:
batch_size = [128, 256, 512, 1024, 2048, 4096, 8192, 16384] #, 32768, 65536, 131072] 

latent_space_metrics = {}

params = [str(x) for x in (100, 200, 300, 400)]

for param in tqdm.tqdm(params[:1]):
    latent_space_metrics[param] = {}
    metrics, errors = sanity_checks_metrics(models[param],
                                            adata,
                                            batch_sizes=batch_size,
                                            n_repeats=100,
                                            use_get_latent=True)
    latent_space_metrics[param]["corr"] = metrics["corr"]
    latent_space_metrics[param]["error"] = errors["corr"]

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(15, 4))

# one curve per model evaluated above
for param, metrics in latent_space_metrics.items():
    ax.plot(batch_size, metrics["corr"], linestyle="--", marker="+", label=f"{param} epochs")
ax.set_xlabel("Batch size")
ax.set_ylabel("Pearson correlation")
ax.set_xticks(batch_size)
ax.legend()
ax.set_title("Sanity check 0: correlation between sum(encodings) ~= encoder(pseudo-bulk)")
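The body of `sanity_checks_metrics` is not shown here; judging from the plot title, the check compares the sum of per-cell encodings with the encoding of the pseudo-bulk (the summed counts of the same cells). The sketch below illustrates that idea with hypothetical toy encoders, not the scVI encoder: a purely linear map passes the check exactly, while a `log1p` input transform (loosely mimicking scVI's log-transformed encoder input) breaks additivity. Since Pearson correlation is scale-invariant, using the sum or the mean of the encodings gives the same value.

```python
import numpy as np

rng = np.random.default_rng(0)
n_genes, n_latent = 50, 10

# Hypothetical encoder weights standing in for a trained model
W = rng.normal(size=(n_latent, n_genes))

def linear_encoder(x):
    # z = W x for every row (cell) of x
    return x @ W.T

def log1p_encoder(x):
    # log1p input transform followed by a linear map: no longer additive over cells
    return np.log1p(x) @ W.T

def linearity_corr(encoder, counts):
    """Pearson correlation between sum(encodings) and encoder(pseudo-bulk)."""
    sum_of_encodings = encoder(counts).sum(axis=0)    # (n_latent,)
    pseudo_bulk = counts.sum(axis=0, keepdims=True)   # (1, n_genes)
    encoding_of_bulk = encoder(pseudo_bulk)[0]        # (n_latent,)
    return float(np.corrcoef(sum_of_encodings, encoding_of_bulk)[0, 1])

results = {}
for n_cells in [128, 256, 512]:
    # toy Poisson counts for the cells aggregated into one pseudo-bulk
    counts = rng.poisson(lam=2.0, size=(n_cells, n_genes)).astype(float)
    results[n_cells] = (
        linearity_corr(linear_encoder, counts),
        linearity_corr(log1p_encoder, counts),
    )
    print(n_cells, results[n_cells])
```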