In [None]:
import sys

# Extra locations appended to the interpreter's module search path.
# NOTE(review): the second entry is a weights file rather than a directory, so it
# presumably has no effect on imports - confirm before relying on it.
package_path = [
    '../input/timmpackagelatestwhl',
    '../input/vistion-transformer-pytorch/jx_vit_base_p16_224-80ecf9dd.pth',
]
sys.path.extend(package_path)

In [None]:
import os
import pandas as pd
pd.set_option('display.max_row', None)
pd.set_option('display.max_columns', None)
import albumentations as albu
import matplotlib.pyplot as plt
import json
import seaborn as sns
import cv2
import albumentations as albu
import numpy as np
import random

import torch
import torch.nn as nn
import torchvision.models as models
from torch.optim import Adam, AdamW
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import StratifiedKFold
from albumentations.pytorch import ToTensorV2
from torch.cuda.amp import autocast, GradScaler

from tqdm import tqdm

!pip install ../input/timmpackagelatestwhl/timm-0.3.4-py3-none-any.whl
import timm

# Load Training Set

In [None]:
# Root of the competition data, the training image folder beneath it, and the
# table read from train.csv.
BASE_DIR = "../input/cassava-leaf-disease-classification/"
TRAIN_IMAGES_DIR = os.path.join(BASE_DIR, 'train_images')
train_df = pd.read_csv(
    os.path.join(BASE_DIR, 'train.csv')
)

In [None]:
# Sanity check: show the first rows of the training table, then print its shape.
display(
    train_df.head()
)
print(train_df.shape)

# Data Loader and Augmentation

In [None]:
class CassavaDataset(Dataset):
    """Image dataset backed by a DataFrame with an ``image_id`` column.

    Args:
        df: table with one row per image; must contain ``image_id`` and, when
            ``train`` is True, ``label``.
        imfolder: directory that ``image_id`` file names are joined onto.
        train: when True ``__getitem__`` returns ``(image, label)``, otherwise
            only the image.
        transforms: optional callable invoked as ``transforms(image=x)`` whose
            result is indexed with ``'image'``.
    """

    def __init__(self,df:pd.DataFrame,imfolder:str,train:bool = True, transforms=None):
        self.df=df
        self.imfolder=imfolder
        self.train=train
        self.transforms=transforms

    def __getitem__(self,index):
        """Load, colour-convert and (optionally) transform the image at ``index``.

        Raises:
            FileNotFoundError: if the image cannot be read from disk.
        """
        # Fetch the row once; it is needed for both the file name and the label.
        row=self.df.iloc[index]
        im_path=os.path.join(self.imfolder,row['image_id'])
        x=cv2.imread(im_path,cv2.IMREAD_COLOR)
        if x is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than inside cvtColor.
            raise FileNotFoundError('Could not read image: {}'.format(im_path))
        x=cv2.cvtColor(x,cv2.COLOR_BGR2RGB)

        if self.transforms:
            x=self.transforms(image=x)['image']

        if self.train:
            return x,row['label']
        return x

    def __len__(self):
        """Number of rows in the backing DataFrame."""
        return len(self.df)

In [None]:
# Settings shared by both pipelines. The mean/std values match the commonly used
# ImageNet channel statistics.
_IMG_SIZE = 224
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)


def _normalize():
    """Build a fresh Normalize transform with the shared mean/std."""
    return albu.Normalize(mean=list(_NORM_MEAN), std=list(_NORM_STD))


# Training: random crop/flip/affine augmentation, then normalise and convert to tensor.
train_augs = albu.Compose([
    albu.RandomResizedCrop(height=_IMG_SIZE, width=_IMG_SIZE, p=1.0),
    albu.HorizontalFlip(p=0.5),
    albu.VerticalFlip(p=0.5),
    albu.ShiftScaleRotate(p=0.5),
    _normalize(),
    ToTensorV2(),
])

# Validation: plain resize only, then the same normalisation and tensor conversion.
valid_augs = albu.Compose([
    albu.Resize(height=_IMG_SIZE, width=_IMG_SIZE, p=1.0),
    _normalize(),
    ToTensorV2(),
])

