In [1]:
import nibabel
import os
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split

# Custom modules
from preprocessing_post_fastsurfer.subject import *
from preprocessing_post_fastsurfer.vis import *

In [2]:

class SubjectDataset(Dataset):
    def __init__(self, data_path):
        
        self.subjects_list = find_subjects(data_path)

    def __len__(self):
        
        return len(self.subjects_list)

    def __getitem__(self, index):
        
        subject = self.subjects_list[index]
        
        
        """IMAGES"""
        
        # Aligned cropped brain
        brain = self.load_mri_to_tensor(subject.brain_aligned_cropped)
        
        # NB these are all cropped
        hcampus_vox = self.load_mri_to_tensor(os.path.join(subject.path, 'Left-Hippocampus_Right-Hippocampus_cropped.nii'))
        
        hcampus_vox_aligned = self.load_mri_to_tensor(os.path.join(subject.path, 'Left-Hippocampus_Right-Hippocampus_aligned_cropped.nii'))
    
        hcampus_pointcloud = torch.tensor(np.load(os.path.join(subject.path, 'Left-Hippocampus_Right-Hippocampus_cropped_mesh_downsampled.npy')), dtype=torch.float32)
                                          
        hcampus_pointcloud_aligned = torch.tensor(np.load(os.path.join(subject.path, 'Left-Hippocampus_Right-Hippocampus_aligned_cropped_mesh_downsampled.npy')), dtype=torch.float32)
        
        
        """REGION VOLUME STATS"""
        
        aseg_stats = subject.aseg_stats
        
        
        """SUBJECT INFO - NB NOT COMPLETE WITH SCORES, NEED TO PARSE XML"""
    
        # Info from subject XML
        subject_dict = subject.xml_df
    
        research_group = subject_dict['idaxs']['project']['subject']['researchGroup']
        
        visit_identifier = subject_dict['idaxs']['project']['subject']['visit']['visitIdentifier']
        
        sex = subject_dict['idaxs']['project']['subject']['subjectSex']
        
        age = subject_dict['idaxs']['project']['subject']['study']['subjectAge']
        
        weight = subject_dict['idaxs']['project']['subject']['study']['weightKg']
        
        apoe_a1 = subject_dict['idaxs']['project']['subject']['subjectInfo'][0]['#text']
        
        apoe_a2 = subject_dict['idaxs']['project']['subject']['subjectInfo'][1]['#text']
            
        
        # Scores
        

        # Return a dictionary with your data
        return {
            'brain': brain,
            'hcampus_vox': hcampus_vox,
            'hcampus_vox_aligned': hcampus_vox_aligned,
            'hcampus_pointcloud': hcampus_pointcloud,
            'hcampus_pointcloud_aligned': hcampus_pointcloud_aligned,
            'aseg_stats': aseg_stats,
            'research_group': research_group,
            'visit_identifier': visit_identifier,
            'sex': sex,
            'age': age,
            'weight': weight,
            'apoe_a1': apoe_a1,
            'apoe_a2': apoe_a2,
        }
        
    def load_mri_to_tensor(self, path):
        
        if path is None or not os.path.isfile(path):
            return torch.empty(0)  # Return empty tensor if the file doesn't exist
        
        # Example of using nibabel to load .mgz files (you can modify as needed)
        image = nibabel.load(path)

        image_data = image.get_fdata()
        
        # Convert to PyTorch tensor
        tensor_data = torch.tensor(image_data, dtype=torch.float32)
        
        return tensor_data

    

In [3]:
data_path = "/uolstore/home/student_lnxhome01/sc22olj/Compsci/year3/individual-project-COMP3931/individual-project-sc22olj/scratch disk/full-datasets/adni1-complete-3T-processed"

In [4]:
# Create a pytorch dataset using the subject list
dataset = SubjectDataset(data_path)

In [5]:
'''
for subject in dataset.subjects_list:
    
    research_group = subject.xml_df['idaxs']['project']['subject']['researchGroup']
    
    print(research_group)
    
    display_mesh(np.load(os.path.join(subject.path, 'Left-Hippocampus_Right-Hippocampus_cropped_mesh.npy')), downsample_factor=1)'''

"\nfor subject in dataset.subjects_list:\n    \n    research_group = subject.xml_df['idaxs']['project']['subject']['researchGroup']\n    \n    print(research_group)\n    \n    display_mesh(np.load(os.path.join(subject.path, 'Left-Hippocampus_Right-Hippocampus_cropped_mesh.npy')), downsample_factor=1)"

In [6]:
print(len(dataset.subjects_list))

475


In [7]:
def split_dataset(dataset, test_size=0.2):
    
    dataset_size = len(dataset)
    
    test_size = int(test_size * dataset_size)
    
    train_size = dataset_size - test_size
    
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
    
    return train_dataset, test_dataset

In [8]:
train_data, test_data = split_dataset(dataset)

train_dataloader = DataLoader(train_data)

test_dataloader = DataLoader(test_data)



In [9]:
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"

print(f"Using {device} device")

Using cuda device


In [15]:
print(dataset[0]['hcampus_pointcloud'].shape)

torch.Size([4930, 3])


In [11]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits