In [None]:
!pip install ../input/timm034/timm-0.3.4-py3-none-any.whl

In [None]:
#!python ../input/ttach-master/setup.py install

In [None]:
import time
import os
import timm
import numpy as np
import pandas as pd
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler, RandomSampler, SequentialSampler
#from efficientnet_pytorch import model as enet
import albumentations
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
#import ttach as tta

# Config

In [None]:
# Folder of test images to classify.
image_folder = '../input/cassava-leaf-disease-classification/test_images/'

enet_type = 'cspresnext50'   # timm architecture name; must match the fold checkpoints
image_size = 512             # images are resized to image_size x image_size
batch_size = 4
num_workers = 4              # DataLoader worker processes
out_dim = 5                  # number of output classes of the classifier head

# Fall back to CPU so the notebook still runs (slowly) when no GPU is present.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model

In [None]:
ls ../input/cdl-cspresnext50-512/ 

In [None]:
# One checkpoint per fold (folds 0-4); predictions from all of them are averaged.
model_pths = [
    f'../input/cdl-cspresnext50-512/light_best_model_fold{fold}.pth'
    for fold in range(5)
]

In [None]:
class net(nn.Module):
    """timm backbone whose final linear layer is replaced by a task-sized head.

    Args:
        model_name: timm architecture name passed to ``timm.create_model``.
        pretrained: whether timm loads its own pretrained weights (default False).
        num_classes: number of outputs of the replacement linear head; defaults
            to the ``out_dim`` config value.
    """

    def __init__(self, model_name=enet_type, pretrained=False, num_classes=out_dim):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)

        # Assumes the architecture exposes its classifier as ``head.fc`` (as the
        # cspresnext50 checkpoints here do) -- TODO confirm for other backbones.
        n_features = self.model.head.fc.in_features
        self.model.head.fc = nn.Linear(n_features, num_classes)

    def forward(self, x):
        """Return the un-normalised class scores of the backbone for batch ``x``."""
        return self.model(x)

# Dataset

In [None]:
class LEAFDataset(Dataset):
    """Unlabelled image dataset over every file in ``folder``.

    Each item is ``(image, image_id)``: ``image_id`` is the file name and
    ``image`` is the RGB image, passed through ``transforms`` when given.

    Args:
        folder: directory whose files are read as images.
        transforms: optional albumentations-style callable, invoked as
            ``transforms(image=...)`` and returning a dict with key ``'image'``.
    """

    def __init__(self, folder, transforms=None):
        self.folder = folder
        # os.listdir order is arbitrary; sort so item order is deterministic and
        # identical for every dataset built over the same folder.
        self.file_names = sorted(os.listdir(folder))
        self.transforms = transforms

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, index):
        image_id = self.file_names[index]

        # Read from the folder this dataset was constructed with.
        image_file = os.path.join(self.folder, image_id)
        image = cv2.imread(image_file)
        if image is None:
            # cv2.imread reports missing/undecodable files by returning None.
            raise FileNotFoundError(f'could not read image: {image_file}')
        # OpenCV decodes to BGR; convert to RGB (yields a contiguous array).
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transforms is not None:
            image = self.transforms(image=image)['image']

        return image, image_id

# Augmentations

In [None]:
# Test-time preprocessing: resize to image_size x image_size, normalise with
# albumentations' default mean/std, then convert to a torch tensor.
transform = albumentations.Compose(
    transforms=[
        albumentations.Resize(height=image_size, width=image_size),
        albumentations.Normalize(),
        ToTensorV2(),
    ]
)

In [None]:
# ====================================================
# inference
# ====================================================

# The test data does not depend on the model, so build the dataset/loader once.
# This also guarantees every fold sees the images in the same order, which the
# per-image averaging below relies on.
test_dataset = LEAFDataset(image_folder, transforms=transform)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
)

res = []
for model_pth in model_pths:
    model = net(enet_type)
    # map_location lets checkpoints saved on a GPU load onto the current device.
    model.load_state_dict(torch.load(model_pth, map_location=device))
    model.to(device)
    model.eval()

    single_model_probs = []
    image_ids_list = []
    with torch.no_grad():
        for images, image_ids in tqdm(test_loader, desc=os.path.basename(model_pth)):
            outputs = model(images.to(device))
            single_model_probs.append(outputs.softmax(1).cpu().numpy())
            image_ids_list.append(image_ids)

    res.append(np.concatenate(single_model_probs))
    # Same order for every fold (single, unshuffled loader), so the last one wins.
    image_ids_list = np.concatenate(image_ids_list)

    del model
    torch.cuda.empty_cache()

# Ensemble: element-wise mean of the per-fold softmax probabilities.
res = np.mean(res, axis=0)
# NOTE: despite the name, `probs` holds the predicted class index per image.
probs = res.argmax(1)

In [None]:
# Submission table: one row per test image with its predicted class index.
sub = pd.DataFrame(
    {
        'image_id': image_ids_list,
        'label': probs,
    }
)
sub.head()

In [None]:
# Write the submission without the DataFrame's integer index column.
sub.to_csv(path_or_buf='submission.csv', index=False)