In [1]:
import sys
import os

# Make the project packages (data/, config/, models/, training/, utils/)
# importable from this notebook. The project root is taken to be the parent
# of the kernel's working directory (the notebook lives in examples/).
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
already_importable = project_root in sys.path
if not already_importable:
    # Prepend so local packages win over any same-named installed ones.
    sys.path.insert(0, project_root)


In [17]:
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import random
from data.nifti_loader import MedicalImageDatasetSplitter,MonaiDatasetCreator,MonaiDataLoaderManager
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import interact, IntSlider, fixed
from config import config_loader
from utils.visualization import show_slice_advanced

import torch.nn as nn
from models.cnn_backbones import Small3DCNN,DenseNet3D,ResNet3D,BasicBlock3D
from training.trainer import train_model, load_model, save_model,test_model


## Basic data preprocessing and loading 

In [3]:
# Resolve the base config relative to the notebook's working directory
# (expected to be examples/, one level below the project root).
config_path = os.path.abspath(os.path.join(os.getcwd(), "..", "config", "base_config.yaml"))

# Seed the Python, NumPy and PyTorch RNGs before any data is split or
# shuffled so that "Restart & Run All" reproduces the same samples and runs.
# NOTE(review): if the config defines its own seed, prefer that value here.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load & process config (creates directories once)
config = config_loader.load_config(config_path)

# Data pipeline: splitter -> dataset creator -> dataloader manager
dataset_splitter = MedicalImageDatasetSplitter(config)
dataset_creator = MonaiDatasetCreator(dataset_splitter)
dataloader_manager = MonaiDataLoaderManager(dataset_creator, config)

# Get all dataloaders and unpack the individual splits
dataloaders = dataloader_manager.get_dataloaders()
train_loader = dataloaders['train']
val_loader = dataloaders['val']
test_loader = dataloaders['test']

# Get class information
class_to_idx, idx_to_class = dataset_splitter.get_class_info()
num_classes = dataset_splitter.get_num_classes()
# The model cells below refer to NUM_CLASSES; define it here as well so they
# still work when the config-override (ablation) cell is skipped.
NUM_CLASSES = num_classes

2025-04-30 11:31:54,300 - INFO - Created directory: results/exp_1_20250430
2025-04-30 11:31:54,302 - INFO - Created directory: results/exp_1_20250430/logs
2025-04-30 11:31:54,304 - INFO - Created directory: results/exp_1_20250430/checkpoints
2025-04-30 11:31:54,305 - INFO - Created directory: results/exp_1_20250430/explanations
2025-04-30 11:31:54,306 - INFO - Created directory: results/exp_1_20250430/predictions
2025-04-30 11:31:54,307 - INFO - Created directory: results/exp_1_20250430/figures


Number of empty labels: 0
Checking filtering status...
No filtering applied.
Loaded dataset with 652 total samples
Training: 456 samples
Validation: 98 samples
Testing: 98 samples
Classes: {'AD': 0, 'CN': 1, 'LMCI': 2}


In [4]:
# Fetch one batch from the train loader and pick a random sample from it for
# visual inspection. batch['image'] is assumed to be [B, C, H, W, D].
batch = next(iter(train_loader))

idx = np.random.randint(0, batch['image'].shape[0])
sample_volume = batch['image'][idx]  # [C, H, W, D]

# Convert to numpy and drop the channel dimension if present
if isinstance(sample_volume, torch.Tensor):
    sample_volume = sample_volume.detach().cpu().numpy()
if sample_volume.ndim == 4:
    # assume channel-first, take channel 0
    sample_volume = sample_volume[0]

# Fail with a readable message instead of an opaque tuple-unpacking error.
assert sample_volume.ndim == 3, f"expected a 3D volume, got shape {sample_volume.shape}"
H, W, D = sample_volume.shape


In [5]:
def show_slice(z: int):
    """
    Plot slice ``z`` (indexed along the last axis) of the notebook-level
    ``sample_volume`` as a grayscale image.

    Reads the globals ``sample_volume`` (a 3D array) and ``D`` (its depth),
    so re-run the sample-extraction cell first. The title reports the slice
    number 1-based, out of ``D``.
    """
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(sample_volume[:, :, z], cmap='gray')
    ax.set_title(f"Slice {z+1}/{D}")
    ax.axis('off')
    plt.show()


