<a href="https://colab.research.google.com/github/tchaase/cVAE_autism/blob/main/code/cVAE_autism.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Contrastive Variational Autoencoder for the ABIDE Data Set

Author - Tobias Haase

## Imports

First, I import the modules that I will use in the following.



In [None]:
import torch  # The main PyTorch library for tensor computations and neural network operations

import torch.nn as nn  # Provides various neural network layers and functionalities
import torch.nn.functional as F  # Provides functional interfaces to common operations (e.g., activation functions)
import torch.optim as optim  # Contains various optimization algorithms (e.g., SGD, Adam)

import torchvision  # A PyTorch library for computer vision tasks
import torchvision.transforms as transforms  # Provides common image transformations (e.g., resizing, normalization)
from torchvision.transforms import ToTensor  # Transforms PIL images to tensors
from torch.utils.data import Dataset, DataLoader  # Provides tools for creating custom datasets and data loaders

import numpy as np  # NumPy library for numerical computations and array operations
import matplotlib  # Matplotlib library for data visualization
import matplotlib.pyplot as plt  # Matplotlib's pyplot module for creating plots
from tqdm import tqdm  # Progress bar library for tracking iterations


Next, let's load the data using nilearn's `fetch_abide_pcp` function. This function downloads ABIDE data that was previously preprocessed by the [Preprocessed Connectomes Project](http://preprocessed-connectomes-project.org/index.html) (PCP). Within this project, the data was preprocessed with four different pipelines.
  

>Due to the controversies surrounding bandpass filtering and global signal regression, four different preprocessing strategies were performed with each pipeline: all combinations of with and without filtering and with and without global signal correction.
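In `fetch_abide_pcp`, these four strategies correspond to two boolean arguments, `band_pass_filtering` and `global_signal_regression`, both of which default to `False`. The following is a small sketch that only enumerates the four combinations as keyword arguments; it does not download anything, and the variable name `strategies` is my own.

```python
from itertools import product

# The four PCP preprocessing strategies, written as keyword arguments for
# nilearn.datasets.fetch_abide_pcp (both flags default to False in nilearn).
strategies = [
    {"band_pass_filtering": bp, "global_signal_regression": gsr}
    for bp, gsr in product([False, True], repeat=2)
]

for kwargs in strategies:
    print(kwargs)
```

Any of these dictionaries can be unpacked into the fetcher, e.g. `datasets.fetch_abide_pcp(data_dir="./data", n_subjects=1, **strategies[0])` for the unfiltered data without global signal regression.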

So, the first question is which preprocessing pipeline to use. Let's go over the four pipelines one by one. For each, I list which data it processes and which preprocessing steps it performs, then the key features that set it apart from the other pipelines, and finally its dependencies. Note that the dependencies are the ones the pipeline needed to preprocess the data, not what is required to load the preprocessed data!

1. [Connectome Computation System](http://preprocessed-connectomes-project.org/abide/ccs.html) (CCS):
    * Preprocessing Steps: CCS performs the usual preprocessing steps on both the structural and the functional data.
    * Key Features: CCS integrates FSL and FreeSurfer and is implemented primarily in bash, with parts written in other programming languages.
    * Dependencies: Accordingly, CCS depends on FSL (e.g. skull stripping, normalization), FreeSurfer (e.g. anatomical segmentation, surface reconstruction) and AFNI (various preprocessing tools).
2. [Configurable Pipeline for the Analysis of Connectomes](http://preprocessed-connectomes-project.org/abide/cpac.html) (CPAC):
    * Preprocessing Steps: CPAC includes a range of preprocessing steps for both structural and functional data, among them motion correction, slice timing correction, spatial normalization, intensity normalization, nuisance signal regression, and band-pass filtering.
    * Key Features: As the name suggests, CPAC is highly configurable, so researchers can choose processing options that fit their study requirements. It provides various quality control measures and outputs, including functional connectivity matrices!
    * Dependencies: CPAC is implemented primarily in Python and relies on libraries and tools such as Nipype, FSL, ANTs, and AFNI.
3. [Data Processing Assistant for Resting-State fMRI](http://preprocessed-connectomes-project.org/abide/dparsf.html) (DPARSF):
    * Preprocessing Steps: DPARSF focuses on resting-state functional MRI data and includes standard preprocessing steps such as slice timing correction, realignment (motion correction), spatial normalization, smoothing, and nuisance signal regression.
    * Key Features: DPARSF provides a graphical user interface. It is configurable to some degree, since the output options can be chosen.
    * Dependencies: DPARSF is implemented in MATLAB and requires the SPM (Statistical Parametric Mapping) toolbox for some of the preprocessing steps.
4. [Neuroimaging Analysis Kit](http://preprocessed-connectomes-project.org/abide/niak.html) (NIAK):
    * Preprocessing Steps: NIAK allows customization of preprocessing steps, including motion correction, slice timing correction, spatial normalization, smoothing, and nuisance signal regression. It also offers quality control measures.
    * Key Features: NIAK provides a flexible pipeline for functional and structural MRI data. It offers a command-line interface and the ability to select specific processing options based on the research requirements.
    * Dependencies: NIAK is implemented primarily in MATLAB and relies on external software packages such as FSL, ANTs, and AFNI for specific preprocessing steps.

My sources for this information are the PCP website and ChatGPT.

For now, I will stick with the default pipeline of `fetch_abide_pcp`, which is **cpac**.

Importantly, quality control was already performed for this data, and I only load the subjects that passed it (`quality_checked=True`, which is the default in `fetch_abide_pcp`).

For now, I am just loading one participant.

In [None]:
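from nilearn import datasets  # nilearn is not imported in the cell above
abide = datasets.fetch_abide_pcp(data_dir="./data", n_subjects=1, pipeline="cpac", quality_checked=True)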
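The training code further below needs an `is_autism` label per subject. In the ABIDE phenotypic data, the `DX_GROUP` column encodes the diagnosis as 1 = autism and 2 = typical control. The following is a minimal sketch of turning those codes into a boolean tensor; the helper name `dx_group_to_is_autism` is hypothetical, and it is written so that it can be tested without downloading any data.

```python
import numpy as np
import torch


def dx_group_to_is_autism(dx_group):
    """Map ABIDE DX_GROUP codes (1 = autism, 2 = control) to a boolean tensor."""
    codes = np.asarray(dx_group, dtype=np.int64)
    if not np.isin(codes, [1, 2]).all():
        raise ValueError("DX_GROUP codes must be 1 (autism) or 2 (control)")
    return torch.from_numpy(codes == 1)


is_autism = dx_group_to_is_autism([1, 2, 2, 1])
print(is_autism)
```

With the fetched data, this would be called as `dx_group_to_is_autism(abide.phenotypic["DX_GROUP"])`; depending on the nilearn version, `phenotypic` is a NumPy record array or a pandas DataFrame, and both support this column lookup.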

## Model specifications

In the following, I specify the model. I roughly follow the paper by Aglinskas, Hartshorne & Anzellotti (2022).

### Defining utility functions

First, I define the loss function.
The loss is the sum of the BCE loss and the KL divergence terms.

* BCE loss: the reconstruction loss. Binary cross entropy is normally used for binary classification; I am using it here as a placeholder until I find a more appropriate reconstruction loss. It is high when the reconstruction deviates strongly from the input, and it requires both the reconstruction and the target to lie in [0, 1].

* Kullback-Leibler divergence (Kullback & Leibler, 1951): a measure of how much two distributions differ ("diverge") from each other. Adding this term to the loss makes the model optimize not only how well it reconstructs the input, but also how close the distribution of the latent variables produced by the encoder is to the prior distribution. In my case, the prior is an isotropic Gaussian, $\mathcal{N}(0, I)$.
  * Why is this desirable? It keeps the latent variables and the sampling process under control: the latent codes stay close to the prior, so the latent space remains smooth and samples drawn from the prior decode to plausible data.
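Because BCE only makes sense for values in [0, 1], the MRI intensities would have to be rescaled first. A minimal sketch, assuming per-volume min-max scaling (my choice for illustration, not a decision made in this notebook) and a decoder that ends in a sigmoid; the helper name `minmax_scale` is hypothetical.

```python
import torch
import torch.nn as nn


def minmax_scale(volume, eps=1e-8):
    """Scale a volume to [0, 1] so it can serve as a BCE target."""
    vmin, vmax = volume.min(), volume.max()
    return (volume - vmin) / (vmax - vmin + eps)


torch.manual_seed(0)
# Hypothetical stand-in for one single-channel volume: (batch, channel, D, H, W).
target = minmax_scale(torch.randn(1, 1, 8, 8, 8) * 100 + 500)
# A decoder ending in a sigmoid produces values in (0, 1) as well.
reconstruction = torch.sigmoid(torch.randn_like(target))

# reduction="sum" puts the reconstruction term on the same scale as the summed KL terms.
criterion = nn.BCELoss(reduction="sum")
bce = criterion(reconstruction, target)
print(bce.item())
```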


In addition, the KL divergence of the second (salient) encoder is only added to the loss if that encoder was actually used, i.e. for autism subjects.
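Written out, the loss for a single subject $x$ looks as follows, where $\hat{x}$ is the reconstruction, $z$ the shared latent variables (from `encoder_z`, the non-salient encoder), $s$ the salient latent variables (from `encoder_s`), $d$ the latent dimension, and $\mathbb{1}[\text{autism}]$ equals 1 for autism subjects and 0 for controls:

$$
\mathcal{L}(x) = \mathrm{BCE}(x, \hat{x}) + D_{\mathrm{KL}}\big(q(z \mid x)\,\|\,\mathcal{N}(0, I)\big) + \mathbb{1}[\text{autism}]\; D_{\mathrm{KL}}\big(q(s \mid x)\,\|\,\mathcal{N}(0, I)\big)
$$

$$
D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \operatorname{diag}\sigma^2)\,\|\,\mathcal{N}(0, I)\big) = -\frac{1}{2}\sum_{j=1}^{d}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
$$

Since the encoders output the log variance $\log\sigma_j^2$, the variance is recovered as $\sigma_j^2 = \exp(\log\sigma_j^2)$, which is the `logvar.exp()` term in the code.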

In [None]:
def final_loss(bce_loss, z_mu, z_logvar, s_mu=None, s_logvar=None):
    """
    This function will add the reconstruction loss (BCELoss) and the KL-Divergence.
    KL-Divergence = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    :param bce_loss: reconstruction loss
    :param z_mu: mean from the latent vector of encoder_z
    :param z_logvar: log variance from the latent vector of encoder_z
    :param s_mu: mean from the latent vector of encoder_s (optional)
    :param s_logvar: log variance from the latent vector of encoder_s (optional)
    """
    BCE = bce_loss
    KLD_z = -0.5 * torch.sum(1 + z_logvar - z_mu.pow(2) - z_logvar.exp())
    if s_mu is not None and s_logvar is not None:
        KLD_s = -0.5 * torch.sum(1 + s_logvar - s_mu.pow(2) - s_logvar.exp())
        return BCE + KLD_z + KLD_s
    else:
        return BCE + KLD_z
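To check that the closed-form KL term used in `final_loss` is correct, it can be compared with PyTorch's own `torch.distributions.kl_divergence`. This sketch re-implements only the KL part as a standalone function (the name `kl_to_standard_normal` is hypothetical) so that it runs on its own.

```python
import torch
from torch.distributions import Normal, kl_divergence


def kl_to_standard_normal(mu, logvar):
    """Closed-form KL(N(mu, exp(logvar)) || N(0, 1)), summed over all elements.

    Same expression as the KL terms in final_loss.
    """
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())


torch.manual_seed(0)
mu = torch.randn(4, 16)
logvar = torch.randn(4, 16)

closed_form = kl_to_standard_normal(mu, logvar)
# Reference value from torch.distributions; Normal takes the standard deviation, not the variance.
reference = kl_divergence(Normal(mu, torch.exp(0.5 * logvar)), Normal(0.0, 1.0)).sum()
print(closed_form.item(), reference.item())
```

A posterior that exactly matches the prior (`mu = 0`, `logvar = 0`) gives a KL divergence of zero.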

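Next comes the training function. Because the model needs to know whether a subject belongs to the autism or the control group, the subjects in a batch are passed through the model one at a time; their gradients accumulate, and the optimizer takes one step per batch.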

In [None]:
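def train(model, dataloader, dataset, device, optimizer, criterion):
    """Run one training epoch and return the mean loss per subject.

    criterion should sum over voxels (e.g. nn.BCELoss(reduction="sum")) so that it
    is on the same scale as the summed KL terms in final_loss.
    """
    model.train()
    running_loss = 0.0
    counter = 0
    # One tqdm step per batch, so the total is the number of batches, not len(dataset).
    for i, (data, is_autism) in tqdm(enumerate(dataloader), total=len(dataloader)):
        data = data.to(device)
        autism_status = is_autism.to(device)

        optimizer.zero_grad()

        for j in range(data.size(0)):
            # Keep a batch dimension of 1 so that the encoders' flatten step works.
            single_data = data[j].unsqueeze(0)
            # Tell the model which branch to use; otherwise model.autism stays None.
            model.autism = bool(autism_status[j])

            if model.autism:
                z_mean, z_log_var, s_mean, s_log_var, reconstructed_data = model(single_data)
                bce_loss = criterion(reconstructed_data, single_data)
                loss = final_loss(bce_loss, z_mean, z_log_var, s_mean, s_log_var)
            else:
                z_mean, z_log_var, reconstructed_data = model(single_data)
                bce_loss = criterion(reconstructed_data, single_data)
                loss = final_loss(bce_loss, z_mean, z_log_var)

            # Gradients accumulate over the subjects of the batch.
            loss.backward()
            running_loss += loss.item()

        # One optimizer step per batch.
        optimizer.step()

        counter += data.size(0)

    train_loss = running_loss / counter
    return train_loss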
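`train()` expects every batch to be a `(data, is_autism)` pair. The following is a minimal sketch of a `Dataset` that yields such pairs, assuming the volumes have already been loaded into one tensor of shape `(N, 1, D, H, W)` and scaled to [0, 1]. The class name `ABIDEVolumes` and the toy tensors are hypothetical stand-ins for the real ABIDE data.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ABIDEVolumes(Dataset):
    """Hypothetical dataset yielding (volume, is_autism) pairs, as train() expects.

    volumes: float tensor of shape (N, 1, D, H, W), already scaled to [0, 1].
    is_autism: bool tensor of shape (N,).
    """

    def __init__(self, volumes, is_autism):
        if len(volumes) != len(is_autism):
            raise ValueError("volumes and is_autism must have the same length")
        self.volumes = volumes
        self.is_autism = is_autism

    def __len__(self):
        return len(self.volumes)

    def __getitem__(self, idx):
        return self.volumes[idx], self.is_autism[idx]


# Toy data standing in for preprocessed ABIDE volumes.
volumes = torch.rand(6, 1, 8, 8, 8)
labels = torch.tensor([True, False, True, False, False, True])
dataset = ABIDEVolumes(volumes, labels)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

for data, is_autism in dataloader:
    print(data.shape, is_autism.tolist())
```

The default collate function stacks the per-subject tensors, so each batch arrives as a `(batch, 1, D, H, W)` tensor plus a boolean label vector, which is what `train(model, dataloader, dataset, device, optimizer, criterion)` iterates over.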

The model should also be validated on held-out subjects. Like the training function, this validation function is still a draft.

In [None]:
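def validate(model, dataloader, dataset, device, criterion):
    model.eval()
    running_loss = 0.0
    counter = 0
    recon_images = None  # stays None if the dataloader yields no batches
    with torch.no_grad():
        # Batches are (data, is_autism) pairs, as in train(); one tqdm step per batch.
        for i, (data, is_autism) in tqdm(enumerate(dataloader), total=len(dataloader)):
            data = data.to(device)
            batch_recons = []
            for j in range(data.size(0)):
                single_data = data[j].unsqueeze(0)  # keep a batch dimension of 1
                model.autism = bool(is_autism[j])
                # The model returns (mean, log_var[, s_mean, s_log_var], reconstruction),
                # so the reconstruction is always the last output.
                outputs = model(single_data)
                reconstruction = outputs[-1]
                bce_loss = criterion(reconstruction, single_data)
                loss = final_loss(bce_loss, *outputs[:-1])
                running_loss += loss.item()
                counter += 1
                batch_recons.append(reconstruction)
            # Keep the reconstructions of the last batch of every epoch.
            recon_images = torch.cat(batch_recons)
    val_loss = running_loss / counter
    return val_loss, recon_images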
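Passing 3D volumes through the model one subject at a time, as `train()` does, is slow. A possible alternative, sketched here under my own assumptions and not taken from the paper, is to run both encoders on the whole batch, zero the salient latent vector of control subjects before decoding, and mask the salient KL term per subject. The helper names `kl_per_subject`, `mask_salient` and `masked_final_loss` are hypothetical, and random tensors stand in for the encoder outputs.

```python
import torch


def kl_per_subject(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, I)) for every row, summed over the latent dimensions."""
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)


def mask_salient(s, is_autism):
    """Zero the salient latent vector of control subjects before decoding."""
    return s * is_autism.to(s.dtype).unsqueeze(1)


def masked_final_loss(recon_per_subject, z_mu, z_logvar, s_mu, s_logvar, is_autism):
    """Batched counterpart of final_loss: the salient KL term only counts for autism subjects."""
    mask = is_autism.to(z_mu.dtype)
    per_subject = (recon_per_subject
                   + kl_per_subject(z_mu, z_logvar)
                   + mask * kl_per_subject(s_mu, s_logvar))
    return per_subject.sum()


torch.manual_seed(0)
n, d = 5, 16
# In practice, e.g. F.binary_cross_entropy(recon, x, reduction="none").flatten(1).sum(1).
recon_per_subject = torch.rand(n)
z_mu, z_logvar = torch.randn(n, d), torch.randn(n, d)
s_mu, s_logvar = torch.randn(n, d), torch.randn(n, d)
is_autism = torch.tensor([True, False, False, True, False])

loss = masked_final_loss(recon_per_subject, z_mu, z_logvar, s_mu, s_logvar, is_autism)
s_for_decoder = mask_salient(torch.randn(n, d), is_autism)
print(loss.item())
```

With this design the model would no longer need the `autism` flag: the mask replaces the per-subject branching.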


## Hyperparameters

These values still need to be adapted for the current model.

In [None]:
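init_channels = 64 # number of filters, i.e. output channels of the first conv layer
image_channels = 1 # single-channel 3D volumes (assumption, see below)
latent_dim = 16 # dimension of each latent space (shared and salient)
filters = init_channels * 2 # output channels of the second encoder conv layer

kernel_size = 3
stride = 2
# padding=0 would be "valid" padding, not "same"; with kernel_size=3 and stride=2,
# padding=1 halves each spatial dimension (rounding up).
padding = 1

lr = 0.001

intermediate_dim = 128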
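The number of input features of the encoders' first linear layer, and the shape the decoder has to unflatten to, both depend on the input volume size. Following the output-size formula in the PyTorch `nn.Conv3d` documentation (here with dilation 1), the sketch below computes them for two identical conv layers. The helper names and the 64×64×64 example volume are hypothetical.

```python
def conv_out(n, kernel_size, stride, padding):
    """Output length of one spatial dimension after nn.Conv3d (dilation 1)."""
    return (n + 2 * padding - kernel_size) // stride + 1


def encoder_output_shape(spatial, channels, n_layers=2, kernel_size=3, stride=2, padding=1):
    """Shape (C, D, H, W) after n_layers identical conv layers, and its flattened size."""
    for _ in range(n_layers):
        spatial = tuple(conv_out(n, kernel_size, stride, padding) for n in spatial)
    shape = (channels,) + spatial
    flat = channels
    for n in spatial:
        flat *= n
    return shape, flat


# Hypothetical 64^3 volume; the encoder's last conv layer outputs init_channels * 2 = 128 channels.
shape, flat = encoder_output_shape((64, 64, 64), channels=128)
print(shape, flat)  # (128, 16, 16, 16) 524288
```

This also shows how quickly the flattened size grows for 3D data: even after two stride-2 layers, the first linear layer would receive over half a million features per subject.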

Next, I define the contrastive variational autoencoder. I implement the encoders as separate modules to make it easier to introduce other encoders later. I am basing this on a cVAE I have written in the past.

Like the network in the paper by Aglinskas, Hartshorne and Anzellotti (2022) mentioned above, this network will have only a few layers.

A few things will probably have to change. In particular, I do not yet know how many channels the data will end up having, so for now I assume a single channel.
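Whether a single channel is right depends on how the 4D functional images are turned into network input. `nn.Conv3d` expects input of shape `(batch, channels, depth, height, width)`. The following sketch shows two hypothetical options on a random stand-in for one functional run, assuming the time axis comes last (as in a 4D NIfTI image loaded into an array): averaging over time gives one channel, while keeping every time point gives one channel per time point.

```python
import torch

# Hypothetical 4D functional run with shape (D, H, W, T): one 3D volume per time point.
D, H, W, T = 16, 16, 16, 10
run = torch.rand(D, H, W, T)

# nn.Conv3d expects input of shape (batch, channels, depth, height, width).

# Option 1: average over time, giving a single channel.
x_mean = run.mean(dim=-1).unsqueeze(0).unsqueeze(0)  # (1, 1, D, H, W)

# Option 2: keep every time point as its own channel.
x_time = run.permute(3, 0, 1, 2).unsqueeze(0)  # (1, T, D, H, W)

print(x_mean.shape, x_time.shape)
```

With option 2, `image_channels` would have to equal the number of time points, which differs between ABIDE sites, so option 1 (or another reduction) fits the current single-channel assumption better.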

In [None]:
class EncoderNS(nn.Module):
    def __init__(self, latent_dim):
        super(EncoderNS, self).__init__()
        self.shared_conv1 = nn.Conv3d(in_channels=image_channels, out_channels=init_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.shared_conv2 = nn.Conv3d(in_channels=init_channels, out_channels=init_channels * 2, kernel_size=kernel_size, stride=stride, padding=padding)
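        self.linear = nn.LazyLinear(intermediate_dim)  # in_features is inferred at the first forward pass; run one forward pass before creating the optimizer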
        self.ns_fc_mean = nn.Linear(intermediate_dim, latent_dim)
        self.ns_fc_log_var = nn.Linear(intermediate_dim, latent_dim)

    def forward(self, x):
        h = F.relu(self.shared_conv1(x))
        h = F.relu(self.shared_conv2(h))
        h = h.view(h.size(0), -1)
        h = F.relu(self.linear(h))
        ns_mean = self.ns_fc_mean(h)
        ns_log_var = self.ns_fc_log_var(h)
        return ns_mean, ns_log_var


class EncoderS(nn.Module):
    def __init__(self, latent_dim):
        super(EncoderS, self).__init__()
        self.specific_conv1 = nn.Conv3d(in_channels=image_channels, out_channels=init_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.specific_conv2 = nn.Conv3d(in_channels=init_channels, out_channels=init_channels * 2, kernel_size=kernel_size, stride=stride, padding=padding)
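        self.linear = nn.LazyLinear(intermediate_dim)  # in_features is inferred at the first forward pass; run one forward pass before creating the optimizer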
        self.s_fc_mean = nn.Linear(intermediate_dim, latent_dim)
        self.s_fc_log_var = nn.Linear(intermediate_dim, latent_dim)

    def forward(self, x):
        h = F.relu(self.specific_conv1(x))
        h = F.relu(self.specific_conv2(h))
        h = h.view(h.size(0), -1)
        h = F.relu(self.linear(h))
        s_mean = self.s_fc_mean(h)
        s_log_var = self.s_fc_log_var(h)
        return s_mean, s_log_var


class Decoder(nn.Module):
    def __init__(self, latent_dim, start_shape):
        super(Decoder, self).__init__()
        # start_shape = (channels, depth, height, width) of the last encoder conv output.
        # It depends on the input volume size, which is not fixed yet, so it is passed in.
        self.linear_decoder_1 = nn.Linear(2 * latent_dim, intermediate_dim)  # input is [ns, s]
        self.linear_decoder_2 = nn.Linear(intermediate_dim, int(np.prod(start_shape)))
        self.unflatten = nn.Unflatten(1, tuple(start_shape))
        # 3D transposed convolutions mirror the encoders' Conv3d layers; output_padding=1
        # (must be smaller than stride) doubles each spatial dimension when padding=1.
        self.conv_decoder_1 = nn.ConvTranspose3d(in_channels=start_shape[0], out_channels=init_channels, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=1)
        self.conv_decoder_2 = nn.ConvTranspose3d(in_channels=init_channels, out_channels=image_channels, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=1)

    def forward(self, s_ns):
        h_output = F.relu(self.linear_decoder_1(s_ns))
        h_output = F.relu(self.linear_decoder_2(h_output))
        h_output = self.unflatten(h_output)
        h_output = F.relu(self.conv_decoder_1(h_output))
        # Sigmoid keeps the output in (0, 1), as BCE requires; change it together with the loss.
        output = torch.sigmoid(self.conv_decoder_2(h_output))
        return output


class cVAE(nn.Module):
    def __init__(self, latent_dim, start_shape):
        super(cVAE, self).__init__()
        self.encoder_ns = EncoderNS(latent_dim)
        self.encoder_s = EncoderS(latent_dim)
        self.decoder = Decoder(latent_dim, start_shape)
        self.autism = None  # set to True/False for each subject before calling forward()


    def reparameterize(self, mean, log_var):
        # Reparameterization trick: sample = mean + std * eps, so gradients reach mean and log_var.
        std = torch.exp(0.5 * log_var)
        epsilon = torch.randn_like(std)
        return mean + epsilon * std


    def forward(self, x):
        ns_mean, ns_log_var = self.encoder_ns(x)
        ns = self.reparameterize(ns_mean, ns_log_var)

        if self.autism:
            s_mean, s_log_var = self.encoder_s(x)
            s = self.reparameterize(s_mean, s_log_var)
        else:
            # Controls get a zero salient vector (torch.empty_like would be uninitialized memory).
            s = torch.zeros_like(ns)
        s_ns = torch.cat([ns, s], dim=1)

        reconstructed_data = self.decoder(s_ns)

        if self.autism:
            return ns_mean, ns_log_var, s_mean, s_log_var, reconstructed_data
        else:
            return ns_mean, ns_log_var, reconstructed_data


Upcoming: I still need to finalize the training and validation functions and set up all of the analyses.
