# Auto-encoder

In their paper, [Badsha et al.](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7144625/) 
propose a fairly elegant scheme for an auto-encoder for single-cell RNAseq data imputation.


![Figure 1A of Badsha et al. 2020](images/autoencoder.png)
*Figure 1A of Badsha et al. 2020*

In order to reproduce their work, 
first we are going to implement a simple auto-encoder for the gene expression data. 

From there we will see how we can adapt the loss function to focus the learning on the signal in the data (rather than the noise, which is the missing data here). 

Here is their code, for inspiration: https://github.com/audreyqyfu/LATE/tree/master

In [None]:
## on google colab, you will have to run the following line:
#!pip install pytorch-model-summary
#!wget https://github.com/Bjarten/early-stopping-pytorch/raw/refs/heads/main/early_stopping_pytorch/early_stopping.py
#!mv early_stopping.py pytorchtools.py

In [None]:
import gc

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA


import torch
from torch import nn
import pytorch_model_summary as pms 

from torch.utils.data import TensorDataset, DataLoader

from pytorchtools import EarlyStopping

# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")


In [None]:
import psutil
import os
def usage():
    '''Return the resident set size (RSS) of this process in MiB (2**20 bytes), rounded to 0.1.'''
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return round(rss_bytes / 2 ** 20, 1)
usage()

In [None]:
## on google colab, you will have to run the following line:
#!wget https://github.com/sib-swiss/pytorch-practical-training/raw/refs/heads/master/data/single_cell/example.hd5
#!wget https://github.com/sib-swiss/pytorch-practical-training/raw/refs/heads/master/data/single_cell/example.cellType.csv
# and adapt the cells below to point to the files in the current directory

In [None]:

fname_input = "data/single_cell/example.hd5"
orientation = 'cell_row'  # cell_row/gene_row

In [None]:
df_tmp = pd.read_hdf(fname_input)


number_0 = (df_tmp != 0).sum().sum()

print("shape is {}".format(df_tmp.shape))
print('non-zero count is {}'.format( number_0 ))
print('non-zero rate  is {:.3f}'.format(number_0 / df_tmp.size ))

In [None]:
# to ease the analysis, we also have a cell type label, extracted from 
# https://github.com/10XGenomics/single-cell-3prime-paper/blob/master/pbmc68k_analysis/68k_pbmc_barcodes_annotation.tsv

cell_types = pd.read_csv('data/single_cell/example.cellType.csv' , index_col=0)
cell_types.celltype.value_counts()

In [None]:
## log10 transformation 
## np.log10 works element-wise on the DataFrame, so no transpose round-trip is
## needed (the old .transpose()...transpose() made two extra full-matrix copies)
pseudocount = 1

input_df = np.log10( df_tmp + pseudocount )



In [None]:
m, n = input_df.shape  # m: n_cells; n: n_genes
print('input_matrix: {} cells, {} genes\n'.format(m, n))

print("memory usage: {}Mb".format(usage()))

Let's split the data into training and validation sets.

In [None]:
valid_fraction = 0.3
valid_size = int( m * valid_fraction )
train_size = m - valid_size


np.random.seed(1884)
arr = np.arange(m)
np.random.shuffle(arr)

X_train = input_df.iloc[ arr[:train_size] , : ].to_numpy()
X_valid = input_df.iloc[ arr[train_size:] , : ].to_numpy()

print('train: {}'.format(train_size))
print('valid: {}'.format(valid_size))

In [None]:
cell_type_train = list( cell_types.loc[ input_df.index[ arr[:train_size] ] , 'celltype' ] )
cell_type_valid = list( cell_types.loc[ input_df.index[ arr[train_size:] ] , 'celltype' ] )

In [None]:
gene_ids = input_df.columns

train_cell_ids = input_df.index[ arr[:train_size] ]
valid_cell_ids = input_df.index[ arr[train_size:] ]


In [None]:
gc.collect()
print("memory usage: {}Mb".format(usage()))

In [None]:
%%time
pca_valid = PCA().fit( X_valid )
x_pca = pca_valid.transform( X_valid )
pca_valid.explained_variance_ratio_[:10]

