This kernel is heavily inspired by the brilliant notebooks of [@abhishek](https://www.kaggle.com/abhishek) and [@cdeotte](https://www.kaggle.com/cdeotte) in [SIIM-ISIC Melanoma Classification](https://www.kaggle.com/c/siim-isic-melanoma-classification). Here I am re-using many ideas from the SIIM-ISIC competition as well as some new ideas from [Using Tez in Leaf Disease Classification](https://www.kaggle.com/abhishek/using-tez-in-leaf-disease-classification) by @abhishek.

To save time, you can disable OOF TTA, which makes the overall run significantly faster. I only monitor OOF TTA to double-check how well the applied TTA methods perform.

The inference notebook is [Pytorch EfficientNet with TTA [inference]](https://www.kaggle.com/dunklerwald/pytorch-efficientnet-with-tta-inference).

**UPDATE**:
- switched to noisy-student
- added simple upsampling option (disabled by default)
- added confusion matrix

**UPDATE1**:
- switched to CosineAnnealingWarmRestarts
- upgraded to image size 512
- switched back to B4 effnet

**UPDATE(FINAL)**:
- added weight decay
- switched to smoothed cross entropy loss

In [None]:
# Put the bundled pytorch-image-models sources on the import path so that the
# later `import timm` resolves to them (presumably an attached Kaggle input
# dataset, used instead of a pip install -- confirm).
package_path = '../input/pytorch-image-models/pytorch-image-models-master'
import sys
sys.path.append(package_path)    

In [None]:
import os
import torch
import albumentations

import numpy as np
import pandas as pd
import warnings

import time
import datetime

import plotly.express as px
from plotly.subplots import make_subplots
import plotly.figure_factory as ff 

import torch.nn as nn
from sklearn import metrics, model_selection
from torch.nn import functional as F
from torch.nn.modules.loss import _WeightedLoss

from PIL import Image
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True


import timm

warnings.simplefilter('ignore')
%matplotlib inline

## Parameters

In [None]:
# Maximum number of epochs trained per fold.
n_epochs = 10
# Epochs without validation-accuracy improvement before early stopping triggers.
n_patience = 5
# Number of stratified cross-validation folds.
n_folds = 3
# Batch sizes for the training, validation and test data loaders.
train_bsize = 24
valid_bsize = 48
test_bsize = 48
# Seed handed to set_seed() and to the StratifiedKFold splitter.
seed = 42

# Width of the pooled feature vector per EfficientNet variant (key = the "b<N>"
# index); used as in_features of the classification head.
effnet_output = {0: 1280, 1: 1280, 2: 1408, 3: 1536, 4: 1792, 5: 2048, 6: 2304, 7: 2560}

# Side length (pixels) of the square crop fed to the network.
IMG_SIZE = 512
# Which tf_efficientnet_b<N>_ns backbone to load; must be a key of effnet_output.
EFFNET_MODEL = 4

# Augmentation pipeline shared by training and test-time augmentation (TTA).
# A crop to IMG_SIZE and normalisation are appended where the list is used.
AUGMENTATION =[albumentations.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=15, border_mode=0, p=0.6),
              albumentations.Flip(p=0.5),
              albumentations.RandomRotate90(p=0.5),
              albumentations.RandomBrightness(limit=0.2, p=0.6),
              albumentations.RandomContrast(limit=0.2, p=0.6),
              albumentations.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=20, val_shift_limit=20, p=0.6),
              albumentations.CoarseDropout(max_holes=8, max_height=int(IMG_SIZE*0.2), max_width=int(IMG_SIZE*0.2), p=0.6),
              albumentations.Cutout(num_holes=1, max_h_size=int(IMG_SIZE*0.33), max_w_size=int(IMG_SIZE*0.33), p=0.6)
              ]   

# IS_TTA switches TTA on for both the out-of-fold check and the test
# predictions; TTA is the number of augmented passes that get averaged.
IS_TTA = True
TTA = 3 

# When UPSAMPLE is True, every training row whose label is not 3 is appended
# N_UPSAMPLE additional times to the training set.
UPSAMPLE = False
N_UPSAMPLE = 1

# NOTE: 'CosineAnnealingLR' actually builds CosineAnnealingWarmRestarts in
# train(); any value other than the first two falls back to CustomSchedulerLR.
SCHEDULER_NAME = 'CosineAnnealingLR' # ReduceLROnPlateau, CosineAnnealingLR, CustomSchedulerLR
# Any value other than 'SmoothCrossEntropyLoss' results in plain CrossEntropyLoss.
LOSS_FN_NAME = 'SmoothCrossEntropyLoss' # SmoothCrossEntropyLoss,CrossEntropyLoss 

