<a href="https://colab.research.google.com/github/tchaase/cVAE_autism/blob/main/code/cVAE_autism.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Contrastive Variational Autoencoder for the ABIDE Data Set

Author - Tobias Haase

## Imports

Firstly I am importaing the necessary modules here, that I will use within the following.



In [11]:
import torch  # The main PyTorch library for tensor computations and neural network operations

import torch.nn as nn  # Provides various neural network layers and functionalities
import torch.nn.functional as F  # Provides functional interfaces to common operations (e.g., activation functions)
import torch.optim as optim  # Contains various optimization algorithms (e.g., SGD, Adam)

import torchvision  # A PyTorch library for computer vision tasks
import torchvision.transforms as transforms  # Provides common image transformations (e.g., resizing, normalization)
from torchvision.transforms import ToTensor  # Transforms PIL images to tensors
from torch.utils.data import Dataset, DataLoader  # Provides tools for creating custom datasets and data loaders
import torch.nn as nn

import numpy as np  # NumPy library for numerical computations and array operations
import matplotlib  # Matplotlib library for data visualization
import matplotlib.pyplot as plt  # Matplotlib's pyplot module for creating plots
from tqdm import tqdm  # Progress bar library for tracking iterations

import os
import requests
import nibabel as nib
import numpy as np
import pandas as pd

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [12]:
device

device(type='cpu')

Next, I am loading the project's data. To load the data, I am using CyberDuck, and I am loaded the already preprocessed cortical thickness data.

Firstly, I need to install **CyberDuck**:


In [None]:
!echo -e "deb https://s3.amazonaws.com/repo.deb.cyberduck.io stable main" | sudo tee /etc/apt/sources.list.d/cyberduck.list > /dev/null
!sudo apt-key adv --keyserver keyserver.ubuntu.com --recv-keys FE7097963FEFBE72
!sudo apt-get update
!sudo apt-get install duck

Executing: /tmp/apt-key-gpghome.UDxYlqxJL7/gpg.1.sh --keyserver keyserver.ubuntu.com --recv-keys FE7097963FEFBE72
gpg: key F7FAE1F32DA69515: public key "Cyberduck <feedback@cyberduck.io>" imported
gpg: Total number processed: 1
gpg:               imported: 1
Hit:1 http://archive.ubuntu.com/ubuntu jammy InRelease
Get:2 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [119 kB]
Get:3 http://security.ubuntu.com/ubuntu jammy-security InRelease [110 kB]
Hit:4 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
Get:5 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,626 B]
Get:6 https://s3.amazonaws.com/repo.deb.cyberduck.io stable InRelease [3,245 B]
Get:7 https://ppa.launchpadcontent.net/c2d4u.team/c2d4u4.0+/ubuntu jammy InRelease [18.1 kB]
Get:8 http://archive.ubuntu.com/ubuntu jammy-backports InRelease [109 kB]
Hit:9 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Hit:10 https://ppa.launchpadcontent.net

In [None]:
#!ls ./data/anat_thickness/
!rm -rf ./data

from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


Next, let's download the file with the participant info:

In [13]:
# URL to download the CSV file!
csv_url = "https://s3.amazonaws.com/fcp-indi/data/Projects/ABIDE_Initiative/Phenotypic_V1_0b_preprocessed1.csv"  # Replace with the actual URL

# Directory to store the CSV file
data_directory = "./data/participant_info"

# Create the directory if it does not exist
os.makedirs(data_directory, exist_ok=True)

# File path to save the CSV file
csv_file_path = os.path.join(data_directory, "participant_info.csv")

# Download the CSV file
response = requests.get(csv_url)
if response.status_code == 200:
    with open(csv_file_path, "wb") as f:
        f.write(response.content)
    print("CSV file downloaded successfully.")
else:
    print("Failed to download the CSV file.")


CSV file downloaded successfully.


Next I am downloading the transformation templates.

In [23]:
!pip install ants

import pandas as pd
import os
import shutil
import urllib.request
from nipype.interfaces.ants import ApplyTransforms

# Define paths
data_directory = "./data/participant_info"
drive_directory = './drive/MyDrive/MasterThesisData'
temp_directory = './temp_data'  # Temporary directory

# Create directories if they don't exist
os.makedirs(drive_directory, exist_ok=True)
os.makedirs(temp_directory, exist_ok=True)

# Load the CSV file
csv_file_path = os.path.join(data_directory, "participant_info.csv")
data = pd.read_csv(csv_file_path)

failed_download = []

# Loop through the association list
for index, row in data.iterrows():
    file_id = row['FILE_ID']
    sub_id = row['SUB_ID']

    try:
        # Construct the MRI data file name
        mri_data_file = f"{file_id}_anat_thickness.nii.gz"
        mri_data_path = os.path.join(temp_directory, mri_data_file)

        # Download the MRI data file to the temporary directory
        mri_url = f"https://fcp-indi.s3.amazonaws.com/data/Projects/ABIDE_Initiative/Outputs/ants/anat_thickness/{mri_data_file}"
        urllib.request.urlretrieve(mri_url, mri_data_path)

        # Construct the transformation template file name
        template_file_name = f'sub-00{sub_id}_from-T1w_to-MNI152NLin2009cAsym_mode-image_xfm.h5'
        template_destination = os.path.join(temp_directory, template_file_name)

        # Download the transformation template
        template_url = f"https://fcp-indi.s3.amazonaws.com/data/Projects/ABIDE/Outputs/fmriprep/fmriprep/sub-00{sub_id}/anat/{template_file_name}"
        urllib.request.urlretrieve(template_url, template_destination)

        # Define paths for the transformed output
        transformed_output_path = os.path.join(drive_directory, f"{file_id}_transformed.nii.gz")

        from nipype import config, logging
        from nipype.interfaces.base import CommandLine

        # Set the execution plugin to CommandLine
        config.enable_debug_mode()
        config.enable_provenance()
        logging.update_logging(config)

        # Define the antsApplyTransforms interface
        at = ApplyTransforms()
        at.inputs.input_image = mri_data_path
        at.inputs.reference_image = '/root/.cache/templateflow/tpl-MNI152NLin2009cAsym/tpl-MNI152NLin2009cAsym_res-01_T1w.nii.gz'  # Replace with the actual path
        at.inputs.transforms = template_destination
        at.inputs.output_image = transformed_output_path

        # Execute the antsApplyTransforms command
        cmdline = at.cmdline
        cl = CommandLine(command=cmdline)
        cl.run()

        # Clean up: Delete the input data and template
        os.remove(mri_data_path)
        os.remove(template_destination)

        print(f"Transformed and saved for FILE_ID: {file_id}")

    except Exception as e:
        print(f"Failed for FILE_ID: {file_id}. Error: {e}")

