In [None]:
! pip install medmnist
! pip install libauc==1.2.0 

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
# Import the necessary libraries
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import medmnist
from medmnist import ChestMNIST
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from libauc.losses import AUCMLoss
from libauc.optimizers import PESG
from libauc.models import resnet18 as ResNet18
from libauc.sampler import DualSampler
from libauc.metrics import auc_prc_score
from libauc.metrics import auc_roc_score
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [None]:
# Preprocessing applied to every image: convert to a tensor, then normalise
# with mean 0.5 / std 0.5 (one value each, for single-channel images).
data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Training split: reshuffled on every pass over the loader.
train_dataset = ChestMNIST(root='./', split='train', download=True, transform=data_transforms)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Validation split: used for evaluation only, so it is NOT shuffled.
# Shuffling gains nothing for evaluation and would draw extra values from
# the global RNG on every pass, perturbing the training-time shuffle order.
val_dataset = ChestMNIST(root='./', split='val', download=True, transform=data_transforms)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False)

# Test split: evaluation only, fixed order.
test_dataset = ChestMNIST(root='./', split='test', download=True, transform=data_transforms)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

Downloading https://zenodo.org/record/6496656/files/chestmnist.npz?download=1 to ./chestmnist.npz


100%|██████████| 82802576/82802576 [00:05<00:00, 14260383.47it/s]


Using downloaded and verified file: ./chestmnist.npz
Using downloaded and verified file: ./chestmnist.npz


In [None]:
# Show the training dataset's summary (its repr) as this cell's output.
train_dataset


Dataset ChestMNIST (chestmnist)
    Number of datapoints: 78468
    Root location: ./
    Split: train
    Task: multi-label, binary-class
    Number of channels: 1
    Meaning of labels: {'0': 'atelectasis', '1': 'cardiomegaly', '2': 'effusion', '3': 'infiltration', '4': 'mass', '5': 'nodule', '6': 'pneumonia', '7': 'pneumothorax', '8': 'consolidation', '9': 'edema', '10': 'emphysema', '11': 'fibrosis', '12': 'pleural', '13': 'hernia'}
    Number of samples: {'train': 78468, 'val': 11219, 'test': 22433}
    Description: The ChestMNIST is based on the NIH-ChestXray14 dataset, a dataset comprising 112,120 frontal-view X-Ray images of 30,805 unique patients with the text-mined 14 disease labels, which could be formulized as a multi-label binary-class classification task. We use the official data split, and resize the source images of 1×1024×1024 into 1×28×28.
    License: CC BY 4.0

In [None]:
# ---- Hyperparameters -------------------------------------------------------
SEED = 54321
BATCH_SIZE = 64
epochs = 15
decay_epochs = [8, 13]   # epochs at which the training loop calls update_regularizer
weight_decay = 0.001
lr = 0.1                 # initial learning rate
margin = 1.0             # margin passed to the PESG optimizer

# Seed the RNGs *before* the model is built so that weight initialisation,
# dropout and the shuffling done by the training DataLoader are repeatable.
torch.manual_seed(SEED)
np.random.seed(SEED)

# ---- Model -----------------------------------------------------------------
model = ResNet18()
# Replace the stem convolution so the network accepts 1-channel images.
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
# Replace the classifier head with a small MLP that emits 14 values, each
# squashed by a sigmoid (one output per ChestMNIST label).
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(num_ftrs, 512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(512, 14),
    nn.Sigmoid(),
)

# ---- Loss and optimizer ----------------------------------------------------
# NOTE(review): AUCMLoss is applied to all 14 label columns at once. LibAUC
# also ships a multi-label AUC-margin loss; confirm which one is intended.
loss_fn = AUCMLoss()

# The keyword must be spelled `epoch_decay`; a misspelled name is not a PESG
# parameter and the intended value would not be applied.
optimizer = PESG(model,
                 loss_fn=loss_fn,
                 momentum=0.8,
                 margin=margin,
                 epoch_decay=0.03,
                 lr=lr,
                 weight_decay=weight_decay)

