In [None]:
# Install the MedMNIST package (datasets, INFO metadata, Evaluator); consider pinning a version for reproducible runs.
%pip install medmnist

In [None]:
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from torchvision import models
import random

import medmnist
from medmnist import INFO, Evaluator

In [None]:
# Record which MedMNIST release produced these results.
print("MedMNIST v{} @ {}".format(medmnist.__version__, medmnist.HOMEPAGE))

In [None]:
# Which MedMNIST subset to use and whether to fetch it if it is not cached.
data_flag = 'breastmnist'
download = True

BATCH_SIZE = 32

# Metadata for the chosen subset from medmnist's INFO registry.
info = INFO[data_flag]
task, n_channels = info['task'], info['n_channels']
n_classes = len(info['label'])  # one entry per class label

# Dataset class whose name is given by the metadata's 'python_class' field.
DataClass = getattr(medmnist, info['python_class'])

# Deliverable 1

## First, we read the MedMNIST data, preprocess it, and wrap it in DataLoaders.

In [None]:
# Preprocessing: convert to a tensor, then normalise with mean 0.5 / std 0.5.
data_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[.5], std=[.5])])

# Transformed splits used for training and evaluation (created in the order train, test, val).
train_dataset, test_dataset, validation_dataset = (
    DataClass(split=split_name, transform=data_transform, download=download)
    for split_name in ('train', 'test', 'val')
)

# Untransformed training split, kept for viewing raw images later.
pil_dataset = DataClass(split='train', download=download)