# Default label-smoothing factor of SmoothCrossEntropyLoss.
SMOOTHING = 0.05


# Show the per-fold loss/accuracy plot after each fold has been trained.
DISPLAY_PLOT= True

def set_seed(seed):
    """Seed every random number generator the pipeline may draw from.

    Covers Python's ``random`` module, NumPy, PyTorch on CPU and on all CUDA
    devices, and pins cuDNN to deterministic, non-benchmarked kernels.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Local import: the file does not import `random` at the top level.
    import random

    os.environ['PYTHONHASHSEED'] = str(seed)
    # Third-party code (e.g. augmentation libraries) may sample from the
    # stdlib generator, which the NumPy/PyTorch seeds do not touch.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every visible GPU, not only the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Auto-tuning may pick different kernels from run to run.
    torch.backends.cudnn.benchmark = False
    
# Seed all generators once, before the fold split and model creation below.
set_seed(seed)

In [None]:
# Root directory of the competition data; train.csv, train_images/,
# test_images/ and sample_submission.csv are read relative to it.
path = '../input/cassava-leaf-disease-classification/'

In [None]:
# create folds
df = pd.read_csv(path + 'train.csv')
df["kfold"] = -1
# Shuffle with an explicit seed so that re-running this cell reproduces the
# same row order (and therefore the same folds) regardless of how much of the
# global NumPy random state has been consumed in the meantime.
df = df.sample(frac=1, random_state=seed).reset_index(drop=True)


kf = model_selection.StratifiedKFold(n_splits=n_folds, random_state=seed, shuffle=True)

# split() yields positional indices; after reset_index(drop=True) the index
# labels are 0..n-1, so they can be used directly as .loc row labels.
for f, (t_, v_) in enumerate(kf.split(np.arange(df.shape[0]), df.label.values)):
    df.loc[v_, 'kfold'] = f

# train() re-reads the fold assignment from this file.
df.to_csv("train_folds.csv", index=False)

N_CLASSES = df.label.nunique()
# Display names for the confusion matrix, positioned by class id.
LABELS = ['Cassava Bacterial Blight','Cassava Brown Streak Disease','Cassava Green Mottle','Cassava Mosaic Disease','Healthy']

## Plotting

In [None]:
def plot_fold(history, title):
    """Plot the loss and accuracy curves of one fold on a dual-axis figure.

    Losses are drawn against the primary y-axis, accuracies against the
    secondary one; the x-axis is the 1-based epoch number.

    Args:
        history: DataFrame with the columns 'train_loss', 'val_loss',
            'train_accuracy' and 'val_accuracy', one row per epoch.
        title: Text shown (in bold) as the figure title.
    """
    epoch_numbers = history.index + 1

    # (column, legend name, dash style, colour, drawn on the secondary axis?)
    curve_specs = [
        ('train_loss', 'train loss', 'dash', '#2ca02c', False),
        ('val_loss', 'val loss', 'dash', '#d62728', False),
        ('train_accuracy', 'train accuracy', 'dashdot', '#ff7f0e', True),
        ('val_accuracy', 'val accuracy', 'dashdot', '#1f77b4', True),
    ]

    fig = make_subplots(specs=[[{"secondary_y": True}]])
    for column, _, _, _, on_secondary in curve_specs:
        curve = px.line(history, x=epoch_numbers, y=column).data[0]
        fig.add_trace(curve, secondary_y=on_secondary)

    for curve, (_, legend_name, dash_style, colour, _) in zip(fig.data, curve_specs):
        curve.line.dash = dash_style
        curve.mode = 'markers+lines'
        curve.line.color = colour
        curve.line.width = 3
        curve.hovertemplate = None
        curve.name = legend_name

    # Axis titles
    fig.update_xaxes(title_text="Epoch")
    fig.update_yaxes(title_text="Loss", secondary_y=False)
    fig.update_yaxes(title_text="Accuracy", secondary_y=True)

    fig.update_layout(
        height=450,
        margin=dict(r=5, t=50, b=50, l=5),
        title_text='<b>'+title+'</b>',
        title_font_size=12,
        legend=dict(orientation='h', yanchor='top', y=1.03, xanchor='left', x=0.15),
    )
    fig.update_layout(font_size=12)
    fig.for_each_annotation(lambda a: a.update(font=dict(size=14)))
    fig.update_layout(hovermode="x unified")
    fig.update_traces(showlegend=True)

    fig.show()
    

    
def plot_confusion_matrix(label, pred):
    """Show a row-normalised confusion matrix as an annotated heatmap.

    Args:
        label: True class ids.
        pred: Predicted class ids, aligned with ``label``.
    """
    class_ids = range(len(LABELS))
    # normalize='true' divides every row by the number of true samples of
    # that class, so each row shows how that class was distributed.
    matrix = metrics.confusion_matrix(label, pred, labels=class_ids, normalize='true')
    cell_text = np.around(matrix, decimals=2)

    fig = ff.create_annotated_heatmap(
        matrix,
        annotation_text=cell_text,
        x=LABELS,
        y=LABELS,
        colorscale='PuBu',
    )
    fig.update_layout(font_size=9, height=450, margin=dict(r=5, t=50, b=50, l=5))

    fig.show()

## Model

In [None]:
class ClassificationDataset:
    """Map-style dataset that loads an image from disk together with its label.

    Args:
        image_paths: Sequence of image file paths.
        targets: Sequence of labels, indexable in parallel with ``image_paths``.
        resize: Optional ``(height, width)`` pair; when given, every image is
            resized to it with bilinear resampling before augmentation.
        augmentations: Optional callable invoked as ``augmentations(image=arr)``
            that returns a mapping with the transformed array under "image".
    """

    def __init__(self, image_paths, targets, resize, augmentations=None):
        self.image_paths = image_paths
        self.targets = targets
        self.resize = resize
        self.augmentations = augmentations

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, item):
        """Return ``{"image": float32 CHW tensor, "targets": label tensor}``."""
        targets = self.targets[item]
        # The context manager releases the file handle as soon as the pixels
        # are decoded. convert("RGB") guarantees a 3-channel (H, W, 3) array,
        # so grayscale or RGBA files do not break the transpose below.
        with Image.open(self.image_paths[item]) as source:
            image = source.convert("RGB")
        if self.resize is not None:
            # PIL expects (width, height), hence the swapped order.
            image = image.resize(
                (self.resize[1], self.resize[0]), resample=Image.BILINEAR
            )
        image = np.array(image)
        if self.augmentations is not None:
            augmented = self.augmentations(image=image)
            image = augmented["image"]
        # Channels-last (H, W, C) -> channels-first (C, H, W).
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)
        return {
            "image": torch.tensor(image),
            "targets": torch.tensor(targets),
        }


class ClassificationDataLoader:
    """Builds a ClassificationDataset and wraps it in a torch DataLoader."""

    def __init__(self, image_paths, targets, resize, augmentations=None):
        """Store the arguments and create the underlying dataset from them."""
        self.image_paths = image_paths
        self.targets = targets
        self.resize = resize
        self.augmentations = augmentations
        self.dataset = ClassificationDataset(
            image_paths=image_paths,
            targets=targets,
            resize=resize,
            augmentations=augmentations,
        )

    def fetch(self, batch_size, num_workers, drop_last=False, shuffle=True, tpu=False):
        """Return a DataLoader over ``self.dataset``.

        ``tpu`` is accepted for interface compatibility but is not used; no
        custom sampler is ever installed.
        """
        loader_options = dict(
            batch_size=batch_size,
            sampler=None,
            drop_last=drop_last,
            shuffle=shuffle,
            num_workers=num_workers,
        )
        return torch.utils.data.DataLoader(self.dataset, **loader_options)

In [None]:
class AverageMeter:
    """
    Tracks the latest value and the running weighted average of a metric
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest value, counted with weight ``n``."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [None]:
class EarlyStopping:
    """Stop training once the monitored score stops improving.

    Call the instance once per epoch with the epoch score. Whenever the score
    improves on the best one seen by at least ``delta``, the model weights are
    saved to ``model_path``; otherwise a counter is incremented and
    ``early_stop`` becomes True once it reaches ``patience``.

    Args:
        patience: Non-improving epochs tolerated before ``early_stop`` is set.
        mode: "max" if larger scores are better, "min" if smaller are better.
        delta: Minimum gain over the best score that counts as an improvement.
        tpu: Use the XLA print/save helpers instead of print/torch.save.
    """

    def __init__(self, patience=7, mode="max", delta=0.0001, tpu=False):
        self.patience = patience
        self.counter = 0
        self.mode = mode
        self.best_score = None
        self.early_stop = False
        self.tpu = tpu
        self.delta = delta
        # np.inf (lower case) is the spelling that exists in every NumPy
        # release; the np.Inf alias was removed in NumPy 2.0.
        if self.mode == "min":
            self.val_score = np.inf
        else:
            self.val_score = -np.inf

    def __call__(self, epoch_score, model, model_path):
        """Register the score of one epoch and checkpoint on improvement."""
        # Scores are negated in "min" mode so that larger is always better.
        if self.mode == "min":
            score = -1.0 * epoch_score
        else:
            score = np.copy(epoch_score)

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(epoch_score, model, model_path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(
                "EarlyStopping counter: {} out of {}".format(
                    self.counter, self.patience
                )
            )
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(epoch_score, model, model_path)
            self.counter = 0

    def save_checkpoint(self, epoch_score, model, model_path):
        """Save ``model.state_dict()`` unless the score is NaN or infinite."""
        # A list-membership test cannot detect a computed NaN because NaN
        # never compares equal to anything, so check finiteness explicitly.
        if np.isfinite(epoch_score):
            message = "Validation score improved ({} --> {}). Saving model!".format(
                self.val_score, epoch_score
            )
            if self.tpu:
                # NOTE(review): `xm` (presumably torch_xla.core.xla_model) is
                # not imported anywhere in this file, so tpu=True raises a
                # NameError here until that import is added.
                xm.master_print(message)
                xm.save(model.state_dict(), model_path)
            else:
                print(message)
                torch.save(model.state_dict(), model_path)
        self.val_score = epoch_score

In [None]:
class Engine:
    """Static helpers that run one training, evaluation or prediction pass.

    The model is called as ``model(**batch)`` with every tensor of the batch
    dict moved to ``device`` and must return a
    ``(predictions, loss, accuracy)`` triple.
    """

    @staticmethod
    def train(
        data_loader,
        model,
        optimizer,
        device,
        scheduler=None,
        accumulation_steps=1,
        fp16=True,
    ):
        """Train ``model`` for one pass over ``data_loader``.

        Args:
            data_loader: Iterable of dict batches.
            model: Module returning ``(predictions, loss, accuracy)``.
            optimizer: Optimizer updating the model parameters.
            device: Device the batch tensors are moved to.
            scheduler: Optional scheduler stepped after every optimizer step.
            accumulation_steps: Batches whose gradients are accumulated
                before each optimizer step.
            fp16: Use CUDA automatic mixed precision with gradient scaling.

        Returns:
            ``(list of per-batch prediction arrays, mean loss, mean accuracy)``.
        """
        losses = AverageMeter()
        accuracies = AverageMeter()
        final_predictions = []
        model.train()
        # Always start the pass from clean gradients.
        optimizer.zero_grad()

        scaler = torch.cuda.amp.GradScaler() if fp16 else None

        for b_idx, data in enumerate(data_loader):
            for key, value in data.items():
                data[key] = value.to(device)

            if fp16:
                with torch.cuda.amp.autocast():
                    predictions, loss, accuracy = model(**data)
            else:
                predictions, loss, accuracy = model(**data)

            predictions = predictions.detach().cpu().numpy()
            final_predictions.append(predictions)

            if fp16:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            if (b_idx + 1) % accumulation_steps == 0:
                if fp16:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                if scheduler is not None:
                    scheduler.step()
                # Clear gradients after EVERY step, including the first one;
                # otherwise the second step would be taken on the sum of the
                # gradients of batch 0 and batch 1.
                optimizer.zero_grad()

            # Weight by the real batch size so a short batch is not
            # over-represented in the running averages.
            n_samples = len(predictions)
            losses.update(loss.item(), n_samples)
            accuracies.update(accuracy.item(), n_samples)

        return final_predictions, losses.avg, accuracies.avg

    @staticmethod
    def evaluate(data_loader, model, device):
        """Run ``model`` in eval mode over ``data_loader`` without gradients.

        Returns:
            ``(list of per-batch prediction arrays, mean loss, mean accuracy)``,
            where the means are weighted by the actual size of each batch.
        """
        losses = AverageMeter()
        accuracies = AverageMeter()
        final_predictions = []
        model.eval()
        with torch.no_grad():
            for b_idx, data in enumerate(data_loader):
                for key, value in data.items():
                    data[key] = value.to(device)
                predictions, loss, accuracy = model(**data)
                predictions = predictions.detach().cpu().numpy()
                final_predictions.append(predictions)

                # The last batch may be smaller than data_loader.batch_size.
                n_samples = len(predictions)
                losses.update(loss.item(), n_samples)
                accuracies.update(accuracy.item(), n_samples)

        return final_predictions, losses.avg, accuracies.avg

    @staticmethod
    def predict(data_loader, model, device):
        """Return the list of per-batch prediction arrays; loss and accuracy
        returned by the model are ignored."""
        model.eval()
        final_predictions = []

        with torch.no_grad():
            for b_idx, data in enumerate(data_loader):
                for key, value in data.items():
                    data[key] = value.to(device)
                predictions, _, _ = model(**data)
                predictions = predictions.detach().cpu().numpy()
                final_predictions.append(predictions)

        return final_predictions

In [None]:
class CustomSchedulerLR:
    """Epoch-indexed schedule: linear ramp-up, optional plateau, exponential decay.

    The learning rate is written directly into every param group of the
    optimizer, once on construction and again on each ``step(epoch)`` call.
    """

    def __init__(self, optimizer, epoch, batch_size):
        self.optimizer = optimizer
        self.lr_start = 0.00005
        self.lr_min = 0.00005
        self.lr_ramp_ep = 5
        self.lr_sus_ep = 0
        self.lr_decay = 0.8
        # The peak learning rate grows linearly with the batch size.
        self.lr_max = 0.00001 * batch_size

        self.step(epoch)

    def lrfn(self, epoch):
        """Return the learning rate the schedule prescribes for ``epoch``."""
        ramp_end = self.lr_ramp_ep
        plateau_end = self.lr_ramp_ep + self.lr_sus_ep

        if epoch < ramp_end:
            # Linear interpolation from lr_start towards lr_max.
            return (self.lr_max - self.lr_start) / self.lr_ramp_ep * epoch + self.lr_start
        if epoch < plateau_end:
            return self.lr_max
        # Exponential decay from lr_max towards lr_min.
        return (self.lr_max - self.lr_min) * self.lr_decay**(epoch - self.lr_ramp_ep - self.lr_sus_ep) + self.lr_min

    def step(self, epoch):
        """Apply the learning rate for ``epoch`` to every param group."""
        new_lr = self.lrfn(epoch)
        for group in self.optimizer.param_groups:
            group['lr'] = new_lr

In [None]:
#source: https://www.kaggle.com/c/siim-isic-melanoma-classification/discussion/173733
class SmoothCrossEntropyLoss(_WeightedLoss):
    """Cross entropy computed against label-smoothed one-hot targets.

    Args:
        weight: Optional per-class weights multiplied into the log-probabilities.
        reduction: 'mean' or 'sum'; any other value returns per-sample losses.
        smoothing: Probability mass moved from the true class to the others.
    """

    def __init__(self, weight=None, reduction='mean', smoothing=SMOOTHING):
        super().__init__(weight=weight, reduction=reduction)
        self.smoothing = smoothing
        self.weight = weight
        self.reduction = reduction

    @staticmethod
    def _smooth_one_hot(targets:torch.Tensor, n_classes:int, smoothing=SMOOTHING):
        """Return a (len(targets), n_classes) matrix of smoothed targets."""
        assert 0 <= smoothing < 1
        off_value = smoothing /(n_classes-1)
        on_value = 1.-smoothing
        with torch.no_grad():
            # Every class starts at the "off" value; the true class of each
            # row is then overwritten with the "on" value.
            smoothed = torch.full(
                (targets.size(0), n_classes), off_value, device=targets.device
            )
            smoothed.scatter_(1, targets.data.unsqueeze(1), on_value)
        return smoothed

    def forward(self, inputs, targets):
        """Compute the loss from raw logits and integer class targets."""
        soft_targets = SmoothCrossEntropyLoss._smooth_one_hot(
            targets, inputs.size(-1), self.smoothing
        )
        log_probs = F.log_softmax(inputs, -1)

        if self.weight is not None:
            log_probs = log_probs * self.weight.unsqueeze(0)

        per_sample = -(soft_targets * log_probs).sum(-1)

        if self.reduction == 'sum':
            return per_sample.sum()
        if self.reduction == 'mean':
            return per_sample.mean()
        return per_sample

In [None]:
class EfficientNet(nn.Module):
    """Noisy-student EfficientNet backbone with a dropout + linear head.

    Args:
        num_classes: Number of output classes of the linear head.
    """

    def __init__(self, num_classes):
        super(EfficientNet, self).__init__()
        self.base_model = timm.create_model(f"tf_efficientnet_b{str(EFFNET_MODEL)}_ns", pretrained=True)
        self.dropout = nn.Dropout(0.2)

        self.out = nn.Linear(
            in_features=effnet_output[EFFNET_MODEL],
            out_features=num_classes,
            bias=True
        )

    def forward(self, image, targets=None):
        """Return ``(logits, loss, accuracy)`` for a batch of images.

        Args:
            image: 4-D image batch whose first dimension is the batch size.
            targets: Class ids for the batch. When omitted, only the logits
                are computed and ``loss`` and ``accuracy`` are ``None``.
        """
        batch_size = image.shape[0]

        x = self.base_model.forward_features(image)
        # Global average pooling, flattened to (batch, features).
        x = F.adaptive_avg_pool2d(x, 1).reshape(batch_size, -1)
        out = self.out(self.dropout(x))

        # Pure inference: nothing to compare the logits against.
        if targets is None:
            return out, None, None

        # Every LOSS_FN_NAME other than 'SmoothCrossEntropyLoss' means plain
        # cross entropy.
        if LOSS_FN_NAME == 'SmoothCrossEntropyLoss':
            loss_fn = SmoothCrossEntropyLoss()
        else:
            loss_fn = nn.CrossEntropyLoss()
        loss = loss_fn(out, targets.long())

        outputs = torch.argmax(out, dim=1).detach().cpu().numpy()
        targets = targets.detach().cpu().numpy()
        accuracy = metrics.accuracy_score(targets, outputs)

        return out, loss, accuracy

In [None]:
def train(fold = 0, apply_tta = False):
    """Train one cross-validation fold and return its out-of-fold predictions.

    The fold assignment is read from "train_folds.csv"; rows whose ``kfold``
    equals ``fold`` form the validation set. The best weights (by validation
    accuracy) are saved to ``model_fold_{fold}.bin``.

    Args:
        fold: Index of the fold used for validation.
        apply_tta: When True, the best checkpoint is reloaded and the
            validation set is additionally predicted with ``TTA`` averaged
            augmented passes.

    Returns:
        ``(best_prediction, oof_tta_predictions)``: predicted class ids of the
        validation rows from the best epoch, and the TTA class ids (``None``
        when ``apply_tta`` is False).
    """
    print('=' * 20, 'Fold', fold, '=' * 20)
    model_path = f"model_fold_{fold}.bin"
    start_time = time.time()

    device = "cuda"
    epochs = n_epochs
    train_bs = train_bsize
    valid_bs = valid_bsize

    training_data_path = path + "train_images/"
    df = pd.read_csv("train_folds.csv")

    df_train = df[df.kfold != fold].reset_index(drop=True)
    df_valid = df[df.kfold == fold].reset_index(drop=True)

    model = EfficientNet(num_classes=N_CLASSES)
    model.to(device)

    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    train_aug = albumentations.Compose(AUGMENTATION + [albumentations.RandomResizedCrop(IMG_SIZE, IMG_SIZE, always_apply=True), albumentations.Normalize(mean, std, max_pixel_value=255.0, always_apply=True)])
    valid_aug = albumentations.Compose([albumentations.CenterCrop(IMG_SIZE, IMG_SIZE, always_apply=True),albumentations.Normalize(mean, std, max_pixel_value=255.0, always_apply=True)])

    train_images = df_train.image_id.values.tolist()
    train_images = [os.path.join(training_data_path, i) for i in train_images]
    train_targets = df_train.label.values

    valid_images = df_valid.image_id.values.tolist()
    valid_images = [os.path.join(training_data_path, i) for i in valid_images]
    valid_targets = df_valid.label.values

    # UPSAMPLE MINORITY CLASSES
    if UPSAMPLE:
        # Every class except label 3 is appended N_UPSAMPLE extra times.
        train_images_ = df_train.loc[df_train['label']!=3, ['image_id', 'label']]
        train_images_m = [os.path.join(training_data_path, i) for i in train_images_.image_id.values.tolist()]
        train_targets_m = train_images_['label']

        for i in range(N_UPSAMPLE):
            train_images = train_images + train_images_m
            train_targets = np.concatenate([train_targets, train_targets_m])

    # preliminary shuffle train_images and train_targets correspondingly if shuffle=False for train_loader
    train_loader = ClassificationDataLoader(
        image_paths=train_images,
        targets=train_targets,
        resize=None,
        augmentations=train_aug,
    ).fetch(
        batch_size=train_bs,
        drop_last=True,
        num_workers=4,
        shuffle=True
    )

    valid_loader = ClassificationDataLoader(
        image_paths=valid_images,
        targets=valid_targets,
        resize=None,
        augmentations=valid_aug,
    ).fetch(
        batch_size=valid_bs,
        drop_last=False,
        num_workers=4,
        shuffle=False
    )

    # SETUP OPTIMIZER AND SCHEDULE
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay = 1e-6)

    if SCHEDULER_NAME == 'CosineAnnealingLR':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=epochs, T_mult=1, eta_min=1e-6, last_epoch=-1)
    elif SCHEDULER_NAME == 'ReduceLROnPlateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
          optimizer,
          patience=1,
          factor=0.25,
          min_lr=1e-6,
          verbose=True,
          mode="max"
        )
    else:
        scheduler = CustomSchedulerLR(optimizer=optimizer, epoch=0, batch_size=train_bs)

    es = EarlyStopping(patience=n_patience, mode="max")

    # Per-epoch metrics are collected as plain dicts and turned into a
    # DataFrame once, because DataFrame.append no longer exists in pandas 2.x.
    history_rows = []
    best_prediction = None

    # TRAIN
    print('Training...')
    for epoch in range(epochs):
        # Learning rate in effect for this epoch (reported in the log line).
        lr = optimizer.param_groups[-1]['lr']

        _, train_loss, train_accuracy = Engine.train(train_loader, model, optimizer, device=device)
        predictions, valid_loss, val_accuracy = Engine.evaluate(valid_loader, model, device=device)
        predictions = np.vstack((predictions))

        history_rows.append({'train_loss':train_loss,'val_loss':valid_loss,'train_accuracy':train_accuracy,'val_accuracy':val_accuracy})
        print(f"Epoch {epoch+1:03}: | train_loss: {train_loss:.4f} | val_loss {valid_loss:.4f} | train_accuracy {train_accuracy:.4f} | val_accuracy {val_accuracy:.4f} | lr {lr:.6f}")

        if SCHEDULER_NAME == 'CosineAnnealingLR':
            # Epoch `epoch` is finished, so move the schedule on to epoch + 1;
            # stepping to `epoch` would repeat the initial learning rate for
            # the second epoch and keep the cosine curve one epoch behind.
            scheduler.step(epoch + 1)
        elif SCHEDULER_NAME == 'ReduceLROnPlateau':
            scheduler.step(val_accuracy)
        else:
            scheduler.step(epoch+1)

        es(val_accuracy, model, model_path=model_path)

        # The counter is reset to 0 exactly when a new best score was
        # checkpointed, so these predictions belong to the saved weights.
        if es.counter == 0:
            best_prediction = predictions.argmax(axis=1)
        if es.early_stop:
            print("Early stopping")
            break

    history = pd.DataFrame(history_rows, columns=['train_loss','val_loss','train_accuracy','val_accuracy'])

    # PREDICT OOF WITH TTA
    # you can remove this phase to speed up the process
    oof_tta_predictions = None
    if apply_tta:
        print('Predicting OOF with TTA...')
        oof_tta_predictions = np.zeros([len(valid_targets),N_CLASSES])

        # Reload the best checkpoint rather than the last-epoch weights.
        model = EfficientNet(num_classes=N_CLASSES)
        model.load_state_dict(torch.load(model_path))
        model.to(device)

        valid_tta_aug = albumentations.Compose(AUGMENTATION + [valid_aug])

        valid_tta_loader = ClassificationDataLoader(
            image_paths=valid_images,
            targets=valid_targets,
            resize=None,
            augmentations=valid_tta_aug,
        ).fetch(
            batch_size=valid_bs,
            drop_last=False,
            num_workers=4,
            shuffle=False
        )

        # Average the logits of TTA independently augmented passes.
        for i in range(TTA):
            tta_predictions = Engine.predict(valid_tta_loader, model, device=device)
            tta_predictions = np.vstack((tta_predictions))
            oof_tta_predictions += tta_predictions/TTA

        oof_tta_predictions = oof_tta_predictions.argmax(axis=1)

    if apply_tta:
        title = 'Fold {}: | training time: {} | OOF Accuracy without TTA: {:.5f} | OOF Accuracy with TTA: {:.5f} '.format(fold, str(datetime.timedelta(seconds=time.time() - start_time))[:7], es.best_score, metrics.accuracy_score(valid_targets, oof_tta_predictions))
    else:
        title = 'Fold {}: | training time: {} | OOF Accuracy without TTA: {:.5f}'.format(fold, str(datetime.timedelta(seconds=time.time() - start_time))[:7], es.best_score)
    print(title)

    # PLOT TRAINING
    if DISPLAY_PLOT:
        plot_fold(history, title = title)

    # Without TTA the second element is None (the former `_` was an undefined
    # name outside an interactive session).
    return best_prediction, oof_tta_predictions

