<a href="https://colab.research.google.com/github/tchaase/cVAE_autism/blob/main/code/cVAE_autism.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Contrastive Variational Autoencoder for the ABIDE Data Set

Author - Tobias Haase

## Imports

Firstly I am importaing the necessary modules here, that I will use within the following.



In [3]:
import torch  # The main PyTorch library for tensor computations and neural network operations

import torch.nn as nn  # Provides various neural network layers and functionalities
import torch.nn.functional as F  # Provides functional interfaces to common operations (e.g., activation functions)
import torch.optim as optim  # Contains various optimization algorithms (e.g., SGD, Adam)

import torchvision  # A PyTorch library for computer vision tasks
import torchvision.transforms as transforms  # Provides common image transformations (e.g., resizing, normalization)
from torchvision.transforms import ToTensor  # Transforms PIL images to tensors
from torch.utils.data import Dataset, DataLoader  # Provides tools for creating custom datasets and data loaders
import torch.nn as nn

import numpy as np  # NumPy library for numerical computations and array operations
import matplotlib  # Matplotlib library for data visualization
import matplotlib.pyplot as plt  # Matplotlib's pyplot module for creating plots
from tqdm import tqdm  # Progress bar library for tracking iterations

import os
import requests
import nibabel as nib
import numpy as np
import pandas as pd

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
device

device(type='cuda')

Next, I am loading the project's data. To load the data, I am using CyberDuck, and I am loaded the already preprocessed cortical thickness data.

Firstly, I need to install **CyberDuck**:


In [5]:
!echo -e "deb https://s3.amazonaws.com/repo.deb.cyberduck.io stable main" | sudo tee /etc/apt/sources.list.d/cyberduck.list > /dev/null
!sudo apt-key adv --keyserver keyserver.ubuntu.com --recv-keys FE7097963FEFBE72
!sudo apt-get update
!sudo apt-get install duck

Executing: /tmp/apt-key-gpghome.Q0odrWal6P/gpg.1.sh --keyserver keyserver.ubuntu.com --recv-keys FE7097963FEFBE72
gpg: key F7FAE1F32DA69515: public key "Cyberduck <feedback@cyberduck.io>" imported
gpg: Total number processed: 1
gpg:               imported: 1
Hit:1 https://ppa.launchpadcontent.net/c2d4u.team/c2d4u4.0+/ubuntu jammy InRelease
Hit:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
Hit:3 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Hit:4 http://archive.ubuntu.com/ubuntu jammy InRelease
Get:5 http://security.ubuntu.com/ubuntu jammy-security InRelease [110 kB]
Get:6 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,626 B]
Hit:7 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Get:8 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [119 kB]
Hit:9 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease
Get:10 https://s3.amazonaws.com/repo.deb.cyb

In [6]:
#!ls ./data/anat_thickness/
!rm -rf ./data

Next let's download the data that I am interested in.

In [7]:
!mkdir -p ./data/anat_thickness  # Creating a directory before, to avoid errors
!mkdir -p ./data/roi_thickness
# Code to load all participants
# !duck --username anonymous --verbose --download s3:/fcp-indi/data/Projects/ABIDE_Initiative/Outputs/ants/anat_thickness/*_anat_thickness.nii.gz ./data/anat_thickness

# Code to load the ROIs
!duck --username anonymous --verbose --download s3:/fcp-indi/data/Projects/ABIDE_Initiative/Outputs/ants/roi_thickness/CMU_a_0050653_roi_thickness.txt ./data/roi_thickness/

# Code to load the 3d volume.
#!duck --username anonymous --verbose --download s3:/fcp-indi/data/Projects/ABIDE_Initiative/Outputs/ants/anat_thickness/CMU_a_0050653_anat_thickness.nii.gz ./data/anat_thickness/