print("Processing completed.")


Failed for FILE_ID: no_filename. Error: HTTP Error 404: Not Found
230811-19:43:35,893 nipype.interface DEBUG:
	 default_value_0.0


DEBUG:nipype.interface:default_value_0.0


230811-19:43:35,901 nipype.interface DEBUG:
	 float_False


DEBUG:nipype.interface:float_False


230811-19:43:35,905 nipype.interface DEBUG:
	 input_image_temp_data/Pitt_0050003_anat_thickness.nii.gz


DEBUG:nipype.interface:input_image_temp_data/Pitt_0050003_anat_thickness.nii.gz


230811-19:43:35,909 nipype.interface DEBUG:
	 reference_image_/root/.cache/templateflow/tpl-MNI152NLin2009cAsym/tpl-MNI152NLin2009cAsym_res-01_T1w.nii.gz


DEBUG:nipype.interface:reference_image_/root/.cache/templateflow/tpl-MNI152NLin2009cAsym/tpl-MNI152NLin2009cAsym_res-01_T1w.nii.gz


Failed for FILE_ID: Pitt_0050003. Error: No command "antsApplyTransforms" found on host 82a33ac3693d. Please check that the corresponding package is installed.
230811-19:43:37,676 nipype.interface DEBUG:
	 default_value_0.0


DEBUG:nipype.interface:default_value_0.0


230811-19:43:37,681 nipype.interface DEBUG:
	 float_False


DEBUG:nipype.interface:float_False


230811-19:43:37,686 nipype.interface DEBUG:
	 input_image_temp_data/Pitt_0050004_anat_thickness.nii.gz


DEBUG:nipype.interface:input_image_temp_data/Pitt_0050004_anat_thickness.nii.gz


230811-19:43:37,689 nipype.interface DEBUG:
	 reference_image_/root/.cache/templateflow/tpl-MNI152NLin2009cAsym/tpl-MNI152NLin2009cAsym_res-01_T1w.nii.gz


DEBUG:nipype.interface:reference_image_/root/.cache/templateflow/tpl-MNI152NLin2009cAsym/tpl-MNI152NLin2009cAsym_res-01_T1w.nii.gz


Failed for FILE_ID: Pitt_0050004. Error: No command "antsApplyTransforms" found on host 82a33ac3693d. Please check that the corresponding package is installed.
230811-19:43:40,85 nipype.interface DEBUG:
	 default_value_0.0


DEBUG:nipype.interface:default_value_0.0


230811-19:43:40,98 nipype.interface DEBUG:
	 float_False


DEBUG:nipype.interface:float_False


230811-19:43:40,107 nipype.interface DEBUG:
	 input_image_temp_data/Pitt_0050005_anat_thickness.nii.gz


DEBUG:nipype.interface:input_image_temp_data/Pitt_0050005_anat_thickness.nii.gz


230811-19:43:40,115 nipype.interface DEBUG:
	 reference_image_/root/.cache/templateflow/tpl-MNI152NLin2009cAsym/tpl-MNI152NLin2009cAsym_res-01_T1w.nii.gz


DEBUG:nipype.interface:reference_image_/root/.cache/templateflow/tpl-MNI152NLin2009cAsym/tpl-MNI152NLin2009cAsym_res-01_T1w.nii.gz


Failed for FILE_ID: Pitt_0050005. Error: No command "antsApplyTransforms" found on host 82a33ac3693d. Please check that the corresponding package is installed.
230811-19:43:41,882 nipype.interface DEBUG:
	 default_value_0.0


DEBUG:nipype.interface:default_value_0.0


230811-19:43:41,893 nipype.interface DEBUG:
	 float_False


DEBUG:nipype.interface:float_False


230811-19:43:41,900 nipype.interface DEBUG:
	 input_image_temp_data/Pitt_0050006_anat_thickness.nii.gz


DEBUG:nipype.interface:input_image_temp_data/Pitt_0050006_anat_thickness.nii.gz


230811-19:43:41,902 nipype.interface DEBUG:
	 reference_image_/root/.cache/templateflow/tpl-MNI152NLin2009cAsym/tpl-MNI152NLin2009cAsym_res-01_T1w.nii.gz


DEBUG:nipype.interface:reference_image_/root/.cache/templateflow/tpl-MNI152NLin2009cAsym/tpl-MNI152NLin2009cAsym_res-01_T1w.nii.gz


Failed for FILE_ID: Pitt_0050006. Error: No command "antsApplyTransforms" found on host 82a33ac3693d. Please check that the corresponding package is installed.
230811-19:43:44,278 nipype.interface DEBUG:
	 default_value_0.0


DEBUG:nipype.interface:default_value_0.0


230811-19:43:44,288 nipype.interface DEBUG:
	 float_False


DEBUG:nipype.interface:float_False


230811-19:43:44,297 nipype.interface DEBUG:
	 input_image_temp_data/Pitt_0050007_anat_thickness.nii.gz


