# MrVI Quick Start Tutorial

MrVI (Multi-resolution Variational Inference) is a package for analyzing multi-sample single-cell RNA-seq data. This tutorial will guide you through the main features of MrVI.

In [None]:
!pip install --quiet scvi-colab
from scvi_colab import install

install()

In [None]:
import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scvi
# from scvi.external import MrVI
from mrvi import MrVI
import seaborn as sns


# Fix the global scvi-tools seed so repeated runs give the same results.
scvi.settings.seed = 0  # optional: ensures reproducibility
# Record the library version this notebook was last executed with.
print("Last run with scvi-tools version:", scvi.__version__)
# Scratch directory; it is deleted on `save_dir.cleanup()` or when the object
# is finalized. NOTE(review): presumably intended for downloaded tutorial data.
save_dir = tempfile.TemporaryDirectory()

# Preprocessing and model fitting

In [None]:
# The tutorial subset is read from a local file one directory above the
# working directory. To download it into the temporary directory instead, use
#     adata_path = os.path.join(save_dir.name, "haniffa_tutorial_subset.h5ad")
# and pass `backup_url=...` to `sc.read` (URL still to be filled in; the file
# can currently be found at s3://largedonor/haniffa_tutorial_subset.h5ad).
adata_path = "../haniffa_tutorial_subset.h5ad"
adata = sc.read(adata_path)
adata.obs.index.name = "cell_name"

# Keep only the 10,000 most variable genes (Seurat v3 flavor); `subset=True`
# together with `inplace=True` filters `adata` itself.
sc.pp.highly_variable_genes(
    adata,
    n_top_genes=10000,
    flavor="seurat_v3",
    subset=True,
    inplace=True,
)
adata

Before training, we need to specify which covariates in `obs` should be used as the target variable (`sample_key`) and as nuisance variables (`batch_key`).
In this tutorial, we use donor IDs as the target variable and leave the batch variable unset, since the data has already been subset to the Newcastle cohort.

In [None]:
# Target covariate: the `obs` column holding the donor ID of each cell.
# No nuisance covariate is registered here; to add one, pass e.g. batch_key="Site".
sample_key = "patient_id"
MrVI.setup_anndata(adata, sample_key=sample_key)

In [None]:
# Build the model on the AnnData object prepared by `MrVI.setup_anndata`.
model = MrVI(adata)
# Train with an upper bound of 400 epochs.
model.train(max_epochs=400)

Once the model is trained, we can plot its validation ELBO to check whether training has converged.

In [None]:
# Skip the first five recorded values before plotting.
# NOTE(review): presumably so the early, much larger values do not dominate the y-axis.
validation_elbo = model.history["elbo_validation"].iloc[5:]

plt.plot(validation_elbo)
plt.xlabel("Epoch")
plt.ylabel("Validation ELBO")
plt.show()

# Visualize cell embeddings and sample distances

The latent representations of the cells can be accessed using the `get_latent_representation` method.
Here, we visualize the latent space in 2D using minimum-distortion embeddings (MDE).

In [None]:
# By default the `u` latent representation is returned; call
# `model.get_latent_representation(give_z=True)` to obtain `z` instead.
u = model.get_latent_representation()

# Compute the MDE of `u` and store it in `obsm` so scanpy can plot it by name.
u_mde = scvi.model.utils.mde(u)
adata.obsm["u_mde"] = u_mde

# One panel per annotation, stacked in a single column.
sc.pl.embedding(adata, basis="u_mde", color=["initial_clustering", "Status"], ncols=1)

Sample distances can be computed using the `get_local_sample_distances` method, which characterizes local sample relationships for any cell in the dataset.
This method can return cell-specific distances (`keep_cell=True`) as well as distances averaged within cell subpopulations, which are defined by the `groupby` argument.
Specifying `keep_cell=False` ensures that cell-specific distances are not returned, which reduces the memory footprint of the returned object when many samples are present.