In [6]:
# Slider over the last (depth) axis: one step per slice, 0 .. D-1.
# continuous_update=False sends the new value only when the handle is
# released, so show_slice redraws once per move instead of on every drag step.
slider = widgets.IntSlider(
    description='Z Slice:',
    min=0,
    max=D - 1,
    step=1,
    value=0,
    continuous_update=False,
)

# interact() calls show_slice(z=<slider value>) whenever the slider changes.
widgets.interact(show_slice, z=slider);

interactive(children=(IntSlider(value=0, continuous_update=False, description='Z Slice:', max=172), Output()),…

In [7]:
# Slider over the last (depth) axis of sample_volume: 0 .. depth-1.
slider = IntSlider(
    value=0,
    min=0,
    max=sample_volume.shape[2] - 1,
    step=1,
    description='Z Slice:',
    continuous_update=False,
)

# Link the slider to utils.visualization.show_slice_advanced; fixed() passes
# sample_volume through as-is rather than turning it into a widget.
interact(show_slice_advanced, volume=fixed(sample_volume), z=slider);

interactive(children=(IntSlider(value=0, continuous_update=False, description='Z Slice:', max=172), Output()),…

### Overriding config values for an ablation run

In [8]:
# 1) Reload the default config from disk so the overrides below start from
#    the base values (processing the config also creates the result directories).
config_path = os.path.abspath(os.path.join(os.getcwd(), "..", "config", "base_config.yaml"))
config = config_loader.load_config(config_path)

# 2) Override values for this ablation run. These must be set before the
#    splitter / dataset / dataloader objects below are built from `config`.
config['data']['batch_size']      = 8
config['data']['perform_slicing'] = False
config['data']['image_size']      = [128, 128, 128]

# 3) Rebuild the data pipeline from the modified config:
#    splitter -> dataset creator -> dataloader manager
dataset_splitter = MedicalImageDatasetSplitter(config)
dataset_creator = MonaiDatasetCreator(dataset_splitter)
dataloader_manager = MonaiDataLoaderManager(dataset_creator, config)

# Get all dataloaders and unpack the individual splits
dataloaders = dataloader_manager.get_dataloaders()
train_loader = dataloaders['train']
val_loader = dataloaders['val']
test_loader = dataloaders['test']

# Get class information
class_to_idx, idx_to_class = dataset_splitter.get_class_info()
NUM_CLASSES = dataset_splitter.get_num_classes()
# Keep the lowercase name used by the first data cell in sync so it cannot
# go stale after the pipeline is rebuilt here.
num_classes = NUM_CLASSES

2025-04-30 11:31:58,361 - INFO - Created directory: results/exp_1_20250430
2025-04-30 11:31:58,363 - INFO - Created directory: results/exp_1_20250430/logs
2025-04-30 11:31:58,364 - INFO - Created directory: results/exp_1_20250430/checkpoints
2025-04-30 11:31:58,365 - INFO - Created directory: results/exp_1_20250430/explanations
2025-04-30 11:31:58,366 - INFO - Created directory: results/exp_1_20250430/predictions
2025-04-30 11:31:58,367 - INFO - Created directory: results/exp_1_20250430/figures


Number of empty labels: 0
Checking filtering status...
No filtering applied.
Loaded dataset with 652 total samples
Training: 456 samples
Validation: 98 samples
Testing: 98 samples
Classes: {'AD': 0, 'CN': 1, 'LMCI': 2}


In [9]:
# Fetch one batch from the (rebuilt) train loader and pick a random sample
# from it for visual inspection. batch['image'] is assumed to be [B, C, H, W, D].
batch = next(iter(train_loader))

idx = np.random.randint(0, batch['image'].shape[0])
sample_volume = batch['image'][idx]  # [C, H, W, D]

# Convert to numpy and drop the channel dimension if present
if isinstance(sample_volume, torch.Tensor):
    sample_volume = sample_volume.detach().cpu().numpy()
if sample_volume.ndim == 4:
    # assume channel-first, take channel 0
    sample_volume = sample_volume[0]

# Fail with a readable message instead of an opaque tuple-unpacking error.
assert sample_volume.ndim == 3, f"expected a 3D volume, got shape {sample_volume.shape}"
H, W, D = sample_volume.shape

In [10]:
# Inspect the keys of the batch fetched in the previous cell. Re-fetching
# would load and transform a whole extra batch of volumes, and would replace
# the batch that sample_volume was taken from.
batch.keys()

dict_keys(['image', 'label'])

In [11]:
# Slider over the last (depth) axis of the resized volume: 0 .. D-1.
# continuous_update=False sends the new value only when the handle is
# released, so show_slice redraws once per move instead of on every drag step.
slider = widgets.IntSlider(
    description='Z Slice:',
    min=0,
    max=D - 1,
    step=1,
    value=0,
    continuous_update=False,
)

# interact() calls show_slice(z=<slider value>) whenever the slider changes.
widgets.interact(show_slice, z=slider);

interactive(children=(IntSlider(value=0, continuous_update=False, description='Z Slice:', max=127), Output()),…

In [12]:
# Slider over the last (depth) axis of sample_volume: 0 .. depth-1.
slider = IntSlider(
    value=0,
    min=0,
    max=sample_volume.shape[2] - 1,
    step=1,
    description='Z Slice:',
    continuous_update=False,
)

# Link the slider to utils.visualization.show_slice_advanced; fixed() passes
# sample_volume through as-is rather than turning it into a widget.
interact(show_slice_advanced, volume=fixed(sample_volume), z=slider);

interactive(children=(IntSlider(value=0, continuous_update=False, description='Z Slice:', max=127), Output()),…

## Training a predefined 3D CNN

In [13]:
# Use CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Small3DCNN sized to the dataset's class count, with its parameters on `device`
# (moved before the optimizer is created so Adam tracks the moved parameters).
model = Small3DCNN(num_classes=NUM_CLASSES)
model = model.to(device)

# Cross-entropy loss and Adam at a learning rate of 1e-3.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [14]:
# Fit Small3DCNN for 10 epochs; `device` is handed to train_model explicitly.
trained_model, history = train_model(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    num_epochs=10,
)

# Dump the per-epoch values of every metric recorded in the training history.
print("Training History:")
for label, key in (
    ("Train Loss", "train_loss"),
    ("Train Accuracy", "train_accuracy"),
    ("Validation Loss", "val_loss"),
    ("Validation Accuracy", "val_accuracy"),
):
    print(f"{label}: {history[key]}")

Epoch 1/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:47<00:00,  1.20it/s]
Epoch 1/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.48it/s]


Epoch [1/10], Train Loss: 1.1059, Train Accuracy: 0.3640, Val Loss: 1.0751, Val Accuracy: 0.3878


Epoch 2/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:47<00:00,  1.21it/s]
Epoch 2/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.53it/s]