In [None]:
## cumulative fraction of variance explained by the first 100 components,
## i.e. the ones fed to t-SNE below (x_pca[:,:100]); index 99 is the 100th component
np.cumsum( pca_valid.explained_variance_ratio_ )[ 99 ]

In [None]:
%%time

tsne = TSNE(n_components=2)
tsne.fit( x_pca[:,:100] )

In [None]:
fig,ax = plt.subplots(figsize=(12,8))
sns.scatterplot( x = tsne.embedding_[:,0],
               y = tsne.embedding_[:,1],
               hue = cell_type_valid,ax=ax)

## build the data loaders

**exercise:** build the DataLoaders, with a batch size of 256

In [None]:
batch_size = 256

## hint : in an autoencoder X is also the target!

# create your dataset
train_dataset = ...


## creating a dataloader
train_dataloader = ...

# create your dataset
valid_dataset = ...

## creating a dataloader
valid_dataloader = ...


In [None]:
# %load solutions/AE_dataload.py

# simple autoencoder

## model building

Here we follow the original paper, which uses only 2 layers for the encoder and 2 for the decoder.

Architecture:
 - encoder: 
        - layer input size > hidden size
        - layer hidden size > latent space size
 - decoder:
        - layer latent space size > hidden size
        - layer hidden size > layer input size
    

layer structure : Dropout > linear > ReLU
     


**exercise:** implement the simple auto-encoder with the following specifications:

In [None]:
input_dim = 949  
hidden_dim=[500] 
latent_dim = 100 
[input_dim] + hidden_dim + [latent_dim] + hidden_dim + [input_dim]

In [None]:

class Simple_AutoEncoder(torch.nn.Module):
    # Exercise placeholder: implement the auto-encoder described in the cells above.
    # Later cells construct it as
    #   Simple_AutoEncoder(input_dim=..., hidden_dim=[500], latent_dim=..., dropout_fraction=...)
    # call model(x) and compare the result to x with an MSE loss (so the output
    # should have the same shape as x), and call model.encode(x) to get the
    # latent representation.
    # Suggested layer structure (from the markdown above): Dropout > Linear > ReLU.
    ...


In [None]:
### test your model with this line:
print(pms.summary(model, torch.zeros(1,949).to(device), show_input=True))

In [None]:
X_train

In [None]:
## one sample = one cell = a vector of 949 values (one per gene, the model's input_dim)

---

In [None]:
# %load solutions/AE_model.py

In [None]:
### test your model with this line:
print(pms.summary(model, torch.zeros(1,949).to(device), show_input=True))

Our loss at this stage will be the Mean Squared Error between the input and the output:

In [None]:

model.eval()
x, = valid_dataset[:5] ## let's go with a batch of 5 samples

mseloss = nn.MSELoss()

with torch.no_grad(): ## disables tracking of gradient: prevent accidental training + speeds up computation
    x = x.to(device)
    pred = model(x)
    print( "input shape:", x.shape)
    print( "prediction shape:", pred.shape)
    print("mean squared error:", mseloss(pred,x))

In [None]:
## get the lower dimensional view of a data point:
model.encode(x[0])

In [None]:
x[0]

## training the model

In [None]:
def train(dataloader, model, loss_fn, optimizer , patience = 10 ,  echo = True , echo_batch = False):
    """Run one training epoch over ``dataloader`` and return the mean batch loss.

    The model is trained as an auto-encoder: each batch ``X`` is both the
    input and the reconstruction target.

    Parameters
    ----------
    dataloader : DataLoader
        Yields 1-tuples ``(X,)``.
    model : torch.nn.Module
        Model to train; it is switched to training mode.
    loss_fn : callable
        ``loss_fn(pred, target)`` returning a scalar tensor.
    optimizer : torch.optim.Optimizer
        Optimizer over ``model``'s parameters.
    patience : int
        Unused; kept for backward compatibility (early stopping is done by
        the caller).
    echo : bool
        Print the epoch's mean training loss.
    echo_batch : bool
        Print the loss after every batch.

    Returns
    -------
    float
        Mean of the per-batch losses over the epoch (the same aggregation as
        ``valid``, so both curves are comparable); ``nan`` if the dataloader
        yields no batch.
    """
    size = len(dataloader.dataset)  # number of samples (not batches)
    model.train()  # training mode: dropout etc. active

    total_loss = 0.0
    num_batches = 0
    for batch, (X,) in enumerate(dataloader):
        X = X.to(device)  # send the data to the GPU or whatever device is used for training

        # forward pass and loss: the input is its own target
        pred = model(X)
        loss = loss_fn(pred, X)

        # backpropagation; gradients are cleared first so nothing left over
        # from an earlier backward() call leaks into this step
        # https://pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1

        if echo_batch:
            current = batch * dataloader.batch_size + len(X)
            print(f"Train loss: {batch_loss:>7f}  [{current:>5d}/{size:>5d}]")

    mean_loss = total_loss / num_batches if num_batches else float('nan')

    if echo:
        print(f"Train loss: {mean_loss:>7f}")

    return mean_loss


