In [None]:
import sys

# Make the local copy of pytorch-image-models importable (used by `import timm`).
package_path = '../input/pytorchimagemodels'
# Guard so re-running this cell does not keep appending duplicate entries.
if package_path not in sys.path:
    sys.path.append(package_path)

In [None]:
import os
import random
import cv2
import timm

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import albumentations as A
import albumentations.pytorch as Apy

import torch
import torchvision
from torch import nn
from torchvision import transforms
from torch.utils.data import Dataset,DataLoader

from tqdm import tqdm

In [None]:
# Inference settings read by the cells below.
config = dict(
    fold_num=1,                 # number of per-fold checkpoints loaded and averaged
    seed=719,
    model_arch='resnext50d_32x4d',  # timm architecture name
    img_size=512,               # square side length passed to the transforms
    valid_bs=256,               # DataLoader batch size for inference
    num_workers=4,
    accum_iter=1,               # NOTE(review): not read in this notebook; likely a training leftover
    verbose_step=1,             # NOTE(review): not read in this notebook; likely a training leftover
    device='cuda:0' if torch.cuda.is_available() else "cpu",
)

In [None]:
# Load the competition's sample submission to inspect the expected output format.
submission = pd.read_csv(
    '../input/cassava-leaf-disease-classification/sample_submission.csv'
)
submission.head()

In [None]:
def seed_everything(seed):
    """Seed every RNG this notebook uses and put cuDNN in deterministic mode.

    Args:
        seed: integer seed applied to ``random``, NumPy and PyTorch (CPU and
            all CUDA devices); also exported as ``PYTHONHASHSEED``.
    """
    random.seed(seed)
    # Only affects Python processes started after this call; the running
    # interpreter's hash seed is already fixed.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and select possibly nondeterministic
    # algorithms, which would defeat `deterministic=True` above.
    torch.backends.cudnn.benchmark = False
    
def get_img(path):
    """Read an image from disk and return it in RGB channel order.

    Args:
        path: file path passed to ``cv2.imread``.

    Returns:
        C-contiguous ndarray of shape (H, W, 3) in RGB order.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path`` (``cv2.imread``
            returns None rather than raising).
    """
    im_bgr = cv2.imread(path)
    if im_bgr is None:
        raise FileNotFoundError("cv2.imread could not read image: {}".format(path))
    # OpenCV loads BGR; reversing the channel axis gives RGB. The reversed view
    # has a negative stride (rejected by torch.from_numpy), so materialise a
    # contiguous copy.
    return np.ascontiguousarray(im_bgr[:, :, ::-1])

In [None]:
class CassavaDataset(Dataset):
    """Map-style dataset that loads the images listed in a DataFrame.

    Args:
        df: DataFrame with an ``image_id`` column (plus ``label`` when
            ``output_label`` is truthy). Its index is reset so rows can be
            addressed positionally.
        data_root: directory prefix joined to each ``image_id`` with ``/``.
        transforms: optional callable invoked as
            ``transforms(image=img)['image']``.
        output_label: if exactly ``True``, items are ``(img, label)`` tuples;
            otherwise only ``img`` is returned.
    """

    def __init__(self, df, data_root, transforms=None, output_label=True):
        super().__init__()
        self.df = df.reset_index(drop=True).copy()
        self.transforms = transforms
        self.data_root = data_root
        self.output_label = output_label

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index: int):
        row = self.df.iloc[index]
        # Label is read before the image so a missing column fails fast.
        target = row['label'] if self.output_label else None

        img = get_img("{}/{}".format(self.data_root, row['image_id']))
        if self.transforms:
            img = self.transforms(image=img)['image']

        if self.output_label == True:
            return img, target
        return img

