### Install Libraries

In [None]:
import sys
# Extend the module search path with two directories under ../input
# (presumably offline copies of the `tez` and `timm` packages imported below — TODO confirm).
sys.path.append('../input/tez-lib')
sys.path.append('../input/timmmaster')

### Import required packages

In [None]:
import argparse
import os

import cv2
import albumentations
import albumentations.pytorch
import pandas as pd
import numpy as np

import tez
import timm
import torch
import torch.nn as nn
import torchvision

from sklearn import metrics, model_selection, preprocessing
from tez.callbacks import EarlyStopping
from tez.datasets import ImageDataset
from torch.nn import functional as F

import matplotlib.pyplot as plt
from tqdm import tqdm
import wandb

import warnings
warnings.filterwarnings('ignore')

### CFG class

In [None]:
class CFG:
    """Experiment-wide settings read by the other cells of this notebook."""

    # --- model / data -----------------------------------------------------
    model_name = 'resnext50_32x4d'   # architecture name handed to timm.create_model
    image_size = 512                 # height and width used by the crop/resize transforms
    target_col = 'label'             # column read as ground truth when scoring
    target_size = 5                  # number of outputs of the replaced final layer

    # --- optimisation -----------------------------------------------------
    batch_size = 16                  # training batch size (validation uses twice this)
    epochs = 15                      # upper bound on epochs passed to model.fit

    # --- cross-validation -------------------------------------------------
    n_fold = 5                       # number of stratified folds
    trn_fold = list(range(n_fold))   # fold indices [0, 1, 2, 3, 4]

### Create folds

In [None]:
# Root folder of the competition data; the training images sit in a sub-folder of it.
INPUT_PATH = "../input/cassava-leaf-disease-classification/"
IMAGE_PATH = INPUT_PATH + "train_images/"
# The log file, model checkpoints and out-of-fold predictions are written under this prefix.
OUTPUT_DIR = './'

In [None]:
# Load the training labels and assign every row to one of CFG.n_fold stratified
# folds. The result is written to ./train_folds.csv, which train() reads back.
df = pd.read_csv(os.path.join(INPUT_PATH, 'train.csv'))
df['kfold'] = -1  # placeholder; every row is overwritten in the loop below
# Seeded shuffle so that fold membership is reproducible between runs.
df = df.sample(frac=1, random_state=42).reset_index(drop=True)
y = df['label'].values

# random_state only has an effect together with shuffle=True; recent
# scikit-learn versions raise a ValueError when it is set without shuffling.
kf = model_selection.StratifiedKFold(n_splits=CFG.n_fold, shuffle=True, random_state=42)
for fold_, (_, valid_idx) in enumerate(kf.split(X=df, y=y)):
    # Rows in the held-out part of split `fold_` are tagged with that fold number.
    df.loc[valid_idx, 'kfold'] = fold_

df.to_csv('./train_folds.csv', index=False)

In [None]:
# Notebook display: number of rows per value of the 'label' column.
df['label'].value_counts()

### Logger

In [None]:
def init_logger(log_file=OUTPUT_DIR+'train.log'):
    """Create the notebook logger that writes to both the console and a file.

    Args:
        log_file: Path of the log file. The default is evaluated once, when
            this function is defined, from OUTPUT_DIR.

    Returns:
        The logger named after this module, at INFO level, with exactly one
        stream handler and one file handler, both using a bare "%(message)s"
        format.
    """
    from logging import getLogger, INFO, FileHandler, Formatter, StreamHandler

    logger = getLogger(__name__)
    logger.setLevel(INFO)

    # getLogger() hands back the same object for the same name, so calling this
    # function again (e.g. re-running the cell) would otherwise stack handlers
    # and emit every message several times. Drop and close the old ones first.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()

    for handler in (StreamHandler(), FileHandler(filename=log_file)):
        handler.setFormatter(Formatter("%(message)s"))
        logger.addHandler(handler)
    return logger


LOGGER = init_logger()

### Define augmentation