Epoch [2/10], Train Loss: 1.0769, Train Accuracy: 0.4057, Val Loss: 1.0631, Val Accuracy: 0.3878


Epoch 3/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:45<00:00,  1.24it/s]
Epoch 3/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.53it/s]


Epoch [3/10], Train Loss: 1.0825, Train Accuracy: 0.4123, Val Loss: 1.0642, Val Accuracy: 0.3878


Epoch 4/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:47<00:00,  1.19it/s]
Epoch 4/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.55it/s]


Epoch [4/10], Train Loss: 1.0629, Train Accuracy: 0.4167, Val Loss: 1.0429, Val Accuracy: 0.4694


Epoch 5/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:47<00:00,  1.20it/s]
Epoch 5/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.55it/s]


Epoch [5/10], Train Loss: 1.0451, Train Accuracy: 0.4254, Val Loss: 1.0292, Val Accuracy: 0.4694


Epoch 6/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:46<00:00,  1.22it/s]
Epoch 6/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.55it/s]


Epoch [6/10], Train Loss: 1.0390, Train Accuracy: 0.4452, Val Loss: 1.0364, Val Accuracy: 0.4286


Epoch 7/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:47<00:00,  1.19it/s]
Epoch 7/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.55it/s]


Epoch [7/10], Train Loss: 1.0223, Train Accuracy: 0.4781, Val Loss: 1.0176, Val Accuracy: 0.5204


