In [1]:
# Kaggle input locations used by the rest of the notebook:
#   test_images_path - directory holding the images to run inference on
#   timm_path        - source checkout of timm that gets added to sys.path
#   folds_path       - directory holding the per-fold model checkpoints
test_images_path, timm_path, folds_path = (
    "/kaggle/input/cassava-leaf-disease-classification/test_images",
    "/kaggle/input/timm-pytorch-image-models/pytorch-image-models-master",
    "../input/leaf-cherry",
)

In [2]:
import sys

# Make the timm source checkout importable. The membership check keeps the
# cell idempotent: re-running it no longer appends a duplicate entry to
# sys.path every time.
if timm_path not in sys.path:
    sys.path.append(timm_path)

In [3]:
from tqdm import tqdm
from pathlib import Path

import numpy as np
import pandas as pd
import cv2
from scipy.special import softmax

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import timm

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

In [4]:
tiny = False

# Maximum number of images kept when a dataset is created with tiny=True.
# The name was referenced by LeafDataset but never defined in this notebook,
# so tiny=True used to fail with NameError; this is a small debugging default.
TINY_SIZE = 32


class LeafDataset(Dataset):
    """Cassava Leaf Disease Classification dataset (images only, no labels).

    Lists the ``*.jpg`` files located directly inside ``img_dir`` and serves
    them one at a time as RGB images.

    Args:
        img_dir: Directory containing the images; a ``str`` or a ``Path``.
        transform: Optional callable invoked as ``transform(image=img)`` that
            returns a mapping with an ``"image"`` entry (the albumentations
            calling convention). When ``None`` the raw RGB array is returned.
        tiny: When true, keep only the first ``TINY_SIZE`` file names.

    Attributes:
        fnames: NumPy array of file names, sorted so that the index -> file
            mapping is the same on every run.
        labels: Always ``None``; this dataset carries no targets.
    """

    def __init__(self, img_dir, transform=None, tiny=False):
        self.img_dir = img_dir if isinstance(img_dir, Path) else Path(img_dir)
        self.transform = transform
        self.tiny = tiny
        # Glob on the normalised Path (not the raw argument, which may be a
        # plain string) and sort, because glob order is not guaranteed.
        self.fnames = np.array(sorted(img.name for img in self.img_dir.glob("*.jpg")))
        if self.tiny:
            self.fnames = self.fnames[:TINY_SIZE]
        self.labels = None
        self.dataset_len = len(self.fnames)

    def __len__(self):
        """Return the number of images in the dataset."""
        return self.dataset_len

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``(image, idx)``.

        The index is returned alongside the image so that callers can map a
        prediction back to ``fnames[idx]``.

        Raises:
            FileNotFoundError: If the image file cannot be read or decoded.
        """
        img_path = self.img_dir / self.fnames[idx]
        img = cv2.imread(str(img_path))
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        # OpenCV decodes to BGR channel order; convert to RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.transform is not None:
            img = self.transform(image=img)["image"]
        return img, idx

In [5]:
# --- Model and input settings ---------------------------------------------
arch = "tf_efficientnet_b4_ns"  # architecture name handed to timm.create_model
tta_img_size = 446              # side length images are resized to before augmentation
batch_size = 32

# --- Ensemble dimensions ---------------------------------------------------
num_classes = 5
num_tta = 12    # augmented passes over the dataset per fold
num_folds = 5

# --- Checkpoints: fold index -> file name inside checkpoint_dir ------------
checkpoint_dir = Path(folds_path)
checkpoints = dict(enumerate([
    "final10_fold0.pth",
    "final10_fold1.pth",
    "final20_fold2.pth",
    "final20_fold3.pth",
    "final20_fold4.pth",
]))

# Test-time augmentation pipeline, applied in the order listed.
# The shift/scale/rotate and brightness/contrast steps always fire (p=1.0)
# and the flip fires half of the time, so repeated loads of the same image
# yield different augmented views.
_tta_steps = [
    A.Resize(tta_img_size, tta_img_size),
    A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=30, p=1.0),
    A.RandomBrightnessContrast(brightness_limit=0.03, contrast_limit=0.03, p=1.0),
    A.HorizontalFlip(p=0.5),
    # Per-channel normalisation of 0-255 pixel values.
    A.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        max_pixel_value=255.0,
        p=1.0,
    ),
    ToTensorV2(),
]
tta_transforms = A.Compose(_tta_steps)

# Test dataset; the TTA pipeline is applied every time an image is fetched.
dset = LeafDataset(Path(test_images_path), transform=tta_transforms, tiny=tiny)

# pretrained=False: the weights come from the fold checkpoints loaded below.
# The head size uses num_classes rather than a second hard-coded 5 so the two
# cannot drift apart.
model = timm.create_model(model_name=arch, num_classes=num_classes, pretrained=False)

# Fall back to the CPU when no GPU is present instead of failing outright.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.cuda.empty_cache()

# Device placement, eval mode and the loader are identical for every fold,
# so they are set up once instead of once per fold.
model = model.to(device)
model.eval()
dataloader = DataLoader(
    dset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=(device.type == "cuda"),  # pinning is only useful for GPU transfers
)

# Raw model outputs, indexed as [fold, tta_pass, image, class].
logits_all = np.zeros((num_folds, num_tta, len(dset), num_classes), dtype=float)

# The context manager closes the progress bar even if inference fails.
with tqdm(total=num_folds * num_tta) as pbar, torch.no_grad():
    for fold in range(num_folds):
        checkpoint = checkpoint_dir / checkpoints[fold]
        # map_location lets checkpoints saved on another device load here.
        checkpoint_dict = torch.load(checkpoint, map_location=device)
        model.load_state_dict(checkpoint_dict['model_state_dict'])

        for i_tta in range(num_tta):
            for imgs, idxs in dataloader:
                # Call the module (not .forward) so registered hooks run.
                batch_logits = model(imgs.to(device))
                # Place rows by the dataset indices returned with the batch
                # rather than by a running offset.
                logits_all[fold, i_tta, idxs.numpy(), :] = batch_logits.cpu().numpy()
            pbar.update(1)

# Softmax each (fold, tta) prediction, average the probabilities over all
# folds and TTA passes, then take the most likely class per image.
probs = softmax(logits_all, axis=-1)
probs = probs.mean(axis=(0, 1))
preds = probs.argmax(axis=-1)

100%|██████████| 60/60 [00:34<00:00,  4.95it/s]

In [6]:
# One row per test image: file name plus predicted class index.
submission_df = pd.DataFrame(data=dict(image_id=dset.fnames, label=preds))

In [7]:
# Write the predictions to the working directory, without the index column.
submission_df.to_csv(path_or_buf="submission.csv", index=False)