In [None]:
def predict(fold = 0, apply_tta = False):
    """Predict the test images with the checkpoint of one fold.

    Args:
        fold: Fold whose ``model_fold_{fold}.bin`` weights are loaded.
        apply_tta: When True, ``TTA`` randomly augmented passes are averaged.

    Returns:
        Without TTA, the list of per-batch prediction arrays. With TTA, a
        single array of shape ``(n_images, 1, N_CLASSES)``. Iterating over
        either yields 2-D chunks that can be stacked row-wise.
    """
    print('=' * 20, 'Fold', fold, '=' * 20)
    device = "cuda"
    model_path = f"model_fold_{fold}.bin"
    test_data_path = path + "test_images/"
    df = test

    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    if apply_tta:
        transforms = AUGMENTATION + [
            albumentations.RandomResizedCrop(IMG_SIZE, IMG_SIZE, always_apply=True),
            albumentations.Normalize(mean, std, max_pixel_value=255.0, always_apply=True),
        ]
    else:
        transforms = [
            albumentations.CenterCrop(IMG_SIZE, IMG_SIZE, always_apply=True),
            albumentations.Normalize(mean, std, max_pixel_value=255.0, always_apply=True),
        ]
    aug = albumentations.Compose(transforms)

    images = [os.path.join(test_data_path, image_id) for image_id in df.image_id.values]
    # The labels of the test frame are only passed because the dataset
    # requires targets; the loss computed from them is discarded.
    targets = df.label.values

    dataset_wrapper = ClassificationDataLoader(
        image_paths=images,
        targets=targets,
        resize=None,
        augmentations=aug,
    )
    test_loader = dataset_wrapper.fetch(
        batch_size=test_bsize,
        drop_last=False,
        num_workers=4,
        shuffle=False
    )

    model = EfficientNet(num_classes=N_CLASSES)
    model.load_state_dict(torch.load(model_path))
    model.to(device)

    # PREDICT
    print('Predicting...')
    if not apply_tta:
        return Engine.predict(test_loader, model, device=device)

    averaged = np.zeros([len(images),N_CLASSES])
    for _ in range(TTA):
        single_pass = np.vstack(Engine.predict(test_loader, model, device=device))
        averaged += single_pass/TTA
    # One (1, N_CLASSES) row per image, so callers can iterate and stack the
    # result exactly like the per-batch list of the non-TTA path.
    return averaged.reshape((len(images),1, N_CLASSES))

