# High-Performance Large-Scale Image Recognition Without Normalization

![](https://pbs.twimg.com/media/EuB11mwXEAAaG_T.png:large)

Google's DeepMind recently published its latest family of CNNs, Normalizer-Free Networks (NFNets). The paper claims state-of-the-art ImageNet results and faster training in their tests. In this notebook I'm going to try them, with some changes, on the problem this competition gives us. Hope you find it useful!

Before I start, I want to thank sin for his great notebook [here](https://www.kaggle.com/underwearfitting/single-fold-training-of-resnet200d-lb0-965). I used it as the starting point for this notebook, so go check it out. Anyway, let's get started!

## Main Points of NFNets:
*From the paper itself:*

* The authors propose **Adaptive Gradient Clipping (AGC)**, which clips gradients based on the unit-wise ratio of gradient norms to parameter norms, and show that AGC allows them to train Normalizer-Free Networks with larger batch sizes and stronger data augmentations.

* The authors design a family of **Normalizer-Free ResNets, called NFNets**, which set new state-of-the-art validation accuracies on ImageNet across a range of training latencies (see the image above). The NFNet-F1 model achieves similar accuracy to EfficientNet-B7 while being 8.7× faster to train, and their largest model sets a new overall state of the art of 86.5% top-1 accuracy without extra data.

* The authors show that NFNets achieve substantially higher validation accuracies than batch-normalized networks when fine-tuning on ImageNet after pre-training on a large private dataset of 300 million labelled images. The best model achieves 89.2% top-1 after fine-tuning.
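To make the AGC rule concrete, here is a minimal PyTorch sketch of unit-wise clipping: for each unit i (one output row of a weight tensor), the gradient G_i is rescaled whenever ‖G_i‖ / max(‖W_i‖, eps) exceeds `clip_factor`. It is written for illustration only and is not timm's `adaptive_clip_grad` (the one this notebook actually imports); `unitwise_norm` and `agc_sketch` are hypothetical names, 1-D tensors such as biases are assumed to count as a single unit, and the defaults mirror the `clip_factor=0.01, eps=1e-3` values used later in training.

```python
import torch
import torch.nn as nn


def unitwise_norm(x: torch.Tensor) -> torch.Tensor:
    """L2 norm per output unit (dim 0); 1-D tensors are treated as a single unit."""
    if x.ndim <= 1:
        return x.norm(2)
    return x.norm(2, dim=tuple(range(1, x.ndim)), keepdim=True)


@torch.no_grad()
def agc_sketch(parameters, clip_factor: float = 0.01, eps: float = 1e-3) -> None:
    """Rescale each unit's gradient so that ||G_i|| <= clip_factor * max(||W_i||, eps)."""
    for p in parameters:
        if p.grad is None:
            continue
        # Largest allowed gradient norm per unit, proportional to that unit's weight norm
        max_norm = unitwise_norm(p).clamp(min=eps) * clip_factor
        grad_norm = unitwise_norm(p.grad)
        # Factor < 1 only where the gradient is too large relative to the weights
        scale = (max_norm / grad_norm.clamp(min=1e-6)).clamp(max=1.0)
        p.grad.mul_(scale)


# Usage: a toy linear layer with deliberately huge gradients
torch.manual_seed(0)
layer = nn.Linear(4, 3)
layer(torch.randn(8, 4)).pow(2).sum().mul(1e4).backward()
agc_sketch(layer.parameters())
```

Because the threshold scales with each unit's own weight norm, small layers and large layers are clipped proportionally, which is what lets AGC stand in for the stabilizing effect of batch normalization.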

## Some Notes About Notebook:

* I've changed some settings so that training fits within Kaggle's GPU time limits, so this is a **light version**.
* This is **single-fold** training only, for the same reason.
* You can tweak the config to your liking to get better results.
* If you switch to the SAM optimizer, **expect 50%+ longer training times**, since it runs two forward-backward passes per batch.
* Again, for the SAM optimizer, don't forget to reduce the learning rate for better results.
* AGC is optional, but if you are getting NaN losses you might want to enable it.
* You can train the other folds as well and use the weights in inference notebooks. CV and LB scores were consistent in my trials...

### If you have any further questions leave a comment and **if you liked this work please give an upvote**, thanks!

# Checking Device

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
    print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
    print('and then re-execute this cell.')
else:
    print(gpu_info)

# Configuration

This is our command center: try different settings to get better results, and if you find some, please share them with the community!

In [None]:
class CFG:
    
    n_splits = 5 
    
    fold_id = 0 # Fold to train

    image_size = 512 
    seed = 42
    init_lr = 5e-4
    batch_size = 64
    valid_batch_size = 64
    n_epochs = 15
    num_workers = 8

    use_amp = True  
    early_stop = 5
    
    AGC = False # Adaptive Gradient Clipping
    optimizer = 'Ranger' #Ranger, SAM, Adam

    model_name = 'eca_nfnet_l0'
    train_dir = '../input/trainfolds/train_folds.csv'
    data_dir = '../input/ranzor-clip-resized-data-512-256/trainXray_512'
    
    target_cols = ['ETT - Abnormal', 'ETT - Borderline', 'ETT - Normal', 'NGT - Abnormal', 
           'NGT - Borderline', 'NGT - Incompletely Imaged', 'NGT - Normal', 'CVC - Abnormal',
           'CVC - Borderline', 'CVC - Normal', 'Swan Ganz Catheter Present']
    
model_dir = 'weights/'
# -p: don't fail if the directory already exists (e.g. when re-running the cell)
! mkdir -p $model_dir

# Loading Libraries

In [None]:
import sys
sys.path.append('../input/nfnets/pytorch-image-models-master')
import timm
from timm.utils.agc import adaptive_clip_grad


import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pylab import rcParams
import math

import os
import time
import cv2
import PIL.Image
import random
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import CosineAnnealingLR 
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import gc
from sklearn.metrics import roc_auc_score


from warnings import filterwarnings
filterwarnings("ignore")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
plt.style.use('ggplot')

In [None]:
def seed_everything(seed):
    
    """Seeding everything for consistent experiments..."""
    
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN pick algorithms non-deterministically, which defeats the line above
    torch.backends.cudnn.benchmark = False
    
seed_everything(CFG.seed)

In [None]:
# Loading metadata

train_df = pd.read_csv(CFG.train_dir)

# Metadata & Exploratory Data Analysis

Here we look for insights in the metadata we were given, both to understand the problem better and to get better scores.

## Target Labels

Let's check the target distribution to see if there's a common pattern. The labels are quite imbalanced: CVC - Normal is the dominant class, while ETT - Abnormal is the rarest one.

In [None]:
# Counting target values.

targ_cts = train_df[CFG.target_cols].sum(axis=0)
fig = plt.figure(figsize=(12,6))
sns.barplot(y=targ_cts.sort_values(ascending=False).index, x=targ_cts.sort_values(ascending=False).values, palette='inferno')
plt.show()

Most samples have only one label per examination, but a decent share of them have several labels at the same time, which makes this a **Multi-Label Classification** problem...

I'm not an expert on the subject, but this might be caused by several catheters being connected to the patient at the same time...
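Because several labels can be positive at once, each of the 11 outputs is treated as an independent binary problem: one sigmoid per column rather than a softmax across columns, which is what `nn.BCEWithLogitsLoss` (the criterion used later in this notebook) computes. A small NumPy sketch of that loss, using the numerically stable form max(x, 0) − x·y + log(1 + e^(−|x|)) on made-up logits:

```python
import numpy as np


def bce_with_logits(logits: np.ndarray, targets: np.ndarray) -> float:
    """Mean binary cross-entropy over all samples and labels, computed from raw logits."""
    # Stable rewrite of -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))]
    losses = np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits)))
    return float(losses.mean())


# Two samples, three labels; the first sample has two positive labels at once
targets = np.array([[1.0, 0.0, 1.0],
                    [0.0, 1.0, 0.0]])
logits = np.array([[2.0, -1.0, 0.5],
                   [-3.0, 1.5, 0.0]])
probs = 1 / (1 + np.exp(-logits))  # independent per-label probabilities; rows need not sum to 1
print(probs.round(3))
print(bce_with_logits(logits, targets))
```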

In [None]:
# Labels per sample.

plt.figure(figsize=(16,6))
plt.title('Total Target Score Counts', weight='bold')
# count positive labels per sample over all 11 target columns
sns.countplot(x=train_df[CFG.target_cols].sum(axis=1), palette='inferno')
plt.xlabel('Total Number of Targets per Sample')
plt.show()



## Patient IDs:

Inspecting the patient IDs closely, we notice that some patients have a high number of chest X-ray samples in the training set. This could cause data leakage in cross-validation if images of the same patient end up in both training and validation folds, so the folds need to be split by patient.


In [None]:
# counting patient id's

plt.figure(figsize=(20,5))
pat_counts=train_df.PatientID.value_counts().reset_index().iloc[:25,:]
pat_counts.columns=['PatientID','Count']
sns.barplot(x='PatientID', y='Count',data=pat_counts, palette='inferno')
plt.xticks(rotation=60)
plt.show()


After splitting the samples by patient and stratifying by their labels, the target distribution looks healthy across folds...

In [None]:
# Checking fold target distributions

fold_dist = train_df.groupby('fold')[CFG.target_cols].sum().reset_index()
fig, axs = plt.subplots(1, len(CFG.target_cols), figsize=(40,4))
for i, col in enumerate(CFG.target_cols):
    axs[i].bar(x=fold_dist.fold, height=fold_dist[col])
    axs[i].set_title(col)

plt.show()

# NFNet Pretrained Model

Here we define our CNN model using pretrained ImageNet weights.

In [None]:
class NFNetModel(nn.Module):
    
    "Main NFNet model class: loads pretrained ImageNet weights and replaces the classification head to match the competition's targets."

    def __init__(self, num_classes=len(CFG.target_cols), model_name=CFG.model_name, pretrained=False):
        super(NFNetModel, self).__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)
        #if pretrained:
        #    pretrained_path = f'../input/nfnet-pretrained-weights/{model_name}.pth'
        #    self.model.load_state_dict(torch.load(pretrained_path))
            
        self.model.head.fc = nn.Linear(self.model.head.fc.in_features, num_classes)
        
    def forward(self, x):
        x = self.model(x)
        return x

# Augmentations

In the paper, the authors show that heavy augmentations lead to better results. Here we use moderate augmentations to be on the safe side...

In [None]:
# applying some augmentations for regularizing effect

transforms_train = A.Compose([
    A.RandomResizedCrop(CFG.image_size, CFG.image_size, scale=(0.95, 1), p=1),
    A.HorizontalFlip(p=0.5),
    #A.ShiftScaleRotate(rotate_limit=(-8, 8),p=0.5),
    A.HueSaturationValue(hue_shift_limit=(-10, 10), sat_shift_limit=(-10, 10), val_shift_limit=(-10, 10), p=0.5),
    A.RandomBrightnessContrast(always_apply=False, brightness_limit=(-0.05, 0.05), contrast_limit=(-0.05, 0.05), brightness_by_max=True, p=0.7),
    A.CLAHE(clip_limit=(1,4), p=0.25),
    A.Cutout(max_h_size=int(CFG.image_size * 0.05), max_w_size=int(CFG.image_size * 0.05), num_holes=12, p=0.5),
    A.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
    ToTensorV2(p=1.0)
])

transforms_valid = A.Compose([
    A.Resize(CFG.image_size, CFG.image_size),
    A.Normalize(
         mean=[0.485, 0.456, 0.406],
         std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
    ToTensorV2(p=1.0)
])

## Train Loader

This class loads our images and pairs each one with its target labels from the metadata for training.

In [None]:
class TrainDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)
        self.file_names = df['StudyInstanceUID'].values
        self.labels = df[CFG.target_cols].values
        self.transform = transform
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        file_name = self.file_names[idx]
        file_path = f'{CFG.data_dir}/{file_name}.jpg'
        image = cv2.imread(file_path)        
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        label = torch.tensor(self.labels[idx]).float()
        return image, label

## Checking the Images

In [None]:
# loading the images with augmentations

train_dataset = TrainDataset(train_df, transform=transforms_train)

fig, axs = plt.subplots(1, 5, figsize=(40,12))

for i in range(5):
    image, label = train_dataset[i]
    axs[i].imshow(image[0])
    axs[i].title.set_text(f'Target Labels: {label}')

plt.show() 

## Train Function

Here we set up our main training function. The key points: we optionally apply AGC (Adaptive Gradient Clipping), as the authors do in the paper, and when SAM is selected we run two forward-backward passes per batch, because SAM needs both to estimate the "sharpness-aware" gradient. This slows training down, but SAM is the optimizer the authors used for their largest models.

In [None]:
def train_func(train_loader):
    
    """ Main training function: runs one epoch, computes the loss between predicted and true labels, optionally clips gradients with AGC, updates the weights and returns the mean training loss. """
    
    model.train()
    bar = tqdm(train_loader)
    # With enabled=False, GradScaler and autocast are pass-throughs, so one code path covers AMP and FP32
    scaler = torch.cuda.amp.GradScaler(enabled=CFG.use_amp)
    losses = []
    for batch_idx, (images, targets) in enumerate(bar):

        images, targets = images.to(device), targets.to(device)

        if CFG.optimizer == 'SAM':
            # GradScaler assumes a single optimizer.step() per iteration, which doesn't fit
            # SAM's first_step/second_step pair, so the SAM path runs in full precision.

            # First forward-backward pass: gradient at the current weights w
            loss = trn_criterion(model(images), targets)
            loss.backward()
            if CFG.AGC:
                adaptive_clip_grad(model.parameters(), clip_factor=0.01, eps=1e-3, norm_type=2.0)
            optimizer.first_step(zero_grad=True)  # move to w + e(w)

            # Second forward-backward pass: gradient at the perturbed weights w + e(w)
            trn_criterion(model(images), targets).backward()
            if CFG.AGC:
                adaptive_clip_grad(model.parameters(), clip_factor=0.01, eps=1e-3, norm_type=2.0)
            optimizer.second_step(zero_grad=True)  # back to w, then the base optimizer update
        else:
            with torch.cuda.amp.autocast(enabled=CFG.use_amp):
                preds = model(images)
                loss = trn_criterion(preds, targets)
            # Backward pass outside autocast, as recommended in the PyTorch AMP docs
            scaler.scale(loss).backward()
            if CFG.AGC:
                # Unscale first so AGC compares the true gradient norms with the weight norms
                scaler.unscale_(optimizer)
                adaptive_clip_grad(model.parameters(), clip_factor=0.01, eps=1e-3, norm_type=2.0)
            scaler.step(optimizer)  # skips the update if the gradients contain inf/NaN
            scaler.update()
            optimizer.zero_grad()

        losses.append(loss.item())
        bar.set_description(f'Mean Loss: {np.mean(losses):.5f}')

    loss_train = np.mean(losses)
    return loss_train


def valid_func(valid_loader):
    
    """ Main validation function: predicts labels for the validation images, computes the validation loss and the macro ROC AUC, and returns both. """
    
    model.eval()
    bar = tqdm(valid_loader)

    TARGETS = []
    losses = []
    PREDS = []
    
    with torch.no_grad():
        for batch_idx, (images, targets) in enumerate(bar):

            images, targets = images.to(device), targets.to(device)
            output = model(images)
            # Move predictions to the CPU right away so they don't accumulate in GPU memory
            PREDS += [output.sigmoid().cpu()]
            TARGETS += [targets.cpu()]
            loss = val_criterion(output, targets)
            losses.append(loss.item())
            bar.set_description(f'Loss: {loss.item():.5f}')
            
    PREDS = torch.cat(PREDS).numpy()
    TARGETS = torch.cat(TARGETS).numpy()
    roc_auc = macro_multilabel_auc(TARGETS, PREDS)
    loss_valid = np.mean(losses)
    return loss_valid, roc_auc

# Utils

In [None]:
def macro_multilabel_auc(label, pred):
    
    """Computes the ROC AUC for each target column separately, prints the per-column scores and returns their mean (macro-averaged AUC)."""
    
    aucs = []
    for i in range(len(CFG.target_cols)):
        aucs.append(roc_auc_score(label[:, i], pred[:, i]))
    print(np.round(aucs, 4))
    return np.mean(aucs)

## Sharpness-Aware Minimization for Efficiently Improving Generalization

The SAM optimizer implementation is taken from [here](https://github.com/davda54/sam).

> SAM simultaneously minimizes loss value and loss sharpness. In particular, it seeks parameters that lie in neighborhoods having uniformly low loss. SAM improves model generalization and yields SoTA performance for several datasets. Additionally, it provides robustness to label noise on par with that provided by SoTA procedures that specifically target learning with noisy labels.


![](https://raw.githubusercontent.com/davda54/sam/main/img/loss_landscape.png)
*ResNet loss landscape at the end of training with and without SAM. Sharpness-aware updates lead to a significantly wider minimum, which then leads to better generalization properties.*


You can find more about it in the official paper:
https://arxiv.org/abs/2010.01412
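Stripped of the optimizer plumbing, one SAM update is: take the gradient g at w, step to the nearby "worst-case" point w + ρ·g/‖g‖, take the gradient there, and apply that gradient to the original w. The NumPy sketch below assumes plain gradient descent as the base optimizer (the notebook's `SAM` class wraps Adam instead) and a hand-written gradient for the toy loss f(w) = Σw²; `sam_step` and `quadratic_grad` are illustrative names only.

```python
import numpy as np


def sam_step(w, grad_fn, lr=0.1, rho=0.05):
    """One SAM update with plain gradient descent as the base optimizer."""
    g = grad_fn(w)
    # Ascent step: perturb the weights towards higher loss, within radius rho
    e_w = rho * g / (np.linalg.norm(g) + 1e-12)
    # Gradient at the perturbed point w + e(w): the "sharpness-aware" gradient
    g_adv = grad_fn(w + e_w)
    # Descend from the ORIGINAL weights using the perturbed-point gradient
    return w - lr * g_adv


def quadratic_grad(w):
    # Gradient of the toy loss f(w) = sum(w**2)
    return 2 * w


w = np.array([3.0, 4.0])
w_new = sam_step(w, quadratic_grad)
print(w_new)
```

The two `grad_fn` calls are the two forward-backward passes that make SAM roughly twice as expensive per batch; in the notebook's class they correspond to `first_step` (apply e(w)) and `second_step` (undo e(w), then run the base optimizer).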

In [None]:
class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
                self.state[p]["e_w"] = e_w

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
                    torch.stack([
                        p.grad.norm(p=2).to(shared_device)
                        for group in self.param_groups for p in group["params"]
                        if p.grad is not None
                    ]),
                    p=2
               )
        return norm

## Ranger Optimizer

Ranger is a synergistic optimizer combining RAdam (Rectified Adam) and LookAhead; this version also applies Gradient Centralization. I chose it over SAM in this notebook to keep training times down.
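The RAdam part of `Ranger.step` in the next cell is the least obvious piece. Written out with the code's names (`N_sma` is $\rho_t$, `N_sma_max` is $\rho_\infty$, `exp_avg` and `exp_avg_sq` are $m_t$ and $v_t$, `N_sma_threshhold` defaults to 5) and ignoring weight decay and Gradient Centralization, one update of a parameter $\theta$ with gradient $g_t$ and learning rate $\eta$ is:

$$
m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2
$$

$$
\rho_\infty = \frac{2}{1-\beta_2} - 1, \qquad \rho_t = \rho_\infty - \frac{2\,t\,\beta_2^{t}}{1-\beta_2^{t}}, \qquad
r_t = \sqrt{\frac{(\rho_t-4)(\rho_t-2)\,\rho_\infty}{(\rho_\infty-4)(\rho_\infty-2)\,\rho_t}}
$$

$$
\theta_t =
\begin{cases}
\theta_{t-1} - \eta\,\dfrac{\sqrt{1-\beta_2^{t}}}{1-\beta_1^{t}}\,r_t\,\dfrac{m_t}{\sqrt{v_t}+\epsilon}, & \rho_t > 5,\\[2ex]
\theta_{t-1} - \eta\,\dfrac{m_t}{1-\beta_1^{t}}, & \rho_t \le 5.
\end{cases}
$$

The code's `step_size` is everything in front of $m_t$ except $\eta$ (`group['lr']`). During the first steps $\rho_t$ is small and the variance estimate is unreliable, so the update falls back to bias-corrected momentum; this rectification is the warmup-like behaviour RAdam is designed to provide.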

In [None]:
def centralized_gradient(x, use_gc=True, gc_conv_only=False):
    '''credit - https://github.com/Yonghongwei/Gradient-Centralization '''
    if use_gc:
        if gc_conv_only:
            if len(list(x.size())) > 3:
                x.add_(-x.mean(dim=tuple(range(1, len(list(x.size())))), keepdim=True))
        else:
            if len(list(x.size())) > 1:
                x.add_(-x.mean(dim=tuple(range(1, len(list(x.size())))), keepdim=True))
    return x


class Ranger(Optimizer):

    def __init__(self, params, lr=1e-3,                       # lr
                 alpha=0.5, k=5, N_sma_threshhold=5,           # Ranger options
                 betas=(.95, 0.999), eps=1e-5, weight_decay=0,  # Adam options
                 # Gradient centralization on or off, applied to conv layers only or conv + fc layers
                 use_gc=True, gc_conv_only=False, gc_loc=True
                 ):

        # parameter checks
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        if not lr > 0:
            raise ValueError(f'Invalid Learning Rate: {lr}')
        if not eps > 0:
            raise ValueError(f'Invalid eps: {eps}')

        # parameter comments:
        # beta1 (momentum) of .95 seems to work better than .90...
        # N_sma_threshold of 5 seems better in testing than 4.
        # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you.

        # prep defaults and init torch.optim base
        defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas,
                        N_sma_threshhold=N_sma_threshhold, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

        # adjustable threshold
        self.N_sma_threshhold = N_sma_threshhold

        # look ahead params

        self.alpha = alpha
        self.k = k

        # radam buffer for state
        self.radam_buffer = [[None, None, None] for ind in range(10)]

        # gc on or off
        self.gc_loc = gc_loc
        self.use_gc = use_gc
        self.gc_conv_only = gc_conv_only
        # level of gradient centralization
        #self.gc_gradient_threshold = 3 if gc_conv_only else 1

        print(
            f"Ranger optimizer loaded. \nGradient Centralization usage = {self.use_gc}")
        if (self.use_gc and self.gc_conv_only == False):
            print(f"GC applied to both conv and fc layers")
        elif (self.use_gc and self.gc_conv_only == True):
            print(f"GC applied to conv layers only")

    def __setstate__(self, state):
        print("set state called")
        super(Ranger, self).__setstate__(state)

    def step(self, closure=None):
        loss = None
        # note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure.
        # Uncomment if you need to use the actual closure...

        # if closure is not None:
        #loss = closure()

        # Evaluate averages and grad, update param tensors
        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()

                if grad.is_sparse:
                    raise RuntimeError(
                        'Ranger optimizer does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]  # get state dict for this param

                if len(state) == 0:  # if first time to run...init dictionary with our desired entries
                    # if self.first_run_check==0:
                    # self.first_run_check=1
                    #print("Initializing slow buffer...should not see this at load from saved model!")
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)

                    # look ahead weight storage now in state dict
                    state['slow_buffer'] = torch.empty_like(p.data)
                    state['slow_buffer'].copy_(p.data)

                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(
                        p_data_fp32)

                # begin computations
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # GC operation for Conv layers and FC layers
                # if grad.dim() > self.gc_gradient_threshold:
                #    grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True))
                if self.gc_loc:
                    grad = centralized_gradient(grad, use_gc=self.use_gc, gc_conv_only=self.gc_conv_only)

                state['step'] += 1

                # compute variance mov avg
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # compute mean moving avg
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                buffered = self.radam_buffer[int(state['step'] % 10)]

                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * \
                        state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma
                    if N_sma > self.N_sma_threshhold:
                        step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (
                            N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    else:
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    buffered[2] = step_size

                # if group['weight_decay'] != 0:
                #    p_data_fp32.add_(-group['weight_decay']
                #                     * group['lr'], p_data_fp32)

                # apply lr
                if N_sma > self.N_sma_threshhold:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    G_grad = exp_avg / denom
                else:
                    G_grad = exp_avg

                if group['weight_decay'] != 0:
                    G_grad.add_(p_data_fp32, alpha=group['weight_decay'])
                # GC operation
                if self.gc_loc == False:
                    G_grad = centralized_gradient(G_grad, use_gc=self.use_gc, gc_conv_only=self.gc_conv_only)

                p_data_fp32.add_(G_grad, alpha=-step_size * group['lr'])
                p.data.copy_(p_data_fp32)

                # integrated look ahead...
                # we do it at the param level instead of group level
                if state['step'] % group['k'] == 0:
                    # get access to slow param tensor
                    slow_p = state['slow_buffer']
                    # (fast weights - slow weights) * alpha
                    slow_p.add_(p.data - slow_p, alpha=self.alpha)
                    # copy interpolated weights to RAdam param tensor
                    p.data.copy_(slow_p)

        return loss
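The look-ahead part at the end of `step` is independent of RAdam: the fast weights take `k` inner steps, then the slow weights move a fraction `alpha` of the way towards them and the fast weights restart from the slow ones. A framework-free sketch of just that mechanism, where `lookahead` is a hypothetical helper and `inner_step` stands in for any inner optimizer update (here a toy rule that adds 1 to every weight):

```python
import numpy as np


def lookahead(fast, inner_step, n_steps, k=5, alpha=0.5):
    """Run n_steps inner updates with Lookahead synchronisation every k steps."""
    fast = np.asarray(fast, dtype=float).copy()
    slow = fast.copy()  # slow weights start equal to the fast weights
    for step in range(1, n_steps + 1):
        fast = inner_step(fast)
        if step % k == 0:
            # slow <- slow + alpha * (fast - slow), then restart the fast weights from slow
            slow += alpha * (fast - slow)
            fast = slow.copy()
    return fast, slow


# Toy inner optimizer that always moves each weight by +1
fast, slow = lookahead([0.0], lambda w: w + 1.0, n_steps=5)
print(fast, slow)  # [2.5] [2.5]
```

With the defaults `k=5, alpha=0.5` (the same as `Ranger`'s), every synchronisation keeps only half of the distance the fast weights travelled, which damps noisy inner updates.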

In [None]:
class Mish_func(torch.autograd.Function):
    
    """from: https://github.com/tyunist/memory_efficient_mish_swish/blob/master/mish.py"""
    
    @staticmethod
    def forward(ctx, i):
        result = i * torch.tanh(F.softplus(i))
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        # saved_tensors replaces the deprecated saved_variables
        i, = ctx.saved_tensors

        # h = softplus(i) = log(1 + exp(i)); F.softplus avoids overflow for large inputs
        h = F.softplus(i)
        # d tanh(h) / dh = sech^2(h) = 1 / cosh^2(h)
        grad_gh = 1. / h.cosh().pow_(2)

        # d softplus(i) / di = sigmoid(i)
        grad_hx = i.sigmoid()

        grad_gx = grad_gh * grad_hx

        # Product rule: d/di [i * tanh(h)] = tanh(h) + i * d tanh(h) / di
        grad_f = torch.tanh(h) + i * grad_gx

        return grad_output * grad_f


class Mish(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        pass
    def forward(self, input_tensor):
        return Mish_func.apply(input_tensor)
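For reference, `Mish_func.backward` computes the derivative of Mish by the chain rule. Writing $\operatorname{sp}(x) = \ln(1+e^{x})$ for softplus, whose derivative is the sigmoid $\sigma(x)$:

$$
\operatorname{Mish}(x) = x\,\tanh\big(\operatorname{sp}(x)\big)
$$

$$
\frac{d}{dx}\operatorname{Mish}(x) = \tanh\big(\operatorname{sp}(x)\big) + x\,\operatorname{sech}^2\big(\operatorname{sp}(x)\big)\,\sigma(x)
$$

In the code, `grad_gh` is $\operatorname{sech}^2(\operatorname{sp}(x)) = 1/\cosh^2(\operatorname{sp}(x))$, `grad_hx` is $\sigma(x)$, and `grad_f` is the complete derivative, which is then multiplied by the incoming `grad_output`.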

In [None]:
def replace_activations(model, existing_layer, new_layer):
    
    """Recursively replaces every module of type `existing_layer` with `new_layer` (the same instance is reused, which is fine for stateless activations)."""
    
    for name, module in reversed(model._modules.items()):
        if len(list(module.children())) > 0:
            model._modules[name] = replace_activations(module, existing_layer, new_layer)

        if type(module) == existing_layer:
            model._modules[name] = new_layer

    return model

### We replace the default activation functions with Mish, since it was recommended in the Ranger implementation.

In [None]:
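# loading the model

model = NFNetModel(pretrained=True)
if CFG.optimizer == 'Ranger':
    # timm's NFNets may implement their activations with custom modules instead of nn.ReLU,
    # so report how many layers will actually be swapped for Mish
    n_relu = sum(type(m) == nn.ReLU for m in model.modules())
    print(f'Replacing {n_relu} nn.ReLU layers with Mish')
    model = replace_activations(model, nn.ReLU, Mish())
model = model.to(device)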

### Here we set our train and validation criteria (they are the same for now, but keeping them separate adds flexibility for future use), the optimizer (Adam, SAM or Ranger), the scheduler, etc...

In [None]:
# setting criterions, optimizers, folds to train etc.

val_criterion = nn.BCEWithLogitsLoss()
trn_criterion = nn.BCEWithLogitsLoss()

# for sam optimizer you can change the base optimizer to get better results

if CFG.optimizer == 'SAM':
    base_optimizer = torch.optim.Adam
    optimizer = SAM(model.parameters(), base_optimizer, lr=CFG.init_lr)
elif CFG.optimizer == 'Ranger':
    optimizer = Ranger(model.parameters(), CFG.init_lr)
elif CFG.optimizer == 'Adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=CFG.init_lr)
else:
    raise ValueError(f'Unknown optimizer: {CFG.optimizer}')

# here you can experiment with other schedulers too, they have decent impact on this competition

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, CFG.n_epochs, eta_min=1e-7)


train_df_this = train_df[train_df['fold'] != CFG.fold_id]
df_valid_this = train_df[train_df['fold'] == CFG.fold_id]

dataset_train = TrainDataset(train_df_this, transform=transforms_train)
dataset_valid = TrainDataset(df_valid_this, transform=transforms_valid)

train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=CFG.batch_size, shuffle=True,  num_workers=CFG.num_workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(dataset_valid, batch_size=CFG.valid_batch_size, shuffle=False, num_workers=CFG.num_workers, pin_memory=True)

### Learning Rates

In the NFNet paper, the authors use a 5-epoch learning rate warmup followed by cosine decay to 0 over the remaining epochs. In my personal experiments this didn't give better results, so we only use cosine annealing here...
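If you want to experiment with warmup anyway, one option is a `LambdaLR` multiplier that ramps up linearly over the warmup epochs and then follows a cosine curve down to zero. A minimal sketch: `warmup_cosine` is a hypothetical helper, its defaults simply mirror `CFG.n_epochs = 15` and a 5-epoch warmup, and the dummy parameter and optimizer exist only to make the snippet runnable. In the notebook you would pass the real `optimizer` instead and keep calling `scheduler.step()` once per epoch, as the training loop already does.

```python
import math

import torch


def warmup_cosine(epoch, warmup_epochs=5, total_epochs=15):
    """LR multiplier: linear warmup, then cosine decay reaching zero at total_epochs."""
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


# Dummy parameter/optimizer just to show the wiring
params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.Adam(params, lr=5e-4)
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=warmup_cosine)

lrs = []
for epoch in range(15):
    lrs.append(opt.param_groups[0]['lr'])
    opt.step()
    sched.step()
```

Plotting `lrs` the same way as the cell below shows the ramp-up over the first five epochs followed by the cosine decay.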

In [None]:
# visualizing the learning rate

lrs = []

modelt = torch.nn.Linear(2, 1)

optimizert = torch.optim.Adam(modelt.parameters(),lr=CFG.init_lr)

schedulert = torch.optim.lr_scheduler.CosineAnnealingLR(optimizert, CFG.n_epochs, eta_min=1e-7)

for i in range(CFG.n_epochs):
    optimizert.step()
    lrs.append(optimizert.param_groups[0]["lr"])
    schedulert.step()
plt.title('Learning Rate over Epochs')
plt.plot(lrs)
plt.show()

# Training Single Fold

In [None]:
# single fold training

log = {}
roc_auc_max = 0.
loss_min = float('inf')
not_improving = 0


for epoch in range(1, CFG.n_epochs+1):
    
    
    loss_train = train_func(train_loader)
    loss_valid, roc_auc = valid_func(valid_loader)

    log['loss_train'] = log.get('loss_train', []) + [loss_train]
    log['loss_valid'] = log.get('loss_valid', []) + [loss_valid]
    log['lr'] = log.get('lr', []) + [optimizer.param_groups[0]["lr"]]
    log['roc_auc'] = log.get('roc_auc', []) + [roc_auc]

    content = time.ctime() + ' ' + f'Fold: {CFG.fold_id}, Epoch: {epoch}/{CFG.n_epochs}, lr: {optimizer.param_groups[0]["lr"]:.7f}, loss_train: {loss_train:.5f}, loss_valid: {loss_valid:.5f}, roc_auc: {roc_auc:.6f}.'
    print(content)
    not_improving += 1
    
    scheduler.step()
    
    if roc_auc > roc_auc_max:
        print(f'roc_auc_max ({roc_auc_max:.6f} --> {roc_auc:.6f}). Saving model ...')
        torch.save(model.state_dict(), f'{model_dir}{CFG.model_name}_fold{CFG.fold_id}_best_AUC.pth')
        roc_auc_max = roc_auc
        not_improving = 0

    if loss_valid < loss_min:
        loss_min = loss_valid
        torch.save(model.state_dict(), f'{model_dir}{CFG.model_name}_fold{CFG.fold_id}_best_loss.pth')
        
    if not_improving == CFG.early_stop:
        print('Early Stopping...')
        break
        


torch.save(model.state_dict(), f'{model_dir}{CFG.model_name}_fold{CFG.fold_id}_final.pth')

# Summary

In this notebook I wanted to try one of the latest state-of-the-art CNN families, Normalizer-Free Networks (NFNets), together with some other promising additions. I hope you enjoyed reading it as much as I enjoyed creating it.

Again if you enjoyed this work please leave an upvote, thanks!