In [2]:
import matplotlib
import matplotlib.pyplot as plt
import optuna
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import time
from sklearn.metrics import ConfusionMatrixDisplay

from datasets import CLASSES
from model import Resnet50Model, EfficientPatchNet
from train import train, validate
from datasets import get_datasets, get_data_loaders
from utils import calculate_accuracy, save_model, save_plots
from optuna_hp import objective

In [None]:
# Hyperparameter search: maximise the value returned by `objective`
# (imported from optuna_hp) with a TPE sampler, stopping after 50 trials
# or 10000 seconds, whichever comes first.
SEED = 42  # fixed so the TPE search is reproducible across runs
sampler = optuna.samplers.TPESampler(seed=SEED)
study = optuna.create_study(study_name='patch-mri-classification', direction='maximize', sampler=sampler)
study.optimize(objective, n_trials=50, timeout=10000)

pruned_trials = study.get_trials(deepcopy=False, states=[optuna.trial.TrialState.PRUNED])
complete_trials = study.get_trials(deepcopy=False, states=[optuna.trial.TrialState.COMPLETE])

print("Study statistics: ")
print("  Number of finished trials: ", len(study.trials))
print("  Number of pruned trials: ", len(pruned_trials))
print("  Number of complete trials: ", len(complete_trials))

# `study.best_trial` raises ValueError when no trial has completed
# (e.g. every trial was pruned or failed), so only query it when one did.
if complete_trials:
    print("Best trial:")
    trial = study.best_trial

    print("  Value: ", trial.value)

    print("  Params: ")
    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))
else:
    print("Best trial: none - no trial completed.")

In [None]:
# Seed torch's global RNG so weight initialisation (and, presumably, any
# DataLoader shuffling inside get_data_loaders) is repeatable.
# NOTE(review): if the dataset transforms use `random`/`numpy` RNGs, seed those too.
SEED = 42
torch.manual_seed(SEED)

# Load the training, validation and test datasets.
dataset_train, dataset_valid, dataset_test, dataset_classes = get_datasets()
print(f"[INFO]: Number of training images: {len(dataset_train)}")
print(f"[INFO]: Number of validation images: {len(dataset_valid)}")
print(f"[INFO]: Number of test images: {len(dataset_test)}")
print(f"[INFO]: Class names: {dataset_classes}\n")
batch_size = 32
# Build the training, validation and test data loaders.
train_loader, valid_loader, test_loader = get_data_loaders(
    dataset_train, dataset_valid, dataset_test, batch_size,
)

# Learning parameters.
# NOTE(review): lr, epochs and in_channels are hard-coded and do not use
# study.best_params from the Optuna search cell - confirm whether the tuned
# values should be applied here instead.
lr = 1e-3  # Hyperparameter
epochs = 100  # Hyperparameter
device = ('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Computation device: {device}")
print(f"Learning rate: {lr}")
print(f"Epochs to train for: {epochs}\n")
in_channels = 49  # Hyperparameter
# Size the classifier head from the same CLASSES list the confusion-matrix
# plots use for their axis labels, instead of a hard-coded 4.
model = Resnet50Model(in_channels=in_channels, num_classes=len(CLASSES)).to(device)

# Total parameters and trainable parameters.
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")

# Optimizer.
optimizer = optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=0.001)
# Cosine-anneal the LR over the whole run. The scheduler's `verbose` flag is
# deprecated in recent PyTorch releases, so the LR is printed explicitly below.
scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
# Loss function.
criterion = nn.CrossEntropyLoss()

# Per-epoch history, consumed later by save_plots.
train_loss, valid_loss = [], []
train_acc, valid_acc = [], []
# Start the training.
for epoch in range(epochs):
    print(f"[INFO]: Epoch {epoch+1} of {epochs}")
    print(f"Learning rate: {scheduler.get_last_lr()[0]:.6g}")
    train_epoch_loss, train_epoch_acc = train(
        model, train_loader, optimizer, criterion, device,
    )
    valid_epoch_loss, valid_epoch_acc = validate(
        model, valid_loader, criterion, device,
    )
    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)
    train_acc.append(train_epoch_acc)
    valid_acc.append(valid_epoch_acc)
    print(f"Training loss: {train_epoch_loss:.3f}, training acc: {train_epoch_acc:.3f}")
    print(f"Validation loss: {valid_epoch_loss:.3f}, validation acc: {valid_epoch_acc:.3f}")
    print('-' * 50)
    # Advance the LR schedule once per epoch, after this epoch's updates.
    scheduler.step()


In [None]:
# Persist the final weights / optimizer state, then the learning-curve plots.
save_model(epochs, model, optimizer, criterion)
history = (train_acc, valid_acc, train_loss, valid_loss)
save_plots(*history)
print('TRAINING COMPLETE')

In [None]:
# Evaluate the trained model on the held-out test set.
# NOTE(review): assumes calculate_accuracy switches the model to eval mode and
# returns accuracy as a percentage - confirm in utils.
test_accuracy, confusion_matrix = calculate_accuracy(model, test_loader, device)
print("test accuracy: {:.3f}%".format(test_accuracy))

matplotlib.style.use('default')
# Plot the confusion matrix, rows labelled as actual and columns as predicted class.
num_classes = len(CLASSES)
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
# vmax is left to autoscale to the largest count; the previous hard-coded
# vmax=410 clipped larger counts to the darkest colour.
image = ax.matshow(confusion_matrix, aspect='auto', vmin=0, cmap=plt.get_cmap('Blues'))
fig.colorbar(image, ax=ax)
ax.set_ylabel('Actual Category')
ax.set_yticks(range(num_classes))
ax.set_yticklabels(CLASSES)
ax.set_xlabel('Predicted Category')
ax.set_xticks(range(num_classes))
ax.set_xticklabels(CLASSES)
ax.set_title('Test confusion matrix')
plt.show()

In [None]:
# Same test-set confusion matrix rendered with scikit-learn's display, which
# also writes each cell's value. Reuses `test_accuracy` / `confusion_matrix`
# from the previous cell instead of running the whole test set a second time.
print("test accuracy: {:.3f}%".format(test_accuracy))

matplotlib.style.use('default')
cm_display = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=CLASSES)
cm_display.plot(cmap=plt.cm.Blues)
cm_display.ax_.set_title('Test confusion matrix')
plt.show()