In [None]:
import os
import random

import numpy as np
import torch
from sklearn.metrics import ConfusionMatrixDisplay, classification_report
from torch import optim

from data.dataset import create_dataloader
from data.transform import get_train_transform, get_validation_transform
from losses.poly_loss import PolyLoss
from models.efficientnet import EfficientNet
from train import train, validate, predict
from utils.plots import plot_stats, plot_roc_curve

# Dataloaders

In [None]:
# Locations of the three dataset splits.
train_data_dir, val_data_dir, test_data_dir = (
    "dataset/train",
    "dataset/validation",
    "dataset/test",
)

# Input size and normalization statistics handed to create_dataloader.
# The mean / std values are the standard ImageNet RGB constants.
image_size = (256, 256)
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

# Loader throughput settings.
batch_size, num_workers = 8, 2

In [None]:
# Settings shared by every split; only the folder and the `train` flag differ.
# num_workers now comes from the config cell for all three loaders. Previously the
# train loader hard-coded 2 and the val/test loaders fell back to create_dataloader's default.
loader_kwargs = dict(batch_size = batch_size,
                     image_size = image_size,
                     mean = mean,
                     std = std,
                     num_workers = num_workers)

# train=True presumably selects the training transform / shuffling inside
# create_dataloader — TODO confirm against data/dataset.py.
train_dataloader = create_dataloader(data_dir = train_data_dir, train = True, **loader_kwargs)
val_dataloader = create_dataloader(data_dir = val_data_dir, train = False, **loader_kwargs)
test_dataloader = create_dataloader(data_dir = test_data_dir, train = False, **loader_kwargs)

# Model

In [None]:
# training params
learning_rate = 0.001
epochs = 3
seed = 42  # global RNG seed so training runs are repeatable

# model params
efficientnet_version = 'b3'
pretrained = True
num_classes = 2

# Seed Python, NumPy and PyTorch RNGs before the model is built and training starts.
# torch.manual_seed seeds every device. GPU kernels can still be non-deterministic
# unless cuDNN determinism is enabled as well.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Passed to `train` as save_folder; created up front if missing.
weights_folder = "weights"
os.makedirs(weights_folder, exist_ok = True)

In [None]:
# Build the network from the config-cell settings.
model = EfficientNet(version = efficientnet_version, pretrained = pretrained, num_classes = num_classes)

# Prefer the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model.to(device)

# Parameter count restricted to tensors with requires_grad=True.
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Trainable params:", total_trainable_params)

In [None]:
# Adam over every model parameter, at the configured learning rate.
optimizer = optim.Adam(params = model.parameters(), lr = learning_rate)

# PolyLoss with softmax=True; presumably the loss normalizes the raw logits itself — TODO confirm.
criterion = PolyLoss(softmax = True)

# Train

In [None]:
# Run the training loop. save_folder presumably receives checkpoints; the
# load cell further down reads weights/best_model_on_val.pth from it.
history = train(
    epochs = epochs,
    model = model,
    train_dataloader = train_dataloader,
    val_dataloader = val_dataloader,
    optimizer = optimizer,
    criterion = criterion,
    device = device,
    save_folder = weights_folder,
)

# Metric curves recorded by `train` (one entry per epoch, as the plot below assumes).
train_loss, valid_loss = history["train_loss"], history["val_loss"]
train_acc, valid_acc = history["train_acc"], history["val_acc"]

### Plot train / val metrics

In [None]:
plot_stats(range(epochs),
           train_loss, valid_loss,   # loss: train vs. validation
           train_acc, valid_acc)     # accuracy: train vs. validation

### ROC curve

In [None]:
# Score the held-out test set. return_probs=True asks `predict` for per-class
# scores rather than hard labels; both outputs are converted to NumPy arrays.
preds, gts = map(np.array, predict(model, test_dataloader, device, return_probs = True))

# ROC for class 0 ('cell' in this notebook's label order), whose score is column 0.
plot_roc_curve(gts, preds[:, 0], pos_label = 0)

### Confusion matrix

In [None]:
# Confusion matrix of the arg-max predictions on the test set.
# labels=[0, 1] pins the matrix to both classes so it always matches the two
# display labels. Otherwise plotting fails if a class is absent from gts and the
# predictions. The trailing ';' hides the display object's repr in the notebook.
ConfusionMatrixDisplay.from_predictions(gts,
                                        np.argmax(preds, axis = 1),
                                        labels = [0, 1],
                                        display_labels = ['cell', 'no_cell']);

### Classification report

In [None]:
# Per-class precision / recall / F1 of the arg-max predictions on the test set.
# labels=[0, 1] keeps the rows aligned with target_names even if a class is
# missing from both gts and the predictions. Otherwise sklearn raises on the
# class-count / target_names length mismatch.
print(classification_report(gts,
                            np.argmax(preds, axis = 1),
                            labels = [0, 1],
                            target_names = ['cell', 'no_cell'],
                            digits = 3))

# Load the best trained model (checkpoint selected on the validation set)

In [None]:
# Rebuild the architecture with the same config used for training, so the
# checkpoint's state dict matches, instead of re-typing the values by hand.
# NOTE(review): the pretrained backbone weights are overwritten by load_state_dict
# below, so pretrained=False would skip a download. Confirm the EfficientNet wrapper
# builds an identical architecture either way before changing it.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = EfficientNet(version = efficientnet_version,
                     pretrained = pretrained,
                     num_classes = num_classes)

# map_location remaps tensors saved on a GPU onto the current device, so a
# CUDA-trained checkpoint can still be loaded on a CPU-only machine.
checkpoint_path = os.path.join(weights_folder, 'best_model_on_val.pth')
checkpoint = torch.load(checkpoint_path, map_location = device)
model.load_state_dict(checkpoint['model'])

# Inference mode for dropout / batch-norm; harmless if validate/predict also set it.
# The trailing ';' avoids dumping the full module repr as cell output.
model.to(device)
model.eval();

### Evaluate on the test set

In [None]:
# Test-set loss and accuracy of the reloaded checkpoint, using the training criterion.
loss, acc = validate(model, test_dataloader, criterion, device)

print(f'Accuracy: {acc:.4f}')
print(f'Loss: {loss:.4f}')

### ROC curve

In [None]:
# Re-score the test set with the reloaded checkpoint (per-class scores, not hard labels).
preds, gts = (np.array(arr) for arr in predict(model, test_dataloader, device, return_probs = True))

# Class 0 ('cell') is the positive class here; its score is column 0.
plot_roc_curve(gts, preds[:, 0], pos_label = 0)

### Confusion matrix

In [None]:
# Confusion matrix of the reloaded checkpoint's arg-max predictions on the test set.
# labels=[0, 1] pins the matrix to both classes so it always matches the two
# display labels. Otherwise plotting fails if a class is absent from gts and the
# predictions. The trailing ';' hides the display object's repr in the notebook.
ConfusionMatrixDisplay.from_predictions(gts,
                                        np.argmax(preds, axis = 1),
                                        labels = [0, 1],
                                        display_labels = ['cell', 'no_cell']);

### Classification report

In [None]:
# Per-class precision / recall / F1 of the reloaded checkpoint on the test set.
# labels=[0, 1] keeps the rows aligned with target_names even if a class is
# missing from both gts and the predictions. Otherwise sklearn raises on the
# class-count / target_names length mismatch.
print(classification_report(gts,
                            np.argmax(preds, axis = 1),
                            labels = [0, 1],
                            target_names = ['cell', 'no_cell'],
                            digits = 3))