In [1]:
folds_path = "submission/final10/"

In [2]:
from tqdm import tqdm
from pathlib import Path

import numpy as np
import pandas as pd
import cv2
from scipy.special import softmax

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import timm

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

In [3]:
TINY_SIZE=100
tiny=True
class LeafDataset(Dataset):
    """Cassava Leaf Disease Classification dataset."""

    def __init__(self, img_dir, transform=None, tiny=False):
        self.img_dir = img_dir if isinstance(img_dir, Path) else Path(img_dir)
        self.transform = transform
        self.tiny = tiny
        self.fnames = np.array([img.name for img in img_dir.glob("*.jpg")])
        if self.tiny:
            self.fnames = self.fnames[:TINY_SIZE]
        self.labels = None
        self.dataset_len = len(self.fnames)

    def __len__(self):
        return self.dataset_len

    def __getitem__(self, idx):
        # img = io.imread(self.img_dir / self.fnames[idx])
        img = cv2.imread(str(self.img_dir / self.fnames[idx]))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.transform:
            img = self.transform(image=img)["image"]
        return img, idx

In [4]:
arch = "tf_efficientnet_b4_ns"
tta_img_size = 446
batch_size = 32

num_folds = 4
num_tta = 10
num_classes = 5

tta_transforms = A.Compose([
    A.Resize(tta_img_size, tta_img_size),
    A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=90, p=1.0),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1.0),
    A.RGBShift(p=1.0),
    A.HorizontalFlip(p=0.5),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0),
    ToTensorV2()
])

dset = LeafDataset(Path("data/mock_test"), transform=tta_transforms, tiny=tiny)

model = timm.create_model(model_name=arch, num_classes=5, pretrained=False)
device = torch.device("cuda")

In [5]:
checkpoint_dir = Path(folds_path)
checkpoints = {
    0: "final10_fold0.pth",
    1: "final10_fold1.pth",
    2: "final10_fold2.pth",
    3: "final10_fold3.pth",
    4: "final10_fold4.pth"
}

torch.cuda.empty_cache()
pbar = tqdm(total=num_folds*num_tta)
logits_all = np.zeros((num_folds, num_tta, len(dset), num_classes), dtype=float)
for fold in range(num_folds):
    checkpoint = checkpoint_dir / checkpoints[fold]
    checkpoint_dict = torch.load(checkpoint)
    model.load_state_dict(checkpoint_dict['model_state_dict'])

    model = model.to(device)
    model.eval()

    dataloader = DataLoader(dset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    with torch.no_grad():
        for i_tta in range(num_tta):
            logits = torch.zeros((len(dset), 5), dtype=float, device=device)
            i = 0
            for imgs, idxs in dataloader:
                imgs = imgs.to(device)
                bs = imgs.shape[0]
                logits[i:(i + bs), :] = model.forward(imgs)
                i += bs
            pbar.update(1)
            logits = logits[:i]
            logits_all[fold, i_tta, :, :] = logits.cpu().numpy()

probs1 = softmax(logits_all, axis=-1)
probs1 = np.mean(probs1.reshape(-1, len(dset), num_classes), axis=0)

100%|██████████| 40/40 [00:56<00:00,  1.40s/it]

In [6]:
whole_checkpoint = "submission/final10_whole/final10_whole.pth"

In [8]:
torch.cuda.empty_cache()
pbar = tqdm(total=num_tta)
logits_all = np.zeros((num_tta, len(dset), num_classes), dtype=float)

checkpoint_dict = torch.load(whole_checkpoint)
model.load_state_dict(checkpoint_dict['model_state_dict'])

model = model.to(device)
model.eval()

dataloader = DataLoader(dset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

with torch.no_grad():
    for i_tta in range(num_tta):
        logits = torch.zeros((len(dset), 5), dtype=float, device=device)
        i = 0
        for imgs, idxs in dataloader:
            imgs = imgs.to(device)
            bs = imgs.shape[0]
            logits[i:(i + bs), :] = model.forward(imgs)
            i += bs
        pbar.update(1)
        logits = logits[:i]
        logits_all[i_tta, :, :] = logits.cpu().numpy()

probs2 = softmax(logits_all, axis=-1)
probs2 = np.mean(probs2, axis=0)

100%|██████████| 10/10 [00:13<00:00,  1.36s/it]

In [9]:
probs = (probs1 + probs2) / 2

In [10]:
preds = probs.argmax(axis=-1)

In [11]:
gold_df = pd.read_csv("data/test_df.csv")

In [12]:
np.sum(gold_df.set_index("image_id").loc[dset.fnames].values.reshape(-1) == preds)/gold_df.shape[0]

0.022201448936667447

In [7]:
submission_df = pd.DataFrame({"image_id": dset.fnames, "label": preds})

In [8]:
submission_df.to_csv("submission.csv", index=False)

In [9]:
# Exploration

In [10]:
probs2 = softmax(logits_all, axis=-1)
probs2 = np.median(probs2.reshape(-1, len(dset), num_classes), axis=0)
preds2 = probs2.argmax(axis=-1)

In [11]:
np.sum(gold_df.set_index("image_id").loc[dset.fnames].values.reshape(-1) == preds2)/gold_df.shape[0]

0.9366674456648749

# Log 
# 1 (wrong)
num_tta = 12
tta_transforms = A.Compose([
    A.Resize(tta_img_size, tta_img_size),
    A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=90, p=1.0),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1.0),
    A.RGBShift(p=1.0),
    A.HorizontalFlip(p=0.5),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0),
    ToTensorV2()
])

mean 0.9291890628651555
median 0.9298901612526291

# 2 fixed
0.9354989483524188
0.9354989483524188

# one tta
0.9340967515774714 (mean better)

# no tta
0.938 (mean better)

# only h flip, 5 tta
0.9366674456648749

# h flip, transpose, 10 tta

0.9364337462023837

# h flip, v flip, 10 tta
0.9347978499649451

# h flip, ssr
0.9331619537275064

# ssr with no rotation, hflip
just a bit better than no

# rbc and rbg shift
worse
0.9369011451273662