DEBUG:nipype.interface:input_image_temp_data/Pitt_0050007_anat_thickness.nii.gz


230811-19:43:44,306 nipype.interface DEBUG:
	 reference_image_/root/.cache/templateflow/tpl-MNI152NLin2009cAsym/tpl-MNI152NLin2009cAsym_res-01_T1w.nii.gz


DEBUG:nipype.interface:reference_image_/root/.cache/templateflow/tpl-MNI152NLin2009cAsym/tpl-MNI152NLin2009cAsym_res-01_T1w.nii.gz


Failed for FILE_ID: Pitt_0050007. Error: No command "antsApplyTransforms" found on host 82a33ac3693d. Please check that the corresponding package is installed.
230811-19:43:47,1 nipype.interface DEBUG:
	 default_value_0.0


DEBUG:nipype.interface:default_value_0.0


230811-19:43:47,7 nipype.interface DEBUG:
	 float_False


DEBUG:nipype.interface:float_False


230811-19:43:47,13 nipype.interface DEBUG:
	 input_image_temp_data/Pitt_0050008_anat_thickness.nii.gz


DEBUG:nipype.interface:input_image_temp_data/Pitt_0050008_anat_thickness.nii.gz


230811-19:43:47,16 nipype.interface DEBUG:
	 reference_image_/root/.cache/templateflow/tpl-MNI152NLin2009cAsym/tpl-MNI152NLin2009cAsym_res-01_T1w.nii.gz


DEBUG:nipype.interface:reference_image_/root/.cache/templateflow/tpl-MNI152NLin2009cAsym/tpl-MNI152NLin2009cAsym_res-01_T1w.nii.gz


Failed for FILE_ID: Pitt_0050008. Error: No command "antsApplyTransforms" found on host 82a33ac3693d. Please check that the corresponding package is installed.
230811-19:43:49,764 nipype.interface DEBUG:
	 default_value_0.0


DEBUG:nipype.interface:default_value_0.0


230811-19:43:49,777 nipype.interface DEBUG:
	 float_False


DEBUG:nipype.interface:float_False


230811-19:43:49,784 nipype.interface DEBUG:
	 input_image_temp_data/Pitt_0050009_anat_thickness.nii.gz


DEBUG:nipype.interface:input_image_temp_data/Pitt_0050009_anat_thickness.nii.gz


230811-19:43:49,790 nipype.interface DEBUG:
	 reference_image_/root/.cache/templateflow/tpl-MNI152NLin2009cAsym/tpl-MNI152NLin2009cAsym_res-01_T1w.nii.gz


DEBUG:nipype.interface:reference_image_/root/.cache/templateflow/tpl-MNI152NLin2009cAsym/tpl-MNI152NLin2009cAsym_res-01_T1w.nii.gz


Failed for FILE_ID: Pitt_0050009. Error: No command "antsApplyTransforms" found on host 82a33ac3693d. Please check that the corresponding package is installed.
230811-19:43:52,306 nipype.interface DEBUG:
	 default_value_0.0


DEBUG:nipype.interface:default_value_0.0


230811-19:43:52,313 nipype.interface DEBUG:
	 float_False


DEBUG:nipype.interface:float_False


230811-19:43:52,319 nipype.interface DEBUG:
	 input_image_temp_data/Pitt_0050010_anat_thickness.nii.gz


DEBUG:nipype.interface:input_image_temp_data/Pitt_0050010_anat_thickness.nii.gz


230811-19:43:52,323 nipype.interface DEBUG:
	 reference_image_/root/.cache/templateflow/tpl-MNI152NLin2009cAsym/tpl-MNI152NLin2009cAsym_res-01_T1w.nii.gz


DEBUG:nipype.interface:reference_image_/root/.cache/templateflow/tpl-MNI152NLin2009cAsym/tpl-MNI152NLin2009cAsym_res-01_T1w.nii.gz


Failed for FILE_ID: Pitt_0050010. Error: No command "antsApplyTransforms" found on host 82a33ac3693d. Please check that the corresponding package is installed.


KeyboardInterrupt: ignored

In [None]:
!ls ./data/anat_thickness/

Before I can continue I need to create values for the cortical thickness via an atlas to make the resutls a bit better understandable and the model slimmer. This way I can also create a deeper model?

In [None]:
# Load the atlas
from nilearn.datasets import fetch_atlas_destrieux_2009
cortical_thickness_atlas_destrieux = fetch_atlas_destrieux_2009(lateralized = True)

# From this, we can also export the labels that we can use for later visualization.
labels = cortical_thickness_atlas_destrieux.labels



Before we can actually transform the images we need to get them into the same space. Originally, the pictures werent transformed into any space and are therefore in their natice space. We are tranforming them into the same space the template is in, i.e. MNI152NLin2009cAsym

In [6]:
!pip install templateflow
from templateflow import api as tflow
mni152 = tflow.get('MNI152NLin2009cAsym', desc=None, resolution=1,
                    suffix='T1w', extension='nii.gz')
mni152