# Only the training loader shuffles; evaluation loaders keep the split's stored order.
eval_batch_size = 2 * BATCH_SIZE
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
train_loader_at_eval = data.DataLoader(dataset=train_dataset, batch_size=eval_batch_size, shuffle=False)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=eval_batch_size, shuffle=False)
validation_loader = data.DataLoader(dataset=validation_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Summaries of the train and test splits, separated by a rule line.
print(train_dataset, "===================", test_dataset, sep="\n")

In [None]:
# montage
# Visual sanity check: medmnist's montage() of training images (length=20).
train_dataset.montage(length=20)

## Define seeding to make model results deterministic

In [None]:
def set_seed(seed):
    """Seed every RNG this notebook uses and pin cuDNN to deterministic behaviour.

    Seeds torch (CPU and all CUDA devices), NumPy's global RNG and Python's
    `random` module. It also turns off cuDNN benchmarking and turns on
    deterministic mode, as recommended by PyTorch's reproducibility notes.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seeder(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


set_seed(0)

## Define the standard ResNet18 network to be compared in performance to the modified network

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class StandardNetwork(nn.Module):
    """ImageNet-pretrained ResNet18 with a new input convolution and a new linear head.

    No layers are frozen: the whole network is fine-tuned. The replacement
    `conv1` is freshly (randomly) initialised, so the pretrained first-layer
    filters are discarded.

    Args:
        n_classes: number of output logits produced by the final linear layer.
        in_channels: number of channels of the input images (default 1, grayscale).
    """

    def __init__(self, n_classes=2, in_channels=1):
        super().__init__()
        # `weights=` replaces the deprecated `pretrained=True` (same ImageNet weights).
        self.resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.resnet.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False)

        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_features, n_classes)

    def forward(self, x):
        """Return raw (unnormalised) logits, one column per class."""
        return self.resnet(x)


# Loss: independent per-label BCE for multi-label tasks, cross-entropy otherwise.
if task == "multi-label, binary-class":
    criterion = nn.BCEWithLogitsLoss()
else:
    criterion = nn.CrossEntropyLoss()

network = StandardNetwork(n_classes=n_classes, in_channels=n_channels)
network.to(device)

## Define the modified ResNet18 (ExtendedNetwork) with an MLP classification head

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ExtendedNetwork(nn.Module):
    """ImageNet-pretrained ResNet18 with a new input convolution and a two-layer MLP head.

    The classifier is Linear(num_features -> 512) -> Dropout(0.5) -> ReLU ->
    Linear(512 -> n_classes). No layers are frozen. The replacement `conv1`
    is freshly initialised, so the pretrained first-layer filters are discarded.

    Args:
        n_classes: number of output logits.
        in_channels: number of channels of the input images (default 1, grayscale).
    """

    def __init__(self, n_classes=2, in_channels=1):
        super().__init__()
        # `weights=` replaces the deprecated `pretrained=True` (same ImageNet weights).
        self.resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.resnet.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False)

        num_features = self.resnet.fc.in_features
        # Module indices fc.0 / fc.3 match the previous layout, so saved checkpoints still load.
        self.resnet.fc = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(512, n_classes),
        )

    def forward(self, x):
        """Return raw (unnormalised) logits, one column per class."""
        return self.resnet(x)


# Loss: independent per-label BCE for multi-label tasks, cross-entropy otherwise.
if task == "multi-label, binary-class":
    criterion = nn.BCEWithLogitsLoss()
else:
    criterion = nn.CrossEntropyLoss()

network = ExtendedNetwork(n_classes=n_classes, in_channels=n_channels)
network.to(device)

## Install Optuna for hyperparameter selection

In [None]:
# Optuna drives the hyperparameter search for ExtendedNetwork below.
%pip install optuna

## Define the test function used to evaluate each trained model

In [None]:
def test(split, checkpoint_path='best_model.pth'):
    """Load a saved checkpoint into the global `network` and score it on one split.

    The checkpoint saved during training is the one with the lowest validation
    loss. `Evaluator.evaluate` receives only the scores, so they must follow
    the split's stored order. The evaluation loaders are built with
    shuffle=False for this reason.

    Args:
        split: 'train', 'val' or 'test'. Any other value uses the test loader.
        checkpoint_path: state-dict file to load (default 'best_model.pth').

    Returns:
        ((split, auc, acc), y_true, y_score), where y_true and y_score are
        numpy arrays in loader order.

    Raises:
        FileNotFoundError: if `checkpoint_path` does not exist.
    """
    network.load_state_dict(torch.load(checkpoint_path, map_location=device))
    network.eval()

    if split == 'train':
        data_loader = train_loader_at_eval
    elif split == 'val':
        data_loader = validation_loader
    else:
        data_loader = test_loader

    true_batches, score_batches = [], []
    with torch.no_grad():
        for inputs, targets in data_loader:
            outputs = network(inputs.to(device))

            if task == 'multi-label, binary-class':
                # Labels are independent, so each logit gets its own sigmoid.
                targets = targets.to(torch.float32)
                outputs = outputs.sigmoid()
            else:
                # Flatten (B, 1) labels to (B,). Unlike squeeze(), a batch of one stays shape (1,).
                targets = targets.reshape(-1).float()
                outputs = outputs.softmax(dim=-1)

            true_batches.append(targets.cpu())
            score_batches.append(outputs.cpu())

    y_true = torch.cat(true_batches).numpy()
    y_score = torch.cat(score_batches).numpy()

    evaluator = Evaluator(data_flag, split)
    auc, acc = evaluator.evaluate(y_score)

    return (split, auc, acc), y_true, y_score

## Defining the accuracy function

In [None]:
def accuracy(outputs, targets):
    """Fraction of samples whose arg-max prediction (over dim 1) equals the target.

    Args:
        outputs: score tensor whose dim 1 indexes classes.
        targets: true class indices, one per sample, in any shape viewable as (N, 1).

    Returns:
        float: number of correct predictions divided by len(targets).
    """
    predicted_labels = torch.argmax(outputs, dim=1, keepdim=True)
    matches = predicted_labels == targets.view_as(predicted_labels)
    return int(matches.sum()) / len(targets)

## Next, define the validation helper used while training and evaluating

In [None]:
def validate_model(network, valid_loader, criterion, device):
    """Evaluate a single-label classifier and return (mean loss, accuracy in percent).

    Args:
        network: model producing (batch, n_classes) logits.
        valid_loader: iterable of (inputs, targets). Targets hold one class index
            per sample, in any shape that flattens to (batch,).
        criterion: loss taking (logits, long class indices), e.g. CrossEntropyLoss.
        device: device the batches are moved to.

    Returns:
        (avg_loss, accuracy): per-sample mean loss and percentage of correct arg-max predictions.

    Raises:
        ValueError: if `valid_loader` yields no samples.
    """
    network.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for validation_inputs, validation_targets in valid_loader:
            validation_inputs = validation_inputs.to(device)
            # reshape(-1) rather than squeeze(): a batch of one must stay shape (1,), not 0-d.
            validation_targets = validation_targets.to(device).reshape(-1).long()
            output = network(validation_inputs)
            batch_size = validation_inputs.size(0)

            loss = criterion(output, validation_targets)
            total_loss += loss.item() * batch_size  # criterion averages over the batch
            predictions = output.argmax(dim=1)
            total_correct += (predictions == validation_targets).sum().item()
            total_samples += batch_size

    if total_samples == 0:
        raise ValueError("validate_model: valid_loader yielded no samples")

    avg_loss = total_loss / total_samples
    accuracy_pct = total_correct / total_samples * 100
    return avg_loss, accuracy_pct

## Training code for the standard ResNet18 model : StandardNetwork

In [None]:
# Baseline: fine-tune the unmodified ResNet18 (StandardNetwork) with fixed Adam settings,
# for comparison with the tuned ExtendedNetwork below.
#
# The baseline gets its own variable. At this point `network` holds the ExtendedNetwork
# built in the previous section. Training `network` here would train the wrong architecture
# and pass an already-trained model to the hyperparameter search.
set_seed(0)
standard_network = StandardNetwork(n_classes=n_classes).to(device)
standard_optimizer = optim.Adam(standard_network.parameters(), lr=1e-5)

STANDARD_NUM_EPOCHS = 20
standard_epoch_losses = []
train_epoch_accuracies = []

for epoch in range(STANDARD_NUM_EPOCHS):
    standard_network.train()
    running_loss = 0.0
    seen_samples = 0
    correct_predictions = 0
    total_predictions = 0
    for tr_inputs, tr_targets in tqdm(train_loader):
        tr_inputs = tr_inputs.to(device)
        tr_targets = tr_targets.to(device)
        standard_optimizer.zero_grad()
        outputs = standard_network(tr_inputs)

        if task == 'multi-label, binary-class':
            tr_targets = tr_targets.to(torch.float32)
            loss = criterion(outputs, tr_targets)
            # Outputs are logits: logit > 0  <=>  sigmoid(logit) > 0.5.
            predicted = outputs > 0
        else:
            tr_targets = tr_targets.reshape(-1).long()
            loss = criterion(outputs, tr_targets)
            predicted = outputs.argmax(dim=1)

        loss.backward()
        standard_optimizer.step()

        running_loss += loss.item() * tr_inputs.size(0)
        seen_samples += tr_inputs.size(0)
        correct_predictions += (predicted == tr_targets).sum().item()
        total_predictions += tr_targets.numel()  # one prediction per label entry

    epoch_loss = running_loss / seen_samples
    epoch_accuracy = 100 * correct_predictions / total_predictions
    standard_epoch_losses.append(epoch_loss)
    train_epoch_accuracies.append(epoch_accuracy)
    print('Epoch {}/{} - Train loss: {:.6f}, Train accuracy: {:.2f}%'.format(
        epoch + 1, STANDARD_NUM_EPOCHS, epoch_loss, epoch_accuracy))

# Validate the final baseline model once.
validation_loss, validation_accuracy = validate_model(standard_network, validation_loader, criterion, device)
print('Validation - Loss: {:.6f}, Accuracy: {:.2f}%'.format(validation_loss, validation_accuracy))
print()



## Training code for the MODIFIED ResNet18 model FOR COMPARISON : ExtendedNetwork

In [None]:
import optuna
from torch.optim import lr_scheduler

results_dict = {}

# Flat histories across all trials. Per-trial histories are also stored on each trial's user attributes.
train_epoch_accuracies = []
test_epoch_accuracies = []

best_model_path = None

# Training epochs per Optuna trial.
TRIAL_NUM_EPOCHS = 5
# A trial's results are recorded only if its test AUC / accuracy reach these values.
AUC_THRESHOLD = 0.901
ACC_THRESHOLD = 0.863


def objective(trial):
    """Train a fresh ExtendedNetwork with sampled hyperparameters; return its best validation loss.

    Each trial:
      * samples Adagrad learning rate / weight decay and the ReduceLROnPlateau patience;
      * re-initialises the model from the same seed, so trials do not continue
        training the previous trial's weights;
      * validates after every epoch, steps the scheduler on the validation loss and
        saves the lowest-validation-loss epoch to best_model.pth (which test() loads);
      * scores that checkpoint on train/val/test. If the test metrics clear the
        thresholds, it stores them (including the train-split predictions) as the
        "results_dict" user attribute.

    Per-epoch histories are stored as user attributes: "epoch_losses" (train loss),
    "epoch_accuracies" (train accuracy %), "val_epoch_losses" and "val_epoch_accuracies".
    """
    global network  # test() evaluates the module-level `network`, so rebind that name

    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-3, log=True)
    weight_decay = trial.suggest_float('weight_decay', 1e-5, 1e-2, log=True)
    patience = trial.suggest_int('patience', 5, 10)

    set_seed(0)
    network = ExtendedNetwork(n_classes=n_classes).to(device)

    optimizer = optim.Adagrad(network.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=patience, threshold=0.0001, threshold_mode="abs")

    best_validation_loss = float("inf")
    trial_train_losses, trial_train_accuracies = [], []
    trial_val_losses, trial_val_accuracies = [], []

    print("TRIAL:", trial.number)
    for epoch in range(TRIAL_NUM_EPOCHS):
        network.train()
        running_loss = 0.0
        seen_samples = 0
        correct_predictions = 0
        total_predictions = 0
        for train_inputs, train_targets in tqdm(train_loader):
            train_inputs = train_inputs.to(device)
            train_targets = train_targets.to(device)
            optimizer.zero_grad()
            outputs = network(train_inputs)

            if task == 'multi-label, binary-class':
                train_targets = train_targets.to(torch.float32)
                loss = criterion(outputs, train_targets)
                # Outputs are logits: logit > 0  <=>  sigmoid(logit) > 0.5.
                predicted = outputs > 0
            else:
                train_targets = train_targets.reshape(-1).long()
                loss = criterion(outputs, train_targets)
                predicted = outputs.argmax(dim=1)

            loss.backward()
            optimizer.step()

            running_loss += loss.item() * train_inputs.size(0)
            seen_samples += train_inputs.size(0)
            correct_predictions += (predicted == train_targets).sum().item()
            total_predictions += train_targets.numel()

        epoch_loss = running_loss / seen_samples
        epoch_accuracy = 100 * correct_predictions / total_predictions
        trial_train_losses.append(epoch_loss)
        trial_train_accuracies.append(epoch_accuracy)

        # Validate every epoch, so checkpoint selection and the LR scheduler act on real signals.
        validation_loss, validation_accuracy = validate_model(network, validation_loader, criterion, device)
        trial_val_losses.append(validation_loss)
        trial_val_accuracies.append(validation_accuracy)
        print('Epoch {}/{} - Train loss: {:.6f}, Train accuracy: {:.2f}% | Validation - Loss: {:.6f}, Accuracy: {:.2f}%'.format(
            epoch + 1, TRIAL_NUM_EPOCHS, epoch_loss, epoch_accuracy, validation_loss, validation_accuracy))

        if validation_loss < best_validation_loss:
            best_validation_loss = validation_loss
            torch.save(network.state_dict(), "best_model.pth")
            print(f"Model saved with validation loss: {validation_loss}")

        # ReduceLROnPlateau counts `patience` in calls, so it must be stepped once per epoch.
        scheduler.step(validation_loss)

    train_epoch_accuracies.extend(trial_train_accuracies)

    # Score this trial's best-epoch checkpoint (best_model.pth) on every split.
    print('==> Evaluating ...')
    train_result, train_y_true, train_y_score = test("train")
    val_result, _, _ = test("val")
    test_result, test_y_true, test_y_score = test("test")

    for split_result in (train_result, val_result, test_result):
        print('%s  auc: %.3f  acc:%.3f' % split_result)
    print()

    test_epoch_accuracies.append(test_result[2])

    if test_result[1] >= AUC_THRESHOLD and test_result[2] >= ACC_THRESHOLD:
        # Store only plain values. A scheduler object would keep this trial's optimizer
        # (and therefore the whole model) alive for the rest of the study.
        results_dict = {
            'Trial': trial.number, 'Split': test_result[0], 'AUC': test_result[1], 'Accuracy': test_result[2],
            'Learning Rate': learning_rate, 'Scheduler': type(scheduler).__name__,
            'Weight Decay': weight_decay, 'Patience': patience,
            "y-true": test_y_true, "y-score": test_y_score,
            # Train-split predictions of this same model. best_model.pth is overwritten by later trials.
            "train-y-true": train_y_true, "train-y-score": train_y_score,
        }
        trial.set_user_attr("results_dict", results_dict)
        print("Good model found for Trial", trial.number, "\n")
    else:
        print("Trial", trial.number, "did not meet AUC/ACC thresholds.\n")

    # Per-trial histories, saved for plotting outside the objective function.
    trial.set_user_attr("epoch_losses", trial_train_losses)
    trial.set_user_attr("epoch_accuracies", trial_train_accuracies)
    trial.set_user_attr("val_epoch_losses", trial_val_losses)
    trial.set_user_attr("val_epoch_accuracies", trial_val_accuracies)

    return best_validation_loss


study = optuna.create_study(direction="minimize", sampler=optuna.samplers.TPESampler(seed=0))
study.optimize(objective, n_trials=100)

trial = study.best_trial
print("Best trial is :", trial)

print(f"Value: {trial.value}")
print("Parameters: ")
for k, v in trial.params.items():
    print(f"{k}: {v}")

## Print the results of the trials that met the AUC/accuracy thresholds

In [None]:
# Show the stored results of every trial that cleared the thresholds.
for trial in study.trials:
    if "results_dict" not in trial.user_attrs:
        continue
    trial_results = trial.user_attrs["results_dict"]
    print(f"Trial {trial.number} Results:")
    print(trial_results)


## Getting trial with best results

In [None]:
# Among the trials that stored a results_dict, pick the highest test AUC,
# breaking ties by higher test accuracy.
best_trial = None
max_auc = 0  # lowest possible AUC, so any qualifying trial beats it
max_acc = 0  # lowest possible accuracy

for trial in study.trials:
    if "results_dict" not in trial.user_attrs:
        continue
    trial_results = trial.user_attrs["results_dict"]
    current_auc = trial_results.get('AUC', 0)
    current_acc = trial_results.get('Accuracy', 0)
    if current_auc > max_auc or (current_auc == max_auc and current_acc > max_acc):
        max_auc = current_auc
        max_acc = current_acc
        best_trial = trial_results

if best_trial is None:
    raise RuntimeError(
        "No Optuna trial met the AUC/accuracy thresholds; relax them or run more trials "
        "before running the cells below (they all depend on best_trial)."
    )

print("Best Trial:")
# Skip the prediction arrays; print only the scalar summary fields.
print({key: value for key, value in best_trial.items() if not isinstance(value, np.ndarray)})
auc = best_trial.get('AUC')
acc = best_trial.get('Accuracy')

print()
print('%s  auc: %.3f  acc:%.3f' % (best_trial.get('Split'), auc, acc))

## Install torcheval for calculating metrics

In [None]:
# torcheval provides the MulticlassAUPRC metric used in Deliverable 2.
%pip install --upgrade torcheval

# Deliverable 2

In [None]:
# Test-split labels and class scores recorded by the selected trial.
y_true_buffer, y_score_buffer = best_trial.get("y-true"), best_trial.get("y-score")

In [None]:
from torcheval.metrics import MulticlassAUPRC

# Train-split scores must come from the same model as the selected trial's test scores.
# Use the train predictions stored on the trial when they exist. Otherwise fall back to
# test('train'), which loads best_model.pth. That file holds the checkpoint of the *last*
# trial run, which need not be the selected trial.
if "train-y-true" in best_trial and "train-y-score" in best_trial:
    train_true, train_scores = best_trial["train-y-true"], best_trial["train-y-score"]
else:
    print("Warning: best trial has no stored train predictions; scoring best_model.pth (last trial) instead.")
    _, train_true, train_scores = test('train')

# CPU tensors throughout: the sklearn/matplotlib cells below need them on the CPU anyway.
# reshape(-1) turns the label arrays into flat (N,) class-index vectors.
train_scores = torch.tensor(train_scores)
train_true = torch.tensor(train_true).long().reshape(-1)
test_scores = torch.tensor(y_score_buffer)
test_true = torch.tensor(y_true_buffer).long().reshape(-1)

metric_train = MulticlassAUPRC(num_classes=n_classes)
metric_train.update(train_scores, train_true)
train_auprc_result = metric_train.compute()

metric_test = MulticlassAUPRC(num_classes=n_classes)
metric_test.update(test_scores, test_true)
test_auprc_result = metric_test.compute()

print("Train AUPRC result:", train_auprc_result)
print("Test AUPRC result:", test_auprc_result)

In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score

# sklearn needs CPU tensors.
train_true, train_scores, test_true, test_scores = (
    tensor.cpu() for tensor in (train_true, train_scores, test_true, test_scores)
)

# Hard test predictions (arg-max class). sklearn's default positive class is label 1,
# which the confusion-matrix cell below displays as "Benign".
test_predictions = test_scores.argmax(dim=1)

test_precision = precision_score(test_true, test_predictions)
test_recall = recall_score(test_true, test_predictions)
test_f1 = f1_score(test_true, test_predictions)

print("Test Precision Score:", test_precision, end="\n\n")
print("Test Recall Score:", test_recall, end="\n\n")
print("Test F1 Score:", test_f1)

# Deliverable 3

In [None]:
from sklearn import metrics
import matplotlib.pyplot as plt

# ROC curves use the predicted probability of class 1 as the decision score.
train_fpr, train_tpr, train_thresholds = metrics.roc_curve(train_true, train_scores[:, 1])
test_fpr, test_tpr, test_thresholds = metrics.roc_curve(test_true, test_scores[:, 1])
train_roc_auc = metrics.auc(train_fpr, train_tpr)
test_roc_auc = metrics.auc(test_fpr, test_tpr)

fig, ax = plt.subplots()
for split_name, fpr, tpr, area in (
    ("Train", train_fpr, train_tpr, train_roc_auc),
    ("Test", test_fpr, test_tpr, test_roc_auc),
):
    ax.plot(fpr, tpr, label=f'{split_name} ROC curve (area = {area:.2f})')
ax.plot([0, 1], [0, 1], "b--")  # chance-level diagonal
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('ROC Curve')
ax.legend(loc="lower right")
plt.show()

In [None]:
from sklearn.metrics import confusion_matrix

# Raw counts from arg-max predictions: rows are true labels, columns are predicted labels.
confusion_matrix(y_true=test_true, y_pred=test_scores.argmax(dim=1))

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

# Binarise the class-1 probability at a fixed cut-off and tabulate it against the true labels.
threshold = 0.5
binary_test_scores = (test_scores[:, 1] >= threshold).numpy().astype(int)
confusionMatrix = confusion_matrix(test_true.numpy(), binary_test_scores)

# Display names for labels 0 and 1, in that order.
labels = ['Malignant', 'Benign']
fig, ax = plt.subplots()
display_confusion_matrix = ConfusionMatrixDisplay(confusion_matrix=confusionMatrix, display_labels=labels)
display_confusion_matrix.plot(cmap="Blues", ax=ax)

ax.set_xticklabels([f'Predicted: {name}' for name in labels])
ax.set_yticklabels([f'Actual: {name}' for name in labels])

# Cell (i, j) is true label i and predicted label j. With label 1 as "positive"
# the layout is [[TN, FP], [FN, TP]]. Replace each count with its tagged count.
cell_tags = {(0, 0): 'TN', (0, 1): 'FP', (1, 0): 'FN', (1, 1): 'TP'}
for (row, col), tag in cell_tags.items():
    display_confusion_matrix.text_[row, col].set_text(f'{tag}: {confusionMatrix[row, col]}')

plt.show()


# Deliverable 4

In [None]:
from sklearn.metrics import roc_auc_score


def train_model(train_loader, model, criterion, optimizer):
    """Run one training epoch of `model` over `train_loader` (single-label targets).

    Batches are moved to the device holding the model's parameters.
    """
    # Must be `model`, not the global `network`. evaluate_model() leaves `model` in eval
    # mode, and without this, later epochs would train with dropout off.
    model.train()
    model_device = next(model.parameters()).device
    for inputs, targets in train_loader:
        inputs = inputs.to(model_device)
        # reshape(-1) keeps a batch of one as shape (1,) instead of a 0-d tensor.
        targets = targets.to(model_device).reshape(-1).long()
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()


def evaluate_model(loader, model, criterion):
    """Return (accuracy fraction, ROC AUC) of a binary classifier on `loader`.

    The AUC uses the softmax probability of class 1 as the score, so it
    assumes exactly two output logits. `criterion` is accepted for call
    compatibility but is not needed for these metrics.
    """
    model.eval()
    model_device = next(model.parameters()).device
    total_correct, total_samples = 0, 0
    all_preds, all_targets = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(model_device)
            targets = targets.to(model_device).reshape(-1).long()
            outputs = model(inputs)
            probabilities = torch.softmax(outputs, dim=1)[:, 1]
            predicted = outputs.argmax(dim=1)
            total_correct += (predicted == targets).sum().item()
            total_samples += targets.size(0)
            all_preds.extend(probabilities.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())

    accuracy = total_correct / total_samples
    auc = roc_auc_score(all_targets, all_preds)
    return accuracy, auc

In [None]:
# List the fields stored in the selected trial's results_dict before using them below.
print(best_trial.keys())

In [None]:
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, ConcatDataset, SubsetRandomSampler

# K-fold cross-validation of ExtendedNetwork with the best trial's hyperparameters,
# over the union of the official train/test/val splits.
num_epochs = 10
k_folds = 5

learning_rate_optimal = best_trial.get("Learning Rate")
weight_decay_optimal = best_trial.get("Weight Decay")
# NOTE: no LR scheduler is used during k-fold training, so patience is read but unused.
patience_optimal = best_trial.get("Patience")

combined_dataset = ConcatDataset([train_dataset, test_dataset, validation_dataset])

k_fold_classifier = KFold(n_splits=k_folds, shuffle=True, random_state=0)

accuracies, aucs = [], []

for fold, (trainval_idx, test_idx) in enumerate(k_fold_classifier.split(np.arange(len(combined_dataset)))):
    # KFold yields the remaining indices in ascending order. Shuffle them (seeded per fold)
    # before holding out a quarter for validation, so the validation set is not one
    # contiguous block of dataset positions.
    shuffled_idx = np.random.default_rng(fold).permutation(trainval_idx)
    val_size = len(shuffled_idx) // 4
    val_idx, train_idx = shuffled_idx[:val_size], shuffled_idx[val_size:]

    # Fold-local names. Reusing train_loader/test_loader would overwrite the global
    # loaders that test() relies on for the official splits.
    fold_train_loader = DataLoader(combined_dataset, batch_size=BATCH_SIZE, sampler=SubsetRandomSampler(train_idx))
    fold_val_loader = DataLoader(combined_dataset, batch_size=BATCH_SIZE, sampler=SubsetRandomSampler(val_idx))
    fold_test_loader = DataLoader(combined_dataset, batch_size=BATCH_SIZE, sampler=SubsetRandomSampler(test_idx))

    # Same initial weights for every fold.
    set_seed(0)
    fold_model = ExtendedNetwork(n_classes=n_classes).to(device)
    fold_optimizer = optim.Adagrad(fold_model.parameters(), lr=learning_rate_optimal, weight_decay=weight_decay_optimal)
    fold_criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        train_model(fold_train_loader, fold_model, fold_criterion, fold_optimizer)
        val_acc, val_auc = evaluate_model(fold_val_loader, fold_model, fold_criterion)

    test_acc, test_auc = evaluate_model(fold_test_loader, fold_model, fold_criterion)
    accuracies.append(test_acc)
    aucs.append(test_auc)
    print(f"Fold {fold+1}: Val Accuracy = {val_acc:.3f}, Val AUC = {val_auc:.3f}, "
          f"Test Accuracy = {test_acc:.3f}, Test AUC = {test_auc:.3f}")

print(f"Average Test Accuracy: {np.mean(accuracies):.3f}")
print(f"Average Test AUC: {np.mean(aucs):.3f}")


# Estimation Testing (single-image prediction)

In [None]:
# Classify one raw training image with the currently loaded `network` and display it.
network.eval()

sample_index = 108
pil_image, true_label = pil_dataset[sample_index]

# Apply exactly the preprocessing used for training (data_transform: ToTensor + Normalize,
# at the dataset's native resolution). The earlier 224x224 resize without normalisation
# fed the model inputs unlike anything it was trained on.
image = data_transform(pil_image).unsqueeze(0).to(device)  # add a batch dimension

with torch.no_grad():
    output = network(image)

# The network outputs logits; softmax turns them into class probabilities.
probabilities = torch.softmax(output, dim=1).cpu().numpy().flatten()

# Take class names from the dataset metadata instead of hard-coding them. The hard-coded
# list said 0=benign, 1=malignant, which contradicts the confusion-matrix labels above.
classes = [info['label'][str(i)] for i in range(n_classes)]
predicted_class = classes[int(probabilities.argmax())]
confidence = probabilities.max()
true_class = classes[int(np.asarray(true_label).reshape(-1)[0])]

print(f"True class: {true_class}")
print(f"Predicted class: {predicted_class} with confidence {confidence:.2f}")

# Show the untransformed image. A single-channel PIL image is drawn in grayscale, and
# matplotlib ignores cmap for RGB images.
fig, ax = plt.subplots()
ax.imshow(pil_image, cmap='gray')
ax.set_title(f"true: {true_class} | predicted: {predicted_class} ({confidence:.2f})")
ax.axis('off')
plt.show()

# References

# To do

1. add val auc and acc DONE
2. k fold DONE
3. loss / accuracy graphs
4. show basic resnet18
5. test on another device