In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
import torch
import torch.optim as optim

# Benny pointnet
from pointnet2_benny import pointnet2_cls_msg
from pointnet2_benny import provider

# Other
from tqdm import tqdm
import nibabel
import os
import numpy as np

# Custom modules
from preprocessing_post_fastsurfer.subject import *
from preprocessing_post_fastsurfer.vis import *

In [2]:

class SubjectDataset(Dataset):
    """Map-style dataset that yields one dictionary of data per subject.

    Subjects are discovered with ``find_subjects(data_path)``. Each item is a
    dict containing only the entries named in ``fields``. Every entry is
    produced lazily, so files and XML values that were not requested are never
    read (previously every volume and XML value was loaded on each access and
    then thrown away).

    Available field names:
        brain                       volume at ``subject.brain_aligned_cropped``
        hcampus_vox                 cropped hippocampus volume (.nii)
        hcampus_vox_aligned         aligned, cropped hippocampus volume (.nii)
        hcampus_pointcloud          downsampled hippocampus mesh points (.npy)
        hcampus_pointcloud_aligned  aligned variant of the point cloud (.npy)
        aseg_stats                  ``subject.aseg_stats``, passed through as-is
        research_group              integer label, see RESEARCH_GROUP_LABELS
        visit_identifier, sex, age, weight, apoe_a1, apoe_a2
                                    raw values read from ``subject.xml_df``

    NB: the subject info is not complete with scores yet; the XML still needs
    to be parsed for those.

    NOTE(review): not every field is guaranteed to batch with the default
    DataLoader collate function (e.g. volumes of differing shapes) - verify
    before enabling additional fields.

    Args:
        data_path: Directory handed to ``find_subjects``.
        fields: Iterable of field names (or a single name) to include in each
            item. Defaults to ``DEFAULT_FIELDS``, which reproduces the
            dictionary this dataset has always returned.

    Raises:
        ValueError: If ``fields`` contains a name that is not listed above.
    """

    # Research-group string from the subject XML -> integer class index.
    RESEARCH_GROUP_LABELS = {'CN': 0, 'MCI': 1, 'AD': 2}

    # Label returned when the research group is not in RESEARCH_GROUP_LABELS.
    # Callers should filter these out before training a classifier on them.
    UNKNOWN_RESEARCH_GROUP = -1

    # Entries returned when the caller does not ask for specific fields.
    DEFAULT_FIELDS = ('hcampus_pointcloud', 'research_group')

    def __init__(self, data_path, fields=DEFAULT_FIELDS):

        # A bare string would otherwise be split into single characters.
        if isinstance(fields, str):
            fields = (fields,)

        self.fields = tuple(fields)

        # Validate before touching the disk so typos fail fast.
        unknown = [name for name in self.fields if name not in self._FIELD_LOADERS]
        if unknown:
            raise ValueError(
                f"Unknown field(s) {unknown}; available: {sorted(self._FIELD_LOADERS)}"
            )

        self.subjects_list = find_subjects(data_path)

    def __len__(self):

        return len(self.subjects_list)

    def __getitem__(self, index):
        """Return a dict mapping each requested field name to its value."""

        subject = self.subjects_list[index]

        return {name: self._FIELD_LOADERS[name](self, subject) for name in self.fields}

    def load_mri_to_tensor(self, path):
        """Load an image file with nibabel and return it as a float32 tensor.

        Returns an empty tensor when ``path`` is None or is not an existing file.
        """

        if path is None or not os.path.isfile(path):
            return torch.empty(0)  # Return empty tensor if the file doesn't exist

        image = nibabel.load(path)

        return torch.tensor(image.get_fdata(), dtype=torch.float32)

    def _load_volume(self, subject, filename):
        """Load the image called ``filename`` from the subject's directory."""

        return self.load_mri_to_tensor(os.path.join(subject.path, filename))

    def _load_points(self, subject, filename):
        """Load the .npy array called ``filename`` from the subject's directory as float32."""

        return torch.tensor(np.load(os.path.join(subject.path, filename)), dtype=torch.float32)

    @staticmethod
    def _xml_subject(subject):
        """Return the ``subject`` node of the parsed subject XML."""

        return subject.xml_df['idaxs']['project']['subject']

    def _research_group(self, subject):
        """Convert the research group disease label str to a number for pytorch."""

        group = self._xml_subject(subject)['researchGroup']

        return self.RESEARCH_GROUP_LABELS.get(group, self.UNKNOWN_RESEARCH_GROUP)

    # Field name -> callable(self, subject) producing that field's value.
    # The callables are plain functions stored in a dict (not bound methods),
    # which is why __getitem__ passes ``self`` explicitly.
    _FIELD_LOADERS = {
        # Images (NB these are all cropped)
        'brain': lambda self, s: self.load_mri_to_tensor(s.brain_aligned_cropped),
        'hcampus_vox': lambda self, s: self._load_volume(
            s, 'Left-Hippocampus_Right-Hippocampus_cropped.nii'),
        'hcampus_vox_aligned': lambda self, s: self._load_volume(
            s, 'Left-Hippocampus_Right-Hippocampus_aligned_cropped.nii'),
        'hcampus_pointcloud': lambda self, s: self._load_points(
            s, 'Left-Hippocampus_Right-Hippocampus_cropped_mesh_downsampled.npy'),
        'hcampus_pointcloud_aligned': lambda self, s: self._load_points(
            s, 'Left-Hippocampus_Right-Hippocampus_aligned_cropped_mesh_downsampled.npy'),
        # Region volume stats
        'aseg_stats': lambda self, s: s.aseg_stats,
        # Info from subject XML
        'research_group': lambda self, s: self._research_group(s),
        'visit_identifier': lambda self, s: self._xml_subject(s)['visit']['visitIdentifier'],
        'sex': lambda self, s: self._xml_subject(s)['subjectSex'],
        'age': lambda self, s: self._xml_subject(s)['study']['subjectAge'],
        'weight': lambda self, s: self._xml_subject(s)['study']['weightKg'],
        'apoe_a1': lambda self, s: self._xml_subject(s)['subjectInfo'][0]['#text'],
        'apoe_a2': lambda self, s: self._xml_subject(s)['subjectInfo'][1]['#text'],
    }

    