# Helper Functions

In [None]:
def seed_everything(seed):
    """Seed every random number generator used by the notebook.

    Covers Python's ``random``, the ``PYTHONHASHSEED`` environment variable,
    NumPy and PyTorch (CPU and all CUDA devices), and asks cuDNN for
    deterministic kernels.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # manual_seed only covers the current device; seed every visible GPU as well.
    torch.cuda.manual_seed_all(seed)
    # Without these flags cuDNN may pick non-deterministic algorithms, so two
    # runs with the same seed can still diverge.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def save_model(model, optimizer, scheduler, fold, epoch, save_every=False, best=False):
    """Write a checkpoint holding the model, optimizer and scheduler state dicts.

    Nothing is written unless at least one of the flags is set.

    Args:
        model, optimizer, scheduler: objects exposing ``state_dict()``.
        fold: zero-based fold index (stored one-based in the file name).
        epoch: zero-based epoch index (stored one-based in the file name).
        save_every: when truthy, write to ``./saved_model``.
        best: when truthy, write to ``./best_model``.
    """
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict()
    }
    file_name = 'model_fold_{}_epoch_{}'.format(fold+1, epoch+1)
    for enabled, directory in ((save_every, './saved_model'), (best, './best_model')):
        if not enabled:
            continue
        # exist_ok avoids the check-then-create race of isdir() + mkdir().
        os.makedirs(directory, exist_ok=True)
        torch.save(state, '{}/{}'.format(directory, file_name))
        
def data_loader(dataset, batch_size, num_workers, phase='train'):
    """Wrap ``dataset`` in a DataLoader, shuffling only for the 'train' phase.

    Any other ``phase`` value (e.g. 'valid', 'test') yields batches in dataset order.
    """
    shuffle_batches = True if phase == 'train' else False
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle_batches,
    )
        
class EarlyStopping:
    """Track validation loss, checkpoint improvements and flag when to stop.

    Attributes are plain Python values:
        patience: number of consecutive non-improving epochs tolerated.
        counter: current run of non-improving epochs.
        early_stop: set to True once ``counter`` reaches ``patience``.
        val_loss_min: best (lowest) validation loss seen so far.
    """

    def __init__(self, patience):
        self.patience = patience
        self.counter = 0
        self.early_stop = False
        # np.inf rather than the np.Inf alias, which NumPy 2.0 removed.
        self.val_loss_min = np.inf

    def __call__(self, val_loss, model, optimizer, scheduler, fold, epoch):
        """Record one epoch's validation loss.

        A loss that is lower than or equal to the best so far is checkpointed
        via ``save_model(..., best=True)`` and resets the counter. This includes
        the very first epoch, so a "best" checkpoint exists even when no later
        epoch improves. Anything else, including NaN, increments the counter.

        Args:
            val_loss: validation loss of the epoch just finished.
            model, optimizer, scheduler: forwarded to ``save_model``.
            fold: zero-based fold index.
            epoch: zero-based epoch index.
        """
        # `<=` is False for NaN, so a diverged epoch counts as "no improvement"
        # instead of poisoning val_loss_min.
        improved = val_loss <= self.val_loss_min
        if not improved:
            self.counter += 1
            print('EarlyStopping counter: {} out of {}'.format(self.counter, self.patience))
            if self.counter >= self.patience:
                # Report the fold one-based, like the checkpoint file names.
                print('Early Stopping - Fold {} Training is Stopping'.format(fold + 1))
                self.early_stop = True
            return

        save_model(model, optimizer, scheduler, fold, epoch, best=True)
        print('*** Validation loss decreased ({} --> {}).  Saving model... ***'.\
              format(np.round(self.val_loss_min, 6), np.round(val_loss, 6)))
        self.val_loss_min = val_loss
        self.counter = 0

# Create Model

In [None]:
class Model(nn.Module):
    """timm network whose ``head`` layer is replaced by an ``n_class``-way linear layer.

    Args:
        model_name: architecture name passed to ``timm.create_model``.
        n_class: number of output units of the new head.
        pretrained: forwarded to ``timm.create_model``.
    """

    def __init__(self, model_name, n_class, pretrained=False):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained)
        # Keep the head's input width, swap only its output size.
        backbone.head = nn.Linear(backbone.head.in_features, n_class)
        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.model(x)

# Train / Validation Functions

In [None]:
def train_one_epoch(epoch, model, loss_fn, optimizer, train_loader, device, scheduler, scaler=None):
    """Train ``model`` for one pass over ``train_loader`` with mixed precision.

    Args:
        epoch: zero-based epoch index (printed one-based).
        model: network to optimise; switched to training mode.
        loss_fn: callable ``loss_fn(preds, labels)`` returning a scalar loss.
        optimizer: optimizer stepped once per batch through the grad scaler.
        train_loader: iterable of ``(images, labels)`` batches supporting ``len()``.
        device: device the batches are moved to.
        scheduler: learning-rate scheduler, stepped once at the end of the epoch.
        scaler: gradient scaler used for the scaled backward pass. When omitted,
            the module-level ``scaler`` variable is used, as before.

    Returns:
        Tuple ``(avg_loss, accuracy)`` for the epoch (accuracy as a fraction).
    """
    if scaler is None:
        # Backward compatibility: earlier callers relied on a global `scaler`.
        scaler = globals()['scaler']

    model.train()
    lst_out = []
    lst_label = []
    avg_loss = 0
    n_batches = len(train_loader)
    status = tqdm(enumerate(train_loader), total=n_batches)
    for step, (images, labels) in status:
        images = images.to(device).float()
        labels = labels.to(device).long()

        # autocast should wrap only the forward pass and the loss computation;
        # backward and the optimizer step run outside the context.
        with autocast():
            preds = model(images)
            loss = loss_fn(preds, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        lst_out.append(torch.argmax(preds, 1).detach().cpu().numpy())
        lst_label.append(labels.detach().cpu().numpy())
        avg_loss += loss.item() / n_batches
    scheduler.step()
    lst_out = np.concatenate(lst_out); lst_label = np.concatenate(lst_label)
    accuracy = (lst_out==lst_label).mean()
    print('{} epoch - train loss : {}, train accuracy score : {}'.\
          format(epoch + 1, np.round(avg_loss,6), np.round(accuracy*100,2)))
    return avg_loss, accuracy

def valid_one_epoch(epoch, model, loss_fn, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` and return the mean batch loss.

    Args:
        epoch: zero-based epoch index (printed one-based).
        model: network to evaluate; switched to eval mode.
        loss_fn: callable ``loss_fn(preds, labels)`` returning a scalar loss.
        val_loader: iterable of ``(images, labels)`` batches supporting ``len()``.
        device: device the batches are moved to.

    Returns:
        The loss averaged over batches (each batch weighted equally).
    """
    model.eval()
    lst_val_out = []
    lst_val_label = []
    avg_val_loss = 0
    n_batches = len(val_loader)
    status = tqdm(enumerate(val_loader), total=n_batches)
    # Disable gradient tracking here so the function does not depend on the
    # caller wrapping it; nesting inside another no_grad block is harmless.
    with torch.no_grad():
        for step, (images, labels) in status:
            val_images = images.to(device).float()
            val_labels = labels.to(device).long()

            val_preds = model(val_images)
            lst_val_out.append(torch.argmax(val_preds, 1).detach().cpu().numpy())
            lst_val_label.append(val_labels.detach().cpu().numpy())
            val_loss = loss_fn(val_preds, val_labels)
            avg_val_loss += val_loss.item() / n_batches

    lst_val_out = np.concatenate(lst_val_out); lst_val_label = np.concatenate(lst_val_label)
    accuracy = (lst_val_out==lst_val_label).mean()
    print('{} epoch - valid loss : {}, valid accuracy : {}'.\
          format(epoch + 1, np.round(avg_val_loss, 6), np.round(accuracy*100,2)))
    return avg_val_loss

