In [None]:
import jax
import jaxlib

# Record the JAX / jaxlib versions in the cell output so the environment this ran in is visible.
for _pkg_label, _pkg_module in (("jax", jax), ("jaxlib", jaxlib)):
    print(f"{_pkg_label} version:", _pkg_module.__version__)
del _pkg_label, _pkg_module

import os

import scvi
import h5py
import anndata as ad
import pandas as pd
import scipy.sparse as sp
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, random_split
from torch.optim import Adam
from scipy.sparse import issparse

# Read in the ATSE AnnData object in read-only backed mode.
# Set the ATSE_ANNDATA_PATH environment variable to load it from another location
# (e.g. on the cluster); the default is the original absolute path on the author's machine.
ATSE_ANNDATA_PATH = os.environ.get(
    "ATSE_ANNDATA_PATH",
    "/Users/smriti/documents/Research_Knowles_Lab/multivi_tools_splicing/ann_data/ATSE_Anndata_Object_BRAIN_only_20241105_wLeafletFAPSIs.h5ad",
)
atse_anndata = ad.read_h5ad(ATSE_ANNDATA_PATH, backed='r')

In [None]:
# Quick look at the 'junc_ratio' layer, which the VAE's training loop later uses as its input.
atse_anndata.layers['junc_ratio']

In [None]:
class AnnDataDataset(Dataset):
    """Map-style Dataset over a dict of per-layer tensors.

    Each item is a dict ``{layer_key: tensor[idx]}``, so the default DataLoader
    collate function batches every layer under its own key.

    Args:
        layer_tensors: mapping of layer name -> tensor. Every tensor is indexed
            along its first dimension, which must be the same size for all layers.

    Raises:
        ValueError: if ``layer_tensors`` is empty, or if the layers disagree on
            their first-dimension size (number of samples).
    """

    def __init__(self, layer_tensors):
        if not layer_tensors:
            raise ValueError("layer_tensors must contain at least one layer")
        # every layer must describe the same samples, otherwise __getitem__ would
        # fail (or silently misalign) part-way through an epoch
        row_counts = {key: tensor.shape[0] for key, tensor in layer_tensors.items()}
        if len(set(row_counts.values())) != 1:
            raise ValueError(f"all layers must have the same number of rows, got {row_counts}")
        self.layer_tensors = layer_tensors
        self.num_samples = next(iter(row_counts.values()))

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return {layer_key: tensor[idx] for layer_key, tensor in self.layer_tensors.items()}

In [None]:
class Encoder(nn.Module):
    """Maps an input vector to the mean and log-variance of a diagonal Gaussian latent.

    Architecture: Linear(input_dim -> num_hidden_units[0]) + ReLU + Dropout, then
    ``num_hidden_layers`` blocks of Linear(num_hidden_units[i] -> num_hidden_units[i+1])
    + ReLU + Dropout, then two parallel Linear heads producing ``latent_dim`` means and
    ``latent_dim`` log-variances.

    Args:
        input_dim: number of input features.
        num_hidden_layers: number of hidden blocks after the input block.
        num_hidden_units: hidden widths; needs at least ``num_hidden_layers + 1`` entries.
            Only the first ``num_hidden_layers + 1`` entries are used.
        latent_dim: dimensionality of the latent space.
        dropout_rate: dropout probability applied after every ReLU.

    Raises:
        ValueError: if ``num_hidden_units`` is too short for ``num_hidden_layers``.
    """

    def __init__(self, input_dim, num_hidden_layers, num_hidden_units, latent_dim, dropout_rate = 0.0):
        super().__init__()
        if len(num_hidden_units) < num_hidden_layers + 1:
            raise ValueError(
                f"num_hidden_units needs at least {num_hidden_layers + 1} entries for "
                f"{num_hidden_layers} hidden layers, got {len(num_hidden_units)}"
            )
        # input layer includes relu and dropout
        self.input = nn.Sequential(nn.Linear(input_dim, num_hidden_units[0]), nn.ReLU(), nn.Dropout(dropout_rate))
        # nn.ModuleList (not a plain Python list) registers the hidden blocks as submodules,
        # so their weights appear in .parameters() (and are optimized), move with
        # .cuda()/.to(), and are saved in state_dict().
        self.hidden_layers = nn.ModuleList([
            nn.Sequential(nn.Linear(num_hidden_units[i], num_hidden_units[i + 1]), nn.ReLU(), nn.Dropout(dropout_rate))
            for i in range(num_hidden_layers)
        ])
        # width actually produced by the last hidden block (equals num_hidden_units[-1]
        # when the list has exactly num_hidden_layers + 1 entries)
        last_hidden_dim = num_hidden_units[num_hidden_layers]
        self.output_means = nn.Linear(last_hidden_dim, latent_dim)
        self.output_log_vars = nn.Linear(last_hidden_dim, latent_dim)

    def forward(self, x):
        """Return ``(means, log_vars)``, each with trailing dimension ``latent_dim``."""
        x = self.input(x)
        for layer in self.hidden_layers:
            x = layer(x)
        means = self.output_means(x)
        log_vars = self.output_log_vars(x)
        return means, log_vars

