# NFNet training


In [None]:
# Put the bundled timm source tree at the front of the module search path so
# that the later `import timm` resolves to this copy.
import sys; 
sys.path.insert(0,'../input/timm-all-models/pytorch-image-models-master/pytorch-image-models-master')
# sys.path.insert(0, '../input/timm-all-models')

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings

import cv2
from tqdm.notebook import tqdm

from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)

from albumentations.pytorch import ToTensorV2

import timm
from timm.utils.agc import adaptive_clip_grad

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler

from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold

warnings.simplefilter("ignore")
from pytorch_lightning import Trainer, seed_everything
import torch.nn.functional as F

In [None]:
class Config:
    """
    Central place for the hyper-parameters and dataset locations used in this file.
    """
    # Reproducibility
    seed = 42

    # Dataset locations
    path = '../input/plant-pathology-2021-fgvc8/'
    train_path = '../input/plant-pathology-2021-fgvc8/train_images'
    test_path = '../input/plant-pathology-2021-fgvc8/test_images'

    # Input resolution (images are resized to img_size x img_size)
    img_size = 224

    # Optimisation settings
    n_epoch = 1
    lr = 1e-4
    weight_decay = 0.001
    train_batch = 16
    test_batch = 32

    # Debug switches
    debug = False
    debug_sample = 100
    
        
# Use the GPU when one is present; otherwise fall back to the CPU instead of
# failing on the first `.to(device)` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the random number generators with the configured seed.
seed_everything(Config.seed)

In [None]:
# Load the training annotations and shuffle the rows once.
df_all = pd.read_csv(Config.path + "train.csv")
df_all = df_all.sample(frac=1).reset_index(drop=True)

# Map every distinct label string to an integer class id, ordered by how
# often the label occurs (value_counts sorts by descending count).
# enumerate() covers however many labels the CSV contains; a fixed range would
# silently drop any label beyond its length.
labels = list(df_all['labels'].value_counts().keys())
labels_dict = {label: class_id for class_id, label in enumerate(labels)}

if Config.debug:
    # Small slices for a quick smoke run.
    train_split = df_all[0:500]
    valid_split = df_all[500:550]
else:
    # First 1000 shuffled rows are held out for validation.
    valid_split = df_all[:1000]
    train_split = df_all[1000:]
    

In [None]:
class Augments:
    """
    Albumentations pipelines for the training and validation datasets.

    Both pipelines resize to Config.img_size, normalise and convert to a
    tensor; the training pipeline additionally applies random transpose and
    horizontal/vertical flips.
    """
    train_augments = Compose(
        transforms=[
            Resize(height=Config.img_size, width=Config.img_size),
            Transpose(p=0.5),
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
                max_pixel_value=255.0,
                p=1.0,
            ),
            ToTensorV2(p=1.0),
        ],
        p=1.0,
    )

    valid_augments = Compose(
        transforms=[
            Resize(height=Config.img_size, width=Config.img_size),
            Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
                max_pixel_value=255.0,
                p=1.0,
            ),
            ToTensorV2(p=1.0),
        ],
        p=1.0,
    )