In [None]:
def valid(dataloader, model, loss_fn, echo = True):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    Each batch ``X`` is used as both the input and the reconstruction target.
    Returns the mean of the per-batch losses as a float, and prints it when
    ``echo`` is true.
    """
    n_batches = len(dataloader)
    model.eval()  # evaluation mode (e.g. dropout disabled)

    running = 0.0
    with torch.no_grad():  # no gradient tracking: cheaper, and no accidental training
        for (X,) in dataloader:
            inputs = X.to(device)
            running += loss_fn(model(inputs), inputs).item()

    mean_loss = running / n_batches

    if echo:
        print(f"Valid Error: {mean_loss:>8f}")
    return mean_loss


In [None]:
## preamble -> define the model, the loss function, and the optimizer
model = Simple_AutoEncoder(  input_dim = len(gene_ids) , 
                             hidden_dim=[500] ,
                             latent_dim = 50 , 
                             dropout_fraction = 0.05).to(device)


mseloss = nn.MSELoss()

optimizer = torch.optim.Adam(model.parameters(), 
                       lr = 3*10**-4) ## using the learning rate from their code


## container to keep the scores across all epochs
train_scores = []
valid_scores = []


# overfitting can be an issue here. 
# we use the early stopping implemented in https://github.com/Bjarten/early-stopping-pytorch
# initialize the early_stopping object. 
# patience: How long to wait after last time validation loss improved.
early_stopping = EarlyStopping(patience=25, verbose=False)


In [None]:
%%time
## lets do a single round, to learn how long it takes
train_scores.append( train(train_dataloader, 
                           model, 
                           mseloss, 
                           optimizer, 
                           echo = True , echo_batch = True ) )

valid_scores.append( valid(valid_dataloader, 
                           model, 
                           mseloss , 
                           echo = True) )


In [None]:
%%time

epoch = 200



for t in range(epoch):
    echo = t%10==0
    if echo:
        print('Epoch',len(train_scores)+1 )    

    train_scores.append( train(train_dataloader, 
                               model, 
                               mseloss, 
                               optimizer, 
                               echo = echo , echo_batch = False ) )

    valid_scores.append( valid(valid_dataloader, 
                               model, 
                               mseloss , 
                               echo = echo) )

    # early_stopping needs the validation loss to check if it has decresed, 
    # and if it has, it will make a checkpoint of the current model
    early_stopping(valid_scores[-1], model)

    if early_stopping.early_stop:
        print("Early stopping")
        break
        
# load the last checkpoint with the best model
model.load_state_dict(torch.load('checkpoint.pt'))


In [None]:
plt.plot(train_scores , label = 'train')
plt.plot(valid_scores, label = 'validation')
plt.axvline(np.argmin(valid_scores), linestyle='--', color='r',label='Early Stopping Checkpoint')
plt.legend()
plt.xlabel('epoch')
plt.ylabel('MSE loss')

In [None]:
x, = valid_dataset[:]
## eval mode (dropout off) + no gradient tracking: the embedding is only used for
## plotting, and .numpy() refuses tensors that require grad
model.eval()
with torch.no_grad():
    valid_encoded = model.encode( x.to(device) )

In [None]:
valid_encoded

In [None]:
%%time

tsne = TSNE(n_components=2)
tsne.fit( valid_encoded.cpu().numpy() )

fig,ax = plt.subplots(figsize=(12,8))
sns.scatterplot( x = tsne.embedding_[:,0],
               y = tsne.embedding_[:,1],
               hue = cell_type_valid,ax=ax)

# from autoencoder to imputer

To go from an autoencoder to an imputer, we will switch the loss function to make it focus on the points where we have some data, and disregard points where the input data is null (zero).

In [None]:
a = torch.Tensor([[0,1,0,2,0,3],
                  [1,1,0,0,0,3]]) # X
b = torch.Tensor([[1,2,1,1,1,2],
                  [0,2,1,2,1,2]]) # prediction

In [None]:
## basic Squared Error:
(a-b)**2

In [None]:
a!=0

In [None]:
## Squared Error where 0s in the original data are masked:
non_zero_mask = (a!=0)
(a-b)**2 * non_zero_mask

In [None]:
## and to compute a mean, we sum and divide by the number of non-zeros
SE = torch.sum((a-b)**2 * non_zero_mask)
N0 = torch.sum( non_zero_mask )
SE/N0

In [None]:
def maskedMeanSquareError(output, target):
    """Mean squared error restricted to the non-zero entries of ``target``.

    Zeros in ``target`` are treated as missing values and are excluded from
    both the sum of squared errors and the count, so the model is only
    penalised where a measurement exists.

    Parameters
    ----------
    output : torch.Tensor
        Model prediction; same shape as ``target`` (or broadcastable to it).
    target : torch.Tensor
        Observed data; entries equal to 0 are ignored.

    Returns
    -------
    torch.Tensor
        Scalar tensor. If ``target`` has no non-zero entry the result is 0
        with zero gradient, instead of 0/0 = nan (whose nan gradient would
        corrupt every weight after a single optimizer step).
    """
    non_zero_mask = (target != 0)
    SE = torch.sum((output - target) ** 2 * non_zero_mask)
    # clamp the count to >= 1 so an all-zero batch cannot produce 0/0
    N0 = torch.clamp(torch.sum(non_zero_mask), min=1)
    return SE / N0
maskedMeanSquareError(b, a)

In [None]:
x, = valid_dataset[:5]
x = x.to(device)
pred = model(x)

mseloss( pred,x )

In [None]:
maskedMeanSquareError(pred, x)

### synthetic dataset for testing

In [None]:
## on google colab you will have to download 
#!wget https://github.com/sib-swiss/pytorch-practical-training/raw/refs/heads/master/data/single_cell/example.msk90.hd5

## and adapt the following cells to read these files from the current directory

In [None]:
## same as the original dataset, but with 90% of 0s
fname_input = "data/single_cell/example.msk90.hd5"
orientation = 'cell_row'  # cell_row/gene_row

In [None]:
df_tmp = pd.read_hdf(fname_input)

number_0 = (df_tmp != 0).sum().sum()

print("shape is {}".format(df_tmp.shape))
print('non-zero count is {}'.format( number_0 ))
print('non-zero rate  is {:.3f}'.format(number_0 / df_tmp.size ))

In [None]:
## log10 transformation 
## np.log10 works element-wise on the DataFrame, so no transpose round-trip is needed
pseudocount = 1

input_sparse_df = np.log10( df_tmp + pseudocount )

Let's split the data into training and validation sets.

In [None]:
## recompute the sizes from THIS dataset instead of reusing m from the first one
m, n = input_sparse_df.shape  # m: n_cells; n: n_genes
valid_fraction = 0.3
valid_size = int( m * valid_fraction )
train_size = m - valid_size


np.random.seed(1884)
arr = np.arange(m)
np.random.shuffle(arr)

X_train = input_sparse_df.iloc[ arr[:train_size] , : ].to_numpy()
X_valid = input_sparse_df.iloc[ arr[train_size:] , : ].to_numpy()

print('train: {}'.format(train_size))
print('valid: {}'.format(valid_size))

In [None]:
cell_type_train = list( cell_types.loc[ input_sparse_df.index[ arr[:train_size] ] , 'celltype' ] )
cell_type_valid = list( cell_types.loc[ input_sparse_df.index[ arr[train_size:] ] , 'celltype' ] )

In [None]:
## take the ids from the same DataFrame that X_train / X_valid were built from,
## so the ground-truth lookup by cell id further down is row-aligned with them
gene_ids = input_sparse_df.columns

train_cell_ids = input_sparse_df.index[ arr[:train_size] ]
valid_cell_ids = input_sparse_df.index[ arr[train_size:] ]

In [None]:
gc.collect()
print("memory usage: {}Mb".format(usage()))

In [None]:
%%time
pca_valid = PCA().fit( X_valid )
x_pca = pca_valid.transform( X_valid )
pca_valid.explained_variance_ratio_[:10]

In [None]:
%%time
tsne = TSNE(n_components=2)
tsne.fit( x_pca[:,:100] )
fig,ax = plt.subplots(figsize=(12,8))
sns.scatterplot( x = tsne.embedding_[:,0],
               y = tsne.embedding_[:,1],
               hue = cell_type_valid,ax=ax)

## build the data loaders

In [None]:
batch_size = 256

In [None]:
# create your dataset
train_dataset = TensorDataset( torch.Tensor(X_train) ) 

## creating a dataloader; reshuffle the training cells at every epoch so that the
## mini-batches differ between epochs. (Imputation later reads train_dataset[:]
## directly, so its row order is unaffected.)
train_dataloader = DataLoader( train_dataset , batch_size = batch_size , shuffle = True ) 

In [None]:
# create your dataset
valid_dataset = TensorDataset( torch.Tensor(X_valid) ) 

## creating a dataloader
valid_dataloader = DataLoader(valid_dataset , batch_size = batch_size )

In [None]:
x, = train_dataset[:]
torch.sum( x == 0 ) / torch.numel( x )

In [None]:
x, = valid_dataset[:]
torch.sum( x == 0 ) / torch.numel( x )

## build and train the model

In [None]:
## preamble -> define the model, the loss function, and the optimizer
model = Simple_AutoEncoder(  input_dim = len(gene_ids) , 
                             hidden_dim=[500] ,
                             latent_dim = 100 , 
                             dropout_fraction = 0.05).to(device)



#############################################
## Here we specify our custom loss function
mseloss = maskedMeanSquareError
#############################################


optimizer = torch.optim.Adam(model.parameters(), 
                       lr = 3*10**-4) ## using the learning rate from their code


## container to keep the scores across all epochs
train_scores = []
valid_scores = []


# overfitting can be an issue here. 
# we use the early stopping implemented in https://github.com/Bjarten/early-stopping-pytorch
# initialize the early_stopping object. 
# patience: How long to wait after last time validation loss improved.
early_stopping = EarlyStopping(patience=25, verbose=False)


In [None]:
%%time
## lets do a single round, to learn how long it takes
train_scores.append( train(train_dataloader, 
                           model, 
                           maskedMeanSquareError, 
                           optimizer, 
                           echo = True , echo_batch = True ) )

valid_scores.append( valid(valid_dataloader, 
                           model, 
                           maskedMeanSquareError , 
                           echo = True) )


In [None]:
%%time

epoch = 200



for t in range(epoch):
    echo = t%10==0
    if echo:
        print('Epoch',len(train_scores)+1 )    

    train_scores.append( train(train_dataloader, 
                               model, 
                               maskedMeanSquareError, 
                               optimizer, 
                               echo = echo , echo_batch = False ) )

    valid_scores.append( valid(valid_dataloader, 
                               model, 
                               maskedMeanSquareError , 
                               echo = echo) )

    # early_stopping needs the validation loss to check if it has decresed, 
    # and if it has, it will make a checkpoint of the current model
    early_stopping(valid_scores[-1], model)

    if early_stopping.early_stop:
        print("Early stopping")
        break
        
# load the last checkpoint with the best model
model.load_state_dict(torch.load('checkpoint.pt'))


In [None]:
plt.plot(train_scores , label = 'train')
plt.plot(valid_scores, label = 'validation')
plt.axvline(np.argmin(valid_scores), linestyle='--', color='r',label='Early Stopping Checkpoint')
plt.legend()
plt.xlabel('epoch')
plt.ylabel('MSE loss')

In [None]:
x, = valid_dataset[:]
## eval mode (dropout off) + no gradient tracking: the embedding is only used for
## plotting, and .numpy() refuses tensors that require grad
model.eval()
with torch.no_grad():
    valid_encoded = model.encode( x.to(device) )

In [None]:
%%time

tsne = TSNE(n_components=2)
tsne.fit( valid_encoded.cpu().numpy() )

fig,ax = plt.subplots(figsize=(12,8))
sns.scatterplot( x = tsne.embedding_[:,0],
               y = tsne.embedding_[:,1],
               hue = cell_type_valid,ax=ax)

It looks like our masked auto-encoder has recovered some of the data structure despite the heavy sparsity!

In [None]:
## imputation: reconstruct every cell with the trained model
model.eval()  # make sure dropout is off, whichever cell was run last
with torch.no_grad():
    x, = train_dataset[:]
    train_imputed = model( x.to(device) ).cpu().numpy()
    x, = valid_dataset[:]
    valid_imputed = model( x.to(device) ).cpu().numpy()


In [None]:
## ground truth
df_truth = pd.read_hdf('data/single_cell/example.hd5')
## log10 transformation (same preprocessing as the model input);
## np.log10 is element-wise, so no transpose round-trip is needed
pseudocount = 1
df_truth = np.log10( df_truth + pseudocount )

In [None]:
## select rows AND columns by label so the truth matrices line up element-wise
## with X_train / X_valid (same cells, same gene order)
truth_train = np.array( df_truth.loc[ train_cell_ids , gene_ids ] )
truth_valid = np.array( df_truth.loc[ valid_cell_ids , gene_ids ] )

In [None]:
## remember, the "ground truth" data already had ~67% of missing data, 
##  we do not want to focus on this for the solution assessment

train_NI_mask = ( truth_train != 0 ) & ( X_train != 0 ) # non imputed data
train_I_mask = ( truth_train != 0 ) & ( X_train == 0 ) # imputed data with a ground truth

valid_NI_mask = ( truth_valid != 0 ) & ( X_valid != 0 ) # non imputed data
valid_I_mask = ( truth_valid != 0 ) & ( X_valid == 0 ) # imputed data with a ground truth


In [None]:

# difference of non imputed in train
train_non_imputed_diff = ( train_imputed[ train_NI_mask ] - truth_train[ train_NI_mask ] )
# difference of imputed in train
train_imputed_diff = ( train_imputed[ train_I_mask ] - truth_train[ train_I_mask ] )

# difference of non imputed in valid
valid_non_imputed_diff = ( valid_imputed[ valid_NI_mask ] - truth_valid[ valid_NI_mask ] )
# difference of imputed in valid
valid_imputed_diff = ( valid_imputed[ valid_I_mask ] - truth_valid[ valid_I_mask ] )



In [None]:
%%time

sizes = (train_non_imputed_diff.shape[0],
         train_imputed_diff.shape[0],
         valid_non_imputed_diff.shape[0],
         valid_imputed_diff.shape[0])

sns.violinplot(x = np.concatenate( [train_non_imputed_diff ,
                                    train_imputed_diff , 
                                    valid_non_imputed_diff, 
                                    valid_imputed_diff]),
            y = ['train']*(sizes[0]+sizes[1]) + ['valid']*(sizes[2]+sizes[3]),
            hue = ['non-imputed']*sizes[0] + ['imputed']*sizes[1] + ['non-imputed']*sizes[2] + ['imputed']*sizes[3] )

In [None]:

## 5%, 50% and 95% quantiles of the absolute error for each (split, imputed?) group
error_groups = [
    ('train, non-imputed - absolute error:', train_non_imputed_diff),
    ('train, imputed - absolute error:', train_imputed_diff),
    ('valid, non-imputed - absolute error:', valid_non_imputed_diff),
    ('valid, imputed - absolute error:', valid_imputed_diff),
]
for header, diff in error_groups:
    print(header)
    V = np.abs(diff)
    q05, q50, q95 = np.quantile(V, [0.05, 0.5, 0.95])
    print('\tq0.05: {:.4f} , q0.5: {:.4f} , q0.95: {:.4f}'.format(q05, q50, q95))