class Decoder(nn.Module):
    """Maps a latent sample ``z`` to raw (unnormalized) logits of the reconstruction.

    Mirror image of the Encoder: Linear(z_dim -> num_hidden_units[num_hidden_layers])
    + ReLU + Dropout, then the hidden blocks in reverse order
    (Linear(num_hidden_units[i+1] -> num_hidden_units[i]) + ReLU + Dropout for
    i = num_hidden_layers-1 down to 0), then a final Linear(num_hidden_units[0] -> output_dim)
    with no activation, so the outputs are unbounded logits.

    Args:
        z_dim: dimensionality of the latent input.
        num_hidden_layers: number of hidden blocks after the input block.
        num_hidden_units: hidden widths (same list as the Encoder); needs at least
            ``num_hidden_layers + 1`` entries. Only the first ``num_hidden_layers + 1`` are used.
        output_dim: number of output logits.
        dropout_rate: dropout probability applied after every ReLU.

    Raises:
        ValueError: if ``num_hidden_units`` is too short for ``num_hidden_layers``.
    """

    def __init__(self, z_dim, num_hidden_layers, num_hidden_units, output_dim, dropout_rate = 0.0):
        super().__init__()
        if len(num_hidden_units) < num_hidden_layers + 1:
            raise ValueError(
                f"num_hidden_units needs at least {num_hidden_layers + 1} entries for "
                f"{num_hidden_layers} hidden layers, got {len(num_hidden_units)}"
            )
        # the first reversed hidden block expects num_hidden_units[num_hidden_layers] inputs
        first_hidden_dim = num_hidden_units[num_hidden_layers]
        self.input = nn.Sequential(nn.Linear(z_dim, first_hidden_dim), nn.ReLU(), nn.Dropout(dropout_rate))
        # nn.ModuleList so these layers are registered: trained by the optimizer,
        # moved by .cuda()/.to(), and included in state_dict()
        self.hidden_layers = nn.ModuleList([
            nn.Sequential(nn.Linear(num_hidden_units[i + 1], num_hidden_units[i]), nn.ReLU(), nn.Dropout(dropout_rate))
            for i in reversed(range(num_hidden_layers))
        ])
        # output a raw logit representing the reconstruction
        self.output = nn.Linear(num_hidden_units[0], output_dim)

    def forward(self, x):
        """Return reconstruction logits with trailing dimension ``output_dim``."""
        x = self.input(x)
        for layer in self.hidden_layers:
            x = layer(x)
        reconstruction = self.output(x)
        return reconstruction