## Training

In [None]:
# Out-of-fold predictions: one row per training image, filled in fold by fold.
oof = df[['image_id','kfold','label']].copy()
oof['pred'] = 0

if IS_TTA:
    # Same layout, but holding the TTA predictions.
    oof_tta = oof.copy()

for i in range(n_folds):
    fold_rows = oof['kfold']==i
    fold_pred, fold_pred_tta = train(fold = i, apply_tta = IS_TTA)
    oof.loc[fold_rows, 'pred'] = fold_pred
    if IS_TTA:
        oof_tta.loc[fold_rows, 'pred'] = fold_pred_tta

In [None]:
# Accuracy over all out-of-fold predictions, with and without TTA.
oof_accuracy = metrics.accuracy_score(oof['label'], oof['pred'])
if IS_TTA:
    oof_tta_accuracy = metrics.accuracy_score(oof['label'], oof_tta['pred'])
    print('Overall OOF Accuracy without TTA: {:.5f} | OOF Accuracy with TTA: {:.5f}'.format(oof_accuracy, oof_tta_accuracy))
else:
    print('Overall OOF Accuracy without TTA: {:.5f} '.format(oof_accuracy))

In [None]:
# Plot the row-normalised confusion matrix of the out-of-fold predictions
# made without TTA.
plot_confusion_matrix(oof['label'], oof['pred'])

## Predict 

In [None]:
# The sample submission lists the test image ids; predict() reads this frame
# through the global name `test`, and its label column is overwritten below.
test = pd.read_csv(path + "sample_submission.csv")

In [None]:
# Average the logits of all fold models, then pick the arg-max class per image.
final_preds = None

for i in range(n_folds):
    preds = predict(fold = i, apply_tta=IS_TTA)
    # predict() returns either a list of (batch, classes) arrays or one
    # (n_images, 1, classes) array; iterating over either yields 2-D chunks.
    # Stacking them in a single call replaces the former chunk-by-chunk
    # np.vstack loop, which re-copied the growing result on every iteration.
    temp_preds = np.vstack(list(preds))
    if final_preds is None:
        final_preds = temp_preds
    else:
        final_preds += temp_preds

final_preds /= n_folds
final_preds = final_preds.argmax(axis=1)

test['label'] = final_preds
test.to_csv('submission.csv', index=False)

In [None]:
test.head()