### A Pytorch Template for all users (beginners/advanced)

The goal of this notebook is to provide a script that runs from the command line and lets users try out different options for experimentation. 

Pytorch AMP is used for training and augmentations are via `Albumentations` (https://github.com/albumentations-team/albumentations) 

### Specifically, the features supported are:

1) List available models via `--find_model`, e.g. `--find_model se*resnet*50`

2) Run on a small sample of the data via `--n_samples` to check that code changes work end to end

3) Test augmentations visually via `--test_loader`, changing the augmentations in the `get_sample_transforms` function. I noticed that many people don't really look at the details of their augmentations; this makes it quick to see what they are really doing

4) Provide a scheduler or loss function via `--scheduler <scheduler name>` / `--loss_fn <loss function name>`. Supported loss functions are:
*  CrossEntropyLoss
*  SmoothCrossEntropyLoss
*  SymmetricCrossEntropy
*  BCEWithLogitsLoss
*  TruncatedLoss
*  TaylorCrossEntropyLoss
*  BiTemperedLoss
*  FocalCosineLoss

5) Optionally get noisy labels using `CleanLab` (https://github.com/cgnorthcutt/cleanlab)

6) Save best model for each epoch and each fold

7) Support for upsampling/downsampling 

8) Optional Cutmix (per batch with 25% probability) 

9) Optional SVM head on CNN model 

### The script saves all state needed for reproducibility, e.g. the list of command-line arguments, the config options, the augmentations performed, the scores, etc.

### With the above functionality, users should be able to try out different architectures and parameters for experimentation. 

### Sample Command line usage: 
`python3 pytorch_amp.py --train --model tf_efficientnet_b4_ns --model_dir effnetb4_2019_epochs --img_size 512 --batch_size 16 --loss_fn BiTemperedLoss --augment` 


### Credits:

This kernel would not have been possible without the contributions of various authors below. 

Loss functions:
https://www.kaggle.com/piantic/train-cassava-starter-using-various-loss-funcs

Pytorch AMP: 
https://www.kaggle.com/khyeh0719/pytorch-efficientnet-baseline-train-amp-aug

Cutmix:
https://www.kaggle.com/ar2017/pytorch-efficientnet-train-aug-cutmix-fmix

Cleanlab Tutorial: 
https://www.kaggle.com/telljoy/noisy-label-eda-with-cleanlab


In [None]:
import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)
import os
from pathlib import Path
from sklearn.utils import resample
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score, accuracy_score
import numpy as np
from torch.utils.data import Dataset
import time
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from torch.nn import functional as F
import random
from sklearn.preprocessing import LabelEncoder
import cv2
from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout,
    ShiftScaleRotate, CenterCrop, Resize, RandomCrop
)
from albumentations.pytorch import ToTensorV2
import timm
from torch.nn.modules.loss import _WeightedLoss
from scipy.special import softmax
import argparse
import matplotlib.pyplot as plt
from torch.cuda.amp import autocast, GradScaler

from cleanlab.pruning import get_noise_indices
from cleanlab.classification import LearningWithNoisyLabels
from sklearn.base import BaseEstimator
import sklearn
from sklearn import svm
import pickle


SEED = 1234

# Root of the competition input data
DATA_PATH = "../input/cassava-leaf-disease-classification/"
# Image directories for the training and test splits
FILE_PATH = "../input/cassava-leaf-disease-classification/train_images/"
TEST_FILE_PATH = "../input/cassava-leaf-disease-classification/test_images/"

# Training CSV file name, the label column and the image-file-name column
TRAIN_DF_FILE = 'train_2019_2020_clean.csv'
TARGET_COL = 'label'
ID_COL = 'image_id'

# Directory to store experiment results
BASE_MODEL_FOLDER = "./models/"

# Prefer the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_WORKERS = 4  # presumably the DataLoader worker count — confirm at the call site
SMOOTHING = 0.05  # default label-smoothing factor (used by SmoothCrossEntropyLoss)

def detect_leaf(img):
    """Keep only pixels whose HSV color is in a brown or yellow/green range; zero the rest.

    Args:
        img: Image in OpenCV BGR order (it is converted with ``COLOR_BGR2HSV``).

    Returns:
        The input image with every pixel outside both color ranges set to 0, in the same
        channel order as the input.
    """
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    # Brown tones (e.g. diseased / dry leaf areas)
    brown = cv2.inRange(hsv, (8, 60, 20), (30, 255, 200))
    # Yellow and green tones of the leaf
    yellow_green = cv2.inRange(hsv, (10, 39, 64), (86, 255, 255))
    # A pixel survives if it matches either range
    leaf_mask = cv2.bitwise_or(yellow_green, brown)
    return cv2.bitwise_and(img, img, mask=leaf_mask)


def rand_bbox(size, lam):
    """Sample a random CutMix box whose side lengths are ``sqrt(1 - lam)`` of the image sides.

    Args:
        size: Batch shape ``(N, C, H, W)``, e.g. ``tensor.size()``.
        lam: Mixing ratio in [0, 1]; the unclipped box covers ``1 - lam`` of the image area.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``. The x bounds index the width axis (dim 3) and the
        y bounds index the height axis (dim 2); all are clipped to the image.
    """
    # NCHW layout: dim 2 is height, dim 3 is width. (The original read W from dim 2 and H
    # from dim 3, which mismatched callers slicing [..., bby1:bby2, bbx1:bbx2] on
    # non-square inputs.)
    H = size[2]
    W = size[3]
    cut_rat = np.sqrt(1. - lam)
    # Built-in int(): np.int was deprecated in NumPy 1.20 and removed in 1.24.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Uniformly sampled box centre
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2


def cutmix(data, target, alpha):
    """Apply CutMix to a batch: paste one random box from a shuffled copy of the batch.

    Args:
        data: 4-D image batch; the box is cut from the last two dimensions.
        target: Per-sample targets, shuffled with the same permutation as ``data``.
        alpha: Beta-distribution parameter used to sample the mixing ratio.

    Returns:
        ``(mixed_data, (target, shuffled_target, lam))``. ``lam`` is the fraction of each
        image that still comes from its original sample.
    """
    perm = torch.randperm(data.size(0))
    donor_images = data[perm]
    donor_targets = target[perm]

    # The sampled ratio is clamped to [0.3, 0.4], i.e. the box nominally replaces 60-70% of
    # the image area (before clipping to the image borders).
    lam = np.clip(np.random.beta(alpha, alpha), 0.3, 0.4)
    x1, y1, x2, y2 = rand_bbox(data.size(), lam)

    mixed = data.clone()
    mixed[:, :, y1:y2, x1:x2] = donor_images[:, :, y1:y2, x1:x2]

    # Recompute lam from the (possibly clipped) box so it matches the pasted pixel ratio.
    width = data.size()[-1]
    height = data.size()[-2]
    lam = 1 - ((x2 - x1) * (y2 - y1) / (width * height))
    return mixed, (target, donor_targets, lam)


# CutMix
class CutMixCollator:
    """DataLoader ``collate_fn`` that batches samples and then applies ``cutmix``.

    Args:
        alpha: Beta-distribution parameter forwarded to ``cutmix``. Defaults to 1.0, the
            value the training loop passes.

    ``__call__`` returns ``(mixed_images, (targets, shuffled_targets, lam))``.
    """

    def __init__(self, alpha=1.0):
        self.alpha = alpha

    def __call__(self, batch):
        # Collate the list of (image, target) samples into batch tensors first. The original
        # passed the raw list to cutmix() without target/alpha, which always raised TypeError.
        images, targets = torch.utils.data.dataloader.default_collate(batch)
        return cutmix(images, targets, self.alpha)


class CutMixCriterion(nn.Module):
    """Blend a base criterion over the two label sets produced by ``cutmix``.

    ``targets`` is the ``(targets_a, targets_b, lam)`` tuple from ``cutmix``. The loss is
    ``lam * criterion(preds, targets_a) + (1 - lam) * criterion(preds, targets_b)``.
    """

    def __init__(self, criterion):
        super(CutMixCriterion, self).__init__()
        self.criterion = criterion

    def forward(self, preds, targets):
        # Move targets to wherever the predictions live, rather than relying on a
        # module-level `device` global.
        targets1 = targets[0].to(preds.device)
        targets2 = targets[1].to(preds.device)
        lam = targets[2]
        # Call the criterion (not .forward) so module hooks run and plain callables also work.
        return lam * self.criterion(preds, targets1) + (1 - lam) * self.criterion(preds, targets2)