In [None]:
def generate_transforms():
    """Build the albumentations pipelines used for training and validation.

    Returns:
        dict with two albumentations.Compose objects:
        "train_transforms" (random crop/flip/colour/dropout augmentations) and
        "valid_transforms" (plain resize). Both normalise the image and end
        with ToTensorV2.
    """
    aug = albumentations
    side = CFG.image_size

    def normalize():
        # Same mean/std in both pipelines; a fresh instance is built per
        # pipeline so the two Compose objects do not share a transform.
        return aug.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0,
            p=1.0,
        )

    train_steps = [
        aug.RandomResizedCrop(side, side),
        aug.Transpose(p=0.5),
        aug.HorizontalFlip(p=0.5),
        aug.VerticalFlip(p=0.5),
        aug.ShiftScaleRotate(p=0.5),
        # NOTE(review): limits of 0.2 look very small next to the library's
        # default limits for this transform — confirm they are intended.
        aug.HueSaturationValue(
            hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5
        ),
        aug.RandomBrightnessContrast(
            brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5
        ),
        normalize(),
        # Both dropout-style transforms run after normalisation, as in the
        # original ordering.
        aug.CoarseDropout(p=0.5),
        aug.Cutout(p=0.5),
        aug.pytorch.ToTensorV2(),
    ]

    valid_steps = [
        aug.Resize(side, side, p=1.0),
        normalize(),
        aug.pytorch.ToTensorV2(),
    ]

    return {
        "train_transforms": aug.Compose(train_steps, p=1.0),
        "valid_transforms": aug.Compose(valid_steps, p=1.0),
    }

### Define PyTorch Dataset class

In [None]:
class FlowerDataset:
    """Index-able dataset that loads an image from disk and augments it.

    Args:
        image_paths: Sequence of image file paths.
        targets: Sequence of labels, aligned with ``image_paths``.
        augmentations: Callable invoked as ``augmentations(image=...)`` that
            returns a mapping with an ``"image"`` entry.
    """

    def __init__(self, image_paths, targets, augmentations):
        self.image_paths = image_paths
        self.targets = targets
        self.augmentations = augmentations

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.image_paths)

    def __getitem__(self, item):
        """Load, convert and augment the sample at position ``item``.

        Returns:
            dict with keys ``"image"`` (the augmented image) and ``"targets"``
            (the label stored at the same position).

        Raises:
            OSError: If the image file cannot be read.
        """
        targets = self.targets[item]

        path = self.image_paths[item]
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than inside cvtColor.
            raise OSError(f"cv2.imread could not read image (missing or unreadable): {path}")
        # OpenCV loads channels as BGR; the rest of the pipeline uses RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        image = self.augmentations(image=image)["image"]

        return {
            "image": image,
            "targets": targets,
        }

### Define model class (the most important part) -> this is the part I am going to learn the hard way

In [None]:
import warnings

import psutil
import torch
import torch.nn as nn
from tez import enums
from tez.callbacks import CallbackRunner
from tez.utils import AverageMeter
from tqdm import tqdm
import time