In [3]:
# Root directory of the processed dataset. Set the ADNI_DATA_PATH environment variable to
# run the notebook on another machine; the fallback is the original hardcoded location.
data_path = os.environ.get(
    "ADNI_DATA_PATH",
    "/uolstore/home/student_lnxhome01/sc22olj/Compsci/year3/individual-project-COMP3931/individual-project-sc22olj/scratch disk/full-datasets/adni1-complete-3T-processed",
)

In [4]:
# Wrap the subjects found under data_path in a PyTorch dataset
dataset = SubjectDataset(data_path=data_path)

In [5]:
# Disabled debugging loop, kept inside a string literal so that it is not executed:
# for every subject it would print the research group and display the hippocampus mesh.
# NOTE(review): the string is the cell's only expression, so the notebook echoes its repr
# as the cell output - consider deleting this cell once the preview is no longer needed.
'''
for subject in dataset.subjects_list:
    
    research_group = subject.xml_df['idaxs']['project']['subject']['researchGroup']
    
    print(research_group)
    
    display_mesh(np.load(os.path.join(subject.path, 'Left-Hippocampus_Right-Hippocampus_cropped_mesh.npy')), downsample_factor=1)'''

"\nfor subject in dataset.subjects_list:\n    \n    research_group = subject.xml_df['idaxs']['project']['subject']['researchGroup']\n    \n    print(research_group)\n    \n    display_mesh(np.load(os.path.join(subject.path, 'Left-Hippocampus_Right-Hippocampus_cropped_mesh.npy')), downsample_factor=1)"

In [6]:
# Number of subjects in the dataset (its __len__ is the length of subjects_list)
print(len(dataset))

475


