In [None]:
import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')
import timm

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.autograd import Variable
from PIL import Image
import matplotlib.pyplot as plt
import torchvision
from tqdm import tqdm
import cv2
from skimage import io
import time

from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)

from albumentations.pytorch import ToTensorV2
from sklearn.metrics import accuracy_score

from sklearn.model_selection import GroupKFold, StratifiedKFold
import tqdm.notebook as tq
from sklearn.model_selection import train_test_split
from scipy.special import softmax

In [None]:

# Inference configuration.
CFG = dict(
    vit_img_size=384,   # side length images are resized to before the ViT
    tta=3,              # number of random-augmentation passes per checkpoint
    valid_bs=16,        # DataLoader batch size
    device='cuda' if torch.cuda.is_available() else 'cpu',
    # Checkpoint files model_5.pt .. model_8.pt loaded from the vit-cassava dataset.
    vit_models=[f'model_{idx}.pt' for idx in range(5, 9)],
)

In [None]:
class DiseaseDatasetInference(torch.utils.data.Dataset):
    """Map-style dataset that reads images from disk for inference.

    Args:
        df: DataFrame with an ``image_id`` column holding image file paths and,
            when ``opt_label`` is True, a ``label`` column.
        transform: optional callable invoked as ``transform(image=...)`` whose
            result is indexed with ``['image']`` (albumentations-style).
        opt_label: if True, ``__getitem__`` returns ``(image, int(label))``;
            otherwise it returns only the image.
    """

    def __init__(self, df, transform=None, opt_label=True):
        self.df = df.reset_index(drop=True).copy()
        self.transform = transform
        self.opt_label = opt_label

        # Vectorized column extraction instead of a per-row iterrows() loop.
        # A list of columns yields an (n, 2) array of (path, label) rows;
        # a single column name yields a 1-D array of paths.
        columns = ['image_id', 'label'] if self.opt_label else 'image_id'
        self.data = self.df[columns].to_numpy()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Load image ``index`` as RGB, apply ``transform``, and optionally pair it with its label.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        if self.opt_label:
            image_path, label = self.data[index]
        else:
            image_path = self.data[index]

        image = cv2.imread(str(image_path))
        if image is None:
            # cv2.imread does not raise on a missing/unreadable file; it returns
            # None, which would otherwise fail obscurely inside cvtColor.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # OpenCV decodes as BGR; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform is not None:
            image = self.transform(image=image)['image']

        if self.opt_label:
            return image, int(label)
        return image