[36m[s[2K[uReading metadata of CMU_a_0050653_roi_thickness.txt…[m[36m[s[2K[uResolving s3.amazonaws.com…[m[36m[s[2K[uOpening S3 connection to s3.amazonaws.com…[m[36m[s[2K[uS3 connection opened…[m[36m[s[2K[uLogin successful…[m
[32m> HEAD /data/Projects/ABIDE_Initiative/Outputs/ants/roi_thickness/CMU_a_0050653_roi_thickness.txt HTTP/1.1[m
[32m> Date: Mon, 31 Jul 2023 15:36:46 GMT[m
[32m> Host: fcp-indi.s3.amazonaws.com:443[m
[32m> Connection: Keep-Alive[m
[32m> User-Agent: Cyberduck/8.6.0.39818 (Linux/5.15.109+) (amd64)[m
[32m> Accept-Encoding: gzip,deflate[m
[31m< HTTP/1.1 200 OK[m
[31m< x-amz-id-2: aNsnIyQEU89b+CHdG4YdoAySYbANHU/DHmcTztTHM6wA2MKzz9oHb5RnSOulmnrmRhYVPXr35GHupvAAXm7u5FwvVCSXI3U4[m
[31m< x-amz-request-id: 3G5CKMJ1ZPKSA30D[m
[31m< Date: Mon, 31 Jul 2023 15:36:48 GMT[m
[31m< Last-Modified: Mon, 17 Oct 2016 18:30:24 GMT[m
[31m< ETag: "f5f9858d88d004f660a43abf7c0bacc3"[m
[31m< x-amz-version-id: null[m
[31m< Accept-Ranges

Next, let's download the file with the participant info:

In [8]:
# URL to download the CSV file
csv_url = "https://s3.amazonaws.com/fcp-indi/data/Projects/ABIDE_Initiative/Phenotypic_V1_0b_preprocessed1.csv"  # Replace with the actual URL

# Directory to store the CSV file
data_directory = "./data/participant_info"

# Create the directory if it does not exist
os.makedirs(data_directory, exist_ok=True)

# File path to save the CSV file
csv_file_path = os.path.join(data_directory, "participant_info.csv")

# Download the CSV file
response = requests.get(csv_url)
if response.status_code == 200:
    with open(csv_file_path, "wb") as f:
        f.write(response.content)
    print("CSV file downloaded successfully.")
else:
    print("Failed to download the CSV file.")


CSV file downloaded successfully.


In [9]:
!ls ./data/anat_thickness/

Now, I have two options currently, eihter I load the participants data via the 3d image and overlay an atlas manually, or I use predefined labels.

In [10]:
# Directory containing your text files
data_directory = "./data/roi_thickness/"

# Read the participant information from the CSV file
csv_file = "./data/participant_info/participant_info.csv"
participant_info_df = pd.read_csv(csv_file)

# Create dictionaries to store data and participant information for autism and non-autism participants
data_info_dict_autism = {}
data_info_dict_no_autism = {}

# Loop through each text file
for file_name in os.listdir(data_directory):
    # Check if the file is a text file
    if file_name.endswith("_roi_thickness.txt"):
        # Load the text file using pandas
        file_path = os.path.join(data_directory, file_name)
        df = pd.read_csv(file_path, sep='\t', header=None)

        # Extract the numerical values from the second row and remove the first entry (file name) and the second entry (sub-brick)
        data_vector = df.iloc[1, 2:].values.astype(float)

        data_length = len(data_vector)
        print(f"File: {file_name}, Data Length: {data_length}")

        # Extract FILE_ID from the complete file name
        file_id = file_name.split("_roi_thickness.txt")[0]

        # Find the participant's information based on FILE_ID in the CSV
        participant_row = participant_info_df.loc[participant_info_df['FILE_ID'] == file_id]

        # Extract age and gender from the participant's information
        age = participant_row['AGE_AT_SCAN'].values[0]
        gender = participant_row['SEX'].values[0]
        dx_group = participant_row['DX_GROUP'].values[0]

        # Store the data and participant information in the appropriate dictionary based on DX_GROUP
        if dx_group == 1:
            data_info_dict_autism[file_id] = {
                "data": data_vector,
                "age": age,
                "gender": gender
            }
        elif dx_group == 2:
            data_info_dict_no_autism[file_id] = {
                "data": data_vector,
                "age": age,
                "gender": gender
            }


File: CMU_a_0050653_roi_thickness.txt, Data Length: 97


In [11]:
data_info_dict_no_autism = data_info_dict_autism

The data downloaded in this way is in a 3D volume. I want to have the data as a vector. Therefore, I am doing the following:

In [12]:
#@title Execute when working with 3D data

# Directory containing your NIfTI files
#data_directory = "./data/anat_thickness/"

# Read the participant information from the CSV file
csv_file = "./data/participant_info/participant_info.csv"
participant_info_df = pd.read_csv(csv_file)

# Create a dictionary to store data and participant information
#data_info_dict = {}

# Loop through each NIfTI file
for file_name in os.listdir(data_directory):
    pass
    # Check if the file is a NIfTI file
    if file_name.endswith("_anat_thickness.nii.gz"):
        # Load the NIfTI file
        nifti_img = nib.load(os.path.join(data_directory, file_name))

        # Get the data as a NumPy array
        data_array = nifti_img.get_fdata()
        print("The 3D data has the shape of" ,data_array.shape)
        # Reshape to a single vector
        data_vector = data_array.ravel()

        # Extract FILE_ID from the complete NIfTI file name
        file_id = file_name.split("_anat_thickness.nii.gz")[0]

        # Find the participant's information based on FILE_ID in the CSV
        participant_row = participant_info_df.loc[participant_info_df['FILE_ID'] == file_id]

        # Extract age and gender from the participant's information
        age = participant_row['AGE_AT_SCAN'].values[0]
        gender = participant_row['SEX'].values[0]

        # Store the data and participant information in the dictionary
        data_info_dict[file_id] = {
            "data": data_vector,
            "age": age,
            "gender": gender
        }


Let's check if this all worked:

In [13]:
# Calculate overall statistics for the autism category
autism_data_lengths = [len(info["data"]) for info in data_info_dict_autism.values()]
total_autism_samples = len(autism_data_lengths)
average_autism_data_length = sum(autism_data_lengths) / total_autism_samples
min_autism_data_length = min(autism_data_lengths)
max_autism_data_length = max(autism_data_lengths)
std_autism_data_length = np.std(autism_data_lengths)
autism_ages = [info["age"] for info in data_info_dict_autism.values()]
average_autism_age = sum(autism_ages) / total_autism_samples
min_autism_age = min(autism_ages)
max_autism_age = max(autism_ages)
std_autism_age = np.std(autism_ages)
autism_genders = [info["gender"] for info in data_info_dict_autism.values()]
autism_male_count = autism_genders.count(1)
autism_female_count = autism_genders.count(2)

# Calculate overall statistics for the non-autism category
non_autism_data_lengths = [len(info["data"]) for info in data_info_dict_no_autism.values()]
total_non_autism_samples = len(non_autism_data_lengths)
average_non_autism_data_length = sum(non_autism_data_lengths) / total_non_autism_samples
min_non_autism_data_length = min(non_autism_data_lengths)
max_non_autism_data_length = max(non_autism_data_lengths)
std_non_autism_data_length = np.std(non_autism_data_lengths)
non_autism_ages = [info["age"] for info in data_info_dict_no_autism.values()]
average_non_autism_age = sum(non_autism_ages) / total_non_autism_samples
min_non_autism_age = min(non_autism_ages)
max_non_autism_age = max(non_autism_ages)
std_non_autism_age = np.std(non_autism_ages)
non_autism_genders = [info["gender"] for info in data_info_dict_no_autism.values()]
non_autism_male_count = non_autism_genders.count(1)
non_autism_female_count = non_autism_genders.count(2)

# Print the statistics for the autism category
print("Autism Data Statistics:")
print("Total Samples:", total_autism_samples)
print("Average Data Length:", average_autism_data_length)
print("Minimum Data Length:", min_autism_data_length)
print("Maximum Data Length:", max_autism_data_length)
print("Standard Deviation of Data Length:", std_autism_data_length)
print("")

print("Autism Age Statistics:")
print("Average Age:", average_autism_age)
print("Minimum Age:", min_autism_age)
print("Maximum Age:", max_autism_age)
print("Standard Deviation of Age:", std_autism_age)
print("")

print("Autism Gender Counts:")
print("Male Count:", autism_male_count)
print("Female Count:", autism_female_count)
print("")

# Print the statistics for the non-autism category
print("Non-Autism Data Statistics:")
print("Total Samples:", total_non_autism_samples)
print("Average Data Length:", average_non_autism_data_length)
print("Minimum Data Length:", min_non_autism_data_length)
print("Maximum Data Length:", max_non_autism_data_length)
print("Standard Deviation of Data Length:", std_non_autism_data_length)
print("")

print("Non-Autism Age Statistics:")
print("Average Age:", average_non_autism_age)
print("Minimum Age:", min_non_autism_age)
print("Maximum Age:", max_non_autism_age)
print("Standard Deviation of Age:", std_non_autism_age)
print("")

print("Non-Autism Gender Counts:")
print("Male Count:", non_autism_male_count)
print("Female Count:", non_autism_female_count)


Autism Data Statistics:
Total Samples: 1
Average Data Length: 97.0
Minimum Data Length: 97
Maximum Data Length: 97
Standard Deviation of Data Length: 0.0

Autism Age Statistics:
Average Age: 30.0
Minimum Age: 30.0
Maximum Age: 30.0
Standard Deviation of Age: 0.0

Autism Gender Counts:
Male Count: 1
Female Count: 0

Non-Autism Data Statistics:
Total Samples: 1
Average Data Length: 97.0
Minimum Data Length: 97
Maximum Data Length: 97
Standard Deviation of Data Length: 0.0

Non-Autism Age Statistics:
Average Age: 30.0
Minimum Age: 30.0
Maximum Age: 30.0
Standard Deviation of Age: 0.0

Non-Autism Gender Counts:
Male Count: 1
Female Count: 0


Next, I need to create a dataloader.

In [14]:
class CombinedDataset(Dataset):
    def __init__(self, autism_data_info, no_autism_data_info):
        self.autism_data_info = autism_data_info
        self.no_autism_data_info = no_autism_data_info
        self.autism_file_ids = list(self.autism_data_info.keys())
        self.no_autism_file_ids = list(self.no_autism_data_info.keys())

    def __len__(self):
        return max(len(self.autism_file_ids), len(self.no_autism_file_ids))

    def __getitem__(self, index):
        autism_index = index % len(self.autism_file_ids)
        no_autism_index = index % len(self.no_autism_file_ids)

        autism_file_id = self.autism_file_ids[autism_index]
        no_autism_file_id = self.no_autism_file_ids[no_autism_index]

        autism_data = torch.tensor(self.autism_data_info[autism_file_id]["data"], dtype=torch.float32)
        autism_age = torch.tensor(self.autism_data_info[autism_file_id]["age"], dtype=torch.float32)
        autism_gender = torch.tensor(self.autism_data_info[autism_file_id]["gender"], dtype=torch.float32)

        no_autism_data = torch.tensor(self.no_autism_data_info[no_autism_file_id]["data"], dtype=torch.float32)
        no_autism_age = torch.tensor(self.no_autism_data_info[no_autism_file_id]["age"], dtype=torch.float32)
        no_autism_gender = torch.tensor(self.no_autism_data_info[no_autism_file_id]["gender"], dtype=torch.float32)

        return (autism_data, autism_age, autism_gender), (no_autism_data, no_autism_age, no_autism_gender)

# Create the combined dataset
combined_dataset = CombinedDataset(data_info_dict_autism, data_info_dict_no_autism)

# Create the dataloader
batch_size = 64
shuffle = True
combined_dataloader = DataLoader(combined_dataset, batch_size=batch_size, shuffle=shuffle)


## Model specifications

In the following I am specifiying the model. I am roughly orienting myself around a paper from Anglinkas, Hartshorne & Anzellotti (2022).

### Defining utility functions

Firstly, I am defining the loss function.
The loss will be computed as the sum of the BCE-Loss, as well as the KL-divergence terms.

* MSE loss: Incoming

* Cross Entropy: Incoming

* Kullback-Leibler divergence (Kullback & Leibler, 1951) This is a measure for the difference between two distributions. I.e. "how much do they diverge" from each other, how much are they different to each other. The introduction of this term into the final loss leads my model to optimize not only if the precited categories are correct and so on, but also how high the difference between the prior distribution and teh latent variables are. The prior distribution in my case is an isotropic gaussian.
  * Why is this desirable? The latent variables and the sampling process should be somewhat controlled. This divergence regulates this.


I have also attempted to regulate that a loss is only completed with the KL divergence from the second encoder if that encoder was used.

In [15]:
def final_loss(MSE, CE, z_mu, z_logvar, s_mu, s_logvar):
    """
    This function will add the reconstruction loss (BCELoss) and the KL-Divergence.
    KL-Divergence = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    :param bce_loss: reconstruction loss
    :param z_mu: mean from the latent vector of encoder_z
    :param z_logvar: log variance from the latent vector of encoder_z
    :param s_mu: mean from the latent vector of encoder_s (optional)
    :param s_logvar: log variance from the latent vector of encoder_s (optional)
    """
    mse_loss = MSE
    cross_entropy = CE
    KLD_z = -0.5 * torch.sum(1 + z_logvar - z_mu.pow(2) - z_logvar.exp())
    if s_mu is not None and s_logvar is not None:
        KLD_s = -0.5 * torch.sum(1 + s_logvar - s_mu.pow(2) - s_logvar.exp())
        return mse_loss + KLD_z + KLD_s + cross_entropy
    else:
        return mse_loss + KLD_z + cross_entropy


Here is the training loop. This model is supposed to achieve multiple things:

* Train the cVAE using the MSE loss.
* Incoming.


In [31]:
from tqdm import tqdm

def train(model, dataloader, dataset, device, optimizer, criterion, criterion_classifier):
    model.train()
    running_loss_autism = 0.0
    running_loss_no_autism = 0.0
    running_age_loss = 0.0
    running_gender_loss = 0.0
    counter = 0

    total_batches = len(dataset) // dataloader.batch_size

    for i, ((autism_data, autism_age, autism_gender), (no_autism_data, no_autism_age, no_autism_gender)) in tqdm(enumerate(dataloader), total=total_batches):
        autism_data = autism_data.to(device)
        no_autism_data = no_autism_data.to(device)

        autism_age = autism_age.to(device)
        autism_gender = autism_gender.to(device)
        no_autism_age = no_autism_age.to(device)
        no_autism_gender = no_autism_gender.to(device)

        optimizer.zero_grad()

        # Get the model outputs
        z_mean, z_log_var, s_mean, s_log_var, z_mean_no_autism, z_log_var_no_autism, reconstructed_data_autism, reconstructed_data_no_autism, class_autism_age, class_autism_gender, class_no_autism_age, class_no_autism_gender = model(autism_data, no_autism_data)

        # Section for the autism images
        bce_loss_autism = criterion(reconstructed_data_autism, autism_data)
        loss_autism = final_loss(bce_loss_autism, z_mean, z_log_var, s_mean, s_log_var)
        running_loss_autism += loss_autism.item()

        # Section for the no_autism images
        bce_loss_no_autism = criterion(reconstructed_data_no_autism, no_autism_data)
        s_mean_no_autism, s_log_var_no_autism = None, None
        loss_no_autism = final_loss(bce_loss_no_autism, z_mean_no_autism, z_log_var_no_autism, s_mean_no_autism, s_log_var_no_autism)
        running_loss_no_autism += loss_no_autism.item()

        # Calculate classifier losses for age and gender predictions
        age_loss_autism = criterion_classifier(class_autism_age, autism_age.unsqueeze(1))
        gender_loss_autism = criterion_classifier(class_autism_gender, autism_gender)

        age_loss_no_autism = criterion_classifier(class_no_autism_age, no_autism_age.unsqueeze(1))
        gender_loss_no_autism = criterion_classifier(class_no_autism_gender, no_autism_gender)

        # Accumulate classifier losses
        running_age_loss += (age_loss_autism.item() + age_loss_no_autism.item())
        running_gender_loss += (gender_loss_autism.item() + gender_loss_no_autism.item())

        # Total loss (you can weigh the classifier losses with appropriate coefficients if needed)
        total_loss = loss_autism + loss_no_autism + age_loss_autism + gender_loss_autism + age_loss_no_autism + gender_loss_no_autism
        total_loss.backward()

        optimizer.step()
        counter += len(autism_data) + len(no_autism_data)

    train_loss_autism = running_loss_autism / counter
    train_loss_no_autism = running_loss_no_autism / counter
    train_age_loss = running_age_loss / counter
    train_gender_loss = running_gender_loss / counter

    return train_loss_autism, train_loss_no_autism, train_age_loss, train_gender_loss


## Model specification

These values still need to be adapted for the current model.

In [33]:
input_dimension = 97 # The numer of features
indermediate_dim = 128
latent_dim = 4 # latent dimension for sampling

lr = 0.001



Next I want to define the contrastive variational autoencoder. While doing so, I am defining seperate encoders, to make it easier to later introduce other encoders. I am orienting myself on an cVAE I have written in the past.

As the paper from Aglinskas, Hartshorne and Anzellotti (2022) I mentioned, the network will have only a few layers.

A few things I will probably have to change - I do not know how many channels the data will end up having. therefore I am using one, assuming it only has one.

In [37]:
class EncoderNS(nn.Module):
    def __init__(self, input_dimension, latent_dim):
        super(EncoderNS, self).__init__()
        self.linear1 = nn.Linear(input_dimension, 64)
        self.linear2 = nn.Linear(64, 32)
        self.linear3 = nn.Linear(32, 4)
        self.ns_fc_mean = nn.Linear(latent_dim, latent_dim)
        self.ns_fc_log_var = nn.Linear(latent_dim, latent_dim)

    def forward(self, x, batch_size):
        h = F.relu(self.linear1(x))
        h = F.relu(self.linear2(h))
        h = F.relu(self.linear3(h))
        ns_mean = self.ns_fc_mean(h)
        ns_log_var = self.ns_fc_log_var(h)
        return ns_mean, ns_log_var


class EncoderS(nn.Module):
    def __init__(self, input_dimension, latent_dim):
        super(EncoderS, self).__init__()
        self.linear1 = nn.Linear(input_dimension, 64)
        self.linear2 = nn.Linear(64, 32)
        self.linear3 = nn.Linear(32, 4)
        self.s_fc_mean = nn.Linear(latent_dim, latent_dim)
        self.s_fc_log_var = nn.Linear(latent_dim, latent_dim)

    def forward(self, x, batch_size):
        h = F.relu(self.linear1(x))
        h = F.relu(self.linear2(h))
        h = F.relu(self.linear3(h))
        s_mean = self.s_fc_mean(h)
        s_log_var = self.s_fc_log_var(h)
        return s_mean, s_log_var

class Decoder(nn.Module):
    def __init__(self, input_dimension, latent_dim):
        super(Decoder, self).__init__()
        self.linear_decoder_1 = nn.Linear(latent_dim, 32)
        self.linear_decoder_2 = nn.Linear(32,64)
        self.linear_decoder_3 = nn.Linear(64, input_dimension)

    def forward(self, zs, batch_size):
        h_output = F.relu(self.linear_decoder_1(zs))
        h_output = F.relu(self.linear_decoder_2(h_output))
        output = F.relu(self.linear_decoder_3(h_output))
        return output

class Classifier(nn.Module):
    def __init__(self, latent_dim):
        super(Classifier, self).__init__()
        self.fc1 = nn.Linear(latent_dim, latent_dim // 2)
        self.fc_age = nn.Linear(latent_dim // 2, 1)
        self.fc_gender = nn.Linear(latent_dim // 2, 2)

    def forward(self, z):
        x = self.fc1(z)
        age_prediction = self.fc_age(x)
        gender_prediction = torch.sigmoid(self.fc_gender(x))  # Apply sigmoid activation for binary gender prediction
        return age_prediction, gender_prediction

class cVAE(nn.Module):
    def __init__(self, input_dimension, latent_dim):
        super(cVAE, self).__init__()
        self.encoder_z = EncoderNS(input_dimension, latent_dim)
        self.encoder_s = EncoderS(input_dimension, latent_dim)
        self.decoder = Decoder(input_dimension, latent_dim)
        self.overlay_status = None
        self.classifier = Classifier(latent_dim)

    def reparameterize(self, mean, log_var):
        std = torch.exp(0.5 * log_var)
        epsilon = torch.randn_like(std)
        return mean + epsilon * std

    def forward(self, autism, no_autism):
        batch_size = autism.size(0)
        z_mean, z_log_var = self.encoder_z(autism, batch_size)
        z = self.reparameterize(z_mean, z_log_var)
        s_mean, s_log_var = self.encoder_s(autism, batch_size)
        s = self.reparameterize(s_mean, s_log_var)
        zs = torch.cat([z, s], dim=1)

        reconstructed_data_autism = self.decoder(zs, batch_size)

        z_mean_no_autism, z_log_var_no_autism = self.encoder_z(no_autism, batch_size)
        z_no_autism = self.reparameterize(z_mean_no_autism, z_log_var_no_autism)
        z_empty = torch.zeros(z_no_autism.shape)
        z_no_autism_0 = torch.cat([z_no_autism, z_empty], dim=1)
        reconstructed_data_no_autism = self.decoder(z_no_autism_0, batch_size)

        class_autism_age, class_autism_gender = self.classifier(z)  # Assuming z is the latent variable after concatenating s and z
        class_no_autism_age, class_no_autism_gender = self.classifier(z_no_autism_0)  # Using the version with 0s to have equal lengths of the latent vectors.

        return z_mean, z_log_var, s_mean, s_log_var, z_mean_no_autism, z_log_var_no_autism, reconstructed_data_autism, reconstructed_data_no_autism, class_autism_age, class_autism_gender, class_no_autism_age, class_no_autism_gender

And finally the training loop - note that I have yet to define the validation function:

In [38]:
++++model = cVAE(input_dimension = 97, latent_dim=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.MSELoss()
classifier_criterion = nn.CrossEntropyLoss

train_loss_list = []  # List to store train losses
val_loss_list = []  # List to store validation losses

num_epochs = 10
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1} of {num_epochs}")
    # Train the model
    train_loss_target, train_loss_bg = train(model, combined_dataloader,combined_dataset, device, optimizer, criterion, classifier_criterion)

    # Validate the model
    #val_loss, recon_images = validate(model, overlaid_dataloader, overlaid_dataset, device, criterion, classifier_criterion)

    # Appending the loss values to a list to allow for visualizations:

    train_loss_list.append(train_loss_target)
    #val_loss_list.append(val_loss)


    # Print the losses
    print(f"Train Loss: {train_loss_target:.4f}, Val Loss: {val_loss:.4f}")
    #print(f"Train Loss for the background: {train_loss_bg:.4f}, Val Loss: {val_loss:.4f}")

print('TRAINING COMPLETE')


Epoch 1 of 10


0it [00:00, ?it/s]


RuntimeError: ignored