# Focal cosine loss
class FocalCosineLoss(nn.Module):
    """Cosine-embedding loss against one-hot targets plus a weighted focal cross-entropy.

    Args:
        alpha: Multiplier of the focal term.
        gamma: Focusing exponent applied to ``(1 - p_t)``.
        xent: Weight of the focal term added to the cosine loss.
    """

    def __init__(self, alpha=1, gamma=2, xent=.1):
        super(FocalCosineLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

        self.xent = xent

        # +1 target for cosine_embedding_loss (pull inputs towards their one-hot vector).
        # Kept on CPU and moved to the input's device in forward(); the original hard-coded
        # .cuda(), which fails on CPU-only machines.
        self.y = torch.ones(1)

    def forward(self, input, target, reduction="mean"):
        """Return the loss; ``reduction`` is 'mean', 'sum' or 'none' (per-sample)."""
        y = self.y.to(input.device)
        one_hot = F.one_hot(target, num_classes=input.size(-1)).to(input.dtype)
        cosine_loss = F.cosine_embedding_loss(input, one_hot, y, reduction=reduction)

        # reduction="none" replaces the deprecated reduce=False.
        cent_loss = F.cross_entropy(F.normalize(input), target, reduction="none")
        pt = torch.exp(-cent_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * cent_loss

        # Reduce the focal term the same way as the cosine term (previously 'sum' left
        # it per-sample, producing a vector instead of a scalar).
        if reduction == "mean":
            focal_loss = torch.mean(focal_loss)
        elif reduction == "sum":
            focal_loss = torch.sum(focal_loss)

        return cosine_loss + self.xent * focal_loss


# BiTemperedLoss
# Code taken from https://github.com/fhopfmueller/bi-tempered-loss-pytorch/blob/master/bi_tempered_loss_pytorch.py

def log_t(u, t):
    """Tempered logarithm of tensor `u`: ``(u**(1 - t) - 1) / (1 - t)``; natural log when t == 1."""
    if t != 1.0:
        one_minus_t = 1.0 - t
        return (u.pow(one_minus_t) - 1.0) / one_minus_t
    return u.log()


def exp_t(u, t):
    """Tempered exponential of tensor `u`: ``relu(1 + (1 - t) * u) ** (1 / (1 - t))``; exp when t == 1."""
    if t != 1:
        one_minus_t = 1.0 - t
        base = (1.0 + one_minus_t * u).relu()
        return base.pow(1.0 / one_minus_t)
    return u.exp()


def compute_normalization_fixed_point(activations, t, num_iters):
    """Returns the normalization value for each example (t > 1.0).
    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (> 1.0 for tail heaviness).
      num_iters: Number of iterations to run the method.
    Return: A tensor of same shape as activation with the last dimension being 1.

    Runs `num_iters` fixed-point iterations on the activations shifted by their per-example
    maximum; the maximum is added back to the returned value.
    """
    # Shift by the per-example max over the last (class) dimension; `mu` is added back below.
    mu, _ = torch.max(activations, -1, keepdim=True)
    normalized_activations_step_0 = activations - mu

    normalized_activations = normalized_activations_step_0

    for _ in range(num_iters):
        # Current partition-sum estimate under the tempered exponential.
        logt_partition = torch.sum(
            exp_t(normalized_activations, t), -1, keepdim=True)
        # Fixed-point update: rescale the *initial* shifted activations by partition**(1 - t).
        normalized_activations = normalized_activations_step_0 * \
                                 logt_partition.pow(1.0 - t)

    logt_partition = torch.sum(
        exp_t(normalized_activations, t), -1, keepdim=True)
    # Turn the final partition sum into an additive normalization constant and undo the shift.
    normalization_constants = - log_t(1.0 / logt_partition, t) + mu

    return normalization_constants


def compute_normalization_binary_search(activations, t, num_iters):
    """Returns the normalization value for each example (t < 1.0).
    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (< 1.0 for finite support).
      num_iters: Number of iterations to run the method.
    Return: A tensor of same rank as activation with the last dimension being 1.

    Bisects, per example, on the normalization value. The value is lowered when the tempered
    probabilities sum to less than 1 and raised otherwise.
    """

    # Shift by the per-example max over the last dimension; `mu` is added back at the end.
    mu, _ = torch.max(activations, -1, keepdim=True)
    normalized_activations = activations - mu

    # Count entries above -1/(1 - t). For t < 1, exp_t's relu makes it exactly 0 at or below
    # that point, so only these entries can carry probability mass.
    effective_dim = \
        torch.sum(
            (normalized_activations > -1.0 / (1.0 - t)).to(torch.int32),
            dim=-1, keepdim=True).to(activations.dtype)

    # Initial bisection bounds, one pair per example (last dimension of size 1).
    shape_partition = activations.shape[:-1] + (1,)
    lower = torch.zeros(shape_partition, dtype=activations.dtype, device=activations.device)
    upper = -log_t(1.0 / effective_dim, t) * torch.ones_like(lower)

    for _ in range(num_iters):
        logt_partition = (upper + lower) / 2.0
        sum_probs = torch.sum(
            exp_t(normalized_activations - logt_partition, t),
            dim=-1, keepdim=True)
        # update == 1 where the midpoint is too large (probabilities sum to < 1): move
        # `upper` down to the midpoint; where update == 0, move `lower` up to it instead.
        update = (sum_probs < 1.0).to(activations.dtype)
        lower = torch.reshape(
            lower * update + (1.0 - update) * logt_partition,
            shape_partition)
        upper = torch.reshape(
            upper * (1.0 - update) + update * logt_partition,
            shape_partition)

    logt_partition = (upper + lower) / 2.0
    return logt_partition + mu


class ComputeNormalization(torch.autograd.Function):
    """
    Class implementing custom backward pass for compute_normalization. See compute_normalization.

    forward() picks the solver by temperature: binary search for t < 1.0, fixed-point
    iteration otherwise. backward() returns an analytic gradient instead of differentiating
    through those iterations.
    """

    @staticmethod
    def forward(ctx, activations, t, num_iters):
        if t < 1.0:
            normalization_constants = compute_normalization_binary_search(activations, t, num_iters)
        else:
            normalization_constants = compute_normalization_fixed_point(activations, t, num_iters)

        # Keep what backward() needs to rebuild the tempered probabilities.
        ctx.save_for_backward(activations, normalization_constants)
        ctx.t = t
        return normalization_constants

    @staticmethod
    def backward(ctx, grad_output):
        activations, normalization_constants = ctx.saved_tensors
        t = ctx.t
        normalized_activations = activations - normalization_constants
        probabilities = exp_t(normalized_activations, t)
        # "Escort" distribution: probabilities**t renormalized over the last dimension.
        escorts = probabilities.pow(t)
        escorts = escorts / escorts.sum(dim=-1, keepdim=True)
        grad_input = escorts * grad_output

        # Only `activations` receives a gradient; t and num_iters get None.
        return grad_input, None, None


def compute_normalization(activations, t, num_iters=5):
    """Per-example normalization value for the tempered softmax, with a custom backward pass.

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      num_iters: Number of solver iterations.
    Return: A tensor of same rank as activation with the last dimension being 1.
    """
    apply_normalization = ComputeNormalization.apply
    return apply_normalization(activations, t, num_iters)


def tempered_sigmoid(activations, t, num_iters=5):
    """Tempered sigmoid: probability of the positive class for binary classification.

    Args:
      activations: Activations for the positive class.
      t: Temperature > 0.0.
      num_iters: Number of iterations for the normalization solver.
    Returns:
      A probabilities tensor with the same shape as `activations`.
    """
    # Pair every activation with a fixed 0 logit for the negative class, so that a two-class
    # tempered softmax yields the tempered sigmoid in slot 0.
    negative_logits = torch.zeros_like(activations)
    two_class = torch.stack([activations, negative_logits], dim=-1)
    return tempered_softmax(two_class, t, num_iters)[..., 0]


def tempered_softmax(activations, t, num_iters=5):
    """Tempered softmax over the last dimension.

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature; exactly 1.0 gives the ordinary softmax.
      num_iters: Number of iterations for the normalization solver.
    Returns:
      A probabilities tensor.
    """
    if t != 1.0:
        shift = compute_normalization(activations, t, num_iters)
        return exp_t(activations - shift, t)
    return activations.softmax(dim=-1)


def bi_tempered_binary_logistic_loss(activations,
                                     labels,
                                     t1,
                                     t2,
                                     label_smoothing=0.0,
                                     num_iters=5,
                                     reduction='mean'):
    """Bi-Tempered binary logistic loss.

    The binary problem is expressed as a two-class problem, with a fixed 0 logit for class 0,
    and delegated to `bi_tempered_logistic_loss`.

    Args:
      activations: A tensor containing activations for class 1.
      labels: A tensor with the same shape as activations, containing probabilities for class 1.
      t1: Temperature 1 (< 1.0 for boundedness).
      t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      label_smoothing: Label smoothing.
      num_iters: Number of iterations to run the method.
      reduction: Passed through unchanged.
    Returns:
      A loss tensor.
    """
    dtype = activations.dtype
    positive = labels.to(dtype)
    two_class_logits = torch.stack([activations, torch.zeros_like(activations)], dim=-1)
    two_class_labels = torch.stack([positive, 1.0 - positive], dim=-1)
    return bi_tempered_logistic_loss(
        two_class_logits,
        two_class_labels,
        t1,
        t2,
        label_smoothing=label_smoothing,
        num_iters=num_iters,
        reduction=reduction,
    )


def bi_tempered_logistic_loss(activations,
                              labels,
                              t1,
                              t2,
                              label_smoothing=0.0,
                              num_iters=5,
                              reduction='mean'):
    """Bi-Tempered Logistic Loss.
    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      labels: A tensor with shape and dtype as activations (onehot),
        or a long tensor of one dimension less than activations (pytorch standard)
      t1: Temperature 1 (< 1.0 for boundedness).
      t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      label_smoothing: Label smoothing parameter between [0, 1). Default 0.0.
      num_iters: Number of iterations to run the method. Default 5.
      reduction: ``'none'`` | ``'mean'`` | ``'sum'``. Default ``'mean'``.
        ``'none'``: No reduction is applied, return shape is shape of
        activations without the last dimension.
        ``'mean'``: Loss is averaged over minibatch. Return shape (1,)
        ``'sum'``: Loss is summed over minibatch. Return shape (1,)
    Returns:
      A loss tensor.
    Raises:
      ValueError: If ``reduction`` is not ``'none'``, ``'mean'`` or ``'sum'``
        (such values previously returned ``None`` silently).
    """
    if reduction not in ('none', 'mean', 'sum'):
        raise ValueError(f"reduction must be 'none', 'mean' or 'sum', got {reduction!r}")

    if len(labels.shape) < len(activations.shape):  # not one-hot
        labels_onehot = torch.zeros_like(activations)
        # Scatter along the last (class) dimension to match `activations`; scattering along
        # dim 1 was only correct for 2-D (batch, classes) inputs.
        labels_onehot.scatter_(-1, labels[..., None], 1)
    else:
        labels_onehot = labels

    if label_smoothing > 0:
        num_classes = labels_onehot.shape[-1]
        labels_onehot = (1 - label_smoothing * num_classes / (num_classes - 1)) \
                        * labels_onehot + \
                        label_smoothing / (num_classes - 1)

    probabilities = tempered_softmax(activations, t2, num_iters)

    # The 1e-10 keeps log_t(0, t1) (-inf for t1 >= 1) from producing NaN on zero label entries.
    loss_values = labels_onehot * log_t(labels_onehot + 1e-10, t1) \
                  - labels_onehot * log_t(probabilities, t1) \
                  - labels_onehot.pow(2.0 - t1) / (2.0 - t1) \
                  + probabilities.pow(2.0 - t1) / (2.0 - t1)
    loss_values = loss_values.sum(dim=-1)  # sum over classes

    if reduction == 'none':
        return loss_values
    if reduction == 'sum':
        return loss_values.sum()
    return loss_values.mean()


class BiTemperedLogisticLoss(nn.Module):
    """Module wrapper around ``bi_tempered_logistic_loss`` that returns the batch mean.

    Args:
        t1: Temperature 1 (< 1.0 for boundedness).
        t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
        smoothing: Label smoothing, forwarded as ``label_smoothing``.
    """

    def __init__(self, t1, t2, smoothing=0.0):
        super().__init__()
        self.t1 = t1
        self.t2 = t2
        self.smoothing = smoothing

    def forward(self, logit_label, truth_label):
        per_sample = bi_tempered_logistic_loss(
            logit_label,
            truth_label,
            t1=self.t1,
            t2=self.t2,
            label_smoothing=self.smoothing,
            reduction='none',
        )
        return per_sample.mean()


## End bitempered loss
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    The true class gets ``1 - smoothing``; each of the other ``classes - 1`` classes gets
    ``smoothing / (classes - 1)``. Returns the mean over the batch.
    """

    def __init__(self, classes=5, smoothing=0.0, dim=-1):
        super().__init__()
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        log_probs = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            smoothed = torch.full_like(log_probs, self.smoothing / (self.cls - 1))
            smoothed.scatter_(1, target.data.unsqueeze(1), self.confidence)
        per_sample = (-smoothed * log_probs).sum(dim=self.dim)
        return per_sample.mean()


class TaylorSoftmax(nn.Module):
    '''
    Softmax variant that replaces ``exp`` by its order-``n`` Taylor polynomial
    ``1 + x + x^2/2! + ... + x^n/n!`` and normalizes along ``dim``.
    ``n`` must be even (presumably so the polynomial stays positive). Autograd version.
    '''

    def __init__(self, dim=1, n=2):
        super().__init__()
        assert n % 2 == 0
        self.dim = dim
        self.n = n

    def forward(self, x):
        '''
        usage similar to nn.Softmax:
            >>> mod = TaylorSoftmax(dim=1, n=4)
            >>> inten = torch.randn(1, 32, 64, 64)
            >>> out = mod(inten)
        '''
        approx = torch.ones_like(x)
        factorial = 1.
        for order in range(1, self.n + 1):
            factorial *= order
            approx = approx + x.pow(order) / factorial
        return approx / approx.sum(dim=self.dim, keepdim=True)


class TaylorCrossEntropyLoss(nn.Module):
    """Label-smoothed cross-entropy computed on Taylor-softmax probabilities.

    Args:
        n: Even Taylor expansion order for ``TaylorSoftmax``.
        ignore_index: Stored but not used by ``forward``.
        reduction: Stored but not used by ``forward``; the loss is always the batch mean.
        smoothing: Label-smoothing factor.
        num_classes: Number of classes. If None (default), it is taken from
            ``logits.size(1)`` on every forward call.
    """

    def __init__(self, n=2, ignore_index=-1, reduction='mean', smoothing=0.05, num_classes=None):
        super(TaylorCrossEntropyLoss, self).__init__()
        assert n % 2 == 0
        self.taylor_softmax = TaylorSoftmax(dim=1, n=n)
        self.reduction = reduction
        self.ignore_index = ignore_index
        self.smoothing = smoothing
        # The original built this with Config.img_size (the image side length) as the class
        # count, which spread the smoothing mass over hundreds of non-existent classes.
        self.lab_smooth = (LabelSmoothingLoss(num_classes, smoothing=smoothing)
                           if num_classes is not None else None)

    def forward(self, logits, labels):
        log_probs = self.taylor_softmax(logits).log()
        lab_smooth = self.lab_smooth
        if lab_smooth is None:
            lab_smooth = LabelSmoothingLoss(logits.size(1), smoothing=self.smoothing)
        # LabelSmoothingLoss re-applies log_softmax; on log-probabilities that already sum
        # to one this leaves them unchanged (up to rounding).
        return lab_smooth(log_probs, labels)


# https://github.com/AlanChou/Truncated-Loss/blob/master/TruncatedLoss.py
class TruncatedLoss(nn.Module):
    """Truncated L_q loss with per-sample 0/1 weights (after AlanChou/Truncated-Loss).

    Args:
        q: Exponent of the L_q loss ``(1 - p_y**q) / q``.
        k: Truncation threshold on the target-class probability.
        trainset_size: Number of training samples; one weight is stored per sample index.
    """

    def __init__(self, q=0.7, k=0.5, trainset_size=50000):
        super(TruncatedLoss, self).__init__()
        self.q = q
        self.k = k
        # Non-trainable per-sample weights of shape (trainset_size, 1); update_weight() sets
        # them to 0 or 1.
        self.weight = torch.nn.Parameter(data=torch.ones(trainset_size, 1), requires_grad=False)

    def _truncation_constant(self):
        """L_q value at probability k: ``(1 - k**q) / q``."""
        return (1 - (self.k ** self.q)) / self.q

    def forward(self, logits, targets, indexes):
        p = F.softmax(logits, dim=1)
        Yg = torch.gather(p, 1, torch.unsqueeze(targets, 1))
        w = self.weight[indexes]
        loss = ((1 - (Yg ** self.q)) / self.q) * w - self._truncation_constant() * w
        return torch.mean(loss)

    def update_weight(self, logits, targets, indexes):
        """Set weight 1 for samples whose L_q loss is below the truncation constant, else 0."""
        p = F.softmax(logits, dim=1)
        Yg = torch.gather(p, 1, torch.unsqueeze(targets, 1))
        Lq = ((1 - (Yg ** self.q)) / self.q)
        # Compare against the scalar constant directly. The original built it through
        # torch.cuda.FloatTensor, which crashes on CPU and ignores the weight's device.
        condition = Lq < self._truncation_constant()
        self.weight[indexes] = condition.to(device=self.weight.device, dtype=self.weight.dtype)


# source: https://www.kaggle.com/c/siim-isic-melanoma-classification/discussion/173733
class SmoothCrossEntropyLoss(_WeightedLoss):
    """Cross-entropy with label smoothing and optional per-class weights.

    The target class gets ``1 - smoothing`` and every other class gets
    ``smoothing / (n_classes - 1)``. ``reduction`` is 'mean', 'sum' or anything else
    (per-sample losses).
    """

    def __init__(self, weight=None, reduction='mean', smoothing=SMOOTHING):
        super().__init__(weight=weight, reduction=reduction)
        self.smoothing = smoothing
        self.weight = weight
        self.reduction = reduction

    @staticmethod
    def _smooth_one_hot(targets: torch.Tensor, n_classes: int, smoothing=SMOOTHING):
        assert 0 <= smoothing < 1
        off_value = smoothing / (n_classes - 1)
        with torch.no_grad():
            smoothed = torch.full((targets.size(0), n_classes), off_value, device=targets.device)
            smoothed.scatter_(1, targets.data.unsqueeze(1), 1. - smoothing)
        return smoothed

    def forward(self, inputs, targets):
        smoothed = SmoothCrossEntropyLoss._smooth_one_hot(targets, inputs.size(-1),
                                                          self.smoothing)
        log_probs = F.log_softmax(inputs, -1)

        if self.weight is not None:
            log_probs = log_probs * self.weight.unsqueeze(0)

        loss = -(smoothed * log_probs).sum(-1)

        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'mean':
            return loss.mean()
        return loss


class SymmetricCrossEntropy(nn.Module):
    """Symmetric cross-entropy: ``alpha * CE + beta * reverse CE``.

    Args:
        alpha: Weight of the standard cross-entropy term.
        beta: Weight of the reverse cross-entropy term.
        num_classes: Number of classes used to one-hot encode ``targets``.
    """

    def __init__(self, alpha=0.1, beta=1.0, num_classes=5):
        super(SymmetricCrossEntropy, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.num_classes = num_classes

    def forward(self, logits, targets, reduction='mean'):
        # One-hot on the targets' own device. The original indexed a CPU eye matrix and then
        # called .cuda(), which fails on CPU-only machines.
        onehot_targets = F.one_hot(targets, num_classes=self.num_classes).float()
        ce_loss = F.cross_entropy(logits, targets, reduction=reduction)
        # Reverse CE: -sum(y * log p), with probabilities clamped away from 0.
        rce_loss = (-onehot_targets * logits.softmax(1).clamp(1e-7, 1.0).log()).sum(1)
        if reduction == 'mean':
            rce_loss = rce_loss.mean()
        elif reduction == 'sum':
            rce_loss = rce_loss.sum()
        return self.alpha * ce_loss + self.beta * rce_loss

def mono_to_color(X, eps=1e-6, mean=None, std=None):
    """Turn a single-channel array into a 3-channel uint8 image scaled to [0, 255].

    A new last axis of size 3 (three copies of ``X``) is added. The values are standardized,
    then min-max scaled to [0, 255]. If the value range is not larger than ``eps``, an
    all-zero image is returned.

    Args:
        X: Single-channel numeric array.
        eps: Added to ``std`` and used as the minimum usable value range.
        mean: Mean for standardization; computed from the data when None.
        std: Standard deviation for standardization; computed from the data when None.
    """
    X = np.stack([X, X, X], axis=-1)

    # Standardize. Test against None explicitly: ``mean or X.mean()`` silently ignored a
    # caller-supplied mean or std of 0.
    mean = X.mean() if mean is None else mean
    std = X.std() if std is None else std
    X = (X - mean) / (std + eps)

    # Normalize to [0, 255]
    _min, _max = X.min(), X.max()

    if (_max - _min) > eps:
        # (The original np.clip(X, _min, _max) was a no-op and has been dropped.)
        V = 255 * (X - _min) / (_max - _min)
        V = V.astype(np.uint8)
    else:
        V = np.zeros_like(X, dtype=np.uint8)

    return V


def resize(image, size=None):
    """Resize an H x W x C image to height ``size``, keeping its aspect ratio.

    Returns ``image`` unchanged when ``size`` is None.
    """
    if size is None:
        return image
    height, width, _ = image.shape
    target_width = int(width * size / height)
    return cv2.resize(image, (target_width, size))


def normalize(image, mean=None, std=None):
    """Divide pixel values by 255, optionally standardize with mean/std, and return float32.

    Standardization is applied only when both ``mean`` and ``std`` are given.
    """
    scaled = image / 255.0
    if mean is None or std is None:
        return scaled.astype(np.float32)
    return ((scaled - mean) / std).astype(np.float32)


# Data loader
class CustomDataset(Dataset):
    """Image dataset backed by a DataFrame of file names (``ID_COL``) and labels (``TARGET_COL``).

    Args:
        df: DataFrame containing ``ID_COL`` and ``TARGET_COL`` columns.
        file_path: Directory prefix (with trailing separator) prepended to each file name.
        train: If True, items are ``(image, ONE_HOT[label])``.
        transforms: Optional Albumentations-style callable, invoked as ``transforms(image=...)``.
    """

    def __init__(self, df, file_path, train=True, transforms=None):
        self.train = train
        self.df = df
        self.file_path = file_path
        self.filename = df[ID_COL].values
        self.transforms = transforms
        # Positional label array, indexed with `idx` in __getitem__.
        self.labels = df[TARGET_COL].values

    def __len__(self):
        return len(self.filename)

    def __getitem__(self, idx: int):
        file = self.file_path + self.filename[idx]
        image = cv2.imread(file)
        if image is None:
            # cv2.imread returns None for missing or unreadable files. Fail clearly here instead
            # of printing the path and then crashing with an AttributeError on .copy().
            raise FileNotFoundError(f"Could not read image: {file}")
        orig_image = image.copy()

        # Special user defined mask
        if Config.cv_mask:
            # NOTE(review): detect_leaf keeps OpenCV's BGR channel order while the other branch
            # converts to RGB — confirm this difference is intended.
            image = detect_leaf(image)
        else:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Both paths resize first, so do it once.
        image = cv2.resize(image, dsize=(Config.img_size, Config.img_size), interpolation=cv2.INTER_LINEAR)
        if self.transforms:
            image = self.transforms(image=image)['image']
        else:
            # No augmentation (also the test-set path): scale to [0, 1] float32.
            image = normalize(image, mean=None, std=None)

        if self.train:
            # Use the positional array: df[TARGET_COL][idx] is a label-based lookup that
            # returns the wrong row (or raises KeyError) when df has a non-default index,
            # e.g. a fold subset taken with .iloc.
            return image, ONE_HOT[self.labels[idx]]
        if Config.test_loader:
            return image, orig_image
        return image


# Test data loader
# audio_data = AudioDataset(df=train_df, params=AudioParams, audio_path=AUDIO_PATH)
# audio_data[0]
# Torch utils
def seed_everything(seed):
    """
    Seeds basic parameters for reproducibility of results

    Arguments:
        seed {int} -- Number of the seed
    """
    random.seed(seed)
    # Only affects subprocesses started later; the current interpreter's hash seed is fixed
    # at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN choose algorithms by timing them, and that choice can vary
    # between runs. This undermines deterministic=True and the purpose of this function.
    torch.backends.cudnn.benchmark = False


def save_model_params(val_scores, fold, cp_folder, use_svm=False):
    """Append the mean of ``val_scores`` for ``fold`` as one line of ``<cp_folder>/params.txt``.

    ``use_svm`` selects the "SVM" wording of the line.
    """
    mean_score = np.array(val_scores).mean()
    if use_svm:
        line = f'Fold: {fold} Mean Val SVM Score: {mean_score}\n'
    else:
        line = f'Fold: {fold} Mean Val Score: {mean_score}\n'
    with open(os.path.join(cp_folder, 'params.txt'), 'a+') as handle:
        handle.write(line)


def save_model_weights(model, filename, verbose=1, cp_folder=""):
    """
    Save ``model.state_dict()`` to ``os.path.join(cp_folder, filename)``.

    Arguments:
        model {torch module} -- Model whose weights are saved
        filename {str} -- Checkpoint file name

    Keyword Arguments:
        verbose {int} -- Print the destination path when truthy (default: {1})
        cp_folder {str} -- Destination folder (default: {''})
    """
    destination = os.path.join(cp_folder, filename)
    if verbose:
        print(f"\n -> Saving weights to {destination}\n")
    torch.save(model.state_dict(), destination)


def load_model_weights(model, filename, verbose=1, cp_folder=""):
    """
    Loads the weights of a PyTorch model.

    The checkpoint is first mapped to CPU. load_state_dict then copies the tensors into the
    model's existing parameters, so checkpoints saved on GPU or CPU both load.

    Arguments:
        model {torch module} -- Model to load the weights to
        filename {str} -- Name of the checkpoint

    Keyword Arguments:
        verbose {int} -- Whether to display infos (default: {1})
        cp_folder {str} -- Folder to load from (default: {''})

    Returns:
        torch module -- Model with loaded weights
    """
    path = os.path.join(cp_folder, filename)
    if verbose:
        print(f"\n -> Loading weights from {path}\n")
    # The original first called load_state_dict(<path string>), which always failed, and
    # caught BaseException (even KeyboardInterrupt) before doing this same CPU-mapped load.
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)
    return model


def count_parameters(model, all=False):
    """
    Total number of scalar parameters in a model.

    Arguments:
        model {torch module} -- Model to inspect

    Keyword Arguments:
        all {bool} -- Also count parameters with requires_grad=False (default: {False})

    Returns:
        int -- Number of parameters
    """
    params = model.parameters()
    if not all:
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)


def get_metric(truth, pred, avg="micro", metrics=None):
    """Compute classification metrics from per-class score arrays.

    Args:
        truth: 1-D class indices (mapped through ``ONE_HOT`` first) or a 2-D per-class array.
        pred: 2-D per-class scores; the predicted class is the argmax over axis 1.
        avg: ``average`` argument for the F1 score.
        metrics: Metric names from {'f1', 'acc', 'mae'}; defaults to ``['f1']``.

    Returns:
        List of scores in the same order as ``metrics``.

    Raises:
        ValueError: On an unknown metric name. Previously it was skipped silently, which
            left the result list misaligned with ``metrics``.
    """
    # None default instead of a shared mutable list.
    if metrics is None:
        metrics = ["f1"]
    if len(truth.shape) == 1:
        truth = ONE_HOT[truth]
    pred_class = np.argmax(pred, axis=1)
    truth_class = np.argmax(truth, axis=1)

    scorers = {
        'f1': lambda: sklearn.metrics.f1_score(truth_class, pred_class, average=avg),
        'acc': lambda: sklearn.metrics.accuracy_score(truth_class, pred_class),
        'mae': lambda: sklearn.metrics.mean_absolute_error(truth_class, pred_class),
    }
    results = []
    for metric in metrics:
        if metric not in scorers:
            raise ValueError(f"Unknown metric {metric!r}; expected one of {sorted(scorers)}")
        results.append(scorers[metric]())
    return results


def smooth_label(y, alpha=0.01):
    """Scale labels by ``1 - alpha``, then set entries that are exactly 0 to ``alpha``.

    Works on a scaled copy; the input array is not modified.
    """
    smoothed = y * (1 - alpha)
    is_zero = smoothed == 0
    smoothed[is_zero] = alpha
    return smoothed


def get_model(name, num_classes=1):
    """Create a pretrained timm model and replace its head with a fresh Linear layer.

    The head attribute is chosen from the model name: ``fc`` for names containing 'resne'
    or 'inception', ``head`` for 'vit', and ``classifier`` otherwise.
    """
    model = timm.create_model(name, pretrained=True)
    if 'resne' in name or 'inception' in name:
        head_attr = 'fc'
    elif 'vit' in name:
        head_attr = 'head'
    else:
        head_attr = 'classifier'
    in_features = getattr(model, head_attr).in_features
    # Delete before re-adding so the new head is registered afresh.
    delattr(model, head_attr)
    setattr(model, head_attr, nn.Linear(in_features, num_classes))
    return model


def train_one_epoch(epoch, model, scaler, criterion, optimizer, train_loader, device, scheduler=None,
                    schd_batch_update=False):
    """Train ``model`` for one epoch with mixed precision (AMP) and optional per-batch CutMix.

    Args:
        epoch: Zero-based epoch index (printed as ``epoch + 1``).
        model: Network to train.
        scaler: AMP ``GradScaler`` used to scale the loss and step the optimizer.
        criterion: Loss called as ``criterion(preds, class_indices)``.
        optimizer: Stepped every ``Config.accum_iter`` batches and on the last batch.
        train_loader: Yields ``(images, labels)`` with one label row per sample; the class
            index is the argmax of each row.
        device: Device the batch is moved to.
        scheduler: Optional learning-rate scheduler.
        schd_batch_update: If True, step ``scheduler`` after every optimizer step;
            otherwise step it once at the end of the epoch.

    Side effects:
        When ``Config.use_svm`` is set, fits the module-level ``svm_clf`` on this epoch's
        logits and labels.
    """
    model.train()

    t = time.time()
    running_loss = None
    base_criterion = criterion
    image_preds_all = []
    image_labels_all = []
    num_batches = len(train_loader)
    step = -1  # stays -1 if the loader is empty

    for step, (images, image_labels) in enumerate(train_loader):

        # Draw every batch (even with CutMix off) so the numpy RNG stream is unchanged.
        mix_decision = np.random.rand()
        use_cutmix = Config.cutmix and mix_decision < 0.25

        if use_cutmix:
            images, (targets_a, targets_b, lam) = cutmix(images, image_labels, 1.)
            image_labels = targets_a
            # cutmix returns targets in the format it received (label rows). Convert both sets
            # to class indices, as the un-mixed path does, so every criterion gets indices.
            mixed_targets = (torch.max(targets_a, 1)[1], torch.max(targets_b, 1)[1], lam)

        images = images.to(device)
        # Argmax of each label row gives the class index. Taking argmax on the raw values,
        # instead of casting to long first, also stays correct if rows are label-smoothed.
        image_labels = torch.max(image_labels.to(device), 1)[1]

        with autocast():
            image_preds = model(images)

            # For SVM aggregate preds and labels to fit
            if Config.use_svm:
                image_preds_all.append(image_preds.cpu().detach().numpy())
                image_labels_all.append(image_labels.cpu().detach().numpy())

            if use_cutmix:
                criterion = CutMixCriterion(base_criterion).to(device)
                loss = criterion(image_preds, mixed_targets)
            else:
                criterion = base_criterion
                loss = criterion(image_preds, image_labels)

        scaler.scale(loss).backward()

        if running_loss is None:
            running_loss = loss.item()
        else:
            running_loss = running_loss * .99 + loss.item() * .01

        if ((step + 1) % Config.accum_iter == 0) or ((step + 1) == num_batches):
            # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if scheduler is not None and schd_batch_update:
                scheduler.step()

    if Config.use_svm and image_preds_all:
        image_preds_all = np.concatenate(image_preds_all)
        image_labels_all = np.concatenate(image_labels_all)
        svm_clf.fit(image_preds_all, image_labels_all)

    train_time = time.time() - t
    if running_loss is not None and (((step + 1) % Config.verbose == 0) or ((step + 1) == num_batches)):
        print(f'Epoch {epoch + 1}: Train loss: {running_loss:.4f} Time: {train_time:.3f} secs')

    if scheduler is not None and not schd_batch_update:
        scheduler.step()


def valid_one_epoch(epoch, model, criterion, val_loader, device, scheduler=None, schd_loss_update=False):
    """
    Evaluate ``model`` on ``val_loader`` for one epoch.

    Labels yielded by the loader are reduced to class indices with
    ``torch.max(labels, 1)[1]``, so they are expected to be one-hot /
    probability rows.

    Arguments:
        epoch {int} -- zero-based epoch index (used only for logging)
        model {torch model} -- model to evaluate (switched to eval mode)
        criterion -- loss called as ``criterion(preds, class_indices)``
        val_loader -- iterable of ``(images, labels)`` batches
        device -- device the batches are moved to

    Keyword Arguments:
        scheduler -- optional LR scheduler stepped once at the end (default: {None})
        schd_loss_update {bool} -- pass the mean validation loss to
            ``scheduler.step`` (default: {False})

    Returns:
        tuple -- ``(accuracy, svm_accuracy)``; ``svm_accuracy`` stays 0 unless
        ``Config.use_svm`` is set.

    Raises:
        ValueError -- if ``val_loader`` yields no samples.
    """
    model.eval()

    loss_sum = 0
    sample_num = 0
    image_preds_all = []
    image_targets_all = []
    image_preds_svm_all = []
    svm_metric_score = 0

    # Evaluation never needs gradients; this is harmless when the caller
    # already wraps the call in torch.no_grad().
    with torch.no_grad():
        for images, image_labels in val_loader:
            # Probabilities / one-hot -> class index
            image_labels = image_labels.to(device).long()
            image_labels = torch.max(image_labels, 1)[1]
            images = images.to(device)

            image_preds = model(images)
            image_preds_all.append(torch.argmax(image_preds, 1).detach().cpu().numpy())
            image_targets_all.append(image_labels.detach().cpu().numpy())

            loss = criterion(image_preds, image_labels)
            batch_size = image_preds.shape[0]
            loss_sum += loss.item() * batch_size
            sample_num += batch_size

            if Config.use_svm:
                image_preds_svm_all.append(svm_clf.predict(image_preds.cpu().detach().numpy()))

    # Without this guard an empty loader failed with an unrelated
    # NameError / ZeroDivisionError / np.concatenate error.
    if sample_num == 0:
        raise ValueError('valid_one_epoch received an empty validation loader')

    val_loss = loss_sum / sample_num
    # The former "(step + 1) % verbose" check ran after the loop, where
    # step + 1 == len(val_loader) always held, so the line always printed.
    print(f'Epoch {epoch + 1}: Val loss: {val_loss:.4f}')

    image_preds_all = np.concatenate(image_preds_all)
    image_targets_all = np.concatenate(image_targets_all)

    metric_score = (image_preds_all == image_targets_all).mean()
    print('Validation multi-class accuracy = {:.4f}'.format(metric_score))

    if Config.use_svm:
        image_preds_svm_all = np.concatenate(image_preds_svm_all)
        svm_metric_score = (image_preds_svm_all == image_targets_all).mean()
        print('SVM validation accuracy = {:.5f}'.format(svm_metric_score))

    if scheduler is not None:
        if schd_loss_update:
            scheduler.step(val_loss)
        else:
            scheduler.step()

    return metric_score, svm_metric_score


def predict(model, dataset, batch_size=64, infer=False):
    """
    Usual torch predict function

    Arguments:
        model {torch model} -- Model to predict with (switched to eval mode)
        dataset {torch dataset} -- Dataset to predict on. With ``infer=False``
            each item is an ``(input, label)`` pair and the label is ignored;
            with ``infer=True`` each item is the input alone.

    Keyword Arguments:
        batch_size {int} -- Batch size (default: {64})
        infer {bool} -- Whether dataset items are bare inputs (default: {False})

    Returns:
        numpy array -- float64 array of shape (len(dataset), NUM_CLASSES) with
        the raw model outputs (no softmax/sigmoid applied).
    """
    model.eval()

    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=NUM_WORKERS
    )

    # Start from an empty float64 block so the result keeps the original
    # dtype and (0, NUM_CLASSES) shape even for an empty dataset.
    chunks = [np.empty((0, NUM_CLASSES))]
    with torch.no_grad():
        for batch in loader:
            if infer:
                x = batch
            else:
                x, _ = batch
            # Both paths must move inputs to the model's device; the
            # non-infer path previously fed CPU tensors to a GPU model.
            y_pred = model(x.to(device))
            chunks.append(y_pred.cpu().numpy())

    # One concatenate at the end instead of re-copying the growing array
    # after every batch.
    return np.concatenate(chunks)


class ImageClassifier(nn.Module):
    """Thin ``nn.Module`` wrapper around a classification backbone.

    Arguments:
        model -- either a model name handed to ``get_model`` (training), or an
            already-built module that is wrapped as-is when ``infer`` is True
        n_class {int} -- number of output classes, used only when building
        infer {bool} -- wrap ``model`` directly instead of building one
    """

    def __init__(self, model, n_class, infer=False):
        super().__init__()
        # The attribute name ``model`` shows up in state_dict keys, so it is
        # kept as-is for checkpoint compatibility.
        self.model = model if infer else get_model(model, n_class)

    def forward(self, x):
        """Delegate the forward pass to the wrapped backbone."""
        return self.model(x)


def _upsample_minority_classes(df_train, upsample):
    """Oversample (with replacement) every non-majority class of ``df_train``.

    Each minority class is resampled to ``int(upsample * n_majority)`` rows,
    where ``n_majority`` is the row count of the most frequent TARGET_COL
    value; the majority-class rows are kept unchanged.

    Arguments:
        df_train {pd.DataFrame} -- training rows of a single fold
        upsample {float} -- target size as a fraction of the majority class

    Returns:
        pd.DataFrame -- resampled frame with a fresh RangeIndex
    """
    counts = df_train[TARGET_COL].value_counts()
    # Most frequent label (was hard-coded to class 3 for one dataset)
    majority_class = counts.idxmax()
    majority_df = df_train[df_train[TARGET_COL] == majority_class]
    samples = int(upsample * majority_df.shape[0])
    print("Before sampling:")
    print(counts)
    print(f"Sampling to match {samples} samples")

    parts = [majority_df]
    for class_val in df_train[TARGET_COL].unique():
        if class_val == majority_class:
            continue
        minority_df = df_train[df_train[TARGET_COL] == class_val]
        parts.append(resample(minority_df,
                              replace=True,  # sample WITH replacement (oversampling)
                              n_samples=samples,  # fraction of the majority class size
                              random_state=123))
    upsampled = pd.concat(parts)
    print('After upsampling:')
    print(upsampled[TARGET_COL].value_counts())
    return upsampled.reset_index(drop=True)


def _build_train_criterion(config, trainset_size):
    """Instantiate the loss named by ``config.loss_fn`` on ``device``.

    Raises:
        ValueError -- for an unsupported loss name (previously this surfaced
            later as an UnboundLocalError on ``criterion``).
    """
    # Factories, so only the selected loss is constructed.
    factories = {
        'CrossEntropyLoss': lambda: nn.CrossEntropyLoss(),
        'SmoothCrossEntropyLoss': lambda: SmoothCrossEntropyLoss(),
        'SymmetricCrossEntropy': lambda: SymmetricCrossEntropy(),
        'BCEWithLogitsLoss': lambda: nn.BCEWithLogitsLoss(reduction="mean"),
        'TruncatedLoss': lambda: TruncatedLoss(trainset_size=trainset_size),
        'MAE': lambda: nn.L1Loss(reduction='mean'),
        'TaylorCrossEntropyLoss': lambda: TaylorCrossEntropyLoss(smoothing=config.label_smoothing),
        'BiTemperedLoss': lambda: BiTemperedLogisticLoss(t1=config.t1, t2=config.t2,
                                                         smoothing=config.label_smoothing),
        'FocalCosineLoss': lambda: FocalCosineLoss(),
    }
    if config.loss_fn not in factories:
        raise ValueError(f'Unsupported loss function: {config.loss_fn}')
    return factories[config.loss_fn]().to(device)


def _build_train_scheduler(config, optimizer, train_loader):
    """Create the LR scheduler named by ``config.scheduler``.

    Returns:
        tuple -- ``(scheduler, step_per_batch)``; ``step_per_batch`` is True
        when the scheduler must be stepped after every optimizer update
        (train_one_epoch's ``schd_batch_update``) rather than once per epoch.
    """
    if config.scheduler == 'CosineAnnealingLR':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=config.epochs, T_mult=1,
                                                                         eta_min=1e-6, last_epoch=-1)
        return scheduler, False
    if config.scheduler == 'ReduceLROnPlateau':
        # mode="max": train_model steps it with the validation accuracy.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            patience=1,
            factor=0.25,
            min_lr=1e-6,
            verbose=True,
            mode="max"
        )
        return scheduler, False
    if config.scheduler == 'StepLR':
        # Reduce LR every step size epochs by gamma
        return torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1), False

    # Default: linear warmup/decay. Its lengths are counted in optimizer
    # updates, so it must be stepped per update, not once per epoch
    # (otherwise the LR never leaves the warmup phase). train_one_epoch
    # updates once per accum_iter batches plus once on a trailing partial
    # group, i.e. ceil(len(train_loader) / accum_iter) times per epoch.
    updates_per_epoch = (len(train_loader) + config.accum_iter - 1) // config.accum_iter
    num_training_steps = int(config.epochs * updates_per_epoch)
    num_warmup_steps = int(config.warmup_prop * num_training_steps)
    return get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps), True


def train_model(config, df_train, df_val, fold):
    """Train one cross-validation fold and checkpoint its best epoch.

    Builds the model, datasets/loaders, optimizer, loss and LR scheduler from
    ``config``, then runs ``config.epochs`` epochs of train + validation.
    When ``config.save`` is set, only the checkpoint of the best
    validation-accuracy epoch is kept, and the final weights plus the
    per-epoch scores of the fold are saved as well.

    Arguments:
        config -- configuration class (``Config``)
        df_train {pd.DataFrame} -- training rows of this fold
        df_val {pd.DataFrame} -- validation rows of this fold
        fold {int} -- zero-based fold index (used in checkpoint names)

    Raises:
        ValueError -- if ``config.loss_fn`` is not a supported loss name.
    """
    print(f"{len(df_train)} training samples ")
    print(f"{len(df_val)} validation samples ")
    seed_everything(config.seed)

    # Build the network once. Previously a DataParallel-wrapped copy was
    # created here and then silently replaced by a second, unwrapped model.
    model = ImageClassifier(config.selected_model, NUM_CLASSES).to(device)
    print(f"Trainable parameters: {count_parameters(model)}")
    if torch.cuda.device_count() > 1:
        print(f"Found {torch.cuda.device_count()} GPUs. Using DataParallel")
        model = nn.DataParallel(model)
    # Save from the unwrapped module so checkpoint keys match a plain
    # ImageClassifier (as rebuilt for inference), with or without DataParallel.
    model_to_save = model.module if isinstance(model, nn.DataParallel) else model
    model.zero_grad()

    df_train = df_train.reset_index(drop=True)
    df_val = df_val.reset_index(drop=True)

    # Upsample only the train split so duplicated rows never leak into val
    if args.upsample:
        df_train = _upsample_minority_classes(df_train, args.upsample)

    epochs = config.epochs

    # Datasets
    if config.augment:
        train_dataset = CustomDataset(df=df_train, file_path=FILE_PATH, train=True,
                                      transforms=get_train_transforms())
        val_dataset = CustomDataset(df=df_val, file_path=FILE_PATH, train=True, transforms=get_valid_transforms())
    else:
        train_dataset = CustomDataset(df=df_train, file_path=FILE_PATH, train=True)
        val_dataset = CustomDataset(df=df_val, file_path=FILE_PATH, train=True)

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        drop_last=False,
        pin_memory=False,
        num_workers=NUM_WORKERS,
    )
    val_loader = DataLoader(
        val_dataset, batch_size=config.val_batch_size, shuffle=False, pin_memory=False, num_workers=NUM_WORKERS
    )

    scaler = GradScaler()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    criterion = _build_train_criterion(config, len(train_dataset))
    scheduler, step_per_batch = _build_train_scheduler(config, optimizer, train_loader)
    # ReduceLROnPlateau.step() needs a metric, so it is stepped below with
    # the validation accuracy instead of inside train_one_epoch.
    is_plateau = isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)

    val_scores = []
    val_svm_scores = []

    # Main training loop over epochs
    print(f"Fold: {fold + 1}")
    best_epoch_score = 0
    save_filename = ''
    for epoch in range(epochs):
        train_one_epoch(epoch, model, scaler, criterion, optimizer, train_loader, device,
                        scheduler=None if is_plateau else scheduler,
                        schd_batch_update=step_per_batch)

        with torch.no_grad():
            acc, svm_acc = valid_one_epoch(epoch, model, criterion, val_loader, device, scheduler=None,
                                           schd_loss_update=False)
        val_scores.append(acc)
        val_svm_scores.append(svm_acc)

        if is_plateau:
            scheduler.step(acc)

        # Save model if this is the best epoch so far
        if acc > best_epoch_score:
            best_epoch_score = acc
            if config.save:
                # Keep only the best checkpoint. Guarding on save_filename
                # avoids trying to delete the folder itself on the first save;
                # os.remove replaces shelling out to `rm`.
                if save_filename:
                    previous_path = f'{SAVE_MODEL_FOLDER}/{save_filename}'
                    if os.path.exists(previous_path):
                        os.remove(previous_path)
                acc_str = str(acc).replace('.', '_')
                save_filename = f'{config.selected_model}_{config.name}_fold{fold}_epoch{epoch}_{acc_str}.pt'
                save_model_weights(
                    model_to_save,
                    save_filename,
                    cp_folder=SAVE_MODEL_FOLDER,
                )

    # Save model per Fold (weights after the last epoch)
    if config.save:
        save_model_weights(
            model_to_save,
            f"{config.selected_model}_{config.name}_fold{fold}.pt",
            cp_folder=SAVE_MODEL_FOLDER,
        )
        save_model_params(val_scores, fold, SAVE_MODEL_FOLDER)

        if config.use_svm:
            save_model_params(val_svm_scores, fold, SAVE_MODEL_FOLDER, use_svm=True)
            with open(f'{SAVE_MODEL_FOLDER}/SVM_model.pkl', 'wb') as svm_file:
                pickle.dump(svm_clf, svm_file)

    del model, model_to_save, optimizer, train_loader, val_loader, scaler, scheduler
    torch.cuda.empty_cache()


def k_fold(config, df):
    """Run stratified K-fold cross-validation training.

    Writes ``params.txt`` (command-line arguments, config values and the
    train/valid transform pipelines) into SAVE_MODEL_FOLDER for
    reproducibility, then calls ``train_model`` once per fold.

    Arguments:
        config -- configuration class (``Config``); reads ``k`` and ``random_state``
        df {pd.DataFrame} -- full training frame, stratified on TARGET_COL
    """
    # scikit-learn (>= 0.24) rejects a random_state unless shuffle=True, so
    # shuffle exactly when a seed is configured; random_state=None keeps the
    # original unshuffled, deterministic split.
    skf = StratifiedKFold(
        n_splits=config.k,
        shuffle=config.random_state is not None,
        random_state=config.random_state,
    )
    splits = list(skf.split(X=df, y=df[TARGET_COL]))

    # Python's dunder entries (__module__, __dict__, __weakref__, ...) carry
    # no settings and only clutter the dump.
    config_items = {key: value for key, value in vars(config).items() if not key.startswith('__')}

    # Write all config values to save model dir for replication
    params_file = os.path.join(SAVE_MODEL_FOLDER, 'params.txt')
    with open(params_file, 'w') as f:
        f.write('Arguments:\n')
        f.write(', '.join("%s: %s" % item for item in vars(args).items()))
        f.write('\n\n')
        f.write('Config:\n')
        f.write(', '.join("%s: %s" % item for item in config_items.items()))
        f.write('\n\n')
        f.write(f'Train transforms: {get_train_transforms()}\n\n')
        f.write(f'Valid transforms: {get_valid_transforms()}\n\n')

    for i, (train_idx, val_idx) in enumerate(splits):
        print(f"\n-------------   Fold {i + 1} / {config.k}  -------------\n")

        df_train = df.iloc[train_idx].copy()
        df_val = df.iloc[val_idx].copy()

        train_model(config, df_train, df_val, i)


def get_scores(folds=(), base_model_path=''):
    """Predict the test set with one or more fold checkpoints.

    Arguments:
        folds -- checkpoint file names inside SAVE_MODEL_FOLDER, each passed
            to ``load_model_weights``
        base_model_path {str} -- path to a fully pickled torch model that
            provides the architecture (loaded with ``torch.load``)

    Returns:
        numpy array -- shape (n_test, NUM_CLASSES): raw model outputs averaged
        over ``Config.tta`` passes (a single pass when TTA is 0) and all folds.
        The non-TTA path used to average argmax indices instead, which
        returned a 1-D array that broke callers taking ``.argmax(1)``.

    Raises:
        ValueError -- if ``folds`` is empty (previously yielded NaN).
    """
    folds = list(folds)
    if not folds:
        raise ValueError('get_scores needs at least one checkpoint in `folds`')

    test_df = pd.read_csv(DATA_PATH + '/sample_submission.csv')

    # Kaggle (no internet): the architecture comes from a pickled model file.
    # torch.load unpickles arbitrary objects -- only load trusted files.
    base_model = torch.load(base_model_path)
    model = ImageClassifier(base_model, NUM_CLASSES, infer=True).to(device)

    n_passes = Config.tta if Config.tta else 1
    all_preds = []
    for _ in range(n_passes):
        # get_test_transforms is random, so each pass sees different crops.
        test_dataset = CustomDataset(df=test_df, file_path=TEST_FILE_PATH, train=False,
                                     transforms=get_test_transforms())
        for fold in folds:
            model = load_model_weights(model, fold, verbose=1, cp_folder=SAVE_MODEL_FOLDER)
            all_preds.append(predict(model, test_dataset, infer=True))

    return np.mean(all_preds, axis=0)


def get_train_transforms():
    """Build the training augmentation pipeline.

    ImageNet-normalise, take a random resized square crop of side
    ``Config.aug_img_size``, apply random horizontal/vertical flips and
    90-degree rotations, then convert to a torch tensor (ToTensorV2).
    """
    side = Config.aug_img_size
    imagenet_norm = Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
        max_pixel_value=255.0,
        p=1.0,
    )
    steps = [
        imagenet_norm,
        RandomResizedCrop(side, side, p=1),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        RandomRotate90(p=0.5),
        # Optional extras, currently disabled:
        # RandomCrop(side, side, p=1), CoarseDropout(p=0.5), Transpose(p=0.5),
        # ShiftScaleRotate(p=0.5),
        # HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
        # RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
        ToTensorV2(p=1.0),
    ]
    return Compose(steps, p=1.)


def get_valid_transforms():
    """Build the validation pipeline.

    ImageNet normalisation followed by tensor conversion only; no
    geometric augmentation is applied to validation images.
    """
    normalize = Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
        max_pixel_value=255.0,
        p=1.0,
    )
    to_tensor = ToTensorV2(p=1.0)
    # Disabled alternatives, e.g. CenterCrop(Config.img_size, Config.img_size, p=1.),
    # RandomResizedCrop / RandomCrop / RandomRotate90 / HorizontalFlip.
    return Compose([normalize, to_tensor], p=1.)


def get_test_transforms():
    """Build the test/TTA pipeline.

    ImageNet normalisation, a random resized square crop of side
    ``Config.aug_img_size`` and tensor conversion. The crop is random, so
    every TTA pass over the same images sees a different view.
    """
    side = Config.aug_img_size
    normalize = Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
        max_pixel_value=255.0,
        p=1.0,
    )
    random_crop = RandomResizedCrop(side, side, p=1)
    to_tensor = ToTensorV2(p=1.0)
    # Disabled alternatives: RandomCrop, RandomRotate90, HorizontalFlip.
    return Compose([normalize, random_crop, to_tensor], p=1.)


def get_sample_transforms():
    """Build the pipeline used by ``test_loader`` to eyeball augmentations.

    Normalisation is deliberately left out so the augmented images stay
    directly viewable. Edit the list below to preview other augmentations.
    """
    side = Config.aug_img_size
    preview_steps = [
        RandomResizedCrop(side, side, p=1),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        # Other candidates to preview:
        # Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
        # RandomRotate90(p=0.5), Cutout(p=0.5), Transpose(p=0.5), ShiftScaleRotate(p=0.5),
        # HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
        # RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
        CoarseDropout(p=0.5),
        ToTensorV2(p=1.0),
    ]
    return Compose(preview_steps, p=1.)


# Test the data loader and augments
def test_loader(df, samples):
    """Show ``samples`` random images next to their augmented versions.

    Rows are drawn from ``df`` and run through ``get_sample_transforms``.
    Assumes CustomDataset with ``train=False`` yields (augmented, original)
    pairs, as unpacked below -- TODO confirm against CustomDataset.

    Arguments:
        df {pd.DataFrame} -- frame to sample rows from
        samples {int} -- number of images to display (one batch)
    """
    sample_df = df.sample(n=samples)
    dataset = CustomDataset(df=sample_df, file_path=FILE_PATH, train=False, transforms=get_sample_transforms())
    data_loader = DataLoader(
        dataset,
        batch_size=samples,
        shuffle=True,
        drop_last=False,
        pin_memory=False,
        num_workers=NUM_WORKERS,
    )
    for images, orig_images in data_loader:
        for img, orig_image in zip(images, orig_images):
            plt.figure(figsize=(20, 16))
            # The pipeline ends in ToTensorV2, which yields channels-first
            # (C, H, W); matplotlib wants (H, W, C). permute(1, 2, 0) keeps
            # rows/columns in place -- the former permute(2, 1, 0) displayed
            # a transposed image.
            img = img.permute(1, 2, 0)
            plt.subplot(1, 2, 1)
            plt.imshow(orig_image)
            plt.subplot(1, 2, 2)
            plt.imshow(img)
            plt.show()


def fit_noisy(
        model,
        train_dataset,
        epochs=50,
        batch_size=32,
        val_batch_size=32,
        warmup_prop=0.1,
        lr=1e-3,
        verbose=1,
        verbose_eval=1
):
    """
    Train ``model`` in place on ``train_dataset`` (no validation step).

    Called by ``Classifier.fit`` so that cleanlab can train the CNN on each
    of its cross-validation splits.

    Arguments:
        model {torch model} -- Model to train (updated in place)
        train_dataset {torch dataset} -- Dataset yielding (image, label)
            pairs; labels are reduced with ``torch.max(y, 1)[1]``, so they are
            expected to be one-hot / probability rows

    Keyword Arguments:
        epochs {int} -- Number of epochs (default: {50})
        batch_size {int} -- Training batch size (default: {32})
        val_batch_size {int} -- Unused; kept for API compatibility (default: {32})
        warmup_prop {float} -- Warmup proportion of the linear schedule (default: {0.1})
        lr {float} -- Start (or maximum) learning rate (default: {1e-3})
        verbose {int} -- Period (in epochs) to display logs at (default: {1})
        verbose_eval {int} -- Unused; kept for API compatibility (default: {1})

    Returns:
        None -- the model is trained in place.

    Raises:
        ValueError -- if ``Config.loss_fn`` is not a supported loss name.
    """
    optimizer = Adam(model.parameters(), lr=lr)

    # Loss function (factories, so only the selected loss is constructed).
    # Same set of names as train_model, including SymmetricCrossEntropy.
    loss_factories = {
        'CrossEntropyLoss': lambda: nn.CrossEntropyLoss(),
        'SmoothCrossEntropyLoss': lambda: SmoothCrossEntropyLoss(),
        'SymmetricCrossEntropy': lambda: SymmetricCrossEntropy(),
        'BCEWithLogitsLoss': lambda: nn.BCEWithLogitsLoss(reduction="mean"),
        'TruncatedLoss': lambda: TruncatedLoss(trainset_size=len(train_dataset)),
        'MAE': lambda: nn.L1Loss(reduction='mean'),
        'TaylorCrossEntropyLoss': lambda: TaylorCrossEntropyLoss(smoothing=Config.label_smoothing),
        'BiTemperedLoss': lambda: BiTemperedLogisticLoss(t1=Config.t1, t2=Config.t2,
                                                         smoothing=Config.label_smoothing),
        'FocalCosineLoss': lambda: FocalCosineLoss(),
    }
    if Config.loss_fn not in loss_factories:
        raise ValueError(f'Unsupported loss function: {Config.loss_fn}')
    criterion = loss_factories[Config.loss_fn]().to(device)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=False,
        pin_memory=False,
        num_workers=NUM_WORKERS,
    )

    # Counted in batches (one optimizer step per batch in this function).
    num_warmup_steps = int(warmup_prop * epochs * len(train_loader))
    num_training_steps = int(epochs * len(train_loader))

    # True for schedulers whose length is measured in batches.
    step_per_batch = False
    if Config.scheduler == 'CosineAnnealingLR':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=epochs, T_mult=1, eta_min=1e-6,
                                                                         last_epoch=-1)
    elif Config.scheduler == 'ReduceLROnPlateau':
        # There is no validation set here, so the plateau is tracked on the
        # epoch's mean training loss (lower is better -> mode="min").
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            patience=1,
            factor=0.25,
            min_lr=1e-6,
            verbose=True,
            mode="min"
        )
    # Reduce LR every step size epochs by gamma
    elif Config.scheduler == 'StepLR':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)
    else:
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
        # Must be stepped every batch; stepping once per epoch kept the LR
        # inside the warmup ramp for the whole run.
        step_per_batch = True
    is_plateau = isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)

    for epoch in range(epochs):
        model.train()
        start_time = time.time()
        optimizer.zero_grad()

        avg_loss = 0
        for x, y_batch in train_loader:
            y_pred = model(x.to(device))
            # One-hot / probability rows -> class indices
            y_true_labels = torch.max(y_batch.to(device).long(), 1)[1]
            loss = criterion(y_pred, y_true_labels)

            loss.backward()
            avg_loss += loss.item() / len(train_loader)
            optimizer.step()
            # Zero out gradients for next loop
            optimizer.zero_grad()

            if step_per_batch:
                scheduler.step()

        # Per-epoch learning-rate update
        if is_plateau:
            # ReduceLROnPlateau.step() requires the monitored metric.
            scheduler.step(avg_loss)
        elif Config.scheduler == 'CosineAnnealingLR':
            scheduler.step(epoch)
        elif not step_per_batch:
            scheduler.step()

        # Read the LR from the optimizer: works for every scheduler, including
        # ReduceLROnPlateau, which lacks get_last_lr() on older torch versions.
        lr = optimizer.param_groups[0]['lr']

        elapsed_time = time.time() - start_time
        if (epoch + 1) % verbose == 0:
            elapsed_time = elapsed_time * verbose
            print(
                f"Epoch {epoch + 1}/{epochs} \t lr={lr:.1e} \t t={elapsed_time:.0f}s  \t loss={avg_loss:.4f}\n",
                end="",
            )

    torch.cuda.empty_cache()


def get_scores_train(folds=(), base_model=''):
    """Predict the full training table with fold checkpoints (for cleanlab).

    Arguments:
        folds -- checkpoint names, each passed to ``load_model_weights``
        base_model {str} -- despite its name, the checkpoint folder passed as
            ``cp_folder`` to ``load_model_weights``

    Returns:
        tuple -- ``(preds_df, avg_preds)``. ``avg_preds`` is the
        (n_rows, NUM_CLASSES) mean of raw model outputs over the ``Config.tta``
        passes (a single pass when TTA is 0) and all folds; ``preds_df`` holds
        ID_COL, TARGET_COL and a 'pred' column with each row's averaged
        outputs as a list.

    Raises:
        ValueError -- if ``folds`` is empty (previously yielded NaN).
    """
    folds = list(folds)
    if not folds:
        raise ValueError('get_scores_train needs at least one checkpoint in `folds`')

    train_df_path = os.path.join(DATA_PATH, TRAIN_DF_FILE)
    test_df = pd.read_csv(train_df_path)

    # NOTE(review): weights are loaded into a bare get_model() network, while
    # train_model saves an ImageClassifier wrapper -- confirm that
    # load_model_weights reconciles the state_dict key prefixes.
    model = get_model(Config.selected_model, num_classes=NUM_CLASSES).to(device)

    n_passes = Config.tta if Config.tta else 1
    all_preds = []
    for _ in range(n_passes):
        test_dataset = CustomDataset(df=test_df, file_path=FILE_PATH, train=False,
                                     transforms=get_test_transforms())
        for fold in folds:
            model = load_model_weights(model, fold, verbose=1, cp_folder=base_model)
            all_preds.append(predict(model, test_dataset, infer=True))

    avg_preds = np.mean(all_preds, axis=0)

    preds_df = pd.DataFrame()
    preds_df[ID_COL] = test_df[ID_COL]
    preds_df[TARGET_COL] = test_df[TARGET_COL]
    preds_df['pred'] = avg_preds.tolist()

    return preds_df, avg_preds


# Running counter of Classifier.fit calls, printed as "CV Fold" and
# incremented in Classifier.fit; reset to 0 for every outer fold of the
# --train_noisy_clf loop.
CV_FOLD = 0


class Classifier(BaseEstimator):
    """scikit-learn style wrapper so cleanlab can train and query the CNN.

    The ``img_idx`` values passed to fit/predict/predict_proba/score are row
    positions into the module-level ``train_df``. Images are loaded through
    ``CustomDataset``.

    Arguments:
        model {torch model} -- network to train (moved to CUDA when available)
        config -- configuration class (``Config``) supplying epochs, batch
            sizes, lr and warmup_prop
    """

    def __init__(self, model, config):
        self.model = model
        self.config = config
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.best_model = None

    def fit(self, img_idx, labels):
        """Train the model on the train_df rows at positions ``img_idx``.

        ``labels`` is accepted for the scikit-learn API but not read; the
        targets come from the selected train_df rows via CustomDataset.

        Returns:
            self -- following the scikit-learn ``fit`` convention.
        """
        global CV_FOLD
        print(f"CV Fold: {CV_FOLD + 1}")
        CV_FOLD += 1
        df_train = train_df.iloc[img_idx].reset_index(drop=True)
        train_dataset = CustomDataset(df=df_train, file_path=FILE_PATH, train=True, transforms=get_train_transforms())

        fit_noisy(
            self.model,
            train_dataset,
            epochs=self.config.epochs,
            batch_size=self.config.batch_size,
            val_batch_size=self.config.val_batch_size,
            lr=self.config.lr,
            warmup_prop=self.config.warmup_prop,
            # Config comments out verbose_eval, so reading it directly raised
            # AttributeError; fall back to fit_noisy's default.
            verbose_eval=getattr(self.config, 'verbose_eval', 1)
        )
        return self

    def _predict_logits(self, img_idx):
        """Raw model outputs for the train_df rows at positions ``img_idx``."""
        df_val = train_df.iloc[img_idx].reset_index(drop=True)
        test_dataset = CustomDataset(df=df_val, file_path=FILE_PATH, train=False, transforms=get_test_transforms())
        return predict(self.model, test_dataset, infer=True)

    def predict_proba(self, img_idx, phase="train"):
        """Class probabilities (each row sums to 1) for rows ``img_idx``.

        cleanlab's LearningWithNoisyLabels treats predict_proba output as
        predicted probabilities, so the raw model outputs are passed through
        a row-wise softmax. ``phase`` is unused.
        """
        logits = self._predict_logits(img_idx)
        # Subtract the row max before exp for numerical stability.
        exp = np.exp(logits - logits.max(axis=1, keepdims=True))
        return exp / exp.sum(axis=1, keepdims=True)

    def predict(self, img_idx, phase="train"):
        """Predicted class index for each row in ``img_idx`` (``phase`` unused)."""
        return np.argmax(self._predict_logits(img_idx), axis=1)

    def score(self, img_idx, label, phase="train"):
        """Accuracy of ``predict(img_idx)`` against ``label``."""
        preds = self.predict(img_idx, phase=phase)
        return accuracy_score(label, preds)


# Default experiment settings. The command-line section below overrides
# several of these, and k_fold() writes vars(Config) to params.txt, so only
# '#' comments are used here (a class docstring would show up in that dump).
class Config:
    # Model architecture name passed to get_model (overridden by --model)
    selected_model = 'resnet50'
    # 'CosineAnnealingLR', 'ReduceLROnPlateau' or 'StepLR'; train_model falls
    # back to a linear warmup schedule for any other value
    scheduler = 'CosineAnnealingLR'
    # Loss name selected in train_model / fit_noisy (overridden by --loss_fn)
    loss_fn = 'TaylorCrossEntropyLoss'
    # Use get_train_transforms / get_valid_transforms in train_model (--augment)
    augment = False
    # Number of test-time-augmentation passes in get_scores; 0 disables TTA
    tta = 3
    # Value to resize to
    # NOTE(review): --img_size updates img_size, but the transform builders
    # read aug_img_size -- confirm where img_size is consumed.
    img_size = 512
    # Crop side length used by RandomResizedCrop in the transform builders
    aug_img_size = 512
    # Hyper params
    batch_size = 16
    val_batch_size = 16
    epochs = 10
    lr = 1e-4
    weight_decay = 1e-6

    # General
    # Seed passed to seed_everything at the start of each fold
    seed = 1234
    # Iterations to show loss
    verbose = 1
    # Iterations to accumulate grads (optimizer steps every accum_iter batches)
    accum_iter = 4
    # verbose_eval = 31
    # NOTE(review): Classifier.fit reads config.verbose_eval; confirm it
    # tolerates this attribute being absent.
    # Write checkpoints and per-fold scores
    save = True
    # k-fold
    k = 5
    # Passed to StratifiedKFold in k_fold
    random_state = None
    # Model
    # NOTE(review): use_msd / use_conf are not read in this part of the file;
    # presumably consumed by get_model -- verify.
    use_msd = False
    use_conf = False

    # Warmup fraction for the linear warmup schedule
    warmup_prop = 0.05
    # For bitempered loss
    t1 = 0.8
    t2 = 1.4
    # Label smoothing (TaylorCrossEntropyLoss / BiTemperedLoss)
    label_smoothing = 0.2
    # NOTE(review): alpha is not read in this part of the file -- confirm purpose.
    alpha = 5
    # Tag included in checkpoint file names
    name = "extra"
    # Set from --cv_mask
    cv_mask = False
    # Set to True when --test_loader is used
    test_loader = False
    # Per-batch CutMix in train_one_epoch (--cutmix)
    cutmix = False
    # Fit a LinearSVC (svm_clf) on model outputs during training (--use_svm)
    use_svm = False


# Default training table; replaced below when --train_df is given.
train_df_path = os.path.join(DATA_PATH, TRAIN_DF_FILE)
train_df = pd.read_csv(train_df_path)


## Command line arguments

In [None]:

parser = argparse.ArgumentParser(description='Process pytorch params.')
parser.add_argument('-model', '--model', type=str, help='Pytorch (timm) model name')
parser.add_argument('-model_dir', '--model_dir', type=str, help='Model save dir name')
parser.add_argument('--folds', type=int, help='Number of folds')
parser.add_argument('-epochs', '--epochs', type=int, help='Number of epochs')
parser.add_argument('-img_size', '--img_size', type=int, help='Image size to resize')
parser.add_argument('-batch_size', '--batch_size', type=int, help='batch size')
parser.add_argument('-lr', '--lr', type=float, help='Learning rate')
parser.add_argument('--augment', action='store_true', help='Augment data')
parser.add_argument('--loss_fn', type=str, help='Loss function')
parser.add_argument('--scheduler', type=str, help='Scheduler')
parser.add_argument('--n_samples', type=int, help='Number of samples of train data')
parser.add_argument('--train', action='store_true', help='Pytorch train mode')
parser.add_argument('--test', action='store_true', help='Pytorch test mode')
# timm.list_models takes wildcard (fnmatch) patterns, not regexes
parser.add_argument('--find_model', type=str, help='Wildcard pattern for timm model search, e.g. se*resnet*50')
parser.add_argument('--cv_mask', action='store_true', help='Set Config.cv_mask')
parser.add_argument('--stats', action='store_true', help='Print train shape and target distribution')
parser.add_argument('--test_loader', action='store_true', help='Plot a few original vs. augmented training images')
parser.add_argument('--downsample', action='store_true', help='Downsample the majority class')
# argparse %-formats help strings, so a literal percent sign must be '%%'
# (a bare '50% of' made --help crash).
parser.add_argument('--upsample', type=float, help='Upsample pct e.g value of 0.5 is 50%% of majority class')
parser.add_argument('--get_noise_indices', action='store_true', help='Get noise indices from cleanlab')
parser.add_argument('--train_noisy_clf', action='store_true', help='Train a noisy classifier (Cleanlab)')
parser.add_argument('--clean_data', type=str, help='CSV of label errors whose rows are removed from train')
parser.add_argument('--clean_samples', type=int, help='Number of top-ranked label errors to remove (default: all)')
parser.add_argument('--train_df', type=str, help='Path to a custom train CSV')
parser.add_argument('--cutmix', action='store_true', help='Cutmix for images')
parser.add_argument('--use_svm', action='store_true', help='Use SVM on top of model')

args = parser.parse_args()

# User defined train df
if args.train_df:
    train_df = pd.read_csv(args.train_df)

NUM_CLASSES = train_df[TARGET_COL].nunique()
print(f"Number of classes: {NUM_CLASSES}")

# %% [code]
ONE_HOT = np.eye(NUM_CLASSES)
le = LabelEncoder()
train_df[TARGET_COL + '_encoded'] = le.fit_transform(train_df[TARGET_COL])

# Parse arguments
# Only override the default architecture when --model is given; assigning
# args.model unconditionally replaced it with None.
if args.model:
    Config.selected_model = args.model
Config.cv_mask = args.cv_mask
Config.cutmix = args.cutmix

if args.img_size:
    Config.img_size = args.img_size
if args.batch_size:
    Config.batch_size = args.batch_size
if args.folds:
    Config.k = args.folds
if args.epochs:
    Config.epochs = args.epochs
if args.augment:
    Config.augment = args.augment
if args.loss_fn:
    Config.loss_fn = args.loss_fn
if args.scheduler:
    Config.scheduler = args.scheduler
if args.lr:
    Config.lr = args.lr
if args.use_svm:
    Config.use_svm = args.use_svm
    svm_clf = svm.LinearSVC(C=1, verbose=0, max_iter=100000, loss='squared_hinge', penalty='l2', dual=True)

print(f'Device:{device}')
print(f'Arguments: {args}\n')

if args.stats:
    print(f'Train shape: {train_df.shape}')
    print(f'Target dist:\n{train_df[TARGET_COL].value_counts()}')

if args.find_model:
    model_names = timm.list_models(f'*{args.find_model}*')
    print(model_names)

# Try a subset of data
if args.n_samples:
    train_df = train_df.sample(n=args.n_samples)

# FIXME - move this to inside fold/epoch loop
if args.downsample:
    # Majority class = most frequent label (was hard-coded to class 3)
    majority_class = train_df[TARGET_COL].value_counts().idxmax()
    train_df_min = train_df[train_df[TARGET_COL] != majority_class]
    samples = int(np.ceil(train_df_min[TARGET_COL].value_counts().max()))
    train_df_maj = train_df[train_df[TARGET_COL] == majority_class]
    df_majority_downsampled = resample(train_df_maj,
                                       replace=False,  # sample without replacement
                                       n_samples=samples,  # to match the largest minority class
                                       random_state=123)
    train_df = pd.concat([train_df_min, df_majority_downsampled])

if args.model_dir:
    model_dir = args.model_dir
    SAVE_MODEL_FOLDER = BASE_MODEL_FOLDER + model_dir
    os.makedirs(SAVE_MODEL_FOLDER, exist_ok=True)

if args.train:
    k_fold(Config, train_df)

if args.test_loader:
    Config.test_loader = True
    test_loader(train_df, 4)

if args.test:
    SAVE_MODEL_FOLDER = '../input/image-pretrained-models/'
    test_df = pd.read_csv(DATA_PATH + '/sample_submission.csv')
    ensemble = False
    if ensemble:
        preds1 = get_scores(['tf_efficientnet_b4_ns_extra_2.pt'], '../input/base-pretrained/effnetb4')
        preds2 = get_scores(['seresnet50_extra_2.pt'], '../input/base-pretrained/seresnet50')
        preds = 0.5 * preds1 + 0.5 * preds2
    else:
        # preds = get_scores(['tf_efficientnet_b4_ns_extra_2_512_1_4.pt'], '../input/base-pretrained/effnetb4')
        preds = get_scores(['seresnet50_extra_0.pt'], '../input/base-pretrained/seresnet50')

    submission_df = pd.DataFrame()
    submission_df[ID_COL] = test_df[ID_COL]
    submission_df[TARGET_COL] = softmax(preds).argmax(1)
    submission_df[TARGET_COL] = submission_df[TARGET_COL].astype(int)

    print(submission_df.head())
    submission_df.to_csv('submission.csv', index=False)

if args.get_noise_indices:
    # Prediction
    preds_df, psx = get_scores_train(['./models/seresnet50_1_3_test/seresnet50_extra_2.pt'])

    labels = preds_df[TARGET_COL].values

    ordered_label_errors = get_noise_indices(
        s=labels,
        psx=psx,
        sorted_index_method='normalized_margin',  # Orders label errors
    )

    print(f'Label errors found={len(ordered_label_errors)}')

    errors_df = preds_df.iloc[ordered_label_errors, :].reset_index()

    print(errors_df.shape)
    print(errors_df.head())
    errors_df.to_csv('errors.csv', index=False)

if args.train_noisy_clf:
    train_image_id = train_df[ID_COL].values
    train_label = train_df[TARGET_COL].values
    # One column per class (was hard-coded to 5 classes)
    val_preds = np.zeros((train_label.shape[0], NUM_CLASSES))
    kfold = StratifiedKFold(n_splits=Config.k, random_state=None)
    seed_everything(0)

    val_scores = []

    for fold, (train_idx, val_idx) in enumerate(kfold.split(train_image_id, train_label)):
        CV_FOLD = 0
        print(f'Fold: {fold + 1}')
        X_train, y_train = train_image_id[train_idx], train_label[train_idx]
        X_val, y_val = train_image_id[val_idx], train_label[val_idx]

        base_model = get_model(Config.selected_model, num_classes=NUM_CLASSES).to(device)
        model = Classifier(base_model, Config)
        lnl = LearningWithNoisyLabels(clf=model, seed=0, n_jobs=os.cpu_count(), cv_n_folds=5)
        clf = lnl.fit(train_idx, y_train)

        val_preds[val_idx, :] = clf.predict_proba(val_idx)
        acc = accuracy_score(y_val, np.argmax(val_preds[val_idx, :], axis=1))
        val_scores.append(acc)
        print(f'Accuracy:{acc}')
        # Save model per Fold
        if Config.save:
            save_model_weights(
                clf.model,
                f"{Config.selected_model}_{Config.name}_{fold}.pt",
                cp_folder=SAVE_MODEL_FOLDER,
            )
            save_model_params(val_scores, fold, SAVE_MODEL_FOLDER)

if args.clean_data:
    errors_df = pd.read_csv(args.clean_data)
    # Keep the top `clean_samples` ranked errors (all of them when not given).
    # The former iloc[0:clean_samples - 1] dropped one row too few, and for
    # clean_samples=0 selected every row but the last.
    if args.clean_samples is not None:
        errors_df = errors_df.iloc[:args.clean_samples, :]
    error_indices = errors_df['index'].values
    print(f"Input shape: {train_df.shape}")
    train_df = train_df.loc[~train_df.index.isin(error_indices)]
    print(f'Removing {len(errors_df)} samples from the dataset')
    print(f"Output shape: {train_df.shape}")
    train_df.to_csv('mod_train_df.csv', index=False)