In [None]:
def get_inference_transforms(img_size=512):
    """Build the randomized test-time-augmentation pipeline.

    Each call of the returned pipeline draws fresh random augmentations, so
    iterating a loader several times yields different views of every image.

    Args:
        img_size: output height and width in pixels.

    Returns:
        An albumentations ``Compose`` producing a normalized CHW tensor.
    """
    # Geometric stage: optional center crop, fixed resize, then random flips/rotations.
    # NOTE(review): CenterCrop presumably requires the source image to be at
    # least img_size on each side — confirm for the test images.
    geometric = [
        CenterCrop(img_size, img_size, p=0.5),
        Resize(img_size, img_size),
        Transpose(p=0.5),
        RandomRotate90(p=0.25),
        ShiftScaleRotate(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
    ]
    # Photometric stage.
    # NOTE(review): albumentations' HueSaturationValue limits default to 20/30/20;
    # 0.2 looks like a near no-op on uint8 images — confirm this was intended.
    photometric = [
        HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
        RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
    ]
    # ImageNet mean/std normalization, then HWC ndarray -> CHW torch tensor.
    to_tensor = [
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ]
    return Compose(geometric + photometric + to_tensor, p=1.)

In [None]:
# Competition dataset root; the submission template and test images both live under it.
DATA_DIR = '/kaggle/input/cassava-leaf-disease-classification/'
df = pd.read_csv(DATA_DIR + 'sample_submission.csv')  # defines which image ids must be predicted
PATH = DATA_DIR + 'test_images/'

In [None]:
# Same rows as the template, with image_id turned into a full file path.
test_csv = df.assign(image_id=PATH + df['image_id'])

In [None]:
def inference(model, data_loader, device):
    """Run ``model`` over every batch from ``data_loader`` and collect raw outputs.

    Args:
        model: torch module; switched to eval mode here.
        data_loader: iterable of image batches (tensors only, no labels).
        device: device each batch is moved to before the forward pass.

    Returns:
        list of numpy arrays, one per sample, in loader order (the rows of
        each batch output — assumes outputs are (batch, n_class) logits).
    """
    preds = []
    model.eval()
    test_tqdm = tq.tqdm(data_loader, total=len(data_loader), desc="Testing", position=0, leave=True)
    # Disable autograd here instead of relying on every caller to do it;
    # nesting inside an outer torch.no_grad() is harmless.
    with torch.no_grad():
        for images in test_tqdm:
            images = images.to(device)
            preds.extend(model(images).cpu().numpy())
    return preds

In [None]:
class CassavaImageClassifier(nn.Module):
    """timm backbone whose ``head`` is replaced by a fresh ``n_class``-way linear layer.

    NOTE(review): this class is not instantiated in this notebook; it is
    presumably defined so torch.load can unpickle the saved full-model
    checkpoints — keep the class and the ``model`` attribute name unchanged.
    """

    def __init__(self, model_arch, n_class, pretrained=False):
        super().__init__()
        backbone = timm.create_model(model_arch, pretrained=pretrained)
        # Swap the classifier for one sized to this task's classes.
        backbone.head = nn.Linear(backbone.head.in_features, n_class)
        self.model = backbone

    def forward(self, x):
        """Return the backbone's class logits for batch ``x``."""
        return self.model(x)

In [None]:
vit_test_ds = DiseaseDatasetInference(
    test_csv,
    transform=get_inference_transforms(img_size=CFG['vit_img_size']),
    opt_label=False,
)

# shuffle=False keeps predictions in the same row order as `df`, which later
# cells rely on to re-attach image ids.
vit_test_loader = torch.utils.data.DataLoader(
    vit_test_ds,
    batch_size=CFG['valid_bs'],
    shuffle=False,
    pin_memory=False,
)

In [None]:
# Ensemble: average raw outputs over every (checkpoint, TTA pass) pair.
# Each pass re-iterates the loader, so the random test-time augmentations
# give a different view of every image per pass.
vit_pass_logits = []
for vit_model_name in CFG['vit_models']:
    print("Model: ", vit_model_name)
    # torch.load unpickles a full nn.Module; pickle can execute arbitrary code,
    # so only load trusted checkpoints.
    vit_model = torch.load('/kaggle/input/vit-cassava/' + vit_model_name, map_location=torch.device(CFG['device']))
    with torch.no_grad():
        for _ in range(CFG['tta']):
            vit_pass_logits.append(inference(vit_model, vit_test_loader, CFG['device']))
    # Release this checkpoint before the next torch.load; otherwise the new
    # model is materialized while the old one is still bound.
    del vit_model
    if CFG['device'] == 'cuda':
        torch.cuda.empty_cache()
vit_preds = np.mean(vit_pass_logits, axis=0)

In [None]:
# Pair each row of averaged outputs with its image id (row order follows `df`),
# then order the rows by image id.
vit_outcomes = pd.concat(
    [df['image_id'], pd.DataFrame(vit_preds)],
    axis=1,
).sort_values(by='image_id')

In [None]:
# Predicted class = index of the largest averaged output per row (row order of vit_outcomes).
final_preds = np.argmax(vit_outcomes.drop(columns='image_id').to_numpy(), axis=1)

In [None]:
# Take the ids from the same frame the predictions were derived from.
# vit_outcomes may be re-ordered (sorted by image_id), so pairing its
# predictions with df's original order would misassign labels.
submit = pd.DataFrame({'image_id': vit_outcomes['image_id'].values, 'label': final_preds})
assert len(submit) == len(df), "submission must cover every image in the template"
submit.to_csv('submission.csv', index=False)