In [None]:
from IPython.display import HTML
# Read the custom stylesheet inside a context manager so the file handle is
# closed deterministically instead of being left open for the kernel's lifetime.
with open("../input/notebookassets/custom.css", encoding="utf-8") as css_file:
    custom_css = css_file.read()

# Must remain the last expression of the cell so the notebook renders it.
HTML("<style>" + custom_css + "</style>")

# RANZCR Complete Modular PyTorch trainer 👾

This is the training pipeline that I used in the Cassava Leaf Disease Classification competition. Feel free to fork it and make your own changes!

<p>If you like this notebook, please give it an <span style="font-size:24px;color:red">Upvote!</span></p>

In [None]:
import sys
# Put the local pytorch-image-models source directory on the import path so
# that the later `import timm` can be resolved from it.
sys.path.append("../input/timm-pytorch-image-models/pytorch-image-models-master")

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings

import cv2
from tqdm.notebook import tqdm

from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)

from albumentations.pytorch import ToTensorV2

import timm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler
from torch.nn.modules.loss import _WeightedLoss
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts

from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold

# Silence every warning category for the rest of the session.
# NOTE(review): this also hides deprecation warnings from the imported
# libraries — consider narrowing the filter while debugging.
warnings.simplefilter("ignore")

In [None]:
class Config:
    """Notebook-wide settings, exposed through the ``CFG`` dictionary."""

    # 'img_size' is the side length (in pixels) used by the augmentations
    # and as the dataset's default image size.
    CFG = dict(img_size=384)

In [None]:
def plot_results(train_acc, valid_acc, train_loss, valid_loss, nb_epochs):
    """
    Draw accuracy and loss curves side by side and show the figure.

    train_acc / valid_acc: per-epoch accuracy values for the left panel.
    train_loss / valid_loss: per-epoch loss values for the right panel.
    nb_epochs: number of epochs; the x axis is 0 .. nb_epochs - 1, so every
        series is expected to contain exactly nb_epochs points.
    """
    epochs = list(range(nb_epochs))

    fig, (acc_axis, loss_axis) = plt.subplots(1, 2)
    fig.set_size_inches(20, 10)

    # Both panels share the same layout; only the data and metric name differ.
    panels = (
        (acc_axis, train_acc, valid_acc, 'Accuracy'),
        (loss_axis, train_loss, valid_loss, 'Loss'),
    )
    for axis, train_values, valid_values, metric in panels:
        axis.plot(epochs, train_values, 'go-', label=f'Training {metric}')
        axis.plot(epochs, valid_values, 'ro-', label=f'Validation {metric}')
        axis.set_title(f'Training & Validation {metric}')
        axis.legend()
        axis.set_xlabel('Epochs')
        axis.set_ylabel(metric)

    plt.show()

In [None]:
class Augments:
    """
    Albumentations pipelines used by the datasets.

    train_augments: random crop/flip/affine augmentation followed by
        normalisation, dropout-style occlusion and tensor conversion.
    valid_augments: deterministic crop/resize followed by normalisation and
        tensor conversion.

    Both pipelines size their output from Config.CFG['img_size'] and are built
    once, when the class body is executed.
    """
    # Training pipeline: every geometric transform except the initial crop is
    # applied with probability 0.5.
    train_augments = Compose([
            RandomResizedCrop(Config.CFG['img_size'], Config.CFG['img_size']),
            Transpose(p=0.5),
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            ShiftScaleRotate(p=0.5),
            # presumably the ImageNet channel statistics — TODO confirm
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
            # NOTE(review): CoarseDropout and Cutout both run after Normalize;
            # Cutout may be deprecated in the installed albumentations
            # release — confirm against the pinned version.
            CoarseDropout(p=0.5),
            Cutout(p=0.5),
            ToTensorV2(p=1.0),
        ],p=1.)
    
    # Validation pipeline: no random transforms.
    valid_augments = Compose([
            # NOTE(review): CenterCrop already yields an img_size x img_size
            # patch, so the following Resize to the same size looks like a
            # no-op, and the crop keeps only the image centre — confirm this
            # is intended rather than a plain Resize of the full image.
            CenterCrop(Config.CFG['img_size'], Config.CFG['img_size'], p=1.),
            Resize(Config.CFG['img_size'], Config.CFG['img_size']),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
            ToTensorV2(p=1.0),
        ], p=1.)

In [None]:
class ResNextModel(nn.Module):
    """
    Wraps a timm ResNeXt backbone and swaps its ``fc`` head for a fresh
    linear layer with ``num_classes`` outputs.
    """
    def __init__(self, num_classes=11, model_name='resnext50d_32x4d', pretrained=True):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained)
        # Reuse the input width of the existing head for the replacement.
        head_inputs = backbone.fc.in_features
        backbone.fc = nn.Linear(head_inputs, num_classes)
        self.model = backbone

    def forward(self, x):
        """Return the backbone's output for the batch ``x``."""
        return self.model(x)