In [7]:
def split_dataset(dataset, test_size=0.2, seed=None):
    """Randomly partition ``dataset`` into a train subset and a test subset.

    Args:
        dataset: Any sized dataset accepted by ``torch.utils.data.random_split``.
        test_size: Fraction of the samples (between 0 and 1) assigned to the
            test subset. The resulting count is rounded down; the remainder
            goes to the train subset.
        seed: Optional integer. When given, the split is drawn from a dedicated
            generator seeded with this value, so repeated calls return the same
            partition. When None (default) the global PyTorch RNG is used, as
            before.

    Returns:
        Tuple ``(train_dataset, test_dataset)``.

    Raises:
        ValueError: If ``test_size`` is outside the range [0, 1].
    """

    if not 0 <= test_size <= 1:
        raise ValueError(f"test_size must be between 0 and 1, got {test_size!r}")

    dataset_size = len(dataset)

    n_test = int(test_size * dataset_size)

    n_train = dataset_size - n_test

    if seed is None:
        train_dataset, test_dataset = random_split(dataset, [n_train, n_test])
    else:
        # A private generator keeps the split reproducible without touching global RNG state.
        generator = torch.Generator().manual_seed(seed)
        train_dataset, test_dataset = random_split(dataset, [n_train, n_test], generator=generator)

    return train_dataset, test_dataset

In [8]:
# Random train/test partition using split_dataset's default test fraction
train_data, test_data = split_dataset(dataset)

# Batches of 16 samples; only the training loader shuffles
train_dataloader = DataLoader(dataset=train_data, batch_size=16, shuffle=True)
test_dataloader = DataLoader(dataset=test_data, batch_size=16, shuffle=False)



In [9]:
# Pick the device type as a string. torch.accelerator is not present in every PyTorch
# release, so guard the attribute access and fall back to the plain CUDA check.
if hasattr(torch, "accelerator") and torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator().type
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Using {device} device")

Using cuda device


In [10]:
class NeuralNetwork(nn.Module):
    """Fully connected classifier: flattens the input and maps 28*28 features to 10 logits.

    Two hidden layers of width 512 with ReLU activations sit between the input
    and output layers.

    NOTE(review): nothing in the visible training code uses this class (the
    training cell builds a pointnet2 model instead) - confirm whether it is
    still needed.
    """

    def __init__(self):
        super().__init__()
        n_inputs, n_hidden, n_outputs = 28 * 28, 512, 10

        # Collapses every dimension after the batch dimension into one.
        self.flatten = nn.Flatten()

        self.linear_relu_stack = nn.Sequential(
            nn.Linear(n_inputs, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_outputs),
        )

    def forward(self, x):
        """Return the raw (unnormalised) logits for a batch ``x``."""
        return self.linear_relu_stack(self.flatten(x))

In [11]:
torch.cuda.empty_cache()

# Training loop configuration
num_epochs = 5
num_classes = 3  # One class per research group label produced by the dataset (CN, MCI, AD)

# Define the model
model = pointnet2_cls_msg.get_model(num_classes, normal_channel=False)

# Initialize the loss function and optimizer
criterion = pointnet2_cls_msg.get_loss()

optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

# Set up the device for training (GPU or CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0      # Sum of the per-batch losses over the epoch (not the mean)
    correct = 0         # Number of correctly classified samples
    total_points = 0    # Number of samples seen (one label per point cloud)

    # The loop variable is called `batch`, not `dict`: it lives in the notebook's global
    # namespace, so naming it `dict` would shadow the builtin for every later cell.
    for batch in tqdm(train_dataloader):
        points = batch['hcampus_pointcloud']  # Points data from the dataset
        labels = batch['research_group']  # Labels (ground truth)

        # Benny script augmentation (currently disabled):
        # points = points.numpy()
        # points = provider.random_point_dropout(points)
        # points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
        # points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
        # points = torch.Tensor(points)

        # Transpose as in benny script: swaps dims 1 and 2 of the batch.
        # NOTE(review): presumably the model expects channels-first input
        # (batch, channels, points) - confirm against pointnet2_cls_msg.
        points = points.transpose(2, 1)

        # Move data to the correct device
        points, labels = points.to(device), labels.to(device)

        # Forward pass
        optimizer.zero_grad()
        pred, _ = model(points)  # Assuming the model returns (predictions, features)

        # Calculate loss, trans_feat argument as None as not used in this function
        loss = criterion(pred, labels, None)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Calculate accuracy
        predicted = pred.argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total_points += labels.size(0)
        total_loss += loss.item()

    # Calculate and log accuracy for the epoch (guard against an empty dataloader)
    epoch_acc = 100 * correct / total_points if total_points else 0.0
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss:.4f}, Accuracy: {epoch_acc:.2f}%')

