# This is an example of training EfficientNet-B2

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
import os
from PIL import Image
import cv2
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from datetime import datetime
from torchsummary import summary

In [None]:
def get_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device_name)

In [None]:
def get_data(mode="dev", percentage=0.3, csv_path='./train.csv'):
    """
        Load the training table from a CSV file.
    :param mode: "prod" returns every row; any other value returns a random sample
    :param percentage: fraction of the rows kept when mode is not "prod"
    :param csv_path: location of the CSV file (defaults to './train.csv')
    :return: pandas DataFrame with all rows ("prod") or a random subset of them
    """
    data = pd.read_csv(csv_path)
    if mode == "prod":
        print("Prod mode used, all data used")
        return data

    print(f"Dev mode used, percentage of the data used: {percentage}")
    # No random_state is passed, so the sampled rows differ from call to call.
    return data.sample(frac=percentage)

In [None]:
# --- Training hyper-parameters ---
batch_size = 32
epochs = 20
image_size = 260

# --- Label columns read from the training table (11 in total) ---
target_cols = [
    'ETT - Abnormal',
    'ETT - Borderline',
    'ETT - Normal',
    'NGT - Abnormal',
    'NGT - Borderline',
    'NGT - Incompletely Imaged',
    'NGT - Normal',
    'CVC - Abnormal',
    'CVC - Borderline',
    'CVC - Normal',
    'Swan Ganz Catheter Present',
]
target_size = len(target_cols)

# --- Runtime objects ---
device = get_device()
data = get_data(mode="prod")

## DatasetTransformer

This is the class that performs the data augmentation and transforms the inputs into tensors.

In [None]:
class DatasetTransformer(torch.utils.data.Dataset):
    """
        Dataset that reads one JPEG per row of `df`, applies an albumentations
        pipeline to it and returns the image together with its label vector.
    """
    def __init__(self, df, augmented=False, transform=None, image_dir='./train'):
        """
        :param df: DataFrame holding a 'StudyInstanceUID' column and the `target_cols` label columns
        :param augmented: when no `transform` is given, add random horizontal/vertical flips
        :param transform: optional albumentations pipeline that replaces the default one
        :param image_dir: folder that contains the '<StudyInstanceUID>.jpg' files
        """
        if transform is None:
            # Default pipeline: resize, optional random flips, conversion to a tensor.
            # NOTE(review): no intensity normalisation is applied here — confirm
            # this matches what the pretrained backbone expects.
            steps = [A.Resize(image_size, image_size)]
            if augmented:
                steps += [A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5)]
            steps.append(ToTensorV2())
            transform = A.Compose(steps)
        self.transform = transform
        self.df = df
        self.labels = self.df[target_cols].values
        self.file_names = df['StudyInstanceUID'].values
        self.image_dir = image_dir

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """
        :param idx: position of the sample in the DataFrame
        :return: dict with "image" (transformed image as float tensor) and
                 "targets" (label vector as float tensor)
        :raises FileNotFoundError: if the image file is missing or cannot be decoded
        """
        image_name = self.file_names[idx]
        image_path = f'{self.image_dir}/{image_name}.jpg'
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than inside cvtColor.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        image_t = self.transform(image=image)['image']
        return {
            "image": image_t.float(),
            "targets": torch.tensor(self.labels[idx]).float()
        }

## Model

We use the `timm` library, which provides many pre-trained models, and wrap it in a class that exposes the same interface regardless of the underlying model.

In [None]:
class EfficientNet(nn.Module):
    """
        timm backbone with its classifier removed, followed by dropout, a linear
        layer and a sigmoid, so the output is one value in [0, 1] per label.
    """
    def __init__(self, model_name, pretrained=False, num_classes=11):
        """
        :param model_name: name of the timm model to create (e.g. "efficientnet_b2")
        :param pretrained: forwarded to timm.create_model
        :param num_classes: size of the output layer (defaults to 11)
        """
        super().__init__()

        # NOTE(review): `timm` is not imported in the visible import cell —
        # confirm it is imported before this class is instantiated.
        backbone = timm.create_model(model_name, pretrained=pretrained)

        # Width of the features that feed the original classifier.
        n_features = backbone.classifier.in_features

        # Keep every child module except the last one (the original classifier).
        # NOTE(review): this assumes the remaining modules already produce a
        # flat (batch, n_features) tensor — verify with the installed timm version.
        self.model = nn.Sequential(*list(backbone.children())[:-1])
        self.drop_out = nn.Dropout(p=0.5)
        self.fc = nn.Linear(n_features, num_classes)

    def forward(self, x):
        """
        :param x: batch of input images
        :return: sigmoid activations, one per output unit
        """
        x = self.model(x)
        x = self.drop_out(x)
        x = self.fc(x)
        # Sigmoid is applied here, so the loss must take probabilities, not logits.
        x = torch.sigmoid(x)
        return x