In [None]:
def get_inference_transforms(tta=False):
    """Build the albumentations pipeline applied to test images.

    Args:
        tta: when False (default) the pipeline is deterministic: resize to a
            ``config['img_size']`` square, normalise with the standard
            ImageNet mean/std, and convert to a tensor. When True, random
            crop/flip/colour-jitter steps are inserted so that several passes
            can be averaged as test-time augmentation. A single stochastic
            pass should not be used on its own: ``RandomResizedCrop`` may keep
            only a small fraction of the image.

    Returns:
        An ``A.Compose`` callable used as ``transforms(image=img)['image']``.
    """
    size = config['img_size']
    steps = [A.Resize(size, size)]
    if tta:
        steps += [
            A.RandomResizedCrop(size, size),
            A.Transpose(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
            A.RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
        ]
    steps += [
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
        Apy.ToTensorV2(p=1.0),
    ]
    return A.Compose(steps, p=1.)

In [None]:
class CassvaImgClassifier(nn.Module):
    """timm backbone whose classification head outputs ``n_class`` logits.

    Args:
        model_arch: timm model name passed to ``timm.create_model``.
        n_class: number of output logits.
        pretrained: whether timm loads pretrained backbone weights.
    """

    def __init__(self, model_arch, n_class, pretrained=False):
        super().__init__()
        # Let timm size the head itself rather than overwriting `.fc`: that
        # attribute only exists on ResNet-style models, while `num_classes`
        # works for any timm architecture. ResNet-family models still name the
        # head `fc`, so saved state_dict keys are unchanged.
        self.model = timm.create_model(model_arch, pretrained=pretrained, num_classes=n_class)

    def forward(self, x):
        """Return the raw, unnormalised logits produced by the backbone."""
        return self.model(x)

In [None]:
def inference_one_epoch(model, data_loader, device):
    """Run ``model`` over every batch of ``data_loader`` and collect softmax probabilities.

    Args:
        model: network returning logits with classes along dim 1; it is
            switched to eval mode.
        data_loader: sized iterable yielding image batches only (no labels).
        device: torch device each batch is moved to.

    Returns:
        np.ndarray with one row of class probabilities per input sample
        (assumes logits are 2-D ``(batch, n_classes)``).

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()

    batch_probs = []
    # no_grad inside the function (not only at the call site) so it never
    # builds an autograd graph, whatever context it is called from.
    with torch.no_grad():
        for imgs in tqdm(data_loader, total=len(data_loader)):
            logits = model(imgs.to(device).float())
            batch_probs.append(torch.softmax(logits, 1).cpu().numpy())

    if not batch_probs:
        raise ValueError("data_loader yielded no batches; nothing to predict")
    return np.concatenate(batch_probs, axis=0)

In [None]:
# Seed first: the inference transforms may be stochastic, so runs should be reproducible.
seed_everything(config['seed'])

test_img_dir = '../input/cassava-leaf-disease-classification/test_images/'
# sorted() makes row (and submission) order independent of filesystem listing order.
test = pd.DataFrame({'image_id': sorted(os.listdir(test_img_dir))})
test_ds = CassavaDataset(test, test_img_dir, transforms=get_inference_transforms(), output_label=False)

tst_loader = torch.utils.data.DataLoader(
    test_ds,
    batch_size=config['valid_bs'],
    num_workers=config['num_workers'],
    shuffle=False,
    pin_memory=False,
)

device = torch.device(config['device'])
model = CassvaImgClassifier(config['model_arch'], 5, pretrained=False).to(device)  # 5 output classes

tst_preds = []
for fold in range(config['fold_num']):
    ckpt_path = '../input/cassava-leave-disease/{}_fold_{}'.format(config['model_arch'], fold)
    # map_location lets GPU-saved checkpoints load when config['device'] fell back to CPU.
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    with torch.no_grad():
        tst_preds += [inference_one_epoch(model, tst_loader, device)]

# Average the per-fold class probabilities -> one row per test image.
tst_preds = np.mean(tst_preds, axis=0)

del model
torch.cuda.empty_cache()

In [None]:
# Predicted class = index of the highest fold-averaged probability in each row.
test['label'] = tst_preds.argmax(axis=1)
test.to_csv('submission.csv', index=False)