In [1]:
pip install monai

Collecting monai
  Downloading monai-1.4.0-py3-none-any.whl.metadata (11 kB)
Downloading monai-1.4.0-py3-none-any.whl (1.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m17.1 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[?25hInstalling collected packages: monai
Successfully installed monai-1.4.0
Note: you may need to restart the kernel to use updated packages.


In [None]:
import torch
from torch.utils.tensorboard import SummaryWriter
import monai
from monai.data import DataLoader, decollate_batch, CacheDataset
from monai.metrics import ROCAUCMetric
from monai.data import DataLoader, ImageDataset
from monai.transforms import Activations, AsDiscrete, Compose
import os, random, time, calendar, datetime, warnings
from sys import platform
import numpy as np
import pandas as pd
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix
import warnings
from monai.networks.nets import DenseNet121
import matplotlib.pyplot as plt
import random
import torch.nn.functional as F
import monai
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score,  roc_auc_score, classification_report
from sklearn.metrics import classification_report

def set_determinism(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (plus the CUDA generators when a GPU is available), and configures cuDNN
    with ``deterministic=True`` and ``benchmark=False``.

    Args:
        seed: Integer seed handed to each generator.
    """
    # The cuDNN flags do not depend on the seed, so set them up front.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# Fix the seed once, before any data is loaded or any model is built.
set_determinism(seed=10)

"""
Build train/val/test tables with columns {'image': file path, 'label': one-hot label} for the selected channel (T1w or T2w)
"""

def obtain_channel_data(df, channel, classes):
    """Split the metadata table into train/val/test frames for one MRI channel.

    Diagnoses are encoded with the fixed indices AD=0, CN=1, cMCI=2, ncMCI=3
    whenever those indices fit into ``classes``. If the labels present in ``df``
    would produce an index >= ``classes`` (e.g. only cMCI/ncMCI with
    ``classes=2``), they are re-indexed to 0..k-1 in that same order so the
    one-hot encoding stays valid.

    Args:
        df: Metadata table with a ``label`` column (diagnosis name), a ``set``
            column (``'train'``, ``'val'`` or ``'test'``) and a
            ``{channel}_path`` column with the image location. Not modified.
        channel: Prefix of the path column to use, e.g. ``'T2w'``.
        classes: Length of the one-hot label vectors.

    Returns:
        ``(df_train, df_val, df_test)``. Each frame has a fresh 0..n-1 index,
        an ``index`` column holding the original row labels, an ``image``
        column with forward-slash paths and a ``label`` column of float
        one-hot tensors.

    Raises:
        ValueError: If ``df`` holds a label that is not one of the four known
            diagnoses, or more distinct diagnoses than ``classes``.
    """
    canonical = {
        'AD': 0,
        'CN': 1,
        'cMCI': 2,
        'ncMCI': 3
    }

    # Diagnoses that actually occur, in canonical order (dicts keep insertion order).
    present = [name for name in canonical if (df['label'] == name).any()]

    label_map = canonical
    if present and canonical[present[-1]] >= classes:
        # The largest canonical index would overflow the one-hot vector: compact the indices.
        if len(present) > classes:
            raise ValueError(
                f"found {len(present)} distinct labels {present} but classes={classes}"
            )
        label_map = {name: position for position, name in enumerate(present)}

    # Encode into a separate Series so the caller's frame is left untouched.
    encoded = df['label'].map(label_map)
    if encoded.isna().any():
        unknown = df.loc[encoded.isna(), 'label'].unique().tolist()
        raise ValueError(f"unknown label value(s) in 'label' column: {unknown}")
    encoded = encoded.astype('int64')

    column = f'{channel}_path'

    def build_split(split_name):
        """Return the image/label frame for the rows whose ``set`` equals ``split_name``."""
        rows = df['set'] == split_name
        # Normalise Windows-style separators; doubled backslashes are collapsed first
        # so they become a single '/'.
        paths = (
            df.loc[rows, column]
            .str.replace('\\\\', '/', regex=False)
            .str.replace('\\', '/', regex=False)
        )
        one_hot = encoded[rows].apply(
            lambda idx: torch.nn.functional.one_hot(torch.as_tensor(idx), num_classes=classes).float()
        )
        # reset_index() gives positional 0..n-1 labels (needed when the columns are
        # indexed by integer position downstream) and keeps the old labels in 'index'.
        return pd.DataFrame({'image': paths, 'label': one_hot}).reset_index()

    df_train = build_split('train')
    df_val = build_split('val')
    df_test = build_split('test')

    return df_train, df_val, df_test
                
# ---- Experiment configuration -------------------------------------------------
scan_type = 'T2w'
task = 'ternary1' # [all, binary1, binary2, ternary1, ternary2]

# Diagnoses kept for each task; the number of classes is the length of the list.
TASK_LABELS = {
    'all': ['AD', 'CN', 'cMCI', 'ncMCI'],
    'binary1': ['CN', 'AD'],
    'binary2': ['cMCI', 'ncMCI'],
    'ternary1': ['CN', 'AD', 'cMCI'],
    'ternary2': ['CN', 'cMCI', 'ncMCI'],
}
if task not in TASK_LABELS:
    # Fail loudly: a typo in `task` must not silently train the 4-class problem.
    raise ValueError(f"unknown task {task!r}; expected one of {sorted(TASK_LABELS)}")

classes = len(TASK_LABELS[task])
data = pd.read_csv('/kaggle/input/oasis-final/metadata_updated.csv')

if task != 'all':
    # .copy() detaches the subset from the full table so later column
    # assignments operate on an independent frame.
    data = data[data['label'].isin(TASK_LABELS[task])].copy()

# ---- Datasets and loaders -----------------------------------------------------
train_set, val_set, test_set = obtain_channel_data(data, scan_type, classes)
pin_memory = torch.cuda.is_available()

train_ds = ImageDataset(image_files=train_set['image'], labels=train_set['label'])
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=4, pin_memory=pin_memory)

val_ds = ImageDataset(image_files=val_set['image'], labels=val_set['label'])
# Validation only measures the model, so the sample order is kept fixed.
val_loader = DataLoader(val_ds, batch_size=8, shuffle=False, num_workers=4, pin_memory=pin_memory)

In [None]:
import gc
# Run a full garbage-collection pass so unreachable Python objects are freed
# before the training cell starts.
gc.collect()
# Ask PyTorch to release the cached CUDA memory it is not currently using.
torch.cuda.empty_cache()

def _run_validation(model, val_loader, loss_function, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Returns:
        ``(mean_batch_loss, true_labels, predicted_labels, probabilities)`` where
        the labels are class indices (argmax of the one-hot targets / logits) and
        ``probabilities`` is the softmax of the logits, all as NumPy arrays.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    model.eval()
    total_loss, batches = 0.0, 0
    preds, targets, probs = [], [], []
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            val_outputs = model(val_images)
            total_loss += loss_function(val_outputs, val_labels).item()
            batches += 1
            preds.append(val_outputs.argmax(dim=1).cpu().numpy())
            targets.append(val_labels.argmax(dim=1).cpu().numpy())  # true labels
            probs.append(F.softmax(val_outputs, dim=1).cpu().numpy())
    if batches == 0:
        raise ValueError("val_loader yielded no batches")
    return total_loss / batches, np.concatenate(targets), np.concatenate(preds), np.concatenate(probs)


def train_model(model, train_loader, val_loader, loss_function, optimizer, val_interval, max_epochs, early_stopping, device):
    """Train ``model``, validate periodically, checkpoint the best AUC and log a summary row.

    Side effects: writes TensorBoard scalars through a default ``SummaryWriter``,
    saves the best weights to ``best_metric_model_classification3d_array.pth`` in
    the current working directory, and appends one row to
    ``/kaggle/working/results.csv``. The summary row also records the
    module-level ``scan_type`` and ``task`` variables.

    Args:
        model: Network to optimise; must already live on ``device``.
        train_loader: Loader yielding ``(images, one_hot_labels)`` training batches.
        val_loader: Loader yielding ``(images, one_hot_labels)`` validation batches.
        loss_function: Callable ``loss_function(outputs, labels)`` returning a scalar tensor.
        optimizer: Optimiser bound to ``model``'s parameters.
        val_interval: Validate every ``val_interval`` epochs.
        max_epochs: Upper bound on the number of epochs.
        early_stopping: Stop once this many epochs have passed since the best
            validation AUC. Not applied before the first validation has run.
        device: Device the batches are moved to.

    Returns:
        ``(results, epoch_loss_values)``: ``results`` is a one-element list with
        the metrics of the most recent validation (NaN if none ran);
        ``epoch_loss_values`` is ``[train_losses, val_losses]`` with one entry
        per trained / validated epoch.

    Raises:
        ValueError: If ``train_loader`` or ``val_loader`` yields no batches.
    """
    reports_file = '/kaggle/working/results.csv'
    results = []
    best_metric, best_metric_epoch = -1, -1
    epoch_loss_values = [[], []]

    # Metrics stay NaN until a validation pass has produced real values, so the
    # summary row can always be written.
    accuracy = precision = recall = f1 = auc = float('nan')
    epoch_loss_train = epoch_loss_eval = float('nan')
    last_epoch = 0

    # Batches per epoch taken from the loader itself (includes a partial last
    # batch), so TensorBoard step numbers never overlap between epochs.
    epoch_len = len(train_loader)
    writer = SummaryWriter()

    try:
        print('Device currently active: ', device)
        for epoch in range(max_epochs):
            last_epoch = epoch + 1
            print("-" * 10)
            print(f"epoch {epoch + 1}/{max_epochs}")
            model.train()
            running_loss = 0.0
            step = 0

            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                batch_loss = loss.item()
                running_loss += batch_loss
                if step % 50 == 0:
                    print(f"{step}/{epoch_len}, train_loss: {batch_loss:.4f}")
                writer.add_scalar("train_loss", batch_loss, epoch_len * epoch + step)

            if step == 0:
                raise ValueError("train_loader yielded no batches")
            epoch_loss_train = running_loss / step
            epoch_loss_values[0].append(epoch_loss_train)
            print(f"epoch {epoch + 1} average loss: {epoch_loss_train:.4f}")

            # Validation
            if (epoch + 1) % val_interval == 0:
                epoch_loss_eval, all_labels, all_preds, all_probs = _run_validation(
                    model, val_loader, loss_function, device
                )
                epoch_loss_values[1].append(epoch_loss_eval)

                accuracy = accuracy_score(all_labels, all_preds)
                precision = precision_score(all_labels, all_preds, average='macro',zero_division=0)
                recall = recall_score(all_labels, all_preds, average='macro',zero_division=0)
                f1 = f1_score(all_labels, all_preds, average='macro',zero_division=0)

                # The class count comes from the model output width, not from a global.
                if all_probs.shape[1] == 2:
                    auc = roc_auc_score(all_labels, all_probs[:, 1])
                else:
                    auc = roc_auc_score(all_labels, all_probs, multi_class='ovr', average='macro')

                if auc > best_metric:
                    best_metric = auc
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), "best_metric_model_classification3d_array.pth")
                    print("saved new best metric model")

                print(f"Current epoch: {epoch+1} current auc: {auc:.4f}")
                print(f"Best auc: {best_metric:.4f} at epoch {best_metric_epoch}")
                print(f"Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}, AUC: {auc:.4f}")
                writer.add_scalar("val_accuracy", accuracy, epoch + 1)
                writer.add_scalar("precision", precision, epoch + 1)
                writer.add_scalar("recall", recall, epoch + 1)
                writer.add_scalar("f1_score", f1, epoch + 1)
                writer.add_scalar("AUC", auc, epoch + 1)

            # Patience is only meaningful once a best epoch exists; before the first
            # validation best_metric_epoch is still the -1 sentinel.
            if best_metric_epoch > 0 and epoch + 1 - best_metric_epoch == early_stopping:
                print(f"\nEarly stopping triggered at epoch: {str(epoch + 1)}\n")
                break
    finally:
        # Flush and close the TensorBoard writer even if training fails part-way.
        writer.close()

    print(f"Training completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")

    results.append({
        'epoch': last_epoch,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'AUC': auc,
        'train_loss': epoch_loss_train,
        'val_loss': epoch_loss_eval,
        'channel': scan_type,
        'multiclass': task
    })

    # Append to the running report, creating it on the first run.
    new_df = pd.DataFrame(results)
    if os.path.isfile(reports_file):
        updated_df = pd.concat([pd.read_csv(reports_file), new_df], ignore_index=True)
    else:
        updated_df = new_df
    updated_df.to_csv(reports_file, index=False)
    return results, epoch_loss_values

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 3-D DenseNet-121 with one input channel and one output per class.
model = DenseNet121(spatial_dims=3, in_channels=1, out_channels=classes)
model = model.to(device)

cross_entropy_loss = torch.nn.CrossEntropyLoss()
adam_optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

# Placeholders; `results` is rebound by the training call below.
results = []
epoch_loss_values = [[], []]

# Everything train_model needs besides the model, gathered in one place.
training_options = dict(
    train_loader=train_loader,
    val_loader=val_loader,
    loss_function=cross_entropy_loss,
    optimizer=adam_optimizer,
    val_interval=1,
    max_epochs=50,
    early_stopping=10,
    device=device,
)
results, epoch_loss_val = train_model(model, **training_options)

In [None]:
# Plot the per-epoch training and validation loss curves

train_losses, val_losses = epoch_loss_val[0], epoch_loss_val[1]
epochs = range(1, len(train_losses) + 1)

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(epochs, train_losses, 'r', label='Train Loss')
ax.plot(epochs, val_losses, 'b', label='Validation Loss')
ax.set_xticks(range(1, len(epochs) + 1, 2))  # tick every second epoch
ax.set_title('Train and Validation Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
plt.show()

In [None]:
# Evaluation phase: score a saved checkpoint on the held-out test split.

model = DenseNet121(spatial_dims=3, in_channels=1, out_channels=classes).to(device)
model_path = '/kaggle/input/20/pytorch/default/1/20.pth'
# map_location remaps the stored tensors onto the active device, so a checkpoint
# saved on a GPU also loads when this runtime only has a CPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load(model_path, map_location=device))

test_ds = ImageDataset(image_files=test_set['image'], labels=test_set['label'])
# Evaluation does not benefit from shuffling; keep the sample order fixed.
test_loader = DataLoader(test_ds, batch_size=16, shuffle=False, num_workers=4, pin_memory=pin_memory)

model.eval()

all_preds = []
all_labels = []
all_probs = []
with torch.no_grad():
    for test_data in test_loader:
        test_images, test_labels = test_data[0].to(device), test_data[1].to(device)
        test_outputs = model(test_images)
        preds = test_outputs.argmax(dim=1)
        all_probs.append(F.softmax(test_outputs, dim=1).cpu().numpy())
        all_preds.append(preds.cpu().numpy())
        # Targets are one-hot vectors; argmax recovers the class index.
        all_labels.append(test_labels.argmax(dim=1).cpu().numpy())
all_preds = np.concatenate(all_preds)
all_labels = np.concatenate(all_labels)
all_probs = np.concatenate(all_probs)

print("Test Classification Report:")
print(classification_report(all_labels, all_preds, zero_division=0))
# Binary AUC takes the positive-class column; multi-class uses macro one-vs-rest.
if classes==2:
    auc = roc_auc_score(all_labels, all_probs[:, 1])
else:
    auc = roc_auc_score(all_labels, all_probs, multi_class='ovr', average='macro')
print('AUC:',auc)


# Confusion matrix
cm = confusion_matrix(all_labels, all_preds)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap=plt.cm.Blues)
plt.title('Confusion matrix')
plt.show()