## Config

In [None]:
# Inference settings shared by the rest of the notebook.
batch_size = 1
image_size = 512

# Backbone name used by every ensemble member, and the (fold, epoch) pair that
# appears in each checkpoint's filename.
_backbone = 'tf_efficientnet_b4_ns'
_fold_epochs = [(0, 8), (1, 9), (2, 9), (3, 5), (4, 11)]

# One backbone entry and one checkpoint path per ensemble member, in fold order.
enet_type = [_backbone for _ in _fold_epochs]
model_path = [
    f'../input/moa-b4-baseline/baseline_cld_fold{fold}_epoch{epoch}_{_backbone}_512.pth'
    for fold, epoch in _fold_epochs
]

## Imports

In [None]:
import os
import sys
sys.path.append('../input/pytorch-image-models/pytorch-image-models-master')
import pandas as pd
import numpy as np
DEBUG = False
import time
import cv2
import PIL.Image
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import albumentations
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
from pylab import rcParams
import timm
# Use the GPU unless DEBUG is set; fall back to the CPU when CUDA is not
# available so the notebook still runs on CPU-only machines instead of failing
# at the first `.to(device)`.
device = torch.device('cuda') if (not DEBUG and torch.cuda.is_available()) else torch.device('cpu')

## Aug

In [None]:
# Deterministic preprocessing applied to every test image, in order:
# centre crop to image_size x image_size, resize to the same size, then
# normalise with albumentations' default Normalize() settings.
_valid_steps = [
    albumentations.CenterCrop(image_size, image_size, p=1),
    albumentations.Resize(image_size, image_size),
    albumentations.Normalize(),
]
transforms_valid = albumentations.Compose(_valid_steps)

## Dataset

In [None]:
class CLDDataset(Dataset):
    """Image dataset backed by a DataFrame of file paths.

    Args:
        df: DataFrame with a ``filepath`` column and, for any mode other than
            ``'test'``, a ``label`` column. The index is reset so items can be
            addressed by position.
        mode: ``'test'`` yields the image only; any other value yields
            ``(image, label)``.
        transform: optional callable invoked as ``transform(image=image)`` that
            returns a mapping with the processed array under ``'image'``.
    """

    def __init__(self, df, mode, transform=None):
        self.df = df.reset_index(drop=True)
        self.mode = mode
        self.transform = transform

    def __len__(self):
        """Return the number of rows (images) in the dataset."""
        return len(self.df)

    def __getitem__(self, index):
        """Load, transform and return the sample at ``index``.

        Returns:
            A float tensor with the last (channel) axis of the loaded image
            moved first, plus a float label tensor when ``mode != 'test'``.

        Raises:
            FileNotFoundError: if the image file is missing or cannot be
                decoded.
        """
        row = self.df.loc[index]
        image = cv2.imread(row.filepath)
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with the offending path instead of inside cvtColor.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {row.filepath}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform is not None:
            res = self.transform(image=image)
            image = res['image']

        # Channels-last -> channels-first, as float32.
        image = image.astype(np.float32)
        image = image.transpose(2, 0, 1)
        if self.mode == 'test':
            return torch.tensor(image).float()
        else:
            return torch.tensor(image).float(), torch.tensor(row.label).float()

def _test_image_path(image_id):
    """Map an image id from the submission file to its path under test_images."""
    return os.path.join('../input/cassava-leaf-disease-classification/test_images', f'{image_id}')


# The sample submission lists the image ids to predict; attach a file path to
# each one and wrap the frame in an ordered (unshuffled) loader.
test = pd.read_csv('../input/cassava-leaf-disease-classification/sample_submission.csv')
test['filepath'] = test.image_id.apply(_test_image_path)
test_dataset = CLDDataset(test, 'test', transform=transforms_valid)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

## Model