In [None]:
def train(model, data_loader, loss_function, optimizer, weight_pos, weight_neg, device):
    """
        It trains the model for one epoch
    :param model: model we need to train
    :param data_loader: Data loader (iterable) yielding dicts with "image" and "targets"
    :param loss_function: Loss Function
    :param optimizer: Optimizer (e.g Adam)
    :param weight_pos: not used by this function; kept so existing calls keep working
    :param weight_neg: not used by this function; kept so existing calls keep working
    :param device: "cpu" or "cuda"
    :return: tuple (loss averaged over all samples, average of the per-batch mean ROC AUC)
    :raises ValueError: if the data loader yields no samples
    """
    total_number = 0
    total_loss = 0.0
    batch_aucs = []
    model.train()
    for i, data in enumerate(data_loader):
        print("Training: batch ", i, end="\r")
        image_input = data["image"].to(device)
        targets = data["targets"].to(device)
        outputs = model(image_input)

        loss = loss_function(outputs, targets)

        # Weight the batch loss by the batch size so a smaller last batch
        # does not skew the epoch average.
        n_samples = image_input.shape[0]
        total_number += n_samples
        total_loss += n_samples * loss.item()
        batch_aucs.append(mean_roc_auc(targets.detach().cpu(), outputs.detach().cpu()))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if total_number == 0:
        # Without this guard the division below fails with an opaque ZeroDivisionError.
        raise ValueError("data_loader yielded no samples; cannot compute the epoch loss")
    return total_loss / total_number, np.average(batch_aucs)

In [None]:
def mean_roc_auc(targets, ouputs):
    """
        Compute the ROC AUC of every label column and average them.
    :param targets: 2-D array/tensor of ground-truth labels, one column per label
    :param ouputs: 2-D array/tensor of predicted scores with the same layout as `targets`
    :return: mean of the per-column ROC AUC values (NaN entries are ignored)
    """
    roc_auc = []
    # Use the actual number of label columns instead of a hard-coded 11.
    for k in range(targets.shape[1]):
        try:
            roc_auc.append(metrics.roc_auc_score(targets[:, k], ouputs[:, k]))
        except ValueError:
            # The score is undefined for this column (presumably because the
            # batch holds a single class for it — confirm against the installed
            # scikit-learn); fall back to the chance level of 0.5.
            roc_auc.append(0.5)
    return np.nanmean(roc_auc)

In [None]:
def test(model, data_loader, loss_function, weight_pos, weight_neg, device):
    """
        It evaluates the model on a data loader without computing gradients
    :param model: model we need to evaluate
    :param data_loader: Data loader (iterable) yielding dicts with "image" and "targets"
    :param loss_function: Loss Function
    :param weight_pos: not used by this function; kept so existing calls keep working
    :param weight_neg: not used by this function; kept so existing calls keep working
    :param device: "cpu" or "cuda"
    :return: tuple (loss averaged over all samples, average of the per-batch mean ROC AUC)
    :raises ValueError: if the data loader yields no samples
    """
    total_number = 0
    total_loss = 0.0
    batch_aucs = []
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(data_loader):
            print("Test: batch ", i, end="\r")
            image_input = data["image"].to(device)
            targets = data["targets"].to(device)
            outputs = model(image_input)

            # Weight the batch loss by the batch size so a smaller last batch
            # does not skew the average.
            n_samples = image_input.shape[0]
            total_number += n_samples
            total_loss += n_samples * loss_function(outputs, targets).item()
            batch_aucs.append(mean_roc_auc(targets.cpu(), outputs.cpu()))

    if total_number == 0:
        # Without this guard the division below fails with an opaque ZeroDivisionError.
        raise ValueError("data_loader yielded no samples; cannot compute the loss")
    return total_loss / total_number, np.average(batch_aucs)

In [None]:
# Hold out 10% of the rows for validation; the remaining 90% are used for training.
train_validation_frac = 0.9
train_data = data.sample(frac=train_validation_frac)
validation_data = data.drop(index=train_data.index)

# Only the training set goes through the augmented pipeline.
train_dataset = DatasetTransformer(df=train_data, augmented=True)
validation_dataset = DatasetTransformer(df=validation_data)

train_dataset_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
validation_dataset_loader = torch.utils.data.DataLoader(
    dataset=validation_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Build the network from pretrained weights, then move it to the selected
# device (the GPU when one is available) to speed up computation.
model = EfficientNet(model_name="efficientnet_b2", pretrained=True)
model.to(device)

In [None]:
# Report the parameter count and print a layer-by-layer summary of the model
pytorch_total_params = sum([weights.numel() for weights in model.parameters()])
print(f"Number of parameters {pytorch_total_params}")
summary(model, input_size=(3, image_size, image_size))

In [None]:
# Optimisation settings
learning_rate = 1e-3
weight_decay = 0.0

optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)
# Halve the learning rate when the monitored value stops decreasing.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    threshold=2e-2,
    verbose=True,
)
loss = nn.BCELoss()

