In [None]:
import cv2
from collections import defaultdict
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from pathlib import Path
from PIL import Image, ImageOps
import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')
import timm
import torch
from torch.nn.functional import softmax
from torch.utils.data import Dataset, DataLoader
# Run inference on the first CUDA GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

from albumentations import Compose, OneOf
from albumentations.augmentations.transforms import *
from albumentations.pytorch.transforms import ToTensorV2 

In [None]:
# --- inference configuration ------------------------------------------------
repeat_preds = 2          # center-crop TTA rounds per model (the random-resized-crop view runs once per model)
input_size = (512, 512)   # crop size, unpacked positionally into the crop transforms
test_img_path = '../input/cassava-leaf-disease-classification/test_images'  # directory of images to predict

In [None]:
# Ensemble members. The trailing comment on each entry is its recorded validation
# accuracy; commented-out entries are excluded from the ensemble.
model_fnames = [
    #'../input/cassava-notebook-34-model/effnet_epoch_12.pickle', # valid acc 0.8924485125858124 
    #'../input/cassava-run-1612850859606531-model/effnet_epoch_19.pickle',
    # Removing this model - was trained on an old version of the transforms
    #'../input/cassava-final-ensemble-models/train_2021-02-12_18-03-44_model_epoch_10.pickle', # valid acc 0.8826849733028223
    #'../input/cassava-final-ensemble-models/train_2021-02-12_10-12-59_model_epoch_8.pickle', # valid acc 0.8858886346300534
    '../input/cassava-final-ensemble-models/train_2021-02-14_10-23-02_model_epoch_9.pickle', # valid acc 0.8987032799389779
    '../input/cassava-final-ensemble-models/train_2021-02-14_21-33-31_model_epoch_11.pickle', # valid acc 0.8994660564454615
    '../input/cassava-final-ensemble-models/train_2021-02-15_21-38-02_model_epoch_9.pickle', # valid acc 0.8890922959572846
    #'../input/cassava-final-ensemble-models/train_2021-02-17_11-31-19_model_epoch_8.pickle', # valid acc 0.883905415713196
    '../input/cassava-train-20210217-214040-model-epoch-14/train_2021-02-17_21-40-40_model_epoch_14.pickle' # valid acc 0.8903127383676582
]


def _load_model(fname):
    """Load one pickled model, move it to `device`, and put it in eval mode.

    `map_location=device` remaps tensors that were saved on a GPU onto the
    selected device, so the checkpoints also load on a CPU-only runtime
    (without it, torch.load raises when CUDA is unavailable).

    NOTE(review): torch.load unpickles arbitrary Python objects; only load
    trusted checkpoints. On torch releases where `weights_only` defaults to
    True, full-model pickles like these may need `weights_only=False`.
    """
    model = torch.load(fname, map_location=device)
    # nn.Module.to() and .eval() both return the module itself.
    return model.to(device).eval()


models = [_load_model(fname) for fname in model_fnames]

In [None]:
# Every regular file in the configured test-image directory. Reuses the
# `test_img_path` config constant instead of repeating the literal, and sorts
# the paths because Path.iterdir() order is arbitrary; sorting makes the
# prediction order (and so the submission row order) reproducible.
predict_files = sorted(
    path for path in Path(test_img_path).iterdir()
    if path.is_file()
)

In [None]:
def _normalize_and_tensorize():
    """Return the Normalize + ToTensorV2 tail that ends every TTA pipeline.

    A fresh list with new transform instances is built on each call, so the two
    pipelines below never share transform objects.
    """
    return [Normalize(), ToTensorV2()]


crop_h, crop_w = input_size

# TTA pipeline 1: square (ratio 1:1) random resized crop with scale in [0.75, 1.0],
# resized with cubic interpolation, followed by Rotate / Transpose / Flip at their
# library-default probabilities.
test_tfms_rrc = Compose([
    RandomResizedCrop(
        crop_h, crop_w,
        scale=(.75, 1.0),
        ratio=(1, 1),
        interpolation=cv2.INTER_CUBIC,
        always_apply=True,
        p=1.0,
    ),
    Rotate(limit=[-45, 45], interpolation=cv2.INTER_LANCZOS4),
    Transpose(),
    Flip(),
    *_normalize_and_tensorize(),
])

# TTA pipeline 2: rotation applied with p=0.3, then an always-applied center crop
# to input_size, then Flip. Transpose is intentionally not part of this pipeline.
test_tfms_cc = Compose([
    Rotate(limit=[-45, 45], interpolation=cv2.INTER_LANCZOS4, p=0.3),
    CenterCrop(crop_h, crop_w, always_apply=True, p=1.0),
    Flip(),
    *_normalize_and_tensorize(),
])

In [None]:
# image_id -> list of softmax distributions, one per model per TTA view.
predictions_list = defaultdict(list)
# image_id -> predicted class index (argmax of the summed distributions).
predictions_agg = {}

# Side length each image is resized to before the center-crop TTA views.
presize = 600


def _predict_distribution(model, image_tensor):
    """Return softmax(model(image_tensor)) for one transformed image.

    `image_tensor` is a single image as emitted by the albumentations
    pipelines (ToTensorV2 output); a leading batch dimension of 1 is added
    before the forward pass, and softmax is taken over dim 1.
    NOTE(review): the result is presumably (1, n_classes) — depends on the model head.
    """
    batch = torch.unsqueeze(image_tensor, dim=0).to(device)
    with torch.no_grad():
        logits = model(batch)
        return softmax(logits, dim=1)


for filepath in predict_files:
    image_id = filepath.name
    # The context manager closes the underlying file as soon as pixels are copied out.
    with Image.open(filepath) as pimage:
        image = np.array(pimage.convert('RGB'))

    # TTA view 1: one random-resized-crop view per model.
    for model in models:
        image_rrc = test_tfms_rrc(image=image)['image']
        predictions_list[image_id].append(_predict_distribution(model, image_rrc))

    # TTA view 2: presize, then `repeat_preds` center-crop views per model.
    # `interpolation` must be passed by keyword: cv2.resize's third positional
    # parameter is `dst`, so passing cv2.INTER_CUBIC positionally would leave the
    # default interpolation in effect.
    image_rs = cv2.resize(image, (presize, presize), interpolation=cv2.INTER_CUBIC)
    for _ in range(repeat_preds):
        for model in models:
            image_cc = test_tfms_cc(image=image_rs)['image']
            predictions_list[image_id].append(_predict_distribution(model, image_cc))

for image_id, prediction_list in predictions_list.items():
    # Summing the per-view distributions gives the same argmax as averaging them.
    pred = torch.stack(prediction_list).sum(dim=0).cpu().numpy()
    print(image_id, pred)
    # The batch dimension is 1, so the flat argmax is the class index;
    # int() stores a plain Python int instead of a numpy scalar.
    predictions_agg[image_id] = int(np.argmax(pred))

In [None]:
# Write the Kaggle submission: a header row, then one "image_id,label" row per
# predicted image, in predictions_agg iteration order.
header = 'image_id,label\n'
rows = (f'{img_id},{prediction}\n' for img_id, prediction in predictions_agg.items())
with open('submission.csv', 'w+') as submission:
    submission.write(header)
    submission.writelines(rows)

In [None]:
# Echo the written submission file (IPython shell escape) for a quick visual check.
!cat submission.csv