In [None]:

# Train the model, recording the AUC on the train / validation / test splits
# after every epoch.
train_log = []
test_log = []
val_log = []


def _mean_auc(y_true, y_score):
    """Return the AUROC averaged over every label column.

    NOTE(review): assumes `auc_roc_score` returns one score per label column
    for 2-D input (TODO confirm); `np.mean` also copes with a scalar return.
    """
    return float(np.mean(auc_roc_score(y_true, y_score)))


def _evaluate_auc(loader):
    """Score the global `model` on `loader` and return its mean AUROC.

    The model is switched to eval mode and the forward passes run under
    `torch.no_grad()` so no autograd graph is built during evaluation.
    """
    model.eval()
    scores, labels = [], []
    with torch.no_grad():
        for images, batch_labels in loader:
            scores.append(model(images).cpu().numpy())
            labels.append(batch_labels.numpy())
    return _mean_auc(np.concatenate(labels), np.concatenate(scores))


for epoch in range(epochs):

    # At the scheduled epochs, invoke the optimizer's regularizer update.
    # NOTE(review): presumably this decays the learning rate by
    # `decay_factor`; confirm against the LibAUC PESG documentation.
    if epoch in decay_epochs:
        optimizer.update_regularizer(decay_factor=10)

    train_loss = []
    train_pred_list = []
    train_true_list = []
    model.train()
    for data, targets in train_dataloader:
        y_pred = model(data)
        loss = loss_fn(y_pred, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())

        # Keep the training-mode predictions to compute the training AUC.
        train_pred_list.append(y_pred.cpu().detach().numpy())
        train_true_list.append(targets.numpy())

    train_true = np.concatenate(train_true_list)
    train_pred = np.concatenate(train_pred_list)
    train_auc = _mean_auc(train_true, train_pred)

    avg_loss = sum(train_loss) / len(train_dataloader)

    # Evaluate on the held-out splits (eval mode, gradients disabled).
    val_auc = _evaluate_auc(val_dataloader)
    test_auc = _evaluate_auc(test_dataloader)

    print(f"Epoch {epoch + 1} loss: {avg_loss}, train AUC: {train_auc}, Val AUC: {val_auc}, Test AUC: {test_auc}")

    train_log.append(train_auc)
    test_log.append(test_auc)
    val_log.append(val_auc)

Epoch 1 loss: 0.19062988007970202, train AUC: 0.5081007590013648, Val AUC: 0.5060694219556003, Test AUC: 0.4982239333689301
Epoch 2 loss: 0.1892669957426977, train AUC: 0.5039363135350738, Val AUC: 0.5203637441492138, Test AUC: 0.5048260882492495
Epoch 3 loss: 0.18933444215780232, train AUC: 0.5066688292210585, Val AUC: 0.48979897185429, Test AUC: 0.45932233923982474
Epoch 4 loss: 0.1897853329554185, train AUC: 0.5075082159864699, Val AUC: 0.5164475884585777, Test AUC: 0.4920393273297976
Epoch 5 loss: 0.18937837720367398, train AUC: 0.5020680313692434, Val AUC: 0.48066957768162877, Test AUC: 0.4872212400782467


In [None]:
# Write the model's learned parameters (its state dict) to disk.
PATH = "./ChestModel.pth"
torch.save(obj=model.state_dict(), f=PATH)

In [None]:
# Plot the per-epoch AUC curves for the three splits on a single axes.
fig, ax = plt.subplots()
for history, label in (
    (train_log, 'Training AUC'),
    (val_log, 'Validation AUC'),
    (test_log, 'Testing AUC'),
):
    # Use 1-based x values so the axis lines up with the "Epoch N" numbers
    # printed during training (a bare list would be plotted from 0).
    ax.plot(range(1, len(history) + 1), history, label=label)
ax.set_xlabel('Epochs')
ax.set_ylabel('AUC')
ax.set_title('AUC per epoch')
ax.legend()
plt.show()