Epoch 8/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:46<00:00,  1.22it/s]
Epoch 8/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.52it/s]


Epoch [8/10], Train Loss: 1.0058, Train Accuracy: 0.5197, Val Loss: 1.0101, Val Accuracy: 0.4796


Epoch 9/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:49<00:00,  1.15it/s]
Epoch 9/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.56it/s]


Epoch [9/10], Train Loss: 1.0047, Train Accuracy: 0.4956, Val Loss: 1.0190, Val Accuracy: 0.4694


Epoch 10/10 (Training): 100%|████████████████████████████████████████████████████████████████████████████| 57/57 [00:47<00:00,  1.21it/s]
Epoch 10/10 (Validation): 100%|██████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.55it/s]

Epoch [10/10], Train Loss: 0.9722, Train Accuracy: 0.5351, Val Loss: 1.0654, Val Accuracy: 0.4490
Training History:
Train Loss: [1.105942675941869, 1.0768783604889585, 1.0824888578632421, 1.0628528709997211, 1.0451049083157589, 1.038959000194282, 1.0223452239705806, 1.0057881390839292, 1.0047482105723597, 0.9721566543244479]
Train Accuracy: [0.36403508771929827, 0.4057017543859649, 0.41228070175438597, 0.4166666666666667, 0.42543859649122806, 0.4451754385964912, 0.4780701754385965, 0.5197368421052632, 0.4956140350877193, 0.5350877192982456]
Validation Loss: [1.0750879232700055, 1.063114115825066, 1.0642333855995765, 1.042895651780642, 1.0291734154407794, 1.0364421330965483, 1.0175692714177644, 1.010080860211299, 1.0189586511025062, 1.065406391253838]
Validation Accuracy: [0.3877551020408163, 0.3877551020408163, 0.3877551020408163, 0.46938775510204084, 0.46938775510204084, 0.42857142857142855, 0.5204081632653061, 0.47959183673469385, 0.46938775510204084, 0.4489795918367347]





In [15]:
# Evaluate the trained Small3DCNN on the held-out test split; returns the
# loss, accuracy, and the predictions/labels collected by test_model.
# NOTE(review): unlike train_model, no `device` is passed here — confirm that
# test_model infers it from the model's parameters, otherwise pass it explicitly.
test_loss, test_acc, preds, labels = test_model(trained_model, test_loader, criterion)


Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.58it/s]

Test Loss: 1.1135, Test Accuracy: 0.4388





In [16]:
# save_path = "/path/to/save/model.pth"
# save_model(model, save_path)

## Training and testing a basic ResNet3D model

In [18]:
# ResNet3D built from BasicBlock3D with layers=[2, 2, 2, 2]
# (presumably two residual blocks per stage, ResNet-18 style — confirm in
# models.cnn_backbones). The class count comes from the dataset instead of a
# hard-coded 3. The model is moved to `device` before the optimizer is
# created, matching the Small3DCNN cell, so this setup does not depend on
# train_model to relocate the parameters.
model = ResNet3D(BasicBlock3D, layers=[2, 2, 2, 2], num_classes=NUM_CLASSES).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [19]:
# Fit ResNet3D for 10 epochs on the same loaders; `device` is passed explicitly.
trained_model, history = train_model(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    num_epochs=10,
)

# Dump the per-epoch values of every metric recorded in the training history.
print("Training History:")
for label, key in (
    ("Train Loss", "train_loss"),
    ("Train Accuracy", "train_accuracy"),
    ("Validation Loss", "val_loss"),
    ("Validation Accuracy", "val_accuracy"),
):
    print(f"{label}: {history[key]}")

Epoch 1/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:49<00:00,  1.15it/s]
Epoch 1/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.57it/s]


Epoch [1/10], Train Loss: 1.0668, Train Accuracy: 0.4496, Val Loss: 3.6297, Val Accuracy: 0.3878