class LeafModel(tez.Model):
    """timm backbone with a replaced final layer, trained through tez.

    ``fit`` is overridden so that each epoch's average train/validation loss
    and wall-clock time are written to ``LOGGER``.
    """

    def __init__(self, pretrained = True):
        """Create the backbone named by CFG.model_name and swap its ``fc`` layer.

        Args:
            pretrained: Forwarded to ``timm.create_model``.
        """
        super().__init__()
        self.model = timm.create_model(model_name=CFG.model_name, pretrained=pretrained)
        # Replace the final layer with one producing CFG.target_size outputs.
        self.n_features = self.model.fc.in_features
        self.model.fc = nn.Linear(self.n_features, CFG.target_size)

        # Read by fit(): step the scheduler once per epoch on this metric.
        self.step_scheduler_after = "epoch"
        self.step_scheduler_metric = "valid_accuracy"

    def monitor_metrics(self, outputs, targets):
        """Return ``{"accuracy": ...}`` for a batch, or ``{}`` without targets."""
        if targets is None:
            return {}
        predicted = torch.argmax(outputs, dim=1).cpu().detach().numpy()
        expected = targets.cpu().detach().numpy()
        return {"accuracy": metrics.accuracy_score(expected, predicted)}

    def fetch_optimizer(self):
        """Return the Adam optimizer over all parameters of this module."""
        return torch.optim.Adam(self.parameters(), lr=1e-4, weight_decay=1e-6)

    def fetch_scheduler(self):
        """Return a ReduceLROnPlateau scheduler bound to ``self.optimizer``."""
        # The scheduler is stepped with "valid_accuracy" (see __init__), where
        # higher is better, so a plateau must be detected in "max" mode.
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode="max", factor=0.2, patience=5, eps=1e-6, verbose=True
        )

    def forward(self, image, targets=None):
        """Run the backbone and, when targets are given, compute loss and metrics.

        Returns:
            ``(outputs, loss, metrics)``; ``loss`` and ``metrics`` are ``None``
            when ``targets`` is ``None``.
        """
        outputs = self.model(image)

        if targets is None:
            return outputs, None, None

        loss = nn.CrossEntropyLoss()(outputs, targets)
        # Named batch_metrics so the module-level sklearn `metrics` import is
        # not shadowed inside this method.
        batch_metrics = self.monitor_metrics(outputs, targets)
        return outputs, loss, batch_metrics

    def fit(
        self,
        train_dataset,
        valid_dataset=None,
        train_sampler=None,
        valid_sampler=None,
        device="cuda",
        epochs=10,
        train_bs=16,
        valid_bs=16,
        n_jobs=8,
        callbacks=None,
        fp16=False,
        train_collate_fn=None,
        valid_collate_fn=None,
        train_shuffle=True,
        valid_shuffle=False,
        accumulation_steps=1,
        clip_grad_norm=None,
    ):
        """Train the model, logging losses and duration after every epoch.

        All arguments except ``epochs`` and ``device`` are forwarded unchanged
        to ``self._init_model``. NOTE(review): the loop is presumably a copy of
        the base-class ``fit`` with timing/logging added — confirm against the
        installed tez version when upgrading.

        Raises:
            RuntimeError: If ``device == "tpu"`` and torch_xla is not installed.
        """
        if device == "tpu":
            # Imported lazily: torch_xla is only needed (and only expected to
            # be installed) when training on a TPU.
            try:
                import torch_xla.core.xla_model as xm
            except ImportError as err:
                raise RuntimeError("XLA is not available. Please install pytorch_xla") from err
            self.using_tpu = True
            fp16 = False
            device = xm.xla_device()
        self._init_model(
            device=device,
            train_dataset=train_dataset,
            valid_dataset=valid_dataset,
            train_sampler=train_sampler,
            valid_sampler=valid_sampler,
            train_bs=train_bs,
            valid_bs=valid_bs,
            n_jobs=n_jobs,
            callbacks=callbacks,
            fp16=fp16,
            train_collate_fn=train_collate_fn,
            valid_collate_fn=valid_collate_fn,
            train_shuffle=train_shuffle,
            valid_shuffle=valid_shuffle,
            accumulation_steps=accumulation_steps,
            clip_grad_norm=clip_grad_norm,
        )

        for epoch in range(epochs):
            start_time = time.time()
            self.train_state = enums.TrainingState.EPOCH_START
            self.train_state = enums.TrainingState.TRAIN_EPOCH_START
            train_loss = self.train_one_epoch(self.train_loader)
            self.train_state = enums.TrainingState.TRAIN_EPOCH_END

            # Stays None when there is no validation loader, so the log line
            # below never touches an unbound name.
            valid_loss = None
            if self.valid_loader:
                self.train_state = enums.TrainingState.VALID_EPOCH_START
                valid_loss = self.validate_one_epoch(self.valid_loader)
                self.train_state = enums.TrainingState.VALID_EPOCH_END

            elapsed = time.time() - start_time
            val_text = "n/a" if valid_loss is None else f"{valid_loss:.4f}"
            LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {train_loss:.4f}  avg_val_loss: {val_text} time: {elapsed:.0f}s')

            if self.scheduler:
                if self.step_scheduler_after == "epoch":
                    if self.step_scheduler_metric is None:
                        self.scheduler.step()
                    else:
                        step_metric = self.name_to_metric(self.step_scheduler_metric)
                        self.scheduler.step(step_metric)
            self.train_state = enums.TrainingState.EPOCH_END
            # The model state is checked after the EPOCH_END transition; "end"
            # stops training before the epoch counter advances.
            if self._model_state.value == "end":
                break
            self.current_epoch += 1
        self.train_state = enums.TrainingState.TRAIN_END