class ResNetModel(nn.Module):
    """
    Wraps a timm ResNet backbone and swaps its ``fc`` head for a fresh
    linear layer with ``num_classes`` outputs.
    """
    def __init__(self, num_classes=11, model_name='resnet18', pretrained=True):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained)
        # Reuse the input width of the existing head for the replacement.
        head_inputs = backbone.fc.in_features
        backbone.fc = nn.Linear(head_inputs, num_classes)
        self.model = backbone

    def forward(self, x):
        """Return the backbone's output for the batch ``x``."""
        return self.model(x)

In [None]:
class RANCZRData(Dataset):
    """
    Dataset that reads one ``<StudyInstanceUID>.jpg`` image per row of ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a 'StudyInstanceUID' column. When ``is_train`` is True the
        label vector is taken from every column except the first and the last.
        The frame is shuffled and re-indexed on construction.
    num_classes : int
        Stored on the instance; not used by the dataset itself.
    is_train : bool
        If True, ``__getitem__`` returns ``(image, label_tensor)``; otherwise
        only the image.
    augments : callable or None
        Albumentations-style callable invoked as ``augments(image=...)`` and
        returning a dict with an 'image' entry. If None, the RGB array is
        returned unchanged.
    img_size : int
        Stored on the instance; not used by the dataset itself.
    img_path : str
        Directory that holds the ``.jpg`` files.

    Raises
    ------
    FileNotFoundError
        From ``__getitem__`` when the image cannot be read from disk.
    """
    def __init__(self, df, num_classes=5, is_train=True, augments=None, img_size=Config.CFG['img_size'], img_path="../input/ranzcr-clip-catheter-line-classification/train"):
        super().__init__()
        self.df = df.sample(frac=1).reset_index(drop=True)
        self.num_classes = num_classes
        self.is_train = is_train
        self.augments = augments
        self.img_size = img_size
        self.img_path = img_path

    def __getitem__(self, idx):
        image_id = self.df['StudyInstanceUID'].iloc[idx]
        image_file = os.path.join(self.img_path, image_id + ".jpg")

        # cv2.imread signals failure by returning None rather than raising.
        image = cv2.imread(image_file)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_file}")

        # BGR -> RGB as a fresh contiguous array (a [::-1] view would carry
        # negative strides, which torch cannot wrap directly).
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Augments must be albumentations; without them return the raw image
        # instead of failing on an unbound variable.
        if self.augments is not None:
            img = self.augments(image=image)['image']
        else:
            img = image

        if self.is_train:
            # Positional lookup of this row's labels: avoids scanning the whole
            # frame per item and always matches the image that was loaded.
            label = self.df.iloc[idx, 1:-1].tolist()
            return img, torch.tensor(label)

        return img

    def __len__(self):
        return len(self.df)

In [None]:
class Trainer:
    """
    Runs mixed-precision training and validation epochs for one model.

    Everything the epochs need (model, optimizer, device, gradient scaler) is
    held on the instance, so the methods do not depend on module-level names.
    """
    def __init__(self, train_dataloader, valid_dataloader, model, optimizer, loss_fn, val_loss_fn, scheduler, device="cuda:0", plot_results=True, scaler=None):
        """
        train_dataloader / valid_dataloader: iterables with ``len()`` that
            yield ``(inputs, targets)`` batches.
        model: the module to train and evaluate.
        optimizer: optimizer that updates ``model``'s parameters.
        loss_fn / val_loss_fn: callables ``(outputs, targets) -> scalar loss``.
        scheduler: stored as ``self.scheduler``; it is not stepped by this class.
        device: device batches are moved to before the forward pass.
        plot_results: stored as ``self.plot_results``; not used by this class.
        scaler: optional GradScaler; a new one is created when omitted.

        TODO: Implement the ROC-AUC Scheduler stuff
        """
        self.train = train_dataloader
        self.valid = valid_dataloader
        self.model = model
        # Use the optimizer that was passed in rather than a global name.
        self.optim = optimizer
        self.loss_fn = loss_fn
        self.val_loss_fn = val_loss_fn
        self.scheduler = scheduler
        self.device = device
        self.plot_results = plot_results
        self.scaler = scaler if scaler is not None else GradScaler()

    def train_one_cycle(self):
        """
        Runs one epoch of training (forward, backward, optimizer step).

        Returns the mean batch loss as a Python float (NaN if the loader
        yields no batches).
        """
        self.model.train()
        train_prog_bar = tqdm(self.train, total=len(self.train))

        running_loss = 0.0

        for xtrain, ytrain in train_prog_bar:
            xtrain = xtrain.to(self.device).float()
            ytrain = ytrain.to(self.device).float()

            self.optim.zero_grad()

            # Only the forward pass and the loss run under autocast; the
            # backward pass and optimizer step happen outside of it.
            with autocast():
                z = self.model(xtrain)
                train_loss = self.loss_fn(z, ytrain)

            self.scaler.scale(train_loss).backward()
            self.scaler.step(self.optim)
            self.scaler.update()

            # .item() detaches the value so no graph or device tensor is kept
            # alive across iterations.
            batch_loss = train_loss.item()
            running_loss += batch_loss

            # Show the current loss to the progress bar
            train_prog_bar.set_description(desc=f'loss: {batch_loss:.4f}')

        # Average the running loss over all batches
        num_batches = len(self.train)
        train_running_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Final Training Loss: {train_running_loss:.4f}")

        return train_running_loss

    def valid_one_cycle(self):
        """
        Runs one epoch of validation without gradient tracking.

        Returns ``(mean_validation_loss, model)``; the loss is a Python float
        (NaN if the loader yields no batches).
        """
        self.model.eval()

        valid_prog_bar = tqdm(self.valid, total=len(self.valid))

        running_loss = 0.0

        with torch.no_grad():
            for xval, yval in valid_prog_bar:
                xval = xval.to(self.device).float()
                yval = yval.to(self.device).float()

                val_z = self.model(xval)
                batch_loss = self.val_loss_fn(val_z, yval).item()
                running_loss += batch_loss

                # Show the current loss
                valid_prog_bar.set_description(desc=f"loss: {batch_loss:.4f}")

        # Get the final loss
        num_batches = len(self.valid)
        final_loss_val = running_loss / num_batches if num_batches else float("nan")
        print(f"Final Validation Loss: {final_loss_val:.4f}")

        return (final_loss_val, self.model)