In [None]:
class NFNetModel(nn.Module):
    """
    Model Class for the newly introduced Normalization Free Network (NFNet) Model Architecture

    Wraps a timm model and replaces its classification head (`head.fc`) with a
    fresh linear layer producing `num_classes` outputs.
    """
    def __init__(self, num_classes=12, model_name='nfnet_f1', pretrained=True):
        """
        Args:
            num_classes: number of output units of the replaced head.
            model_name: timm model identifier passed to `timm.create_model`.
            pretrained: forwarded to `timm.create_model`.
        """
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)
        self.model.head.fc = nn.Linear(self.model.head.fc.in_features, num_classes)

    def forward(self, x):
        """Return the wrapped model's output for the batch `x`."""
        return self.model(x)

    def save(self, optim, epoch=None, path='./nfnet.pth'):
        """
        Write a checkpoint with model and optimizer state.

        Args:
            optim: optimizer whose `state_dict()` is stored with the weights.
            epoch: epoch number to record. When omitted, a module-level
                `epoch` variable is used if one exists (the previous
                behaviour of this method), then `self.epoch` as set by
                `load`, and finally 0.
            path: destination file; defaults to the previous fixed location.

        Side effect: the module is switched to eval mode, as before.
        """
        if epoch is None:
            # Previously this read a bare global `epoch` and raised NameError
            # when no training loop had defined it yet.
            epoch = globals().get('epoch', getattr(self, 'epoch', 0))

        self.eval()
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': optim.state_dict(),
            'loss': 0,
            }, path)

    def load(self, optim, path, map_location=None):
        """
        Restore model and optimizer state from a checkpoint written by `save`.

        Args:
            optim: optimizer to restore in place.
            path: checkpoint file to read.
            map_location: forwarded to `torch.load`; pass e.g. "cpu" to open
                a checkpoint on a machine without the device it was saved on.

        Sets `self.epoch` and `self.loss` from the checkpoint.
        """
        checkpoint = torch.load(path, map_location=map_location)
        self.load_state_dict(checkpoint['model_state_dict'])
        optim.load_state_dict(checkpoint['optimizer_state_dict'])
        self.epoch = checkpoint['epoch']
        self.loss = checkpoint['loss']