In [None]:
# Average the local sample distances within each `initial_clustering` group;
# per-cell distances are not kept (`keep_cell=False`).
dists = model.get_local_sample_distances(
    groupby="initial_clustering",
    keep_cell=False,
    batch_size=32,
)

# Pick out the entry for the "CD16" group, then its `initial_clustering` variable.
cd16_dists = dists.loc[{"initial_clustering_name": "CD16"}]
d1 = cd16_dists.initial_clustering

The following cell defines utility functions that perform hierarchical clustering based on sample distances and extract the sample metadata used to annotate the distance matrices.

In [None]:
from matplotlib.colors import to_hex
from scipy.spatial.distance import squareform
from scipy.cluster.hierarchy import linkage, optimal_leaf_ordering


def get_sample_colors():
    """Build per-sample annotation colors from the global ``model.sample_info``.

    Returns
    -------
    pandas.DataFrame
        Two columns:

        * ``"covid"`` -- ``Status`` mapped to ``"red"`` (``"Covid"``) or
          ``"green"`` (``"Healthy"``).
        * ``"onset"`` -- ``Days_from_onset`` mapped to a hex color: a fixed
          grey for ``"Healthy"``, otherwise viridis evaluated at
          ``int(value) / 30.0``.

        The ``"onset"`` Series is indexed by ``sample_id`` and provides the
        index of the returned frame.
    """
    viridis = sns.color_palette("viridis", as_cmap=True)
    healthy_grey = to_hex(np.array([0.5, 0.5, 0.5, 1.0]))
    status_palette = {"Covid": "red", "Healthy": "green"}

    def onset_to_hex(days):
        # Healthy samples have no onset day and get the neutral grey.
        if days == "Healthy":
            return healthy_grey
        return to_hex(viridis(int(days) / 30.0))

    info = model.sample_info.set_index("sample_id")
    return pd.DataFrame(
        {
            # `.values` drops the index here; the "onset" Series supplies it.
            "covid": info.Status.map(status_palette).values,
            "onset": info.Days_from_onset.map(onset_to_hex),
        }
    )

def get_dendrogram(dists):
    """Return a Ward linkage of ``dists`` with optimally ordered leaves.

    Parameters
    ----------
    dists
        Square distance matrix; it is converted to condensed form with
        ``scipy.spatial.distance.squareform``.

    Returns
    -------
    The linkage matrix produced by ``linkage(..., method="ward")`` after
    being passed through ``optimal_leaf_ordering``.
    """
    condensed = squareform(dists)
    return optimal_leaf_ordering(linkage(condensed, method="ward"), condensed)

In [None]:
# The same linkage orders both rows and columns of the distance matrix.
Z = get_dendrogram(d1)
colors = get_sample_colors()

sample_distances = d1.to_pandas()
sns.clustermap(
    sample_distances,
    row_linkage=Z,
    col_linkage=Z,
    row_colors=colors,
    xticklabels=False,
    yticklabels=False,
)

# Differential expression and differential abundance analysis

In [None]:
# Sample-level covariate(s) to test; replace with your covariate of interest.
sample_cov_keys = ["Status"]
de_res = model.differential_expression(sample_cov_keys=sample_cov_keys)

In [None]:
da_res = model.differential_abundance(sample_cov_keys=sample_cov_keys)

# Log-probabilities for each `Status` level, then their difference:
# positive values mean a higher log-probability under "Covid" than under "Healthy".
status_log_probs = da_res.Status_log_probs
A_log_probs = status_log_probs.loc[{"Status": "Covid"}]
B_log_probs = status_log_probs.loc[{"Status": "Healthy"}]
A_B_log_prob_ratio = A_log_probs - B_log_probs

In [None]:
# Store the Covid-vs-Healthy log-probability difference per cell and show it
# on the MDE embedding; the color scale is limited to [-1, 1].
adata.obs["DA_covid"] = A_B_log_prob_ratio.values
sc.pl.embedding(
    adata,
    basis="u_mde",
    color=["initial_clustering", "DA_covid"],
    vmin=-1,
    vmax=1,
    ncols=1,
)