# Save the trained model
torch.save(model.state_dict(), 'trained_model.pth')
print("Training complete and model saved!")


100%|██████████| 24/24 [00:40<00:00,  1.68s/it]


Epoch 1/5, Loss: 28.9137, Accuracy: 35.79%


100%|██████████| 24/24 [00:18<00:00,  1.30it/s]


Epoch 2/5, Loss: 25.9028, Accuracy: 44.21%


100%|██████████| 24/24 [00:18<00:00,  1.31it/s]


Epoch 3/5, Loss: 25.8059, Accuracy: 46.84%


100%|██████████| 24/24 [00:18<00:00,  1.31it/s]


Epoch 4/5, Loss: 25.8583, Accuracy: 42.37%


100%|██████████| 24/24 [00:18<00:00,  1.31it/s]

Epoch 5/5, Loss: 24.9977, Accuracy: 48.16%
Training complete and model saved!





In [12]:
# Earlier draft of the training cell, disabled by wrapping it in a string literal.
# NOTE(review): it appears superseded by the active training cell above (this draft calls
# criterion(pred, labels) with two arguments and does not transpose the points) and it
# reuses `dict` as a loop variable - consider deleting the cell rather than keeping it.
'''import torch.optim as optim
from pointnet2_benny import pointnet2_cls_msg

# Training loop
num_epochs = 5
num_class = 2

model = pointnet2_cls_msg

classifier = model.get_model(3)
    
for dict in train_dataloader:
    
    print(dict['hcampus_pointcloud'].shape)
    
    print(dict['research_group'].shape)
    
    break

# Initialize model and loss function
num_classes = 40  # Update to match your dataset
classifier = pointnet2_cls_msg.get_model(num_classes)  # Update the method to match your model
criterion = pointnet2_cls_msg.get_loss()  # Make sure your model has a loss function defined

# Set up optimizer
optimizer = optim.Adam(classifier.parameters(), lr=0.001, weight_decay=1e-4)

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
classifier.to(device)

# Training loop
num_epochs = 5  # Update as needed
for epoch in range(num_epochs):
    classifier.train()
    total_loss = 0
    correct = 0
    total_points = 0
    for dict in train_dataloader:
        
        points = dict['hcampus_pointcloud']
        
        print(points.shape)
        
        labels = dict['research_group']        
        
        points, labels = points.to(device), labels.to(device)

        # Forward pass
        optimizer.zero_grad()
        pred, _ = classifier(points)
        loss = criterion(pred, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Calculate accuracy
        _, predicted = torch.max(pred, 1)
        correct += (predicted == labels).sum().item()
        total_points += labels.size(0)
        total_loss += loss.item()

    epoch_acc = 100 * correct / total_points
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss:.4f}, Accuracy: {epoch_acc:.2f}%')

# Save the trained model
torch.save(classifier.state_dict(), 'trained_model.pth')
print("Training complete and model saved!")
'''

'import torch.optim as optim\nfrom pointnet2_benny import pointnet2_cls_msg\n\n# Training loop\nnum_epochs = 5\nnum_class = 2\n\nmodel = pointnet2_cls_msg\n\nclassifier = model.get_model(3)\n    \nfor dict in train_dataloader:\n    \n    print(dict[\'hcampus_pointcloud\'].shape)\n    \n    print(dict[\'research_group\'].shape)\n    \n    break\n\n# Initialize model and loss function\nnum_classes = 40  # Update to match your dataset\nclassifier = pointnet2_cls_msg.get_model(num_classes)  # Update the method to match your model\ncriterion = pointnet2_cls_msg.get_loss()  # Make sure your model has a loss function defined\n\n# Set up optimizer\noptimizer = optim.Adam(classifier.parameters(), lr=0.001, weight_decay=1e-4)\n\n# Move model to GPU if available\ndevice = torch.device(\'cuda\' if torch.cuda.is_available() else \'cpu\')\nclassifier.to(device)\n\n# Training loop\nnum_epochs = 5  # Update as needed\nfor epoch in range(num_epochs):\n    classifier.train()\n    total_loss = 0\n    cor