### Train on Multiple folds

In [None]:
def predict(valid_dataset,fold):
    """Predict a class index for every sample using the fold's saved weights.

    Args:
        valid_dataset: Dataset handed to ``model.predict``.
        fold: Fold number used to locate the checkpoint under OUTPUT_DIR.

    Returns:
        1-D array holding the argmax over axis 1 of the stacked predictions.

    Raises:
        ValueError: If ``model.predict`` yields nothing.
    """
    model = LeafModel(pretrained=False)
    model.load(OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best.bin', device="cuda", weights_only=True)

    # Collect all per-batch outputs first and stack once; stacking inside the
    # loop re-copies everything gathered so far on every iteration.
    batches = list(model.predict(valid_dataset, batch_size=2*CFG.batch_size))
    if not batches:
        raise ValueError("model.predict yielded no batches; is valid_dataset empty?")
    return np.vstack(batches).argmax(axis=1)

In [None]:
def train(fold):
    """Train one cross-validation fold and return its out-of-fold predictions.

    Rows of ./train_folds.csv whose ``kfold`` equals ``fold`` are used for
    validation and all others for training. The best checkpoint is written by
    the EarlyStopping callback and reloaded by ``predict``.

    Args:
        fold: Fold number to hold out.

    Returns:
        The validation rows with an added ``preds`` column.
    """
    LOGGER.info(f"========== fold: {fold} training ==========")

    folds = pd.read_csv("./train_folds.csv")
    held_out = folds.kfold == fold
    df_train = folds[~held_out].reset_index(drop=True)
    df_valid = folds[held_out].reset_index(drop=True)

    def to_paths(frame):
        # image_id holds file names relative to IMAGE_PATH.
        return [os.path.join(IMAGE_PATH, name) for name in frame.image_id.values]

    pipelines = generate_transforms()
    train_dataset = FlowerDataset(
        image_paths=to_paths(df_train),
        targets=df_train.label.values,
        augmentations=pipelines["train_transforms"],
    )
    valid_dataset = FlowerDataset(
        image_paths=to_paths(df_valid),
        targets=df_valid.label.values,
        augmentations=pipelines["valid_transforms"],
    )

    # Same file name that predict() rebuilds from CFG.model_name and fold.
    checkpoint_path = OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best.bin'

    model = LeafModel()
    es = EarlyStopping(
        monitor="valid_accuracy",
        model_path=checkpoint_path,
        patience=2,
        mode="max",
        save_weights_only=True,
    )
    model.fit(
        train_dataset,
        valid_dataset=valid_dataset,
        train_bs=CFG.batch_size,
        valid_bs=2*CFG.batch_size,
        device="cuda",
        epochs=CFG.epochs,
        callbacks=[es],
        fp16=True,
    )

    df_valid['preds'] = predict(valid_dataset, fold)
    return df_valid

In [None]:
def get_score(y_true, y_pred):
    """Return the accuracy of ``y_pred`` measured against ``y_true``."""
    return metrics.accuracy_score(y_true, y_pred)

def get_result(result_df):
    """Log the score of the ``preds`` column against the CFG.target_col column.

    Args:
        result_df: Frame containing both a ``preds`` column and the target column.
    """
    y_pred = result_df['preds'].values
    y_true = result_df[CFG.target_col].values
    LOGGER.info(f'Score: {get_score(y_true, y_pred):<.4f}')

# Train every fold in turn and keep each fold's out-of-fold predictions.
fold_frames = []
for fold in range(CFG.n_fold):
    _oof_df = train(fold)
    fold_frames.append(_oof_df)
    LOGGER.info(f"========== fold: {fold} result ==========")
    get_result(_oof_df)

# Concatenate once at the end: growing a frame from an empty DataFrame inside
# the loop relies on deprecated pandas behaviour for empty inputs and re-copies
# all earlier rows on every iteration.
oof_df = pd.concat(fold_frames)

# CV result
LOGGER.info("========== CV ==========")
get_result(oof_df)
# save result
oof_df.to_csv(OUTPUT_DIR+'oof_df.csv', index=False)

In [None]:
# with open('./train.log') as f:
#     f = f.readlines()
# for line in f:
#     print(line)