# Integrative analysis of single-cell multiomics data using deep learning

**Filled notebook:** 
[![View on Github](https://img.shields.io/static/v1.svg?logo=github&label=Repo&message=View%20On%20Github&color=lightgrey)](https://github.com/phlippe/uvadlc_notebooks/blob/master/docs/tutorial_notebooks/tutorial5/Inception_ResNet_DenseNet.ipynb)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/phlippe/uvadlc_notebooks/blob/master/docs/tutorial_notebooks/tutorial5/Inception_ResNet_DenseNet.ipynb)   
**Recording:** 
[![YouTube](https://img.shields.io/static/v1?logo=youtube&label=&message=Youtube&color=red)](https://youtu.be/ELEqNwv9vkE)   
**Author:** Yuan Tian [![Connect](https://img.shields.io/static/v1?label=&logo=linkedin&message=Connect&color=blue)](https://www.linkedin.com/in/ytiancompbio) 

In this tutorial, we will take a closer look at autoencoders (AE). An autoencoder is trained to encode input data, such as images, into a smaller feature vector and then to reconstruct the input from that vector with a second neural network, called a decoder. The feature vector is called the “bottleneck” of the network because it forces the input to be compressed into fewer features. This property is useful in many applications, in particular for compressing data or for comparing images with a metric that goes beyond pixel-level comparison. Besides learning about the autoencoder framework, we will also see the “deconvolution” (or transposed convolution) operator in action for scaling up feature maps in height and width. Such deconvolution networks are necessary wherever we start from a small feature vector and need to output a full-size image (e.g. in VAEs, GANs, or super-resolution applications).
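
To make the bottleneck idea concrete, here is a minimal sketch on random tabular data. The layer sizes (20 input features, a 4-dimensional bottleneck) and the names `toy_encoder` and `toy_decoder` are illustrative assumptions, not the architecture used later in this notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes: 20 input features compressed to a 4-dimensional bottleneck
n_features, n_latent = 20, 4

toy_encoder = nn.Linear(n_features, n_latent)   # input -> bottleneck
toy_decoder = nn.Linear(n_latent, n_features)   # bottleneck -> reconstruction

x_batch = torch.randn(8, n_features)            # batch of 8 samples
z_batch = toy_encoder(x_batch)                  # compressed representation
x_hat_batch = toy_decoder(z_batch)              # same shape as the input

print("Input:         ", tuple(x_batch.shape))
print("Bottleneck:    ", tuple(z_batch.shape))
print("Reconstruction:", tuple(x_hat_batch.shape))
```

Training then amounts to making the reconstruction match the input, so that the 4 bottleneck values have to carry as much information about the 20 inputs as possible.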

In [None]:
# Standard libraries
import pandas as pd
import numpy as np
import urllib.request
from pathlib import Path
from urllib.error import HTTPError
from tqdm.notebook import tqdm 
from sklearn import preprocessing

# Pytorch and Pytorch Lightning
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader, random_split
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Plotting
import umap
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns

# Tensorboard extension (for visualization purposes later)
from torch.utils.tensorboard import SummaryWriter
%load_ext tensorboard

# Path to the folder where the datasets are/should be downloaded (here, the CITE-seq CSV files)
DATASET_PATH = Path("data")
if not DATASET_PATH.exists():
    DATASET_PATH.mkdir()
    
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = Path("saved_models")
if not CHECKPOINT_PATH.exists():
    CHECKPOINT_PATH.mkdir()

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

## Single-cell multiomics and CITE-seq

<ul>
    <li>Single-cell sequencing reveals cellular heterogeneity that is masked by bulk sequencing methods.</li>
    <li>CITE-seq simultaneously measures gene expression and surface protein abundance in the same single cells.</li>
</ul>

<figure>
    <center><img src="imgs/citeseq.jpg"/></center>
    <center><figcaption>Image source: 10x Genomics</figcaption></center>
</figure>

## Datasets and Dataloaders

In [None]:
# URL for downloading data
data_url = "https://raw.githubusercontent.com/naity/citeseq_autoencoder/master/data/"

# Files to download
data_files = ["rna_scale.csv.gz", "protein_scale.csv.gz", "metadata.csv.gz"]

# Download datafile if necessary
for file_name in data_files:
    file_path = Path(DATASET_PATH/file_name)
    if not file_path.exists():
        file_url = data_url + file_name
        print(f"Downloading {file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print(f"Something went wrong. Please try downloading the file manually from {file_url}\n", e)

<figure>
    <center><img src="imgs/dataset.png"/></center>
    <center><figcaption><b>CITE-seq dataset overview</b></figcaption></center>
</figure>

In [None]:
rna = pd.read_csv(DATASET_PATH/"rna_scale.csv.gz", index_col=0).T
pro = pd.read_csv(DATASET_PATH/"protein_scale.csv.gz", index_col=0).T

ncells = rna.shape[0]
nfeatures_rna = rna.shape[1]
nfeatures_pro = pro.shape[1]

print("Number of cells:", ncells)
print("Number of genes:", nfeatures_rna)
print("Number of proteins:", nfeatures_pro)

In [None]:
# concat rna and pro
print("RNA and protein cell barcodes match:", all(rna.index == pro.index))
citeseq = pd.concat([rna, pro], axis=1)
citeseq.head()
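
The barcode check matters because the two tables are only meaningful side by side when each row refers to the same cell. The sketch below illustrates this on made-up data; the barcodes, feature names, and the `*_toy` variables are hypothetical.

```python
import pandas as pd

# Hypothetical toy tables: rows are cells (barcodes), columns are features
barcodes = ["cell_a", "cell_b", "cell_c"]
rna_toy = pd.DataFrame({"gene1": [0.1, -0.3, 1.2], "gene2": [0.0, 0.5, -0.7]}, index=barcodes)
pro_toy = pd.DataFrame({"CD4": [1.5, -0.2, 0.3]}, index=barcodes)

# Index.equals checks both the labels and their order
same_order = rna_toy.index.equals(pro_toy.index)
combined_toy = pd.concat([rna_toy, pro_toy], axis=1)
print("Barcodes match:", same_order, "| combined shape:", combined_toy.shape)

# If one table arrives in a different row order, reorder it explicitly first
pro_shuffled = pro_toy.iloc[::-1]
print("Shuffled table matches:", rna_toy.index.equals(pro_shuffled.index))
pro_fixed = pro_shuffled.loc[rna_toy.index]
print("After reordering:", rna_toy.index.equals(pro_fixed.index))
```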

In [None]:
# annotations
metadata = pd.read_csv(DATASET_PATH/"metadata.csv.gz", index_col=0)
metadata.head()

In [None]:
# separate CD4 and CD8 in l1
metadata["celltype.l1.5"] = metadata["celltype.l1"].values
metadata.loc[metadata["celltype.l2"].str.startswith("CD4"), "celltype.l1.5"] = "CD4 T"
metadata.loc[metadata["celltype.l2"].str.startswith("CD8"), "celltype.l1.5"] = "CD8 T"
metadata.loc[metadata["celltype.l2"]=="Treg", "celltype.l1.5"] = "CD4 T"
metadata.loc[metadata["celltype.l2"]=="MAIT", "celltype.l1.5"] = "MAIT"
metadata.loc[metadata["celltype.l2"]=="gdT", "celltype.l1.5"] = "gdT"
print("CITE-seq data and metadata cell barcodes match:", all(citeseq.index == metadata.index))

# convert to categorical
le = preprocessing.LabelEncoder()
labels = le.fit_transform(metadata["celltype.l1.5"])
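
As a small illustration of the relabeling and encoding steps above, the sketch below applies the same pattern to four made-up cells. The annotation values in `meta_toy` are invented for the example and are not taken from the real metadata file.

```python
import pandas as pd
from sklearn import preprocessing

# Hypothetical annotations for four cells
meta_toy = pd.DataFrame({
    "celltype.l1": ["T", "T", "T", "B"],
    "celltype.l2": ["CD4 Naive", "CD8 TEM", "Treg", "B naive"],
})

# Start from the coarse level, then refine T cells using the finer level
meta_toy["celltype.l1.5"] = meta_toy["celltype.l1"].values
meta_toy.loc[meta_toy["celltype.l2"].str.startswith("CD4"), "celltype.l1.5"] = "CD4 T"
meta_toy.loc[meta_toy["celltype.l2"].str.startswith("CD8"), "celltype.l1.5"] = "CD8 T"
meta_toy.loc[meta_toy["celltype.l2"] == "Treg", "celltype.l1.5"] = "CD4 T"

# LabelEncoder maps the sorted unique names to integers 0..n_classes-1
le_toy = preprocessing.LabelEncoder()
codes = le_toy.fit_transform(meta_toy["celltype.l1.5"])
decoded = le_toy.inverse_transform(codes)

print("Classes:", list(le_toy.classes_))
print("Codes:  ", list(codes))
print("Decoded:", list(decoded))
```

Because `inverse_transform` recovers the original names, integer labels can be stored in tensors during training and turned back into cell type names for plotting.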

In [None]:
class TabularDataset(Dataset):
    """Custom dataset for tabular data"""
    def __init__(self, df: pd.DataFrame, labels: np.ndarray):
        self.data = torch.tensor(df.to_numpy(), dtype=torch.float)
        self.labels = torch.tensor(labels, dtype=torch.long)  # integer class indices

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]
        y = self.labels[idx]
        return x, y

In [None]:
dataset = TabularDataset(citeseq, labels)

# train, validation, and test split
train_size = int(ncells*0.7)
val_size = int(ncells*0.15)
train_ds, val_ds, test_ds = random_split(dataset, [train_size, val_size, ncells-train_size-val_size],
                                         generator=torch.Generator().manual_seed(0))
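
The split sizes are computed with `int(...)`, which truncates, so the test set receives whatever remains after the training and validation sets. A minimal sketch with a hypothetical dataset of 103 items shows that the three parts cover every item exactly once:

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical dataset of 103 items, chosen so the fractions do not divide evenly
n_toy = 103
toy_ds = TensorDataset(torch.arange(n_toy))

n_train = int(n_toy * 0.7)         # truncated towards zero
n_val = int(n_toy * 0.15)          # truncated towards zero
n_test = n_toy - n_train - n_val   # remainder goes to the test set

parts = random_split(toy_ds, [n_train, n_val, n_test],
                     generator=torch.Generator().manual_seed(0))

sizes = [len(p) for p in parts]
# Each Subset stores the indices it was assigned
all_idx = sorted(i for p in parts for i in p.indices)

print("Split sizes:", sizes)
print("Every item used exactly once:", all_idx == list(range(n_toy)))
```

Fixing the generator seed makes the assignment of cells to splits reproducible across runs.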

In [None]:
print("Number of cells for training:", len(train_ds))
print("Number of cells for validation:", len(val_ds))
print("Number of cells for test:", len(test_ds))

In [None]:
bs = 256
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True, drop_last=True, pin_memory=True)
val_dl = DataLoader(val_ds, batch_size=bs, shuffle=False, drop_last=False)
test_dl = DataLoader(test_ds, batch_size=bs, shuffle=False, drop_last=False)

In [None]:
x, y = train_dl.dataset[0]
print("Input data:", x)
print("Label:     ", y)

## Use autoencoders for single-cell analysis

<figure>
    <center><img src="imgs/autoencoder.png"/></center>
    <center><figcaption>Image source: Eraslan et al. Nat Rev Genet. 2019</figcaption></center>
</figure>

<figure>
    <center><img src="imgs/autoencoder_arch.png"/></center>
    <center><figcaption><b>Autoencoder architecture for CITE-seq data</b></figcaption></center>
</figure>

In [None]:
class LinBnDrop(nn.Sequential):
    """Module grouping `BatchNorm1d`, `Dropout` and `Linear` layers, adapted from fastai."""
    
    def __init__(self, n_in, n_out, bn=True, p=0., act=None, lin_first=True):
        layers = [nn.BatchNorm1d(n_out if lin_first else n_in)] if bn else []
        if p != 0: layers.append(nn.Dropout(p))
        lin = [nn.Linear(n_in, n_out, bias=not bn)]
        if act is not None: lin.append(act)
        layers = lin+layers if lin_first else layers+lin
        super().__init__(*layers)
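
The effect of the `lin_first` flag is easiest to see by listing the layers that end up in the block. The sketch below repeats the class definition from above so that it runs on its own; the sizes 8 and 4 are arbitrary.

```python
import torch.nn as nn

class LinBnDrop(nn.Sequential):
    """Same definition as above, repeated so this example is self-contained."""

    def __init__(self, n_in, n_out, bn=True, p=0., act=None, lin_first=True):
        layers = [nn.BatchNorm1d(n_out if lin_first else n_in)] if bn else []
        if p != 0: layers.append(nn.Dropout(p))
        lin = [nn.Linear(n_in, n_out, bias=not bn)]
        if act is not None: lin.append(act)
        layers = lin+layers if lin_first else layers+lin
        super().__init__(*layers)

block_lin_first = LinBnDrop(8, 4, p=0.1, act=nn.LeakyReLU())
block_bn_first = LinBnDrop(8, 4, p=0.1, act=nn.LeakyReLU(), lin_first=False)

order_lin_first = [type(m).__name__ for m in block_lin_first]
order_bn_first = [type(m).__name__ for m in block_bn_first]

print("lin_first=True: ", order_lin_first)
print("lin_first=False:", order_bn_first)
```

With `lin_first=True` (the default used in this notebook), batch normalization acts on the `n_out` activations after the linear layer. With `lin_first=False` it acts on the `n_in` inputs instead. In both cases the linear layer has no bias when `bn=True`.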

In [None]:
class Encoder(nn.Module):
    """Encoder for CITE-seq data"""
    
    def __init__(self,
                 nfeatures_rna: int,
                 nfeatures_pro: int,
                 hidden_rna: int,
                 hidden_pro: int,
                 latent_dim: int,
                 p: float = 0):
        super().__init__()
        self.nfeatures_rna = nfeatures_rna
        self.nfeatures_pro = nfeatures_pro
        hidden_dim = hidden_rna + hidden_pro
        
        self.encoder_rna = nn.Sequential(
            LinBnDrop(nfeatures_rna, nfeatures_rna // 2, p=p, act=nn.LeakyReLU()),
            LinBnDrop(nfeatures_rna // 2, hidden_rna, act=nn.LeakyReLU())
        )
        self.encoder_protein = LinBnDrop(nfeatures_pro, hidden_pro, p=p, act=nn.LeakyReLU())
        self.encoder = LinBnDrop(hidden_dim, latent_dim, act=nn.LeakyReLU())

    def forward(self, x):
        x_rna = self.encoder_rna(x[:, :self.nfeatures_rna])
        x_pro = self.encoder_protein(x[:, self.nfeatures_rna:])
        x = torch.cat([x_rna, x_pro], 1)
        return self.encoder(x)
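
The forward pass relies on the column layout of the concatenated table: the first `nfeatures_rna` columns are RNA and the rest are protein. The following simplified sketch traces the tensor shapes through the two branches. It uses plain `nn.Linear` layers as stand-ins for the `LinBnDrop` blocks, and all sizes are made up.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes: 10 RNA features followed by 4 protein features
n_rna, n_pro = 10, 4
branch_rna = nn.Linear(n_rna, 6)   # RNA branch -> 6 hidden units
branch_pro = nn.Linear(n_pro, 3)   # protein branch -> 3 hidden units
joint = nn.Linear(6 + 3, 5)        # concatenated hidden units -> 5 latent dims

x_toy = torch.randn(8, n_rna + n_pro)       # batch of 8 cells
h_rna = branch_rna(x_toy[:, :n_rna])        # first n_rna columns
h_pro = branch_pro(x_toy[:, n_rna:])        # remaining columns
z_toy = joint(torch.cat([h_rna, h_pro], dim=1))

print("RNA hidden:    ", tuple(h_rna.shape))
print("Protein hidden:", tuple(h_pro.shape))
print("Latent:        ", tuple(z_toy.shape))
```

Encoding each modality separately before merging lets the much smaller protein panel keep its own hidden units instead of competing directly with the larger set of genes in the first layer.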

In [None]:
class Decoder(nn.Module):
    """Decoder for CITE-seq data"""
    def __init__(self,
                 nfeatures_rna: int,
                 nfeatures_pro: int,
                 hidden_rna: int,
                 hidden_pro: int,
                 latent_dim: int):
        super().__init__()
        hidden_dim = hidden_rna + hidden_pro
        out_dim = nfeatures_rna + nfeatures_pro
        
        self.decoder = nn.Sequential(
            LinBnDrop(latent_dim, hidden_dim, act=nn.LeakyReLU()),
            LinBnDrop(hidden_dim, out_dim // 2, act=nn.LeakyReLU()),
            LinBnDrop(out_dim // 2, out_dim, bn=False)
            )

    def forward(self, x):
        return self.decoder(x)

In [None]:
class CiteAutoencoder(pl.LightningModule):
    def __init__(self,
                 nfeatures_rna: int,
                 nfeatures_pro: int,
                 hidden_rna: int,
                 hidden_pro: int,
                 latent_dim: int,
                 p: float = 0,
                 lr: float = 0.1):
        """ Autoencoder for citeseq data """
        super().__init__()
        
        # save hyperparameters
        self.save_hyperparameters()
 
        self.encoder = Encoder(nfeatures_rna, nfeatures_pro, hidden_rna, hidden_pro, latent_dim, p)
        self.decoder = Decoder(nfeatures_rna, nfeatures_pro, hidden_rna, hidden_pro, latent_dim)
        
        # example input array for visualizing network graph
        self.example_input_array = torch.zeros(256, nfeatures_rna + nfeatures_pro)

    def forward(self, x):
        # extract latent embeddings
        z = self.encoder(x)
        return z
    
    def _get_reconstruction_loss(self, batch):
        """ Calculate MSE loss for a given batch. """
        x, _ = batch
        z = self.encoder(x)
        x_hat = self.decoder(z)
        # MSE loss
        loss = F.mse_loss(x_hat, x)
        return loss
    
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}
    
    def training_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)
        self.log("train_loss", loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)
        self.log("val_loss", loss)
        
    def test_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)
        self.log("test_loss", loss)
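
The reconstruction loss used in all three steps is the mean squared error, averaged over every entry of the batch. A small worked example with made-up numbers:

```python
import torch
import torch.nn.functional as F

# Hypothetical input and reconstruction for 2 cells with 2 features each
x_true = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
x_recon = torch.tensor([[1.0, 2.0], [3.0, 6.0]])

# Only one entry differs (6 vs 4), so the squared errors are 0, 0, 0, 4
loss_toy = F.mse_loss(x_recon, x_true)
manual = ((x_recon - x_true) ** 2).mean()

print("F.mse_loss:", loss_toy.item())
print("Manual:    ", manual.item())
```

Both values equal 4 / 4 = 1.0, since the default `reduction="mean"` divides by the total number of elements rather than by the number of cells.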

In [None]:
def train_citeseq(latent_dim: int):
    trainer = pl.Trainer(default_root_dir=CHECKPOINT_PATH,
                         gpus=1 if "cuda" in str(device) else 0,
                         max_epochs=50,
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode="min", monitor="val_loss"),
                                    LearningRateMonitor("epoch")])
    trainer.logger._log_graph = True
    trainer.logger._default_hp_metric=None
    
    model = CiteAutoencoder(nfeatures_rna,
                            nfeatures_pro,
                            hidden_rna=30,
                            hidden_pro=18,
                            latent_dim=latent_dim,
                            p=0.1,
                            lr=0.1)
    trainer.fit(model, train_dl, val_dl)
     
    val_result = trainer.test(model, val_dl, verbose=False)
    test_result = trainer.test(model, test_dl, verbose=False)
    result = {"test": test_result, "val": val_result}
    return model, result

In [None]:
model, result = train_citeseq(24)

In [None]:
%tensorboard --port 6006 --logdir saved_models/lightning_logs/version_0

In [None]:
test_encodings = []
test_labels = []
    
model.eval()
with torch.no_grad():    
    for x, y in tqdm(test_dl, desc="Encoding cells"):
        test_encodings.append(model(x.to(model.device)))
        test_labels += y.to(torch.int).tolist()
        
test_embeds = torch.cat(test_encodings, dim=0).cpu().numpy()
test_labels = le.inverse_transform(test_labels)
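
The loop above follows a common pattern for batched inference: switch to evaluation mode, disable gradient tracking, collect the per-batch outputs, and concatenate them at the end. Here is a minimal sketch of that pattern, with a small stand-in network (`toy_encoder`) and random data in place of the trained model and the test set:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Stand-in encoder and data (hypothetical): 10 samples, 6 features, 2 classes
toy_encoder = nn.Sequential(nn.Linear(6, 3), nn.BatchNorm1d(3))
toy_loader = DataLoader(TensorDataset(torch.randn(10, 6), torch.arange(10) % 2),
                        batch_size=4)

chunks, label_ids = [], []
toy_encoder.eval()                  # batch norm uses running statistics
with torch.no_grad():               # no autograd graph is built
    for xb, yb in toy_loader:
        chunks.append(toy_encoder(xb))
        label_ids += yb.tolist()

toy_embeds = torch.cat(chunks, dim=0).numpy()
print("Embeddings:", toy_embeds.shape, "| labels:", len(label_ids))
```

Calling `eval()` matters here because batch normalization and dropout behave differently during training, and the last batch can be smaller than the others.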

In [None]:
# run umap for dimensionality reduction and visualization
embeds_umap = umap.UMAP(random_state=0).fit_transform(test_embeds)

In [None]:
# visualize umap
fig = px.scatter(x=embeds_umap[:, 0], y=embeds_umap[:, 1], color=test_labels, width=800, height=600,
                 labels={
                     "x": "UMAP1",
                     "y": "UMAP2",
                     "color": "Cell type"}
                )
fig.show()

In [None]:
# visualization with tensorboard
writer = SummaryWriter("tensorboard/")
writer.add_embedding(test_embeds, metadata=test_labels)

In [None]:
%tensorboard --port 6007 --logdir tensorboard/

In [None]:
writer.close()