# Efficient Net Train
- AMP (automatic mixed precision)
- Focal Loss
- FMix
- Early Stopping
- AdamW with cosine-annealing warm restarts

- If this helps, please upvote this notebook and the original ones 👍🏼<br><br>

**Reference code**

- [Cutmix v.s. Fmix with Visualization](https://www.kaggle.com/khyeh0719/cutmix-v-s-fmix-with-visualization)
- [Pytorch Efficientnet Baseline [Train] AMP+Aug](https://www.kaggle.com/khyeh0719/pytorch-efficientnet-baseline-train-amp-aug)

In [None]:
import sys

# Extra directories (Kaggle input datasets) appended to the module search path.
package_path = [
    '../input/timmpackagelatestwhl',
    '../input/image-fmix/FMix-master',
]
sys.path.extend(package_path)

In [None]:
import os
import pandas as pd
pd.set_option('display.max_row', None)
pd.set_option('display.max_columns', None)
import albumentations as albu
import matplotlib.pyplot as plt
import json
import seaborn as sns
import cv2
import albumentations as albu
import numpy as np
import random
from tqdm import tqdm
import warnings
warnings.filterwarnings(action='ignore')

import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import StratifiedKFold
from albumentations.pytorch import ToTensorV2
from torch.cuda.amp import autocast, GradScaler


!pip install ../input/timmpackagelatestwhl/timm-0.3.4-py3-none-any.whl
import timm
from fmix import sample_mask, make_low_freq_image, binarise_mask

# Load TrainSet

In [None]:
# Root folder of the competition data and the sub-folder holding the training images.
BASE_DIR = "../input/cassava-leaf-disease-classification/"
TRAIN_IMAGES_DIR = os.path.join(BASE_DIR, 'train_images/')

# Table read from train.csv in the root folder.
train_df = pd.read_csv(os.path.join(BASE_DIR, 'train.csv'))

In [None]:
# Show the first five rows, then print the (rows, columns) shape of the table.
display(train_df.head(n=5))
print(train_df.shape)

# Helper Functions

In [None]:
def seed_everything(seed):
    """Seed every random number generator used by this notebook.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy and PyTorch
            (CPU and CUDA). It is also exported as ``PYTHONHASHSEED``.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Seed every CUDA device, not only the current one.
    torch.cuda.manual_seed_all(seed)
    # Request deterministic cuDNN kernels so that seeded runs are repeatable.
    torch.backends.cudnn.deterministic = True
    
def get_img(path):
    """Read an image file with OpenCV and return it with channels in RGB order.

    Args:
        path: Path of the image file to read.

    Returns:
        The decoded image as a contiguous array whose last axis is reversed
        relative to what ``cv2.imread`` returned (BGR -> RGB).

    Raises:
        FileNotFoundError: If ``cv2.imread`` could not read the file.
    """
    im_bgr = cv2.imread(path)
    if im_bgr is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError('Could not read image: {}'.format(path))
    # Reversing the channel axis yields a negative-stride view; copy it into a
    # contiguous array so it can be handed to torch even without transforms.
    return np.ascontiguousarray(im_bgr[:, :, ::-1])

def rand_bbox(size, lam):
    """Sample a random box covering about ``1 - lam`` of a ``size`` area.

    Args:
        size: Sequence whose first two entries are the extents ``W`` and ``H``
            of the two box axes.
        lam: Mixing ratio; the unclipped box has side lengths
            ``W * sqrt(1 - lam)`` and ``H * sqrt(1 - lam)``.

    Returns:
        Tuple ``(bbx1, bby1, bbx2, bby2)`` of box corners, clipped to
        ``[0, W]`` on the first axis and ``[0, H]`` on the second.
    """
    W = size[0]
    H = size[1]
    cut_rat = np.sqrt(1. - lam)
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` truncates the
    # same way.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Box centre drawn uniformly over the whole area.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

def save_model(model, optimizer, scheduler, fold, epoch, save_every=False, best=False):
    """Write the model, optimizer and scheduler state dicts to disk.

    Args:
        model: Object exposing ``state_dict()``; stored under ``'model'``.
        optimizer: Object exposing ``state_dict()``; stored under ``'optimizer'``.
        scheduler: Object exposing ``state_dict()``; stored under ``'scheduler'``.
        fold: Zero-based fold index (written one-based into the file name).
        epoch: Zero-based epoch index (written one-based into the file name).
        save_every: When true, save the checkpoint under ``./saved_model``.
        best: When true, save the checkpoint under ``./best_model``.
    """
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict()
    }
    file_name = 'model_fold_{}_epoch_{}'.format(fold+1, epoch+1)
    for enabled, directory in ((save_every, './saved_model'), (best, './best_model')):
        if not enabled:
            continue
        # exist_ok avoids the check-then-create race of isdir() + mkdir().
        os.makedirs(directory, exist_ok=True)
        torch.save(state, '{}/{}'.format(directory, file_name))
        
class EarlyStopping:
    """Track the best validation loss and flag when training should stop.

    Every call with a strictly lower loss saves a "best" checkpoint through
    ``save_model`` and resets the counter. Any other call increments the
    counter, and ``early_stop`` becomes true once the counter reaches
    ``patience``.
    """

    def __init__(self, patience):
        """Create a tracker that tolerates ``patience`` non-improving calls in a row."""
        self.patience = patience
        self.counter = 0           # consecutive calls without improvement
        self.early_stop = False    # set to True once patience is exhausted
        # ``np.Inf`` was removed in NumPy 2.0; ``np.inf`` is the supported name.
        self.val_loss_min = np.inf

    def __call__(self, val_loss, model, optimizer, scheduler, fold, epoch):
        """Record ``val_loss`` for this epoch and update the stopping state.

        Args:
            val_loss: Validation loss of the epoch that just finished.
            model, optimizer, scheduler: Forwarded to ``save_model`` when the
                loss improves.
            fold: Zero-based fold index.
            epoch: Zero-based epoch index.
        """
        # Strict "<" means the first epoch (compared against inf) is saved as
        # the initial best, while NaN or equal losses never count as progress.
        if val_loss < self.val_loss_min:
            save_model(model, optimizer, scheduler, fold, epoch, best=True)
            print('*** Validation loss decreased ({} --> {}).  Saving model... ***'.format(
                round(self.val_loss_min, 6), round(val_loss, 6)))
            self.val_loss_min = val_loss
            self.counter = 0
            return

        self.counter += 1
        print('EarlyStopping counter: {} out of {}'.format(self.counter, self.patience))
        if self.counter >= self.patience:
            # Report the fold one-based, matching the checkpoint file names.
            print('Early Stopping - Fold {} Training is Stopping'.format(fold + 1))
            self.early_stop = True

# Data Augmentation

In [None]:
from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)

from albumentations.pytorch import ToTensorV2

def get_train_transforms():
    """Build the augmentation pipeline applied to training images.

    Returns:
        A ``Compose`` that crops/resizes to 512x512, applies random geometric
        and colour augmentations, normalises, applies random dropout-style
        occlusions and finally converts the image to a tensor.
    """
    geometric_steps = [
        RandomResizedCrop(512, 512),
        Transpose(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        ShiftScaleRotate(p=0.5),
    ]
    colour_steps = [
        HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
        RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1, 0.1), p=0.5),
    ]
    normalize_step = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0)
    # The occlusion steps run after normalisation, as in the original ordering.
    occlusion_steps = [
        CoarseDropout(p=0.5),
        Cutout(p=0.5),
    ]
    pipeline = geometric_steps + colour_steps + [normalize_step] + occlusion_steps + [ToTensorV2(p=1.0)]
    return Compose(pipeline, p=1.)
  
        
def get_valid_transforms():
    """Build the preprocessing pipeline applied to validation images.

    Returns:
        A ``Compose`` that centre-crops to 512x512, resizes to 512x512,
        normalises and converts the image to a tensor.
    """
    pipeline = []
    pipeline.append(CenterCrop(512, 512, p=1.))
    pipeline.append(Resize(512, 512))
    pipeline.append(Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0))
    pipeline.append(ToTensorV2(p=1.0))
    return Compose(pipeline, p=1.)

# Data Loader

In [None]:
class CassavaDataset(Dataset):
    """Image dataset with optional FMix / CutMix mixing of two samples.

    Each item is read with ``get_img`` from ``"<data_root>/<image_id>"`` and
    passed through ``transforms``. When a mixing mode is enabled, a sample is
    mixed with a second, randomly drawn row of ``df`` whenever a uniform draw
    exceeds 0.5.

    Args:
        df: DataFrame with an ``image_id`` column and, when ``output_label``
            is true, a ``label`` column.
        data_root: Directory that contains the image files.
        transforms: Optional callable invoked as ``transforms(image=img)`` that
            returns a mapping with an ``'image'`` entry.
            NOTE(review): mixing indexes the result as ``[:, a:b, c:d]`` with a
            512x512 mask, so it presumably expects (C, 512, 512) tensors -
            confirm against the transform pipeline.
        do_fmix: Enable FMix mixing.
        do_cutmix: Enable CutMix mixing.
        output_label: When true, items are ``(img, target)``; otherwise ``img``.
    """

    def __init__(self, df, data_root, transforms=None, do_fmix=False, do_cutmix=False, output_label=True):
        self.df=df
        self.data_root=data_root
        self.transforms=transforms
        self.do_fmix = do_fmix
        self.do_cutmix = do_cutmix
        self.fmix_params={'alpha': 1.,
                          'decay_power': 3.,
                          'shape': (512, 512),
                          'max_soft': True,
                          'reformulate': False}
        self.cutmix_params={'alpha': 1}
        self.output_label = output_label
        # Only touch the label column when labels are requested, so unlabeled
        # frames (no 'label' column) can be wrapped as well.
        self.labels = self.df['label'].values if output_label else None

    def _load_image(self, position):
        """Read and transform the image stored at row ``position`` (positional)."""
        img = get_img(path="{}/{}".format(self.data_root, self.df.image_id.iloc[position]))
        if self.transforms:
            img = self.transforms(image=img)['image']
        return img

    def _mixed_label(self, rate, target, other_position):
        """Return the label of the image that covers most of the mixed sample.

        A weighted average of two class indices (e.g. 0.65*4 + 0.35*0 = 2.6)
        is not itself a class, so the label of the dominant image is kept
        instead. ``rate`` is the fraction contributed by the original image.
        """
        if not self.output_label:
            return None
        return target if rate >= 0.5 else self.labels[other_position]

    def _fmix(self, img, target):
        """Blend ``img`` with a random second image through an FMix mask."""
        shape = self.fmix_params['shape']
        alpha = self.fmix_params['alpha']
        with torch.no_grad():
            lam = np.clip(np.random.beta(alpha, alpha), 0.6, 0.7)
            mask = make_low_freq_image(self.fmix_params['decay_power'], shape)
            mask = binarise_mask(mask, lam, shape, self.fmix_params['max_soft'])
            other_position = np.random.choice(self.df.shape[0], size=1)[0]
            other_img = self._load_image(other_position)
            mask_torch = torch.from_numpy(mask)
            # Fraction of the mask that keeps the original image.
            rate = mask.sum()/shape[0]/shape[1]
            img = mask_torch*img+(1.-mask_torch)*other_img
        return img, self._mixed_label(rate, target, other_position)

    def _cutmix(self, img, target):
        """Paste a random box taken from a second image into ``img``."""
        shape = self.fmix_params['shape']
        alpha = self.cutmix_params['alpha']
        with torch.no_grad():
            # Draw a row *position* (not an index label): both ``iloc`` and
            # ``self.labels`` are positional.
            other_position = np.random.choice(len(self.df), size=1)[0]
            other_img = self._load_image(other_position)
            lam = np.clip(np.random.beta(alpha, alpha),0.3,0.4)
            bbx1, bby1, bbx2, bby2 = rand_bbox(shape, lam)
            # Fraction of the area that still shows the original image.
            rate = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (shape[0] * shape[1]))
            img[:, bbx1:bbx2, bby1:bby2] = other_img[:, bbx1:bbx2, bby1:bby2]
        return img, self._mixed_label(rate, target, other_position)

    def __getitem__(self,index):
        """Return ``(img, target)`` or just ``img``, depending on ``output_label``."""
        img = self._load_image(index)
        target = self.df.label.iloc[index] if self.output_label else None

        if self.do_fmix and np.random.uniform(0., 1., size=1)[0] > 0.5:
            img, target = self._fmix(img, target)

        if self.do_cutmix and np.random.uniform(0., 1., size=1)[0] > 0.5:
            img, target = self._cutmix(img, target)

        if self.output_label:
            return img, target
        return img

    def __len__(self):
        """Return the number of rows in the wrapped DataFrame."""
        return len(self.df)

# Create Model

In [None]:
class Model(nn.Module):
    """timm network whose ``classifier`` layer is replaced by an ``n_class``-way linear layer."""

    def __init__(self, model_name, n_class, pretrained=False):
        """Create the ``model_name`` network and swap in a new linear classifier."""
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained)
        # Keep the classifier's input width; only the output size changes.
        backbone.classifier = nn.Linear(backbone.classifier.in_features, n_class)
        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.model(x)

# Focal Loss

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.modules.loss._WeightedLoss):
    """Multi-class focal loss: ``(1 - p_t) ** gamma * CE`` computed per sample.

    Args:
        weight: Optional per-class weight tensor applied to each sample's loss.
        gamma: Focusing exponent; ``gamma=0`` reduces to plain cross-entropy.
        reduction: ``'mean'`` (default), ``'sum'`` or ``'none'``.
    """

    def __init__(self, weight=None, gamma=2,reduction='mean'):
        # The parent class already stores ``weight`` and ``reduction``.
        super(FocalLoss, self).__init__(weight,reduction=reduction)
        self.gamma = gamma

    def forward(self, input_, target):
        """Return the focal loss of logits ``input_`` against class indices ``target``."""
        # The focal factor must be computed per sample: applying it to the
        # batch-averaged CE would scale every sample by the same factor.
        ce_loss = F.cross_entropy(input_, target, reduction='none')
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        if self.weight is not None:
            focal_loss = focal_loss * self.weight[target]

        if self.reduction == 'none':
            return focal_loss
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss.mean()

# Train / Validation Functions

In [None]:
def prepare_dataloader(df, train_index, val_index, train_batch, valid_batch, num_workers, data_root=TRAIN_IMAGES_DIR):
    """Build the training and validation loaders for one fold.

    Args:
        df: Full DataFrame to split.
        train_index: Row positions of the training samples.
        val_index: Row positions of the validation samples.
        train_batch: Batch size of the training loader.
        valid_batch: Batch size of the validation loader.
        num_workers: Worker count passed to both loaders.
        data_root: Directory that contains the image files.

    Returns:
        Tuple ``(train_loader, val_loader)``. The training set is shuffled and
        uses FMix; the validation set is neither shuffled nor mixed.
    """
    # The fold indices are row positions, so select with iloc rather than loc.
    trainset = df.iloc[train_index].reset_index(drop=True)
    validset = df.iloc[val_index].reset_index(drop=True)
    train_dataset = CassavaDataset(trainset, data_root=data_root, transforms=get_train_transforms(), output_label=True, do_fmix=True)
    valid_dataset = CassavaDataset(validset, data_root=data_root, transforms=get_valid_transforms(), output_label=True, do_fmix=False)
    train_loader = DataLoader(train_dataset, batch_size=train_batch, pin_memory=False, drop_last=False, shuffle=True, num_workers=num_workers)
    # Use the validation batch size here; it was previously ignored in favour
    # of train_batch.
    val_loader = DataLoader(valid_dataset, batch_size=valid_batch, num_workers=num_workers, shuffle=False, pin_memory=False)
    return train_loader, val_loader


def train_one_epoch(epoch, model, loss_fn, optimizer, train_loader, device, scheduler):
    """Train ``model`` for one pass over ``train_loader`` and print loss/accuracy.

    Gradient scaling uses the module-level ``scaler`` object, which must exist
    before this function is called.

    Args:
        epoch: Zero-based epoch index (printed one-based).
        model: Network to train.
        loss_fn: Callable ``loss_fn(preds, labels)`` returning a scalar loss.
        optimizer: Optimizer stepped once per batch.
        train_loader: Iterable of ``(images, labels)`` batches.
        device: Device the batches are moved to.
        scheduler: Learning-rate scheduler stepped once at the end of the epoch.
    """
    model.train()
    lst_out = []
    lst_label = []
    avg_loss = 0
    n_batches = len(train_loader)

    for images, labels in tqdm(train_loader, total=n_batches):
        images = images.to(device).float()
        labels = labels.to(device).long()
        # Only the forward pass and the loss run under autocast; the backward
        # pass and the optimizer step belong outside the context.
        with autocast():
            preds = model(images)
            loss = loss_fn(preds, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        # Store plain Python ints instead of one device tensor per sample.
        lst_out.extend(preds.argmax(1).tolist())
        lst_label.extend(labels.tolist())
        avg_loss += loss.item() / n_batches
    scheduler.step()
    accuracy = accuracy_score(y_pred=lst_out, y_true=lst_label)
    print('{} epoch - train loss : {}, train accuracy : {}'.\
          format(epoch + 1, np.round(avg_loss,6), np.round(accuracy*100,2)))

def valid_one_epoch(epoch, model, loss_fn, val_loader, device, scheduler):
    """Evaluate ``model`` on ``val_loader`` and print loss/accuracy.

    Args:
        epoch: Zero-based epoch index (printed one-based).
        model: Network to evaluate; switched to eval mode.
        loss_fn: Callable ``loss_fn(preds, labels)`` returning a scalar loss.
        val_loader: Iterable of ``(images, labels)`` batches.
        device: Device the batches are moved to.
        scheduler: Unused; kept so the signature stays unchanged.

    Returns:
        The validation loss averaged over batches.
    """
    model.eval()
    lst_val_out = []
    lst_val_label = []
    avg_val_loss = 0
    n_batches = len(val_loader)

    # Disable autograd here so evaluation never builds a graph, even when the
    # caller does not wrap the call in torch.no_grad().
    with torch.no_grad():
        for images, labels in tqdm(val_loader, total=n_batches):
            val_images = images.to(device).float()
            val_labels = labels.to(device).long()

            val_preds = model(val_images)
            loss = loss_fn(val_preds, val_labels)

            # Store plain Python ints instead of one device tensor per sample.
            lst_val_out.extend(val_preds.argmax(1).tolist())
            lst_val_label.extend(val_labels.tolist())
            avg_val_loss += loss.item() / n_batches
    accuracy = accuracy_score(y_pred=lst_val_out, y_true=lst_val_label)
    print('{} epoch - valid loss : {}, valid accuracy : {}'.\
          format(epoch + 1, np.round(avg_val_loss, 6), np.round(accuracy*100,2)))
    return avg_val_loss

# Main - Training

In [None]:
# Entry point: stratified K-fold training with a checkpoint after every epoch
# and early stopping on the validation loss.
if __name__ == '__main__':
    # Run configuration.
    train_batch = 16
    valid_batch = 32
    num_workers = 4
    seed = 719
    split = 5        # number of folds
    epochs = 100     # upper bound; early stopping may end a fold sooner
    patience = 5     # non-improving epochs tolerated by EarlyStopping

    n_class = 5
    model_arch = 'tf_efficientnet_b4_ns'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    seed_everything(seed)
    # All columns but the last vs. the last column; used only to drive the stratified split.
    X_train = train_df.iloc[:, :-1]; Y_train = train_df.iloc[:, -1]
    cv = StratifiedKFold(n_splits=split, random_state=seed, shuffle=True)
    for fold, (train_index, val_index) in enumerate(cv.split(X_train, Y_train)):
        # NOTE(review): the first fold is skipped - presumably it was trained
        # in an earlier session; confirm before relying on all folds being run.
        if fold == 0:
            continue
        torch.cuda.empty_cache()
        print('---------- Fold {} is training ----------'.format(fold + 1))
        print('Train Size : {}, Valid Size : {}'.format(len(train_index), len(val_index)))
        train_loader, val_loader = prepare_dataloader(train_df, train_index, val_index, train_batch, valid_batch, num_workers, data_root=TRAIN_IMAGES_DIR)
        # Fresh model, losses, optimizer, scheduler, scaler and stopper for every fold.
        model = Model(model_arch, n_class, pretrained=True).to(device)
        loss_tr = FocalLoss().to(device); loss_fn = FocalLoss().to(device)
        optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-6)
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1)
        # `scaler` is not passed as an argument: train_one_epoch reads it as a global.
        scaler = GradScaler()
        early_stopping = EarlyStopping(patience=patience)
        for epoch in range(epochs):
            train_one_epoch(epoch, model, loss_tr, optimizer, train_loader, device, scheduler=scheduler)
            # Checkpoint every epoch, independent of the validation result.
            save_model(model, optimizer, scheduler, fold, epoch, save_every=True)
            with torch.no_grad():
                val_loss = valid_one_epoch(epoch, model, loss_fn, val_loader, device, scheduler=None)
                early_stopping(val_loss, model, optimizer, scheduler, fold, epoch)
                if early_stopping.early_stop:
                    break

        # Drop this fold's objects before the next fold allocates its own.
        del model, optimizer, train_loader, val_loader, scheduler, scaler
        torch.cuda.empty_cache()