In [None]:
class enet_v2(nn.Module):
    """timm backbone whose built-in classifier is replaced by a new linear head.

    Args:
        backbone: model name passed to ``timm.create_model``.
        out_dim: number of outputs of the new head (``myfc``).
        pretrained: forwarded to ``timm.create_model``.
    """

    def __init__(self, backbone, out_dim, pretrained=False):
        super().__init__()
        net = timm.create_model(backbone, pretrained=pretrained)
        # Width of the features feeding the backbone's original classifier.
        feat_dim = net.classifier.in_features
        self.enet = net
        self.myfc = nn.Linear(feat_dim, out_dim)
        # Turn the original classifier into a pass-through so the backbone
        # emits features for `myfc` to consume.
        self.enet.classifier = nn.Identity()

    def forward(self, x):
        """Return the head's output for the backbone features of ``x``."""
        return self.myfc(self.enet(x))


## Inference (with 8x TTA)

In [None]:
def inference_func(test_loader):
    """Run the global ``model`` over ``test_loader`` without augmentation.

    Args:
        test_loader: iterable yielding image batches as tensors.

    Returns:
        numpy array with one row of softmax probabilities (over dim 1 of the
        model output) per input image, in loader order.
    """
    model.eval()
    bar = tqdm(test_loader)

    PREDS = []

    with torch.no_grad():
        for images in bar:
            x = images.to(device)
            logits = model(x)
            # Only the probabilities are returned, so raw logits are not kept.
            PREDS.append(torch.softmax(logits, 1).cpu())
    PREDS = torch.cat(PREDS).cpu().numpy()
    return PREDS

def tta_inference_func(test_loader):
    """Run the global ``model`` with 8x test-time augmentation.

    Each image is expanded into eight views (the image, its horizontal,
    vertical and combined flips, and the same four applied to the transposed
    image). The eight logit vectors are averaged per image before the softmax.

    Args:
        test_loader: iterable yielding image batches as 4-D tensors. Images
            must be square, since transposed and untransposed views are
            stacked together.

    Returns:
        numpy array with one row of class probabilities per input image, in
        loader order.
    """
    model.eval()
    bar = tqdm(test_loader)
    PREDS = []

    with torch.no_grad():
        for images in bar:
            x = images.to(device)
            # Use the real batch size: the last batch may be smaller than the
            # loader's configured batch_size.
            n = x.shape[0]
            xt = x.transpose(-1, -2)
            # Stack on dim 1 so the layout is (n, 8, C, H, W); flattening then
            # keeps the 8 views of one image contiguous, which is what the
            # (n, 8, -1) regrouping below relies on.
            views = torch.stack([
                x, x.flip(-1), x.flip(-2), x.flip(-1, -2),
                xt, xt.flip(-1), xt.flip(-2), xt.flip(-1, -2),
            ], 1)
            logits = model(views.reshape(-1, *x.shape[1:]))
            logits = logits.view(n, 8, -1).mean(1)
            PREDS.append(torch.softmax(logits, 1).cpu())

    PREDS = torch.cat(PREDS).cpu().numpy()

    return PREDS

In [None]:
# Run TTA inference once per ensemble member and collect each model's
# probability matrix. `model` is deliberately a module-level name: the
# inference functions read it as a global.
if len(enet_type) != len(model_path):
    raise ValueError(
        f'enet_type has {len(enet_type)} entries but model_path has {len(model_path)}'
    )

test_preds = []
for backbone_name, weights_file in zip(enet_type, model_path):
    model = enet_v2(backbone_name, out_dim=5)
    model = model.to(device)
    # map_location remaps the stored tensors onto `device`, so loading also
    # works on the CPU path if the checkpoint was saved from CUDA tensors.
    model.load_state_dict(torch.load(weights_file, map_location=device))
    test_preds += [tta_inference_func(test_loader)]

## Submit

In [None]:
# Average the per-model probability matrices, then pick the highest-probability
# class for each image and write it into the submission template.
ensemble_probs = np.mean(test_preds, axis=0)
predicted_labels = np.argmax(ensemble_probs, axis=1)

submission = pd.read_csv('../input/cassava-leaf-disease-classification/sample_submission.csv')
submission.label = predicted_labels
submission.to_csv('submission.csv', index=False)

Weights are available at: https://www.kaggle.com/underwearfitting/moa-b4-baseline