Collecting templateflow
  Downloading templateflow-23.0.0-py3-none-any.whl (455 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m455.4/455.4 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pybids>=0.15.2 (from templateflow)
  Downloading pybids-0.16.1-py3-none-any.whl (14.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.9/14.9 MB[0m [31m26.9 MB/s[0m eta [36m0:00:00[0m
Collecting formulaic<0.6,>=0.2.4 (from pybids>=0.15.2->templateflow)
  Downloading formulaic-0.5.2-py3-none-any.whl (77 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.1/77.1 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
Collecting bids-validator>=1.11 (from pybids>=0.15.2->templateflow)
  Downloading bids_validator-1.12.0-py2.py3-none-any.whl (21 kB)
Collecting num2words>=0.5.5 (from pybids>=0.15.2->templateflow)
  Downloading num2words-0.5.12-py3-none-any.whl (125 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.2/1

Downloading https://templateflow.s3.amazonaws.com/tpl-MNI152NLin2009cAsym/tpl-MNI152NLin2009cAsym_res-01_T1w.nii.gz
100%|██████████| 13.7M/13.7M [00:00<00:00, 45.6MB/s]


PosixPath('/root/.cache/templateflow/tpl-MNI152NLin2009cAsym/tpl-MNI152NLin2009cAsym_res-01_T1w.nii.gz')

Next, I need to transform the images. For this, the Advanced Normalisation Tool needs to be used.

In [9]:
#!pip install nipype
from nipype.interfaces.ants import ANTS
from nipype.interfaces.ants import ApplyTransforms



at = ApplyTransforms()
at.inputs.input_image = 'moving1.nii'
at.inputs.reference_image = 'fixed1.nii'
at.inputs.transforms = 'identity'

at.cmdline
'antsApplyTransforms --default-value 0 --float 0 --input moving1.nii --interpolation Linear --output moving1_trans.nii --reference-image fixed1.nii --transform identity'



AttributeError: ignored

Now, I have two options currently, eihter I load the participants data via the 3d image and overlay an atlas manually, or I use predefined labels.

In [None]:
# Directory containing your text files
data_directory = "./data/roi_thickness/"

# Read the participant information from the CSV file
csv_file = "./data/participant_info/participant_info.csv"
participant_info_df = pd.read_csv(csv_file)

# Create dictionaries to store data and participant information for autism and non-autism participants
data_info_dict_autism = {}
data_info_dict_no_autism = {}

# Loop through each text file
for file_name in os.listdir(data_directory):
    # Check if the file is a text file
    if file_name.endswith("_roi_thickness.txt"):
        # Loop through each SUB_ID in the participant info
        for sub_id in participant_info_df['SUB_ID']:
            sub_id_str = str(sub_id)  # Convert SUB_ID to string
            if sub_id_str in file_name:
                # Load the text file using pandas
                file_path = os.path.join(data_directory, file_name)
                df = pd.read_csv(file_path, sep='\t', header=None)

                # Extract the numerical values from the second row and remove the first entry (file name) and the second entry (sub-brick)
                data_vector = df.iloc[1, 2:].values.astype(float)

                data_length = len(data_vector)
                print(f"File: {file_name}, Data Length: {data_length}")

                # Find the participant's information based on SUB_ID
                participant_row = participant_info_df.loc[participant_info_df['SUB_ID'] == sub_id]
                if not participant_row.empty:
                    # Extract age and gender from the participant's information
                    age = participant_row['AGE_AT_SCAN'].values[0]
                    gender = participant_row['SEX'].values[0] - 1
                    dx_group = participant_row['DX_GROUP'].values[0]

                    # Store the data and participant information in the appropriate dictionary based on DX_GROUP
                    if dx_group == 1:
                        data_info_dict_autism[sub_id_str] = {
                            "data": data_vector,
                            "age": age,
                            "gender": gender
                        }
                    elif dx_group == 2:
                        data_info_dict_no_autism[sub_id_str] = {
                            "data": data_vector,
                            "age": age,
                            "gender": gender
                        }
                else:
                    print(f"No participant information found for {sub_id}")


File: Yale_0050616_roi_thickness.txt, Data Length: 97
File: Pitt_0050002_roi_thickness.txt, Data Length: 97
File: Stanford_0051181_roi_thickness.txt, Data Length: 97
File: UM_1_0050338_roi_thickness.txt, Data Length: 97
File: UCLA_1_0051225_roi_thickness.txt, Data Length: 97
File: Trinity_0050271_roi_thickness.txt, Data Length: 97
File: CMU_a_0050659_roi_thickness.txt, Data Length: 97
File: Caltech_0051475_roi_thickness.txt, Data Length: 97
File: NYU_0051084_roi_thickness.txt, Data Length: 97
File: OHSU_0050147_roi_thickness.txt, Data Length: 97
File: UM_2_0050385_roi_thickness.txt, Data Length: 97
File: Yale_0050603_roi_thickness.txt, Data Length: 97
File: UM_1_0050364_roi_thickness.txt, Data Length: 97
File: MaxMun_c_0051336_roi_thickness.txt, Data Length: 97
File: Olin_0050111_roi_thickness.txt, Data Length: 97
File: UM_1_0050371_roi_thickness.txt, Data Length: 97
File: NYU_0051032_roi_thickness.txt, Data Length: 97
File: UCLA_2_0051296_roi_thickness.txt, Data Length: 97
File: NYU_0

In [None]:
#data_info_dict_no_autism = data_info_dict_autism

The data downloaded in this way is in a 3D volume. I want to have the data as a vector. Therefore, I am doing the following:

In [None]:
#@title Execute when working with 3D data

# Directory containing your NIfTI files
#data_directory = "./data/anat_thickness/"

# Read the participant information from the CSV file
csv_file = "./data/participant_info/participant_info.csv"
participant_info_df = pd.read_csv(csv_file)

# Create a dictionary to store data and participant infogender = torch.tensor(participant_row['SEX'].values[0] - 1rmation
#data_info_dict = {}

# Loop through each NIfTI file
for file_name in os.listdir(data_directory):
    pass
    # Check if the file is a NIfTI file
    if file_name.endswith("_anat_thickness.nii.gz"):
        # Load the NIfTI file
        nifti_img = nib.load(os.path.join(data_directory, file_name))

        # Get the data as a NumPy array
        data_array = nifti_img.get_fdata()
        print("The 3D data has the shape of" ,data_array.shape)
        # Reshape to a single vector
        data_vector = data_array.ravel()

        # Extract FILE_ID from the complete NIfTI file name
        file_id = file_name.split("_anat_thickness.nii.gz")[0]

        # Find the participant's information based on FILE_ID in the CSV
        participant_row = participant_info_df.loc[participant_info_df['FILE_ID'] == file_id]

        # Extract age and gender from the participant's information
        age = participant_row['AGE_AT_SCAN'].values[0]
        gender = participant_row['SEX'].values[0]

        # Store the data and participant information in the dictionary
        data_info_dict[file_id] = {
            "data": data_vector,
            "age": age,
            "gender": gender
        }


Lets see if the number of files I loaded matches the numbers of files in my dictionary:

In [None]:
! cd ./data & du -a | cut -d/ -f2 | sort | uniq -c | sort -nr

   1107 data
     18 .config
      7 sample_data
      1 60560	.


Let's check if this all worked:

In [None]:
# Calculate overall statistics for the autism category
autism_data_lengths = [len(info["data"]) for info in data_info_dict_autism.values()]
total_autism_samples = len(autism_data_lengths)
average_autism_data_length = sum(autism_data_lengths) / total_autism_samples
min_autism_data_length = min(autism_data_lengths)
max_autism_data_length = max(autism_data_lengths)
std_autism_data_length = np.std(autism_data_lengths)
autism_ages = [info["age"] for info in data_info_dict_autism.values()]
average_autism_age = sum(autism_ages) / total_autism_samples
min_autism_age = min(autism_ages)
max_autism_age = max(autism_ages)
std_autism_age = np.std(autism_ages)
autism_genders = [info["gender"] for info in data_info_dict_autism.values()]
# Calculate gender counts for the autism category
autism_male_count = autism_genders.count(0)
autism_female_count = autism_genders.count(1)

# Calculate overall statistics for the non-autism category
non_autism_data_lengths = [len(info["data"]) for info in data_info_dict_no_autism.values()]
total_non_autism_samples = len(non_autism_data_lengths)
average_non_autism_data_length = sum(non_autism_data_lengths) / total_non_autism_samples
min_non_autism_data_length = min(non_autism_data_lengths)
max_non_autism_data_length = max(non_autism_data_lengths)
std_non_autism_data_length = np.std(non_autism_data_lengths)
non_autism_ages = [info["age"] for info in data_info_dict_no_autism.values()]
average_non_autism_age = sum(non_autism_ages) / total_non_autism_samples
min_non_autism_age = min(non_autism_ages)
max_non_autism_age = max(non_autism_ages)
std_non_autism_age = np.std(non_autism_ages)
non_autism_genders = [info["gender"] for info in data_info_dict_no_autism.values()]
# Calculate gender counts for the non-autism category
non_autism_male_count = non_autism_genders.count(0)
non_autism_female_count = non_autism_genders.count(1)





# Print the statistics for the autism category
print("Autism Data Statistics:")
print("Total Samples:", total_autism_samples)
print("Average Data Length:", average_autism_data_length)
print("Minimum Data Length:", min_autism_data_length)
print("Maximum Data Length:", max_autism_data_length)
print("Standard Deviation of Data Length:", std_autism_data_length)
print("")

print("Autism Age Statistics:")
print("Average Age:", average_autism_age)
print("Minimum Age:", min_autism_age)
print("Maximum Age:", max_autism_age)
print("Standard Deviation of Age:", std_autism_age)
print("")

print("Autism Gender Counts:")
print("Male Count:", autism_male_count)
print("Female Count:", autism_female_count)
print("")

# Print the statistics for the non-autism category
print("Non-Autism Data Statistics:")
print("Total Samples:", total_non_autism_samples)
print("Average Data Length:", average_non_autism_data_length)
print("Minimum Data Length:", min_non_autism_data_length)
print("Maximum Data Length:", max_non_autism_data_length)
print("Standard Deviation of Data Length:", std_non_autism_data_length)
print("")

print("Non-Autism Age Statistics:")
print("Average Age:", average_non_autism_age)
print("Minimum Age:", min_non_autism_age)
print("Maximum Age:", max_non_autism_age)
print("Standard Deviation of Age:", std_non_autism_age)
print("")

print("Non-Autism Gender Counts:")
print("Male Count:", non_autism_male_count)
print("Female Count:", non_autism_female_count)


Autism Data Statistics:
Total Samples: 531
Average Data Length: 97.0
Minimum Data Length: 97
Maximum Data Length: 97
Standard Deviation of Data Length: 0.0

Autism Age Statistics:
Average Age: 17.066997551789047
Minimum Age: 7.0
Maximum Age: 64.0
Standard Deviation of Age: 8.408176785458945

Autism Gender Counts:
Male Count: 467
Female Count: 64

Non-Autism Data Statistics:
Total Samples: 571
Average Data Length: 97.0
Minimum Data Length: 97
Maximum Data Length: 97
Standard Deviation of Data Length: 0.0

Non-Autism Age Statistics:
Average Age: 17.102401576182142
Minimum Age: 6.47
Maximum Age: 56.2
Standard Deviation of Age: 7.719682046137984

Non-Autism Gender Counts:
Male Count: 472
Female Count: 99


Next, I need to create a dataloader.

In [None]:
class CombinedDataset(Dataset):
    def __init__(self, autism_data_info, no_autism_data_info):
        self.autism_data_info = autism_data_info
        self.no_autism_data_info = no_autism_data_info
        self.autism_file_ids = list(self.autism_data_info.keys())
        self.no_autism_file_ids = list(self.no_autism_data_info.keys())

    def __len__(self):
        return max(len(self.autism_file_ids), len(self.no_autism_file_ids))

    def __getitem__(self, index):
        autism_index = index % len(self.autism_file_ids)
        no_autism_index = index % len(self.no_autism_file_ids)

        autism_file_id = self.autism_file_ids[autism_index]
        no_autism_file_id = self.no_autism_file_ids[no_autism_index]

        autism_data = torch.tensor(self.autism_data_info[autism_file_id]["data"], dtype=torch.float32)
        autism_age = torch.tensor(self.autism_data_info[autism_file_id]["age"], dtype=torch.float32)
        autism_gender = torch.tensor(self.autism_data_info[autism_file_id]["gender"], dtype=torch.float32)

        no_autism_data = torch.tensor(self.no_autism_data_info[no_autism_file_id]["data"], dtype=torch.float32)
        no_autism_age = torch.tensor(self.no_autism_data_info[no_autism_file_id]["age"], dtype=torch.float32)
        no_autism_gender = torch.tensor(self.no_autism_data_info[no_autism_file_id]["gender"], dtype=torch.float32)

        return (autism_data, autism_age, autism_gender), (no_autism_data, no_autism_age, no_autism_gender)

# Create the combined dataset
combined_dataset = CombinedDataset(data_info_dict_autism, data_info_dict_no_autism)

# Create the dataloader
batch_size = 64
shuffle = True
combined_dataloader = DataLoader(combined_dataset, batch_size=batch_size, shuffle=shuffle)


## Model specifications

In the following I am specifiying the model. I am roughly orienting myself around a paper from Anglinkas, Hartshorne & Anzellotti (2022).

### Defining utility functions

Firstly, I am defining the loss function.
The loss will be computed as the sum of the BCE-Loss, as well as the KL-divergence terms.

* MSE loss: Incoming

* Cross Entropy: Incoming

* Kullback-Leibler divergence (Kullback & Leibler, 1951) This is a measure for the difference between two distributions. I.e. "how much do they diverge" from each other, how much are they different to each other. The introduction of this term into the final loss leads my model to optimize not only if the precited categories are correct and so on, but also how high the difference between the prior distribution and teh latent variables are. The prior distribution in my case is an isotropic gaussian.
  * Why is this desirable? The latent variables and the sampling process should be somewhat controlled. This divergence regulates this.


I have also attempted to regulate that a loss is only completed with the KL divergence from the second encoder if that encoder was used.

In [None]:
def final_loss(MSE, CE, MSE_age, z_mu, z_logvar, s_mu, s_logvar):
    """
    This function will add the reconstruction loss (BCELoss) and the KL-Divergence.
    KL-Divergence = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    :param bce_loss: reconstruction loss
    :param z_mu: mean from the latent vector of encoder_z
    :param z_logvar: log variance from the latent vector of encoder_z
    :param s_mu: mean from the latent vector of encoder_s (optional)
    :param s_logvar: log variance from the latent vector of encoder_s (optional)
    """
    mse_loss = MSE
    mse_age = MSE_age
    cross_entropy = CE
    KLD_z = -0.5 * torch.sum(1 + z_logvar - z_mu.pow(2) - z_logvar.exp())
    if s_mu is not None and s_logvar is not None:
        KLD_s = -0.5 * torch.sum(1 + s_logvar - s_mu.pow(2) - s_logvar.exp())
        return mse_loss + KLD_z + KLD_s + cross_entropy + mse_age
    else:
        return mse_loss + KLD_z + cross_entropy + mse_age


Here is the training loop. This model is supposed to achieve multiple things:

* Train the cVAE using the MSE loss.
* Incoming.


In [None]:
from tqdm import tqdm

def train(model, dataloader, dataset, device, optimizer, criterion, criterion_classifier):
    model.train()
    running_loss_autism = 0.0
    running_loss_no_autism = 0.0
    running_age_loss_autism = 0.0
    running_gender_loss_autism = 0.0
    running_age_loss_no_autism = 0.0
    running_gender_loss_no_autism = 0.0
    counter = 0

    total_batches = len(dataset) // dataloader.batch_size

    for i, ((autism_data, autism_age, autism_gender), (no_autism_data, no_autism_age, no_autism_gender)) in tqdm(enumerate(dataloader), total=total_batches):
        autism_data = autism_data.to(device)
        no_autism_data = no_autism_data.to(device)

        autism_age = autism_age.to(device)
        autism_gender = autism_gender.to(device)

        no_autism_age = no_autism_age.to(device)
        no_autism_gender = no_autism_gender.to(device)

        optimizer.zero_grad()

        # Get the model outputs
        z_mean, z_log_var, s_mean, s_log_var, z_mean_no_autism, z_log_var_no_autism, reconstructed_data_autism, reconstructed_data_no_autism, class_autism_age, class_autism_gender, class_no_autism_age, class_no_autism_gender = model(autism_data, no_autism_data)

        # Calculate classifier losses for age and gender predictions
        age_loss_autism = criterion(class_autism_age, autism_age.unsqueeze(1))
        gender_loss_autism = criterion_classifier(class_autism_gender, autism_gender.unsqueeze(1))

        age_loss_no_autism = criterion(class_no_autism_age, no_autism_age.unsqueeze(1))
        gender_loss_no_autism = criterion_classifier(class_no_autism_gender, no_autism_gender.unsqueeze(1))

        # Section for the autism images
        bce_loss_autism = criterion(reconstructed_data_autism, autism_data)
        loss_autism = final_loss(bce_loss_autism,  gender_loss_autism, age_loss_autism, z_mean, z_log_var, s_mean, s_log_var)
        running_loss_autism += loss_autism.item()
        running_age_loss_autism += age_loss_autism.item()
        running_gender_loss_autism += gender_loss_autism.item()

        # Section for the no_autism images
        bce_loss_no_autism = criterion(reconstructed_data_no_autism, no_autism_data)
        s_mean_no_autism, s_log_var_no_autism = None, None
        loss_no_autism = final_loss(bce_loss_no_autism,  gender_loss_no_autism, age_loss_no_autism, z_mean_no_autism, z_log_var_no_autism, s_mean_no_autism, s_log_var_no_autism)
        running_loss_no_autism += loss_no_autism.item()
        running_age_loss_no_autism += age_loss_no_autism.item()
        running_gender_loss_no_autism += gender_loss_no_autism.item()

        # Total loss
        loss_no_autism.backward()
        loss_autism.backward()

        optimizer.step()
        counter += len(autism_data) + len(no_autism_data)

    train_loss_autism = running_loss_autism / counter
    train_loss_no_autism = running_loss_no_autism / counter
    train_age_loss_autism = running_age_loss_autism / counter
    train_gender_loss_autism = running_gender_loss_autism / counter
    train_age_loss_no_autism = running_age_loss_no_autism / counter
    train_gender_loss_no_autism = running_gender_loss_no_autism / counter

    return train_loss_autism, train_loss_no_autism, train_age_loss_autism, train_gender_loss_autism, train_age_loss_no_autism, train_gender_loss_no_autism


## Model specification

These values still need to be adapted for the current model.

In [None]:
input_dimension = 97 # The numer of features
indermediate_dim = 128
latent_dim = 4 # latent dimension for sampling

lr = 0.01



Next I want to define the contrastive variational autoencoder. While doing so, I am defining seperate encoders, to make it easier to later introduce other encoders. I am orienting myself on an cVAE I have written in the past.

As the paper from Aglinskas, Hartshorne and Anzellotti (2022) I mentioned, the network will have only a few layers.

A few things I will probably have to change - I do not know how many channels the data will end up having. therefore I am using one, assuming it only has one.

In [None]:
class EncoderNS(nn.Module):
    def __init__(self, input_dimension, latent_dim):
        super(EncoderNS, self).__init__()
        self.linear1 = nn.Linear(input_dimension, 64)
        self.linear2 = nn.Linear(64, 32)
        self.linear3 = nn.Linear(32, 4)
        self.ns_fc_mean = nn.Linear(latent_dim, latent_dim)
        self.ns_fc_log_var = nn.Linear(latent_dim, latent_dim)

    def forward(self, x, batch_size):
        h = F.relu(self.linear1(x))
        h = F.relu(self.linear2(h))
        h = F.relu(self.linear3(h))
        ns_mean = self.ns_fc_mean(h)
        ns_log_var = self.ns_fc_log_var(h)
        return ns_mean, ns_log_var


class EncoderS(nn.Module):
    def __init__(self, input_dimension, latent_dim):
        super(EncoderS, self).__init__()
        self.linear1 = nn.Linear(input_dimension, 64)
        self.linear2 = nn.Linear(64, 32)
        self.linear3 = nn.Linear(32, 4)
        self.s_fc_mean = nn.Linear(latent_dim, latent_dim)
        self.s_fc_log_var = nn.Linear(latent_dim, latent_dim)

    def forward(self, x, batch_size):
        h = F.relu(self.linear1(x))
        h = F.relu(self.linear2(h))
        h = F.relu(self.linear3(h))
        s_mean = self.s_fc_mean(h)
        s_log_var = self.s_fc_log_var(h)
        return s_mean, s_log_var

class Decoder(nn.Module):
    def __init__(self, input_dimension, latent_dim):
        super(Decoder, self).__init__()
        self.linear_decoder_1 = nn.Linear(latent_dim*2, 32)
        self.linear_decoder_2 = nn.Linear(32,64)
        self.linear_decoder_3 = nn.Linear(64, input_dimension)

    def forward(self, zs, batch_size):
        h_output = F.relu(self.linear_decoder_1(zs))
        h_output = F.relu(self.linear_decoder_2(h_output))
        output = F.relu(self.linear_decoder_3(h_output))
        return output

class Classifier(nn.Module):
    def __init__(self, latent_dim):
        super(Classifier, self).__init__()
        self.fc1 = nn.Linear(latent_dim*2, latent_dim // 2)
        self.fc_age = nn.Linear(latent_dim // 2, 1)
        self.fc_gender = nn.Linear(latent_dim // 2, 1)

    def forward(self, z):
        x = self.fc1(z)
        age_prediction = self.fc_age(x)
        gender_prediction = torch.sigmoid(self.fc_gender(x))  # Apply sigmoid activation for binary gender prediction
        return age_prediction, gender_prediction

class cVAE(nn.Module):
    def __init__(self, input_dimension, latent_dim):
        super(cVAE, self).__init__()
        self.encoder_z = EncoderNS(input_dimension, latent_dim)
        self.encoder_s = EncoderS(input_dimension, latent_dim)
        self.decoder = Decoder(input_dimension, latent_dim)
        self.classifier = Classifier(latent_dim)

    def reparameterize(self, mean, log_var):
        std = torch.exp(0.5 * log_var)
        epsilon = torch.randn_like(std)
        return mean + epsilon * std

    def forward(self, autism, no_autism):
        batch_size = autism.size(0)
        z_mean, z_log_var = self.encoder_z(autism, batch_size)
        z = self.reparameterize(z_mean, z_log_var)
        s_mean, s_log_var = self.encoder_s(autism, batch_size)
        s = self.reparameterize(s_mean, s_log_var)
        zs = torch.cat([z, s], dim=1)

        reconstructed_data_autism = self.decoder(zs, batch_size)

        z_mean_no_autism, z_log_var_no_autism = self.encoder_z(no_autism, batch_size)
        z_no_autism = self.reparameterize(z_mean_no_autism, z_log_var_no_autism)
        z_empty = torch.zeros(z_no_autism.shape)
        model_device = z_no_autism.device
        z_empty = z_empty.to(model_device)
        z_no_autism_0 = torch.cat([z_no_autism, z_empty], dim=1)

        reconstructed_data_no_autism = self.decoder(z_no_autism_0, batch_size)

        class_autism_age, class_autism_gender = self.classifier(zs)  # Assuming z is the latent variable after concatenating s and z
        class_no_autism_age, class_no_autism_gender = self.classifier(z_no_autism_0)  # Using the version with 0s to have equal lengths of the latent vectors.

        return z_mean, z_log_var, s_mean, s_log_var, z_mean_no_autism, z_log_var_no_autism, reconstructed_data_autism, reconstructed_data_no_autism, class_autism_age, class_autism_gender, class_no_autism_age, class_no_autism_gender

And finally the training loop - note that I have yet to define the validation function:

In [None]:
model = cVAE(input_dimension=97, latent_dim=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.MSELoss()
classifier_criterion = nn.BCELoss()

train_loss_list = []  # List to store train losses
val_loss_list = []  # List to store validation losses

num_epochs = 10
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1} of {num_epochs}")
    # Train the model
    (   train_loss_autism,
        train_loss_no_autism,
        train_age_loss_autism,
        train_gender_loss_autism,
        train_age_loss_no_autism,
        train_gender_loss_no_autism) = train(
        model,
        combined_dataloader,
        combined_dataset,
        device,
        optimizer,
        criterion,
        classifier_criterion,
    )

    # Validate the model
    # val_loss, recon_images = validate(model, overlaid_dataloader, overlaid_dataset, device, criterion, classifier_criterion)

    # Appending the loss values to a list to allow for visualizations:
    train_loss_list.append(
        train_loss_autism
        + train_loss_no_autism
        + train_age_loss_autism
        + train_gender_loss_autism
        + train_age_loss_no_autism
        + train_gender_loss_no_autism
    )
    # val_loss_list.append(val_loss)

    # Print the losses
    print(
        f"Train Loss Autism: {train_loss_autism:.4f}, Train Loss Non-Autism: {train_loss_no_autism:.4f}, "
        f"Train Age Loss Autism: {train_age_loss_autism:.4f}, Train Gender Loss Autism: {train_gender_loss_autism:.4f}, "
        f"Train Age Loss Non-Autism: {train_age_loss_no_autism:.4f}, Train Gender Loss Non-Autism: {train_gender_loss_no_autism:.4f}"
    )
    # print(f"Train Loss for the background: {train_loss_bg:.4f}, Val Loss: {val_loss:.4f}")

print("TRAINING COMPLETE")


Epoch 1 of 10


9it [00:00, 49.19it/s]                       


Train Loss Autism: 3.0942, Train Loss Non-Autism: 2.9689, Train Age Loss Autism: 2.8864, Train Gender Loss Autism: 0.0070, Train Age Loss Non-Autism: 2.8629, Train Gender Loss Non-Autism: 0.0068
Epoch 2 of 10


9it [00:00, 45.80it/s]                       


Train Loss Autism: 2.9134, Train Loss Non-Autism: 2.8607, Train Age Loss Autism: 2.8247, Train Gender Loss Autism: 0.0065, Train Age Loss Non-Autism: 2.8137, Train Gender Loss Non-Autism: 0.0064
Epoch 3 of 10


9it [00:00, 49.46it/s]                       


Train Loss Autism: 2.8056, Train Loss Non-Autism: 2.7804, Train Age Loss Autism: 2.7597, Train Gender Loss Autism: 0.0059, Train Age Loss Non-Autism: 2.7494, Train Gender Loss Non-Autism: 0.0060
Epoch 4 of 10


9it [00:00, 51.06it/s]                       


Train Loss Autism: 2.6960, Train Loss Non-Autism: 2.6812, Train Age Loss Autism: 2.6564, Train Gender Loss Autism: 0.0053, Train Age Loss Non-Autism: 2.6507, Train Gender Loss Non-Autism: 0.0054
Epoch 5 of 10


9it [00:00, 56.08it/s]                       


Train Loss Autism: 2.5903, Train Loss Non-Autism: 2.6011, Train Age Loss Autism: 2.5338, Train Gender Loss Autism: 0.0047, Train Age Loss Non-Autism: 2.5590, Train Gender Loss Non-Autism: 0.0049
Epoch 6 of 10


9it [00:00, 61.82it/s]                       


Train Loss Autism: 2.4595, Train Loss Non-Autism: 2.4590, Train Age Loss Autism: 2.3620, Train Gender Loss Autism: 0.0043, Train Age Loss Non-Autism: 2.3910, Train Gender Loss Non-Autism: 0.0047
Epoch 7 of 10


9it [00:00, 91.88it/s]               


Train Loss Autism: 2.2445, Train Loss Non-Autism: 2.2818, Train Age Loss Autism: 2.0729, Train Gender Loss Autism: 0.0039, Train Age Loss Non-Autism: 2.1739, Train Gender Loss Non-Autism: 0.0041
Epoch 8 of 10


9it [00:00, 91.23it/s]               


Train Loss Autism: 1.9591, Train Loss Non-Autism: 2.0631, Train Age Loss Autism: 1.6783, Train Gender Loss Autism: 0.0038, Train Age Loss Non-Autism: 1.8949, Train Gender Loss Non-Autism: 0.0043
Epoch 9 of 10


9it [00:00, 101.46it/s]              


Train Loss Autism: 1.7742, Train Loss Non-Autism: 1.8530, Train Age Loss Autism: 1.3892, Train Gender Loss Autism: 0.0038, Train Age Loss Non-Autism: 1.6112, Train Gender Loss Non-Autism: 0.0042
Epoch 10 of 10


9it [00:00, 97.78it/s]               


Train Loss Autism: 1.5570, Train Loss Non-Autism: 1.6691, Train Age Loss Autism: 1.0956, Train Gender Loss Autism: 0.0036, Train Age Loss Non-Autism: 1.3558, Train Gender Loss Non-Autism: 0.0040
TRAINING COMPLETE


Questions to Answer:
How to best deal with with the issue that my data is in the shape of [batch size, 1] but the expected category against which its compared is [batch_size]

Need to now prepare the validation data set to properly set the learning parameters.