def binomial_loss_function(reconstruction, junction_counts, mean, log_vars, n_cluster_counts, eps=1e-04):
    """Negative ELBO for a binomial likelihood with a standard-normal prior.

    Reconstruction term: binomial log-likelihood (dropping the constant
    log-binomial-coefficient) of ``junction_counts`` successes out of
    ``n_cluster_counts`` trials with success logit ``reconstruction``:

        log p(k | n, logit) = k * logit - n * log(1 + exp(logit)) + const

    KL term: closed-form KL(N(mean, exp(log_vars)) || N(0, I)), as MultiVI does.

    Both terms are summed over their last dimension to give one value per cell and
    then averaged over the batch, so they are on the same scale.

    Args:
        reconstruction: decoder logits; presumably (batch, n_junctions) — TODO confirm.
        junction_counts: observed counts k, same shape as ``reconstruction``.
        mean: encoder means, presumably (batch, latent_dim).
        log_vars: encoder log-variances, same shape as ``mean``.
        n_cluster_counts: trial counts n, same shape as ``reconstruction``
            (presumably the count of the cluster each junction belongs to — verify).
        eps: unused; kept only so existing call signatures keep working.

    Returns:
        Scalar tensor: batch mean of (reconstruction loss + KL divergence).
    """
    # log(1 + exp(logit)) computed stably as logaddexp(0, logit), i.e. softplus
    log1p_exp_logits = torch.logaddexp(torch.zeros_like(reconstruction), reconstruction)
    loglik = (junction_counts * reconstruction) - (n_cluster_counts * log1p_exp_logits)
    reconstruction_loss = -loglik.sum(dim=-1)  # per cell
    kl_divergence = -0.5 * torch.sum(1 + log_vars - mean.pow(2) - log_vars.exp(), dim=-1)  # per cell
    return (reconstruction_loss + kl_divergence).mean()

def construct_input_dataloaders(atse_anndata, batch_size, layer_keys=None, shuffle=True):
    """Build a DataLoader that yields dicts of float32 tensors, one entry per AnnData layer.

    Args:
        atse_anndata: AnnData object whose ``.layers`` hold the model inputs.
        batch_size: number of cells per batch.
        layer_keys: names of the layers to load. Defaults to the first three layers
            in ``atse_anndata.layers`` (the original behavior). Passing the names
            explicitly, e.g. ``["junc_ratio", "cell_by_junction_matrix",
            "cell_by_cluster_matrix"]`` (the keys ``VAE.train_model`` reads), avoids
            depending on the order in which layers are stored.
        shuffle: whether the DataLoader reshuffles the cells every epoch.

    Returns:
        DataLoader over an ``AnnDataDataset`` built from the selected layers.
    """
    if layer_keys is None:
        layer_keys = list(atse_anndata.layers.keys())[:3]
    layer_tensors = {}
    for layer_key in layer_keys:
        # read each layer once; sparse matrices are densified because torch.tensor
        # needs a dense array
        layer = atse_anndata.layers[layer_key]
        dense = layer.toarray() if issparse(layer) else layer
        # float32 to match the default dtype of the model's nn.Linear weights
        layer_tensors[layer_key] = torch.tensor(dense, dtype=torch.float32)
    dataset = AnnDataDataset(layer_tensors)
    # TODO: add validation/test splits (e.g. random_split) for early stopping and latent-space UMAPs
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