# Main - Training

In [None]:
if __name__ == '__main__':
    # ---- Run configuration ----
    train_batch = 16
    valid_batch = 32
    num_workers = 4
    seed = 42
    split = 5          # number of stratified folds
    epochs = 100       # upper bound; early stopping usually ends a fold sooner
    patience = 5       # non-improving epochs tolerated before stopping a fold

    n_class = 5
    model_arch = 'vit_base_patch16_224' # 'resnext50_32x4d', 'tf_efficientnet_b4_ns', 'vit_base_patch16_224'
    weight_path = '../input/vistion-transformer-pytorch/jx_vit_base_p16_224-80ecf9dd.pth'
    device = 'cuda'

    seed_everything(seed)
    # Features are every column but the last; the last column is the stratification target.
    X_train = train_df.iloc[:, :-1]; Y_train = train_df.iloc[:, -1]
    cv = StratifiedKFold(n_splits=split, random_state=seed, shuffle=True)
    for fold, (train_index, val_index) in enumerate(cv.split(X_train, Y_train)):
        print('---------- Fold {} is training ----------'.format(fold + 1))
        # cv.split yields positional indices, so select with .iloc on both
        # objects; plain [] on the Series would do a label lookup instead.
        train_x, train_y = X_train.iloc[train_index], Y_train.iloc[train_index]
        val_x, val_y = X_train.iloc[val_index], Y_train.iloc[val_index]

        train_dataset=CassavaDataset(df=pd.concat([train_x, train_y], axis=1), imfolder=TRAIN_IMAGES_DIR, train=True, transforms=train_augs)
        valid_dataset=CassavaDataset(df=pd.concat([val_x, val_y], axis=1), imfolder=TRAIN_IMAGES_DIR, train=True, transforms=valid_augs)
        train_loader = data_loader(train_dataset, train_batch, num_workers, phase='train')
        valid_loader = data_loader(valid_dataset, valid_batch, num_workers, phase='valid')

        model = Model(model_arch, n_class, pretrained=False)
        # Load the pretrained weights into the wrapped network (model.model).
        # Loading them into the wrapper with strict=False matches no key at all,
        # because the wrapper's keys carry a 'model.' prefix, so training would
        # silently start from random weights.
        # Assumes the checkpoint is a plain state_dict of the bare timm network
        # with un-prefixed keys - TODO confirm for other weight files.
        pretrained_state = torch.load(weight_path, map_location='cpu')
        # The checkpoint's classifier does not fit the new n_class head and would
        # raise a size-mismatch error, so it is left out.
        backbone_state = {k: v for k, v in pretrained_state.items() if not k.startswith('head.')}
        load_result = model.model.load_state_dict(backbone_state, strict=False)
        not_loaded = [k for k in load_result.missing_keys if not k.startswith('head.')]
        if not_loaded or load_result.unexpected_keys:
            print('WARNING: pretrained weights only partially loaded - missing: {}, unexpected: {}'.\
                  format(not_loaded, list(load_result.unexpected_keys)))
        model = model.to(device)

        loss_tr = nn.CrossEntropyLoss().to(device); loss_fn = nn.CrossEntropyLoss().to(device)
        optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-6)
        # train_one_epoch picks this scaler up from module scope.
        scaler = GradScaler()
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1)
        early_stopping = EarlyStopping(patience=patience)

        for epoch in range(epochs):
            train_one_epoch(epoch, model, loss_tr, optimizer, train_loader, device, scheduler=scheduler)
            # With both flags left at their defaults this call writes nothing;
            # pass save_every=True to keep a checkpoint for every epoch.
            save_model(model, optimizer, scheduler, fold, epoch)
            with torch.no_grad():
                val_loss = valid_one_epoch(epoch, model, loss_fn, valid_loader, device)
                early_stopping(val_loss, model, optimizer, scheduler, fold, epoch)
                if early_stopping.early_stop:
                    break

        # Release per-fold objects before the next fold allocates its own.
        del model, optimizer, train_dataset, valid_dataset, train_loader, valid_loader, scheduler, scaler
        torch.cuda.empty_cache()