In [None]:
class Plant_data(Dataset):
    """
    Dataset yielding (image, label) pairs read from an image directory.

    Each row of `df` supplies a file name in its 'image' column and a label
    string in its 'labels' column; the label string is converted to an integer
    class id through the module-level `labels_dict`.
    """
    def __init__(self, df, num_classes=12, is_train=True, augments=None, img_size=Config.img_size, img_path="../input/plant-pathology-2021-fgvc8/train_images"):
        """
        Args:
            df: DataFrame with 'image' and 'labels' columns.
            num_classes: stored on the instance; not used by this class.
            is_train: stored on the instance; not used by this class.
            augments: albumentations transform applied to each image, or None.
            img_size: stored on the instance; not used by this class.
            img_path: directory that contains the image files.
        """
        super().__init__()
        self.df = df.sample(frac=1).reset_index(drop=True)
        self.num_classes = num_classes
        self.is_train = is_train
        self.augments = augments
        self.img_size = img_size
        self.img_path = img_path
        # Read ids and labels from the same (shuffled) frame that __len__
        # reports on, so row idx of self.df always matches item idx.
        self.image_id = self.df['image'].values
        self.labels = self.df['labels'].values

    def __getitem__(self, idx):
        """
        Return `(image, label)` for item `idx`.

        The image is the output of `augments` when one was given, otherwise
        the RGB array as read from disk. The label is a 0-dim tensor holding
        the class id.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        image_file = os.path.join(self.img_path, self.image_id[idx])
        image = cv2.imread(image_file)
        if image is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError(f"Could not read image: {image_file}")

        # OpenCV loads BGR; reverse the channel axis to get RGB. The copy
        # removes the negative stride left by the reversed view.
        image = np.ascontiguousarray(image[:, :, ::-1])

        # Augments must be albumentations
        if self.augments is not None:
            img = self.augments(image=image)['image']
        else:
            img = image

        label = labels_dict[self.labels[idx]]
        return img, torch.tensor(label)

    def __len__(self):
        """Number of rows in the underlying frame."""
        return len(self.df)

In [None]:
class Trainer:
    """
    Mixed-precision training / validation loop for a single-label classifier
    trained against one-hot targets.

    NOTE: this class reuses the name `Trainer` that is also imported from
    `pytorch_lightning` at the top of the file; after this definition the
    name refers to this class.
    """
    def __init__(self, train_dataloader, valid_dataloader, model, optimizer, loss_fn, val_loss_fn, agc=False, device="cuda:0", num_classes=12, scaler=None):
        """
        Constructor for Trainer class

        Args:
            train_dataloader: iterable of (images, integer labels) batches.
            valid_dataloader: iterable of (images, integer labels) batches.
            model: the network to optimise.
            optimizer: optimizer stepping `model`'s parameters.
            loss_fn: training loss, called as `loss_fn(logits, one_hot)`.
            val_loss_fn: validation loss, called the same way.
            agc: when True, apply adaptive gradient clipping before each step.
            device: device the batches are moved to.
            num_classes: width of the one-hot targets.
            scaler: GradScaler to use; a new one is created when omitted.
        """
        self.train = train_dataloader
        self.valid = valid_dataloader
        # Keep the objects that were passed in instead of reaching for
        # same-named module-level globals.
        self.model = model
        self.optim = optimizer
        self.loss_fn = loss_fn
        self.val_loss_fn = val_loss_fn
        self.device = device
        self.agc = agc
        self.num_classes = num_classes
        self.scaler = scaler if scaler is not None else GradScaler()

    def train_one_cycle(self):
        """
        Runs one epoch of training, backpropagation and optimization

        Returns:
            (mean_loss, accuracy): the mean per-batch loss as a 0-dim tensor
            and the fraction of correctly classified samples over the epoch.

        Raises:
            ValueError: if the training dataloader has no batches.
        """
        n_batches = len(self.train)
        if n_batches == 0:
            raise ValueError("train dataloader is empty")

        self.model.train()
        train_prog_bar = tqdm(self.train, total=n_batches)

        all_train_labels = []
        all_train_preds = []

        running_loss = 0

        for iteration, (images, targets) in enumerate(train_prog_bar):
            xtrain = images.to(self.device).float()
            ytrain = F.one_hot(targets, num_classes=self.num_classes).to(self.device).float()

            # Only the forward pass and the loss run under autocast.
            with autocast():
                z = self.model(xtrain)
                train_loss = self.loss_fn(z, ytrain)

            self.scaler.scale(train_loss).backward()

            if self.agc:
                # Clipping compares gradient norms with parameter norms, so
                # the loss scale has to be removed from the gradients first.
                self.scaler.unscale_(self.optim)
                adaptive_clip_grad(self.model.parameters(), clip_factor=0.01, eps=1e-3, norm_type=2.0)

            self.scaler.step(self.optim)
            self.scaler.update()
            self.optim.zero_grad()

            # For averaging and reporting later; detach so the running sum
            # does not carry autograd history from every batch.
            running_loss = running_loss + train_loss.detach()

            # Convert the predictions and corresponding labels to right form
            train_predictions = torch.argmax(z, 1).detach().cpu().numpy()
            train_labels = targets.cpu().numpy()

            all_train_preds.append(train_predictions)
            all_train_labels.append(train_labels)

            if iteration % 5 == 0:
                # Divide by the real batch length: the last batch may be short.
                acc_per_iteration = np.sum(train_predictions == train_labels) / len(train_labels)
                print('acc per iter ={}'.format(acc_per_iteration))

            # Show the current loss to the progress bar
            train_pbar_desc = f'loss: {train_loss.item():.4f}'
            train_prog_bar.set_description(desc=train_pbar_desc)

        # Now average the running loss over all batches and return
        train_running_loss = running_loss / n_batches
        # Sample-weighted accuracy over the whole epoch.
        train_running_acc = np.mean(np.concatenate(all_train_preds) == np.concatenate(all_train_labels))
        print(f"Final Training Loss: {train_running_loss:.4f}")
        print(f"Final Training acc: {train_running_acc:.4f}")

        return train_running_loss, train_running_acc

    def valid_one_cycle(self):
        """
        Runs one epoch of prediction

        Returns:
            (mean_loss, accuracy, model): the mean per-batch validation loss
            as a float, the fraction of correctly classified samples, and the
            model that was evaluated.

        Raises:
            ValueError: if the validation dataloader has no batches.
        """
        n_batches = len(self.valid)
        if n_batches == 0:
            raise ValueError("valid dataloader is empty")

        self.model.eval()

        valid_prog_bar = tqdm(self.valid, total=n_batches)

        all_valid_labels = []
        all_valid_preds = []
        running_loss = 0

        with torch.no_grad():
            for images, targets in valid_prog_bar:
                xval = images.to(self.device).float()
                yval = F.one_hot(targets, num_classes=self.num_classes).to(self.device).float()

                val_z = self.model(xval)
                val_loss = self.val_loss_fn(val_z, yval)
                running_loss += val_loss.item()

                all_valid_preds.append(torch.argmax(val_z, 1).cpu().numpy())
                all_valid_labels.append(targets.cpu().numpy())

                # Show the current loss
                valid_pbar_desc = f"loss: {val_loss.item():.4f}"
                valid_prog_bar.set_description(desc=valid_pbar_desc)

        # Get the final loss
        final_loss_val = running_loss / n_batches

        # Get Validation Accuracy (sample-weighted, so a short final batch
        # does not distort the result)
        final_val_acc = np.mean(np.concatenate(all_valid_preds) == np.concatenate(all_valid_labels))

        print(f"Final Validation Loss: {final_loss_val:.4f}")
        print(f"acc: {final_val_acc:.4f}")

        return (final_loss_val, final_val_acc, self.model)

In [None]:

# Fresh model, optimizer and losses. Learning rate and weight decay come from
# Config so they are defined in one place only.
model = NFNetModel().to(device)
optim = torch.optim.AdamW(model.parameters(), lr=Config.lr, weight_decay=Config.weight_decay)
loss_fn_train = nn.BCEWithLogitsLoss()
loss_fn_val = nn.BCEWithLogitsLoss()

In [None]:

# Build the model and optimizer, then restore both from the checkpoint at
# './nfnet.pth'. Learning rate and weight decay come from Config so they are
# defined in one place only.
model = NFNetModel().to(device)
optim = torch.optim.AdamW(model.parameters(), lr=Config.lr, weight_decay=Config.weight_decay)
model.load(optim,'./nfnet.pth')

loss_fn_train = nn.BCEWithLogitsLoss()
loss_fn_val = nn.BCEWithLogitsLoss()

In [None]:
# History containers; the loop below does not fill these four lists.
all_epoch_loss_training=[]
all_epoch_loss_val=[]
all_epoch_acc_training=[]
all_epoch_acc_val=[]


train_set = Plant_data(df=train_split, augments=Augments.train_augments)
valid_set = Plant_data(df=valid_split, augments=Augments.valid_augments)

# Batch sizes are read from Config instead of being repeated as literals.
train = DataLoader(
    train_set,
    batch_size=Config.train_batch,
    shuffle=True,
    pin_memory=False,
    drop_last=False,
    num_workers=8
)

valid = DataLoader(
    valid_set,
    batch_size=Config.test_batch,
    shuffle=False,
    pin_memory=False,
    num_workers=8
)

# NOTE(review): this creates a brand-new model and optimizer, so any state
# restored earlier via model.load(...) is discarded here; confirm that is the
# intent before relying on the resume cell.
model = NFNetModel().to(device)
optim = torch.optim.AdamW(model.parameters(), lr=Config.lr, weight_decay=Config.weight_decay)
loss_fn_train = nn.BCEWithLogitsLoss()
loss_fn_val = nn.BCEWithLogitsLoss()

trainer = Trainer(
    train_dataloader=train,
    valid_dataloader=valid,
    model=model,
    optimizer=optim,
    loss_fn=loss_fn_train,
    val_loss_fn=loss_fn_val,
    agc=True,
    device=device,
)

# Per-epoch metrics collected by the loop below.
train_losses_nfn = []
valid_losses_nfn = []
train_acc_nfn = []
valid_acc_nfn = []

# Module-level gradient scaler for mixed-precision training.
scaler = GradScaler()

for epoch in range(Config.n_epoch):
    print(f"{'-'*20} EPOCH: {epoch+1}/{Config.n_epoch} {'-'*20}")

    # Run one training epoch
    current_train_loss, current_train_acc = trainer.train_one_cycle()
    train_losses_nfn.append(current_train_loss)
    train_acc_nfn.append(current_train_acc)

    # Run one validation epoch
    current_val_loss, current_val_acc, op_model = trainer.valid_one_cycle()
    valid_losses_nfn.append(current_val_loss)
    valid_acc_nfn.append(current_val_acc)

    # Empty CUDA cache
    torch.cuda.empty_cache()
    


In [None]:
# Write the model and optimizer state to the checkpoint file.
model.save(optim)