In [44]:
# Torch
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch
import torch.optim as optim
from torcheval.metrics import *

# 3d cnn
import cnn3d_xmuyzz

# Other
from tqdm import tqdm
from datetime import datetime
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import pickle
# Custom modules
from preprocessing_post_fastsurfer.subject import *
from preprocessing_post_fastsurfer.vis import *
from ozzy_torch_utils.split_dataset import *
from ozzy_torch_utils.SubjectDataset import *

### Dataset hyperparameters

In [45]:
data_path = "/uolstore/home/student_lnxhome01/sc22olj/Compsci/year3/individual-project-COMP3931/individual-project-sc22olj/scratch-disk/full-datasets/hcampus-large-cohort"

selected_labels = ['CN', 'MCI']

# Dictionary key representing the data of interest
data_string = 'hcampus_vox_aligned'

# Dictionary key representing the disease labels
labels_string = 'research_group'

# Prevent class imbalance
downsample_majority = True

# NB this argument makes prevent_id_leakage redundant
single_img_per_subject = False

# Prevent the same subject id from occuring in train and test, in case of more than one image per id
prevent_id_leakage = True

batch_size = 20

test_size = 0.3

### Dataset creation

In [46]:
dataset = SubjectDataset(data_path, selected_labels, downsample_majority=downsample_majority, single_img_per_subject=single_img_per_subject)

### Data checks

Check the size of the dataset and the number of unique labels and IDs

In [47]:
print(f"Dataset size: {len(dataset)}\n")

labels = [dataset[index]['research_group'] for index in range(len(dataset.subject_list))]

ids = [dataset.subject_list[index].subject_metadata['Subject'] for index in range(len(dataset.subject_list))]

print(f"Unique labels: {np.unique(labels, return_counts=True)}\n")

print(f"Unique ids: {np.unique(ids, return_counts=True)}\n")


Dataset size: 3926

Unique labels: (array([0, 1]), array([1963, 1963]))