In [None]:
class GraphUpdater():
    """
        Class that stores the loss and the score of every epoch, writes them to
        a CSV file after each update and can plot them at the end of the training
    """
    def __init__(self, type, name=None):
        """
        :param type: label of the run (e.g. "Train"); used in the default file name and the plot title
        :param name: base name of the output files; defaults to "<type>_loss_auc_<unix timestamp>"
        """
        # `type` shadows the builtin, but callers pass it by keyword, so the name is kept.
        self.type = type
        if name is None:
            name = self.type + "_loss_auc_" + str(int(datetime.timestamp(datetime.now())))
        self.filepath = os.path.join("./", "results", "epoch_model")
        self.name = name
        self.loss = []
        self.accuracy = []
        self.epoch = []

    def _history(self):
        # Build the per-epoch table shared by update() and display().
        return pd.DataFrame(data={"loss": self.loss, "accuracy": self.accuracy, "epoch": self.epoch})

    def update(self, loss, accuracy):
        """
            Record the values of one epoch and rewrite the CSV file
        :param loss: loss of the epoch
        :param accuracy: score of the epoch (stored under the "accuracy" key)
        :return: None
        """
        self.loss.append(loss)
        self.accuracy.append(accuracy)
        self.epoch.append(len(self.epoch))
        # Create the output folder on demand so the first write does not fail.
        os.makedirs(self.filepath, exist_ok=True)
        # index_label is kept as-is; note the file therefore has two "epoch" columns.
        self._history().to_csv(os.path.join(self.filepath, self.name + ".csv"), index_label="epoch")

    def display(self):
        """
            Plot loss and score per epoch, save the figure, then show it
        :return: None
        """
        axes = self._history().plot(x="epoch", y=["loss", "accuracy"], title="["+self.type + "]" + " Loss and Accuracy per epoch")
        os.makedirs(self.filepath, exist_ok=True)
        # Save before showing: after plt.show() the current figure may already
        # be released, and saving at that point can write an empty image.
        axes.get_figure().savefig(os.path.join(self.filepath, self.name))
        plt.show()

In [None]:
class ModelCheckpoint:
    """
        Class that allows to save the best model (the one with the lowest loss seen so far)
    """
    def __init__(self, model, filename=None, filepath=None):
        """
        :param model: model whose state_dict is saved
        :param filename: name of the checkpoint file (defaults to "best_model_efficientnet_b7_v3.pt")
        :param filepath: folder of the checkpoint file (defaults to "./results/best_model")
        """
        self.min_loss = None
        self.model = model

        if filepath is None:
            filepath = os.path.join("./", "results", "best_model")

        if filename is None:
            # NOTE(review): the default name mentions "b7" — confirm it is intended
            # for the model actually being trained.
            filename = "best_model_efficientnet_b7_v3.pt"

        self.filepath = os.path.join(filepath, filename)

    def update(self, loss):
        """
            Save the model when `loss` is lower than every loss seen before
        :param loss: loss of the current epoch
        :return: None
        """
        if (self.min_loss is None) or (loss < self.min_loss):
            print(f"Saving a better model here: {self.filepath}", end='\n')
            # Create the destination folder on demand; torch.save does not create it.
            target_dir = os.path.dirname(self.filepath)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)
            torch.save(self.model.state_dict(), self.filepath)
            self.min_loss = loss

In [None]:
# One history tracker per phase, plus the checkpoint helper bound to the model
train_updater = GraphUpdater("Train")
validation_updater = GraphUpdater("Validation")
model_checkpoint = ModelCheckpoint(model=model)

In [None]:
print(f"Training on {len(train_dataset)} images")

# Make sure the folders written to below exist before the first save.
os.makedirs('./results/epoch_model', exist_ok=True)
os.makedirs(os.path.dirname(model_checkpoint.filepath) or ".", exist_ok=True)

for t in range(epochs):
    print(f'Epoch: {t}')

    # The functions defined in this file are train() and test(); the previously
    # called trainv2()/testv2() and the freq_pos/freq_neg values are not defined.
    # weight_pos / weight_neg are never read by train() or test(), so None is passed.
    train_loss, train_auc = train(model=model, data_loader=train_dataset_loader, loss_function=loss,
                                  optimizer=optimizer, weight_pos=None, weight_neg=None, device=device)
    train_updater.update(loss=train_loss, accuracy=train_auc)

    print(f'Training step: Loss: {train_loss}, AUC: {train_auc}', end='\n')

    val_loss, val_auc = test(model=model, data_loader=validation_dataset_loader,
                             loss_function=loss, weight_pos=None, weight_neg=None, device=device)
    validation_updater.update(loss=val_loss, accuracy=val_auc)

    print(f'Validation step: Loss: {val_loss}, AUC: {val_auc}', end='\n')

    # Keep the best model (lowest validation loss) and a snapshot of every epoch.
    model_checkpoint.update(loss=val_loss)
    torch.save(model.state_dict(), f'./results/epoch_model/checkpoint_epoch_{t}.pt')
    print('Model saved')
    # The scheduler monitors the validation loss to decide when to lower the learning rate.
    scheduler.step(val_loss)

train_updater.display()
validation_updater.display()