Epoch 2/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:47<00:00,  1.20it/s]
Epoch 2/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.54it/s]


Epoch [2/10], Train Loss: 1.0313, Train Accuracy: 0.4496, Val Loss: 1.1321, Val Accuracy: 0.3878


Epoch 3/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:46<00:00,  1.24it/s]
Epoch 3/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.53it/s]


Epoch [3/10], Train Loss: 1.0341, Train Accuracy: 0.4254, Val Loss: 1.4501, Val Accuracy: 0.4184


Epoch 4/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:46<00:00,  1.22it/s]
Epoch 4/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.56it/s]


Epoch [4/10], Train Loss: 1.0253, Train Accuracy: 0.4254, Val Loss: 1.1539, Val Accuracy: 0.4286


Epoch 5/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:46<00:00,  1.24it/s]
Epoch 5/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.55it/s]


Epoch [5/10], Train Loss: 0.9890, Train Accuracy: 0.5329, Val Loss: 1.1160, Val Accuracy: 0.3673


Epoch 6/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:45<00:00,  1.24it/s]
Epoch 6/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.59it/s]


Epoch [6/10], Train Loss: 0.9932, Train Accuracy: 0.4956, Val Loss: 1.1885, Val Accuracy: 0.4592


Epoch 7/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:47<00:00,  1.21it/s]
Epoch 7/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.59it/s]


Epoch [7/10], Train Loss: 0.9568, Train Accuracy: 0.5285, Val Loss: 1.6104, Val Accuracy: 0.4082


Epoch 8/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:45<00:00,  1.25it/s]
Epoch 8/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.54it/s]


Epoch [8/10], Train Loss: 0.9444, Train Accuracy: 0.5154, Val Loss: 1.8933, Val Accuracy: 0.4082


Epoch 9/10 (Training): 100%|█████████████████████████████████████████████████████████████████████████████| 57/57 [00:47<00:00,  1.21it/s]
Epoch 9/10 (Validation): 100%|███████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.60it/s]


Epoch [9/10], Train Loss: 0.9813, Train Accuracy: 0.4715, Val Loss: 0.9999, Val Accuracy: 0.4286


Epoch 10/10 (Training): 100%|████████████████████████████████████████████████████████████████████████████| 57/57 [00:48<00:00,  1.19it/s]
Epoch 10/10 (Validation): 100%|██████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.59it/s]

Epoch [10/10], Train Loss: 0.9289, Train Accuracy: 0.5132, Val Loss: 1.0879, Val Accuracy: 0.4490
Training History:
Train Loss: [1.0668172721277203, 1.0312869015492891, 1.0341277132954514, 1.02528918312307, 0.9890300296900565, 0.993183620143355, 0.9567706428076092, 0.9444349376778853, 0.9813297652361685, 0.9289414349355196]
Train Accuracy: [0.44956140350877194, 0.44956140350877194, 0.42543859649122806, 0.42543859649122806, 0.5328947368421053, 0.4956140350877193, 0.5285087719298246, 0.5153508771929824, 0.47149122807017546, 0.5131578947368421]
Validation Loss: [3.6297158094552846, 1.132100916825808, 1.4500975150328417, 1.15388128390679, 1.1159689793219933, 1.1884520283112159, 1.6103506546754103, 1.8932755222687354, 0.9998583977039044, 1.0879319814535289]
Validation Accuracy: [0.3877551020408163, 0.3877551020408163, 0.41836734693877553, 0.42857142857142855, 0.3673469387755102, 0.45918367346938777, 0.40816326530612246, 0.40816326530612246, 0.42857142857142855, 0.4489795918367347]





In [20]:
# Evaluate the trained ResNet3D on the held-out test split. This overwrites
# test_loss / test_acc / preds / labels from the Small3DCNN evaluation.
# NOTE(review): no `device` is passed here — confirm that test_model infers it
# from the model's parameters, otherwise pass it explicitly.
test_loss, test_acc, preds, labels = test_model(trained_model, test_loader, criterion)


Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.55it/s]

Test Loss: 1.2013, Test Accuracy: 0.4592