In [None]:
# Number of full passes over the training data.
nb_epochs = 10
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
SPLIT_SEED = 42     # fixed seed so the shuffle (and therefore the split) is reproducible
VALID_SIZE = 2500   # number of shuffled rows held out for validation

data = pd.read_csv("../input/ranzcr-clip-catheter-line-classification/train.csv")
data = data.sample(frac=1, random_state=SPLIT_SEED).reset_index(drop=True)

# 27,583 in Train, 2500 in Valid
# NOTE(review): if the CSV carries a patient identifier, a purely random split
# can place the same patient in both sets — consider a grouped split.
train_split = data.iloc[VALID_SIZE:]
valid_split = data.iloc[:VALID_SIZE]

print(train_split.shape, valid_split.shape)

In [None]:
# Datasets: one per split, each with its own augmentation pipeline.
train_set = RANCZRData(augments=Augments.train_augments, df=train_split)
valid_set = RANCZRData(augments=Augments.valid_augments, df=valid_split)

# Loaders: training batches are shuffled, validation batches are not.
train = DataLoader(train_set, batch_size=16, shuffle=True, num_workers=8, pin_memory=False, drop_last=False)
valid = DataLoader(valid_set, batch_size=32, shuffle=False, num_workers=8, pin_memory=False)

# Model, optimiser and the (identical) train/validation loss functions.
model = ResNetModel().to(device)
optim = torch.optim.AdamW(model.parameters(), weight_decay=0.001, lr=1e-4)
loss_fn_train = nn.BCEWithLogitsLoss()
loss_fn_val = nn.BCEWithLogitsLoss()

# Wire everything into the trainer; no LR scheduler is used.
trainer = Trainer(
    train,
    valid,
    model,
    optim,
    loss_fn_train,
    loss_fn_val,
    scheduler=None,
    device=device,
)

In [None]:
# Per-epoch loss history, stored as plain Python floats so it can be plotted.
train_losses = []
valid_losses = []

# Gradient scaler for mixed-precision training; created before the first epoch.
scaler = GradScaler()

for epoch in range(nb_epochs):
    print(f"{'-'*20} EPOCH: {epoch+1}/{nb_epochs} {'-'*20}")

    # Run one training epoch. float() accepts both a Python number and a
    # 0-dim tensor (including one living on the GPU), so the history never
    # holds device tensors or autograd graphs.
    current_train_loss = trainer.train_one_cycle()
    train_losses.append(float(current_train_loss))

    # Run one validation epoch
    current_val_loss, op_model = trainer.valid_one_cycle()
    valid_losses.append(float(current_val_loss))

    # Empty CUDA cache
    torch.cuda.empty_cache()

    # Save the model every epoch (the same file is overwritten each time)
    print("Saving Model for this epoch...")
    torch.save(op_model.state_dict(), "resnet18_model.pth")

In [None]:
# Plot the per-epoch training and validation loss on one labelled axis.
fig, ax = plt.subplots(figsize=(10, 6))
# float() makes the history plottable even if it holds 0-dim tensors.
ax.plot([float(value) for value in train_losses], 'go-', label='Training Loss')
ax.plot([float(value) for value in valid_losses], 'ro-', label='Validation Loss')
ax.set_title('Training & Validation Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
plt.show()