# **Train Inception-ResNet-V2 for two-class classification**

* [Dependencies and imports](#section-one)
* [Basic configurations](#section-two)
    - [Check label distribution](#sub-section-one-one)
    - [Approaches for imbalanced data](#sub-section-one-two)
* [Split data to folds](#section-three)
* [Data augmentation using Albumentations](#section-four)
* [Custom dataset](#section-five)
* [Fitter](#section-six)
* [Train](#section-seven)

<a id="section-one"></a>
## **Dependencies and imports**

In [None]:
conda install gdcm -c conda-forge

In [None]:
!pip install --upgrade --force-reinstall numpy

In [None]:
!pip install timm

In [None]:
!pip install torchmetrics timm

In [None]:
import torch
import torchvision
from torch import nn
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
from glob import glob
import timm
import torchmetrics 
import matplotlib.pyplot as plt
# --- images --- 
import cv2
import albumentations as A
# --- time ---
from datetime import datetime
import time
# --- data ---
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from sklearn.model_selection import StratifiedShuffleSplit
# --- wandb ---
import wandb
from kaggle_secrets import UserSecretsClient
# --- dicom ---
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut

In [None]:
OFFLINE = False

if not OFFLINE:
    # Authenticate with the API key stored in Kaggle secrets and start an online run.
    user_secrets = UserSecretsClient()
    wandb_key = user_secrets.get_secret("wandb-key")
    wandb.login(key=wandb_key)

    run = wandb.init(project="siim-covid19-detection", name="2-class-classification", mode='online')
else:
    # The training code logs through `run`; a disabled run turns every wandb call
    # into a no-op instead of failing with NameError when working offline.
    run = wandb.init(mode='disabled')

<a id="section-two"></a>
## **Basic configuration**

In [None]:
# --- configs ---
# String names of the two image-level classes.
NONE = 'none'
OPACITY = 'opacity'

class Configs:
    """Global settings shared by fold creation, datasets and training."""

    # images are resized to img_size x img_size pixels
    img_size = 1024
    # whether DataFolds duplicates part of the 'none' rows
    oversample = True
    n_folds = 5
    test_size = 0.15
    # class name -> integer label used as the training target
    classes = {NONE: 0, OPACITY: 1}
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = 4
    num_workers = 8

<a id="sub-section-one-one"></a>
### **Check label distribution**

In [None]:
# Load the prepared training table and report how many rows carry each image-level label.
train_df = pd.read_csv('../input/d/miriamassraf/siim-covid19-detection/train_df.csv')
for class_name in Configs.classes:
    n_samples = (train_df['image_level'] == class_name).sum()
    print(f"number of samples for class '{class_name}': {n_samples}")

<a id="sub-section-one-two"></a>
### **Approaches for imbalanced data**
1. Oversampling - duplicate randomly drawn samples of the less common class to balance the labels<br>
2. Stratified splitting - keep the same label distribution in every fold<br>

<a id="section-three"></a>
## **Split data to folds**

In [None]:
# Preview the first rows of the training table; the columns used later include
# 'image_level', 'study_level' and 'dicom_path'.
train_df.head()

In [None]:
class DataFolds:
    """Assign every training row to one of ``Configs.n_folds`` stratified folds.

    For fold ``k``, rows with ``fold == k`` are the validation set and all other
    rows are the training set (see ``get_train_df`` / ``get_val_df``).

    Args:
        train_df: table with at least the 'image_level' and 'study_level'
            columns. It is copied, so the caller's DataFrame is not modified.
        continue_train: if True, ignore ``train_df`` and reload a previously
            saved split (with 'int_label' and 'fold' columns) from CSV.

    Raises:
        ValueError: if ``Configs.n_folds`` is smaller than 2.
    """

    def __init__(self, train_df, continue_train=False):
        # K-fold needs at least two folds, otherwise the training part is empty.
        if Configs.n_folds < 2:
            raise ValueError("num folds must be at least 2")
        if continue_train:
            self.train_df = pd.read_csv('../input/d/miriamassraf/siim-covid19-detection/splitted_train_df.csv')
        else:
            self.train_df = train_df.copy()
            self.set_int_labels()
            self.split_to_folds(Configs.test_size)
            if Configs.oversample:
                # Oversample AFTER splitting: every duplicate inherits the 'fold' of
                # the row it was copied from, so one image can never be in both the
                # training and the validation part of the same fold.
                # frac=1.0 adds as many rows as the selected subset already has.
                self.oversample(NONE, 1.0)

    def oversample(self, cls, frac):
        """Append ``frac`` * (subset size) rows drawn with replacement from the subset
        whose 'image_level' equals ``cls`` and whose 'study_level' is 'negative'."""
        df = self.train_df
        mask = (df['image_level'] == cls) & (df['study_level'] == 'negative')
        rows_to_add = df[mask].sample(frac=frac, replace=True)
        # DataFrame.append was removed in pandas 2.0; pd.concat is its replacement.
        self.train_df = pd.concat([df, rows_to_add], ignore_index=True)

    def set_int_labels(self):
        """Add an integer 'int_label' column: the OPACITY id for 'opacity' rows,
        the NONE id for every other row."""
        is_opacity = self.train_df['image_level'] == OPACITY
        self.train_df['int_label'] = np.where(is_opacity, Configs.classes[OPACITY], Configs.classes[NONE])

    def split_to_folds(self, test_size=None):
        """Add a 'fold' column holding a stratified K-fold assignment of every row.

        StratifiedKFold puts each row in exactly one validation fold and keeps the
        'int_label' ratio in every fold. A StratifiedShuffleSplit cannot be used
        this way: its ``n_folds`` independent validation draws overlap, so later
        draws overwrite earlier fold ids and rows never drawn (about 0.85**5, i.e.
        ~44% with test_size=0.15) would end up without a fold.

        Args:
            test_size: unused, kept for backward compatibility; each validation
                fold holds roughly ``1 / Configs.n_folds`` of the rows.
        """
        from sklearn.model_selection import StratifiedKFold

        skf = StratifiedKFold(n_splits=Configs.n_folds, shuffle=True)
        folds = np.empty(len(self.train_df), dtype=int)
        splits = skf.split(X=self.train_df.index, y=self.train_df['int_label'])
        for n, (_, val_index) in enumerate(splits):
            # val_index holds positions, so fill the NumPy array positionally
            folds[val_index] = n
        self.train_df['fold'] = folds

    def _check_fold(self, fold_number):
        """Raise ValueError unless 0 <= fold_number < Configs.n_folds."""
        if not 0 <= fold_number < Configs.n_folds:
            raise ValueError(f"fold_number must be in [0, {Configs.n_folds}), got {fold_number}")

    def get_train_df(self, fold_number):
        """Return all rows that are NOT in fold ``fold_number``."""
        self._check_fold(fold_number)
        return self.train_df[self.train_df['fold'] != fold_number]

    def get_val_df(self, fold_number):
        """Return the rows of fold ``fold_number``."""
        self._check_fold(fold_number)
        return self.train_df[self.train_df['fold'] == fold_number]

**Visualize distribution of labels over folds**

In [None]:
# Plot the label distribution of every fold
def plot_folds(data_folds):
    """Draw one bar chart per fold with the number of 'opacity' and 'none' rows.

    Subplots are laid out two per row. ``squeeze=False`` keeps ``ax`` 2-D even
    with a single row (n_folds <= 2), where ``ax[row, col]`` would otherwise
    fail on a 1-D array of axes.
    """
    df = data_folds.train_df
    nrows = (Configs.n_folds + 1) // 2
    fig, ax = plt.subplots(nrows=nrows, ncols=2, figsize=(20, 10), squeeze=False)

    for fold in range(Configs.n_folds):
        row, col = divmod(fold, 2)
        in_fold = df['fold'] == fold
        labels_count = {
            OPACITY: int((in_fold & (df['int_label'] == Configs.classes[OPACITY])).sum()),
            NONE: int((in_fold & (df['int_label'] == Configs.classes[NONE])).sum()),
        }

        axis = ax[row, col]
        axis.bar(list(labels_count.keys()), list(labels_count.values()))
        for j, value in enumerate(labels_count.values()):
            axis.text(j, value + 2, str(value), color='#267DBE', fontweight='bold')
        axis.grid(axis='y', alpha=0.75)
        axis.set_title("For fold #{}".format(fold), fontsize=15)
        axis.set_ylabel("count")

    # With an odd number of folds the last subplot has nothing to show; hide it.
    if Configs.n_folds % 2 != 0:
        ax[-1, -1].set_visible(False)

In [None]:
# Build the folds from scratch (pass continue_train=True to reload the saved split
# instead) and persist the resulting table next to the notebook.
data_folds = DataFolds(train_df, continue_train=False)
split_csv_path = "./splitted_train_df.csv"
data_folds.train_df.to_csv(split_csv_path, index=False)

In [None]:
# Bar charts of the per-fold label counts.
plot_folds(data_folds)

<a id="section-four"></a>
## **Data augmentation using Albumentations**

In [None]:
def get_transforms(train=True):
    """Return the Albumentations pipeline for training or validation images.

    Both pipelines end by resizing to ``Configs.img_size`` x ``Configs.img_size``.
    The training pipeline first applies a horizontal flip, a small rotation, one
    of blur / median blur / Gaussian noise / sharpen, and CLAHE.
    """
    resize = A.Resize(height=Configs.img_size, width=Configs.img_size, p=1)
    if not train:
        return A.Compose([resize])

    # IAASharpen was removed from recent Albumentations releases and replaced by
    # Sharpen; fall back to IAASharpen only on old versions that lack Sharpen.
    sharpen_cls = getattr(A, 'Sharpen', None) or A.IAASharpen
    return A.Compose([
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=10),
        A.OneOf([
            A.Blur(blur_limit=3, p=0.5),
            A.MedianBlur(blur_limit=3, p=0.5),
            A.GaussNoise(p=0.5),
            sharpen_cls(p=0.5),
        ], p=0.4),
        A.CLAHE(p=0.6),
        resize,
    ])

<a id="section-five"></a>
## **Custom dataset**

In [None]:
def get_dicom_img(path):
    """Read a DICOM file and return its pixels as a uint8 image scaled to 0-255.

    The file's VOI LUT is applied first. MONOCHROME1 images are inverted
    (max - value) so that all images share the same grey-level polarity.
    """
    data_file = pydicom.dcmread(path)
    img = apply_voi_lut(data_file.pixel_array, data_file)

    if data_file.PhotometricInterpretation == "MONOCHROME1":
        img = np.amax(img) - img

    # Rescale grey levels to 0-255 and convert to uint8. A constant image has
    # zero range after subtracting the minimum; dividing by that would give NaNs
    # (undefined after the uint8 cast), so such an image becomes all zeros.
    img = img - np.min(img)
    peak = np.max(img)
    if peak > 0:
        img = img / peak
    img = (img * 255).astype(np.uint8)

    return img

In [None]:
class Covid19Dataset(Dataset):
    """Dataset of (image, label) pairs read from DICOM files.

    Args:
        df: table with one row per sample; the 'dicom_path' and 'int_label'
            columns are read.
        transform: optional Albumentations pipeline, called as
            ``transform(image=img)``.
    """

    def __init__(self, df, transform=None):
        super().__init__()
        self.df = df
        self.transform = transform

    def __getitem__(self, idx):
        """Return a float32 [C, H, W] tensor with values in [0, 1] and the row's 'int_label'."""
        row = self.df.iloc[idx]
        img = get_dicom_img(row['dicom_path'])
        label = row['int_label']

        if self.transform:
            img = self.transform(image=img)['image']

        # Monochrome DICOMs give a 2-D single-channel array, which
        # cv2.COLOR_BGR2RGB rejects; replicate the grey channel into 3 channels
        # instead. 3-channel input keeps the BGR -> RGB conversion.
        code = cv2.COLOR_GRAY2RGB if img.ndim == 2 else cv2.COLOR_BGR2RGB
        img = cv2.cvtColor(img, code).astype(np.float32)
        # normalize pixel values to [0, 1]
        img /= 255.0

        # convert to a tensor and permute [H, W, C] -> [C, H, W] (channels first)
        img = torch.as_tensor(img, dtype=torch.float32).permute(2, 0, 1)

        return img, label

    def __len__(self):
        return len(self.df)

In [None]:
def get_dataset_fold(data_folds, fold, train=True):
    """Build the training (default) or validation dataset for ``fold``,
    with the matching augmentation pipeline."""
    fold_df = data_folds.get_train_df(fold) if train else data_folds.get_val_df(fold)
    return Covid19Dataset(fold_df, transform=get_transforms(train))

<a id="section-six"></a>
## **Fitter**

In [None]:
class AverageMeter:
    """Track the most recent value and a running, count-weighted mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running sum, count and average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new observations and refresh ``avg``."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

In [None]:
class Fitter:
    """Train and validate one timm classification model, with checkpoints and logs.

    Args:
        dir: output directory for checkpoints and ``log.txt`` (created if missing).
        model_name: timm model name; a pretrained model with
            ``len(Configs.classes)`` outputs is created.
        verbose: print progress and log messages to the console.
    """

    def __init__(self, dir, model_name, verbose=True):
        # create pretrained timm model by name
        self.model_name = model_name
        self.model = timm.create_model(model_name, pretrained=True, num_classes=len(Configs.classes))
        # Move to the target device before building the optimizer, so that a
        # later load() restores optimizer state onto the same device as the params.
        self.model.to(Configs.device)
        self.verbose = verbose

        self.epoch = 0
        self.dir = dir
        os.makedirs(self.dir, exist_ok=True)

        self.log_path = os.path.join(self.dir, 'log.txt')
        self.best_summary_loss = 10**5

        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=TrainConfigs.lr)
        self.scheduler = TrainConfigs.SchedulerClass(self.optimizer, **TrainConfigs.scheduler_params)
        self.log(f'Fitter prepared. Device is {Configs.device}')

    def fit(self, fold, train_loader, validation_loader):
        """Train from ``self.epoch`` up to ``TrainConfigs.n_epochs``, validating every epoch.

        When the validation loss improves, ``best-checkpoint.bin`` is written. The
        scheduler is stepped with the validation loss, and ``last-checkpoint.bin``
        is written at the end of every epoch (after the best-loss and scheduler
        updates, so resuming from it restores a consistent state).
        """
        self.model.to(Configs.device)
        self.log("Fold {}".format(fold))
        last_path = os.path.join(self.dir, 'last-checkpoint.bin')
        best_path = os.path.join(self.dir, 'best-checkpoint.bin')

        for e in range(self.epoch, TrainConfigs.n_epochs):
            if self.verbose:
                lr = self.optimizer.param_groups[0]['lr']
                timestamp = datetime.utcnow().isoformat()
                self.log(f'\n{timestamp}\nLR: {lr}')

            # train one epoch
            t = time.time()
            summary_loss, summary_accuracy = self.train_one_epoch(train_loader)
            self.log(f'[RESULT]: Train. Epoch: {self.epoch},\ttotal loss: {summary_loss.avg:.5f},\ttotal accuracy: {summary_accuracy.avg:.5f},\ttime: {(time.time() - t):.5f}')
            self._wandb_log({f"{self.model_name}/train/total_loss_fold{fold}": summary_loss.avg})

            # validate one epoch
            t = time.time()
            summary_loss, summary_accuracy = self.validation_one_epoch(validation_loader)
            self.log(f'[RESULT]: Val. Epoch: {self.epoch},\ttotal loss: {summary_loss.avg:.5f},\ttotal accuracy: {summary_accuracy.avg:.5f},\ttime: {(time.time() - t):.5f}')
            self._wandb_log({f"{self.model_name}/val/total_loss_fold{fold}": summary_loss.avg})

            # update the best validation loss and save the best checkpoint if needed
            if summary_loss.avg < self.best_summary_loss:
                self.best_summary_loss = summary_loss.avg
                self.save(best_path)
                self._wandb_save(best_path)

            self.scheduler.step(metrics=summary_loss.avg)

            # save the last checkpoint (self.epoch is still the finished epoch;
            # load() resumes from the next one)
            self.save(last_path)
            self._wandb_save(last_path)

            self.epoch += 1

    def train_one_epoch(self, train_loader):
        """Run one optimisation pass over ``train_loader``.

        Returns:
            (loss meter, accuracy meter), both averaged over samples.
        """
        self.model.train()
        summary_loss = AverageMeter()
        summary_accuracy = AverageMeter()

        t = time.time()
        for step, (images, labels) in enumerate(train_loader):
            if self.verbose:
                print(
                    f'Train Step {step}/{len(train_loader)},\t'
                    f'total_loss: {summary_loss.avg:.5f},\t'
                    f'total_accuracy: {summary_accuracy.avg:.5f},\t'
                    f'time: {(time.time() - t):.5f}',
                    end='\r',
                )

            images = images.to(Configs.device).float()
            labels = labels.to(Configs.device).long()
            batch_size = images.shape[0]

            self.optimizer.zero_grad()
            logits = self.model(images)
            loss = TrainConfigs.loss_fn(logits, labels)
            loss.backward()
            self.optimizer.step()

            summary_loss.update(loss.detach().item(), batch_size)
            summary_accuracy.update(self._batch_accuracy(logits, labels), batch_size)

            del images, labels
            torch.cuda.empty_cache()

        return summary_loss, summary_accuracy

    def validation_one_epoch(self, val_loader):
        """Evaluate the model on ``val_loader`` without gradient tracking.

        Returns:
            (loss meter, accuracy meter), both averaged over samples.
        """
        self.model.eval()
        summary_loss = AverageMeter()
        summary_accuracy = AverageMeter()

        t = time.time()
        with torch.no_grad():
            for step, (images, labels) in enumerate(val_loader):
                if self.verbose:
                    print(
                        f'Val Step {step}/{len(val_loader)}, '
                        f'total_loss: {summary_loss.avg:.5f}, '
                        f'total_accuracy: {summary_accuracy.avg:.5f},\t'
                        f'time: {(time.time() - t):.5f}',
                        end='\r',
                    )

                images = images.to(Configs.device).float()
                labels = labels.to(Configs.device).long()
                batch_size = images.shape[0]

                logits = self.model(images)
                loss = TrainConfigs.loss_fn(logits, labels)

                summary_loss.update(loss.item(), batch_size)
                summary_accuracy.update(self._batch_accuracy(logits, labels), batch_size)

                del images, labels
                torch.cuda.empty_cache()

        return summary_loss, summary_accuracy

    @staticmethod
    def _batch_accuracy(logits, labels):
        """Return the fraction of samples whose arg-max logit equals the label (a float).

        Computed directly: the former ``torchmetrics.functional.accuracy(labels, preds)``
        call did not match its (preds, target) signature ([B] labels vs [B, 1]
        preds), and newer torchmetrics versions also require a ``task=`` argument.
        """
        preds = logits.detach().argmax(dim=1)
        return (preds == labels).float().mean().item()

    @staticmethod
    def _wandb_log(metrics):
        """Log ``metrics`` to the active wandb run; no-op when wandb was never initialised."""
        if wandb.run is not None:
            wandb.run.log(metrics)

    @staticmethod
    def _wandb_save(path):
        """Upload the file at ``path`` to the active wandb run, if there is one."""
        if wandb.run is not None:
            wandb.save(path)

    # save checkpoint to path
    def save(self, path):
        """Write model/optimizer/scheduler state, best loss and current epoch to ``path``."""
        self.model.eval()
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_summary_loss': self.best_summary_loss,
            'epoch': self.epoch,
        }, path)

    # load checkpoint from path
    def load(self, path):
        """Restore a checkpoint written by ``save``; training resumes at the next epoch."""
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine (and vice versa)
        checkpoint = torch.load(path, map_location=Configs.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_summary_loss = checkpoint['best_summary_loss']
        self.epoch = checkpoint['epoch'] + 1

    # log to console/log file
    def log(self, message):
        """Append ``message`` to ``log.txt`` and, when verbose, print it."""
        if self.verbose:
            print(message)
        with open(self.log_path, 'a+') as logger:
            logger.write(f'{message}\n')

<a id="section-seven"></a>
## **Train**

**Train configurations**

In [None]:
class TrainConfigs:
    """Hyper-parameters read by Fitter: epochs, learning rate, loss and LR scheduler."""

    n_epochs = 10
    lr = 1e-3
    loss_fn = nn.CrossEntropyLoss()
    # Fitter steps the scheduler with the validation loss: the LR is halved once
    # the loss has failed to improve by more than 1e-4 (absolute) for more than
    # 2 consecutive epochs, never going below 1e-8.
    SchedulerClass = torch.optim.lr_scheduler.ReduceLROnPlateau
    scheduler_params = {
        'mode': 'min',
        'factor': 0.5,
        'patience': 2,
        'verbose': True,
        'threshold': 0.0001,
        'threshold_mode': 'abs',
        'min_lr': 1e-8,
    }

**Run training**

In [None]:
def run_training(model_name, fold, train_dataset, val_dataset):
    """Wrap both datasets in DataLoaders and fit a fresh ``model_name`` model on ``fold``.

    Checkpoints and logs are written to ``./<model_name>/<model_name>_fold<fold>``.
    """
    shared_loader_kwargs = dict(
        batch_size=Configs.batch_size,
        num_workers=Configs.num_workers,
        pin_memory=False,
    )
    # random order every epoch; the last incomplete batch is dropped
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=RandomSampler(train_dataset),
        drop_last=True,
        **shared_loader_kwargs,
    )
    # fixed order for validation
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        shuffle=False,
        sampler=SequentialSampler(val_dataset),
        **shared_loader_kwargs,
    )

    output_dir = f'./{model_name}/{model_name}_fold{fold}'
    fitter = Fitter(output_dir, model_name)
    fitter.fit(fold, train_loader, val_loader)

**Train inception_resnet_v2 on each of the folds**

In [None]:
# Candidate timm model names. More models were planned, but only models[0]
# ('inception_resnet_v2') is trained in the cells below.
models = ['inception_resnet_v2', 'pnasnet5large', 'inception_v4']

In [None]:
# Train on fold 0 only.
# NOTE: the next cell also trains fold 0 into the same output directory, so run
# only one of the two cells.
fold = 0
train_dataset = get_dataset_fold(data_folds, fold, train=True)
val_dataset = get_dataset_fold(data_folds, fold, train=False)

run_training(
    model_name=models[0],
    fold=fold,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
)

In [None]:
# Train one model per fold; each run writes to ./<model>/<model>_fold<k>.
model_name = models[0]
for fold_idx in range(Configs.n_folds):
    run_training(
        model_name,
        fold_idx,
        get_dataset_fold(data_folds, fold_idx),
        get_dataset_fold(data_folds, fold_idx, train=False),
    )

**Zip the results and create a download link**

In [None]:
!zip -r ./inception_resnet_v2.zip ./inception_resnet_v2

In [None]:
# Display a clickable link to the zipped training outputs in the notebook.
from IPython.display import FileLink
FileLink('./inception_resnet_v2.zip')