class VAE(nn.Module):
    """Variational autoencoder over junction usage.

    Encodes an input (the ``junc_ratio`` layer in ``train_model``) into the mean and
    log-variance of a diagonal Gaussian, samples a latent ``z`` with the
    reparametrization trick, and decodes ``z`` into per-feature logits. The training
    loss (``binomial_loss_function``) scores how well those logits explain the observed
    junction counts given the cluster counts, plus a KL penalty towards N(0, I), so
    ``z`` acts as a denoised summary of junction usage.
    """

    def __init__(self, input_dim, num_hidden_layers, num_hidden_units, latent_dim, output_dim, dropout_rate = 0.0):
        super().__init__()
        self.encoder = Encoder(input_dim, num_hidden_layers, num_hidden_units, latent_dim, dropout_rate)
        self.decoder = Decoder(latent_dim, num_hidden_layers, num_hidden_units, output_dim, dropout_rate)

    def reparametrize(self, mean, log_vars):
        """Draw z ~ N(mean, exp(log_vars)) in a differentiable way.

        Sampling z directly is not differentiable with respect to mean/log_vars, so
        z is written as z = mean + std * eps with eps ~ N(0, I): the randomness lives
        in eps and gradients flow through mean and std.
        """
        std = torch.exp(0.5 * log_vars)
        eps = torch.randn_like(std)
        return mean + eps * std

    def forward(self, x):
        """Return ``(reconstruction_logits, mean, log_vars)``, all cast to float32."""
        mean, log_vars = self.encoder(x)
        z = self.reparametrize(mean, log_vars)
        reconstruction = self.decoder(z)
        return reconstruction.to(torch.float32), mean.to(torch.float32), log_vars.to(torch.float32)

    def train_model(self, train_dataloader, num_epochs, learning_rate):
        """Train with Adam; return the mean over epochs of each epoch's mean batch loss.

        Each batch must be a dict with keys "junc_ratio" (model input),
        "cell_by_junction_matrix" and "cell_by_cluster_matrix" (loss inputs).
        Batch tensors are moved to the device holding the model's parameters, so
        training also works after ``model.cuda()`` / ``model.to(device)``.

        Raises:
            ValueError: if ``train_dataloader`` has no batches.
        """
        print("Beginning Training")
        num_batches = len(train_dataloader)
        if num_batches == 0:
            raise ValueError("train_dataloader is empty; nothing to train on")
        # DataLoader batches are built from CPU tensors; send them to the model's device
        device = next(self.parameters()).device
        optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        train_loss = 0
        for epoch in range(num_epochs):
            epoch_loss = 0
            self.train()  # enable dropout
            for batch in train_dataloader:
                batch = {key: value.to(device) for key, value in batch.items()}
                optimizer.zero_grad()
                reconstruction, mean, log_vars = self(batch["junc_ratio"])
                loss = binomial_loss_function(reconstruction, batch["cell_by_junction_matrix"], mean, log_vars, batch["cell_by_cluster_matrix"])
                loss.backward()
                epoch_loss += loss.item()
                optimizer.step()
            print(f"Epoch {epoch+1} of {num_epochs}; Train Loss = {epoch_loss/num_batches}")
            train_loss += epoch_loss/num_batches
        return train_loss/num_epochs

In [None]:
# Hyperparameters: using Karin's VAE settings for now; revisit once this runs on the cluster.
SEED = 0  # seeds torch's global RNG: weight init, reparametrization noise, DataLoader shuffling
LEARNING_RATE = 1e-3
NUM_EPOCHS = 10
BATCH_SIZE = 128
USE_CUDA = torch.cuda.is_available()
INPUT_DIM = atse_anndata.var.shape[0]  # one input feature per row of .var
NUM_HIDDEN_LAYERS = 1
HIDDEN_DIMS = [128, 64]  # needs NUM_HIDDEN_LAYERS + 1 entries
LATENT_DIM = 20
OUTPUT_DIM = INPUT_DIM  # decoder emits one logit per input feature

# Seed before the model is built and the DataLoader is iterated so runs are reproducible.
torch.manual_seed(SEED)

# Build the training DataLoader from the ATSE AnnData with the chosen batch size.
dataloader = construct_input_dataloaders(atse_anndata, BATCH_SIZE)

In [None]:
# Build the VAE, move it to the GPU if one is available, then train.
model = VAE(INPUT_DIM, NUM_HIDDEN_LAYERS, HIDDEN_DIMS, LATENT_DIM, OUTPUT_DIM)
if USE_CUDA:
    model.cuda()

# Use the LEARNING_RATE constant from the config cell instead of repeating the literal,
# so changing the config actually changes training.
mean_train_loss = model.train_model(dataloader, NUM_EPOCHS, learning_rate=LEARNING_RATE)
mean_train_loss