Unique ids: (array(['003_S_0908', '003_S_1074', '003_S_4081', '003_S_4119',
       '003_S_4288', '003_S_4350', '003_S_4441', '003_S_4555',
       '003_S_4644', '003_S_4872', '003_S_4900', '003_S_5154',
       '005_S_0324', '005_S_0448', '005_S_0553', '005_S_0572',
       '005_S_0602', '005_S_4168', '007_S_1206', '007_S_1222',
       '007_S_2394', '007_S_4272', '007_S_4387', '007_S_4488',
       '007_S_4516', '007_S_4620', '007_S_4637', '007_S_5265',
       '009_S_0751', '009_S_0842', '009_S_1030', '009_S_4337',
       '009_S_4388', '009_S_4612', '011_S_0021', '011_S_0023',
       '011_S_1080', '011_S_1282', '011_S_4075', '011_S_4105',
       '011_S_4120', '011_S_4222', '011_S_4278', '013_S_1035',
       '013_S_1186', '013_S_1276', '013_S_4579', '013_S_4580',
       '013_S_4616', '014_S_4080', '014_S_4093', '014_S_4401',
       '014_S_4576', '014_S_4577', '016_S_0769', '016_S_1117',
       '016_S_1121', '016_S_1138

### Loader config

In [48]:

train_data, test_data = split_dataset(dataset, test_size=test_size, prevent_id_leakage=prevent_id_leakage)

train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn([data_string, labels_string]))

test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False, collate_fn=collate_fn([data_string, labels_string]))

### Data checks
Check if there are subjects split across train and test

In [49]:

train_ids = [dataset.subject_list[index].subject_metadata['Subject'].iloc[0] for index in train_data.indices]

test_ids = [dataset.subject_list[index].subject_metadata['Subject'].iloc[0] for index in test_data.indices]

print(f"Id intersection between train and test: {np.intersect1d(np.unique(train_ids), np.unique(test_ids))}\n")

Id intersection between train and test: []



### Cuda setup

In [50]:
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"

print(f"Using {device} device")

Using cuda device


### Training hyperparameters

In [51]:
num_epochs = 200

learning_rate = 0.001

threshold = 0.5

In [52]:
print(dataset[0][data_string].shape)

torch.Size([0])


In [53]:
import cnn3d_xmuyzz.ResNetV2


metrics = {
    "training_losses" : [],
    "validation_losses": [],
    "conf_matrices": [],
    "accuracies": [],
    "f1s": [],
    "precisions": [],
    "recalls": [],
    "train_time": None,
    "num_training_images": None
}

model = cnn3d_xmuyzz.ResNetV2.generate_model(
            model_depth=18,
            n_classes=2,
            n_input_channels=1,
            shortcut_type='B',
            conv1_t_size=7,
            conv1_t_stride=1,
            no_max_pool=False,
            widen_factor=1.0)

criterion = torch.nn.CrossEntropyLoss()

optimizer = optim.Adam(
            model.parameters(),
            lr=0.001,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=1e-4,
            amsgrad=True
        )

model.to(device)

start_time = datetime.now()

for epoch in range(num_epochs):
    
    print(f"Starting epoch {epoch + 1}\n")
    
    # Training loop
    model.train()
    
    running_loss = 0.0

    for batch_idx, dict in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
        
        # Access dict returned by dataset __getitem__
        points = dict[data_string]
        labels = dict[labels_string]
        
        points, labels = points.to(device), labels.to(device)
        
        # Transform to have redundant channel of dimension 1 for model
        points = points.unsqueeze(1)

        # Forward pass
        output = model(points)

        # Calculate loss, trans_feat argument as None as not used in this function
        loss = criterion(output, labels)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        # Multiply loss by batch size to account for differences in batch size (e.g last batch)
        running_loss += loss.item() * points.size(0)
        
    metrics['training_losses'].append(running_loss/len(train_dataloader))
    
    end_time = datetime.now()
    
    # Validation loop
    model.eval()
    
    # Initialise loop metrics
    running_loss = 0.0; conf_matrix = BinaryConfusionMatrix(); accuracy = BinaryAccuracy(); f1 = BinaryF1Score(); precision = BinaryPrecision(); recall = BinaryRecall()
    
    with torch.no_grad():
        
        for batch_idx, dict in enumerate(test_dataloader):
            
            points = dict[data_string]
            labels = dict[labels_string]
                     
            points, labels = points.to(device), labels.to(device)
            
            # Transform to have redundant channel of dimension 1 for model
            points = points.unsqueeze(1)
            
            output = model(points)
            
            running_loss += criterion(output, labels).item() * points.size(0)
            
            # Apply exponent as the output of the model is log softmax
            pred_probability = torch.exp(output)
            
            # Threshold is variable to give preference to FN or FP
            pred_labels = (pred_probability[:, 1] >= threshold).int()
            
            # Old label conversion
            # pred_labels = torch.argmax(pred_probability, dim=-1)

            # Update metrics
            [metric.update(pred_labels, labels) for metric in [conf_matrix, accuracy, f1, precision, recall]]

    end_time = datetime.now()
            
    # Append metric lists
    [metrics[key].append(metric.compute()) for key, metric in [("conf_matrices", conf_matrix), ("accuracies", accuracy), ("f1s", f1), ("precisions", precision), ("recalls", recall)]]       
         
    metrics['validation_losses'].append(running_loss/len(test_dataloader))
    
    print(f"\nEpoch {epoch + 1} complete\n")
    print("------------------------")
    print(conf_matrix.compute())
    print(f"Training Loss:   {metrics['training_losses'][-1]:.4f}")
    print(f"Validation Loss: {metrics['validation_losses'][-1]:.4f}")
    print(f"Accuracy:        {metrics['accuracies'][-1]:.4f}")
    print(f"F1 Score:        {metrics['f1s'][-1]:.4f}")
    print(f"Precision:       {metrics['precisions'][-1]:.4f}")
    print(f"Recall:          {metrics['recalls'][-1]:.4f}")
    print("------------------------\n\n")
        
    # Break before nightly restart
    current_time = datetime.now()
    
    if current_time.hour == 23 and current_time.minute >= 30:
        
        print("Break before nightly restart")
        
        break
    
metrics['train_time'] = end_time - start_time
metrics['num_training_images'] = len(test_data)

torch.save(model.state_dict(), 'trained_model.pth')

print("Training complete and model saved")

Starting epoch 1



  0%|          | 0/175 [00:00<?, ?it/s]


RuntimeError: Expected 4D (unbatched) or 5D (batched) input to conv3d, but got input of size: [16, 1, 0]

### Plotting

In [38]:
# NB this function has to remain in the notebook for it to work properly
# Plot training loss, validation loss, and accuracy on separate subplots, along with displaying hyperparameters
def plot(metrics, model_name, param_list, save=True, ylim=None):

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10), sharex=True)

    ax1.plot(metrics['training_losses'], label='Training Loss', color='blue')
    ax1.plot(metrics['validation_losses'], label='Validation Loss', color='red')
    ax1.set_title('Training and Validation Loss over Epochs')
    ax1.set_ylabel('Loss')
    ax1.legend()
    
    if ylim is not None:
        
        ax1.set_ylim(ylim)
        
    else:
        
        ax1.set_ylim(0, metrics['training_losses'][0] + 5)
        
    ax1.grid(True)
    ax1.xaxis.set_major_locator(MaxNLocator(integer=True))
    
    ax2.plot(metrics['accuracies'], label='Accuracy', color='green')
    ax2.plot(metrics['f1s'], label='F1 Score', color='blue')
    ax2.plot(metrics['precisions'], label='Precision', color='red')
    ax2.plot(metrics['recalls'], label='Recall', color='orange')
    ax2.set_title('Metrics over epochs')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Value')
    ax2.legend()
    ax2.grid(True)
    ax2.xaxis.set_major_locator(MaxNLocator(integer=True))
    
    minutes = metrics['train_time'].seconds // 60

    seconds = metrics['train_time'].seconds % 60

    train_time_str = f"Training time: {minutes:02d}m {seconds:02d}s"

    info = []
    info.append(train_time_str)
        
    try:
        
        info.append(f"Number of training images: {metrics['num_training_images']:.0f}")
        
    except:
        
        print("Error with num_training_images")
        
        
    info.append(f"Model name: {model_name}")
    info.append(f"Best accuracy: {max(metrics['accuracies']):.2f}")
    info.append(f"Best F1 Score: {max(metrics['f1s']):.2f}")
    info.append(f"Best Precision: {max(metrics['precisions']):.2f}")
    info.append(f"Best Recall: {max(metrics['recalls']):.2f}")
    info.append(f"Epoch with smallest validation loss: {metrics['validation_losses'].index(min(metrics['validation_losses'])):.0f}")
    info.append("\n\n")

    # Nasty hack using globals() to get variable names automatically
    for param in param_list:
        
        for name, value in globals().items():
            
            if value is param and name not in [info_line.split(":")[0] for info_line in info]:
                
                info.append(f"{name}: {value}")
    
    info_text = "\n".join(info)
    
    fig.text(0.5, 0.02, info_text, ha='center', va='top', wrap=True, fontsize=10)

    if save:
        
        # Save the fig and the lists of values
        current_time = datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
        
        name = f"plot_{current_time}"
        
        with open(f'/uolstore/home/student_lnxhome01/sc22olj/Compsci/year3/individual-project-COMP3931/individual-project-sc22olj/figs/{name}.pkl', 'wb') as file:
            
            pickle.dump(metrics, file)
        
        plt.savefig(f'/uolstore/home/student_lnxhome01/sc22olj/Compsci/year3/individual-project-COMP3931/individual-project-sc22olj/figs/{name}.png', bbox_inches='tight')
    
    plt.show()
    
    return

In [None]:

plot(metrics, "ResNetV2", [selected_labels, data_string, labels_string, downsample_majority, single_img_per_subject, prevent_id_leakage, batch_size, test_size, learning_rate, num_epochs, threshold])
