## Install dependencies

In [None]:
!pip install ../input/timm-pytorch-image-models/pytorch-image-models-master/

## Libraries

In [None]:
import os
import random
import time
import warnings

import albumentations as A
import cv2
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as torchdata

from pathlib import Path
from typing import List

from albumentations.pytorch import ToTensorV2
from tqdm import tqdm

## Utilities

In [None]:
def set_seed(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU plus current/all CUDA
    devices), exports ``PYTHONHASHSEED`` and configures cuDNN for
    deterministic, non-autotuned execution.

    Args:
        seed: Integer seed shared by all generators.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def get_device() -> torch.device:
    """Return the CUDA device when one is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def prepare_model_for_inference(model, path: Path):
    """Load trained weights into ``model`` and switch it to evaluation mode.

    Args:
        model: Module whose ``load_state_dict`` accepts the stored weights.
        path: Checkpoint file holding a mapping with a ``"model_state_dict"``
            entry.

    Returns:
        The same ``model`` instance, with the weights loaded and ``eval()``
        applied. The model stays on whatever device it was on; callers move it
        afterwards.
    """
    # Always deserialize onto the CPU. Without map_location, tensors are restored
    # onto the device they were saved from, which fails when that device does not
    # exist on this machine. load_state_dict copies the values into the model's
    # own parameters, so the result is the same.
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()
    return model

## Data 

In [None]:
# For the on-site competition these two directories must be changed.
DATADIR = Path("../input/riadd-public-test/Evaluation_Set")
MODELDIR = Path("../input/riadd-trained-weights/")

# The sort key parses the part of the file name before the first dot as an
# integer, so images are ordered numerically rather than lexicographically.
all_pngs = sorted(DATADIR.glob("*.png"), key=lambda png: int(png.name.split(".")[0]))
all_png_names = [png.name for png in all_pngs]
print(all_png_names[:5])

In [None]:
# One row per discovered image; "ID" is the integer in front of the first dot
# of the file name.
test_df = pd.DataFrame(
    {"ID": [int(png_name.split(".")[0]) for png_name in all_png_names]}
)
test_df.head()

## Dataset

In [None]:
def crop_image_from_gray(image: np.ndarray, threshold: int = 7):
    """Trim the dark border around an image.

    Rows and columns in which no (grayscale) pixel is brighter than
    ``threshold`` are removed.

    Args:
        image: Either a 2-D grayscale array or a 3-D array that is converted
            with ``cv2.COLOR_RGB2GRAY`` to build the mask; for 3-D input the
            first three channels are cropped and returned.
        threshold: Pixels with a gray value strictly greater than this count
            as content.

    Returns:
        The cropped array. When no pixel exceeds ``threshold`` the input is
        returned unchanged instead of an empty array.

    Raises:
        ValueError: If ``image`` is neither 2-D nor 3-D.
    """
    if image.ndim == 2:
        gray_image = image
    elif image.ndim == 3:
        gray_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    else:
        raise ValueError(
            f"expected a 2-D or 3-D image, got an array with {image.ndim} dimension(s)"
        )

    mask = gray_image > threshold
    if not mask.any():
        # Nothing brighter than the threshold: cropping would yield an empty image.
        return image

    # Build the row/column selector once and reuse it for every channel.
    keep = np.ix_(mask.any(1), mask.any(0))
    if image.ndim == 2:
        return image[keep]
    return np.stack([image[:, :, channel][keep] for channel in range(3)], axis=-1)


def center_crop(image: np.ndarray, ar: float):
    """Crop an image vertically around its centre to a given height/width ratio.

    Args:
        image: 3-D array laid out as (height, width, channels).
        ar: Desired ratio ``height / width`` of the result; the width is kept.

    Returns:
        A view of ``image`` with ``int(ar * width)`` rows taken from the
        vertical centre. If the image is already shorter than that, all rows
        are returned.
    """
    h, w, _ = image.shape
    # Clamp to the available height: a larger target would make `start` negative
    # and the slice would silently return a strip from the bottom of the image.
    new_h = min(int(ar * w), h)
    start = (h - new_h) // 2
    return image[start:start + new_h, :, :]

In [None]:
class TestDatasetV1(torchdata.Dataset):
    """Inference dataset reading ``<ID>.png`` files; border cropping is off by default.

    Args:
        df: Frame with an ``"ID"`` column; one sample per row.
        datadir: Directory containing the ``<ID>.png`` images.
        transform: Optional callable invoked as ``transform(image=...)`` whose
            result is indexed with ``"image"``.
        center_crop: When true, the dark border is removed and, except for
            2848x4288 images, the result is centre-cropped to ratio 0.834.
    """

    def __init__(self, df: pd.DataFrame, datadir: Path, transform=None, center_crop=False):
        self.df = df
        self.filenames = df["ID"].values
        self.datadir = datadir
        self.transform = transform
        self.center_crop = center_crop

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index: int):
        """Return ``{"ID": <id>, "image": <preprocessed image>}`` for one sample.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be decoded.
        """
        filename = self.filenames[index]
        path = self.datadir / f"{filename}.png"

        # cv2.imread does not raise on failure, it returns None; fail with the path
        # instead of an opaque cvtColor error.
        bgr_image = cv2.imread(str(path))
        if bgr_image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

        # The original resolution identifies the camera; only "C2" skips the
        # aspect-ratio crop below.
        h, w, _ = image.shape
        if h == 1424 and w == 2144:
            camera = "C1"
        elif h == 2848 and w == 4288:
            camera = "C2"
        else:
            camera = "C3"

        if self.center_crop:
            image = crop_image_from_gray(image)
            if camera != "C2":
                image = center_crop(image, ar=0.834)

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented["image"]
        return {
            "ID": filename,
            "image": image
        }


class TestDataset(torchdata.Dataset):
    """Inference dataset reading ``<ID>.png`` files; border cropping is on by default.

    Args:
        df: Frame with an ``"ID"`` column; one sample per row.
        datadir: Directory containing the ``<ID>.png`` images.
        transform: Optional callable invoked as ``transform(image=...)`` whose
            result is indexed with ``"image"``.
        center_crop: When true, the dark border is removed and, except for
            2848x4288 images, the result is centre-cropped to ratio 0.834.
    """

    def __init__(self, df: pd.DataFrame, datadir: Path, transform=None, center_crop=True):
        self.df = df
        self.filenames = df["ID"].values
        self.datadir = datadir
        self.transform = transform
        self.center_crop = center_crop

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index: int):
        """Return ``{"ID": <id>, "image": <preprocessed image>}`` for one sample.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be decoded.
        """
        filename = self.filenames[index]
        path = self.datadir / f"{filename}.png"

        # cv2.imread does not raise on failure, it returns None; fail with the path
        # instead of an opaque cvtColor error.
        bgr_image = cv2.imread(str(path))
        if bgr_image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

        # The original resolution identifies the camera; only "C2" skips the
        # aspect-ratio crop below.
        h, w, _ = image.shape
        if h == 1424 and w == 2144:
            camera = "C1"
        elif h == 2848 and w == 4288:
            camera = "C2"
        else:
            camera = "C3"

        if self.center_crop:
            image = crop_image_from_gray(image)
            if camera != "C2":
                image = center_crop(image, ar=0.834)

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented["image"]
        return {
            "ID": filename,
            "image": image
        }

In [None]:
def get_transforms(img_size: int, mode="train"):
    """Build the albumentations pipeline for the given mode.

    Args:
        img_size: Side length of the square output image.
        mode: ``"train"`` prepends random augmentations; ``"valid"`` and any
            other value yield only resize, normalisation and tensor conversion.

    Returns:
        An ``A.Compose`` pipeline.
    """

    def resize_normalize_to_tensor():
        # Shared tail of every pipeline.
        # NOTE(review): the third mean (0.4406) differs from the commonly used
        # ImageNet value 0.406 -- presumably it has to match what the checkpoints
        # were trained with, so it is left untouched; confirm before changing.
        return [
            A.Resize(img_size, img_size),
            A.Normalize(
                mean=[0.485, 0.456, 0.4406],
                std=[0.229, 0.224, 0.225],
                always_apply=True),
            ToTensorV2(),
        ]

    if mode != "train":
        return A.Compose(resize_normalize_to_tensor())

    augmentations = [
        A.RandomResizedCrop(
            height=img_size,
            width=img_size,
            scale=(0.9, 1.1),
            ratio=(0.9, 1.1),
            p=0.5),
        A.ShiftScaleRotate(
            shift_limit=0.1,
            scale_limit=0.1,
            rotate_limit=180,
            border_mode=cv2.BORDER_CONSTANT,
            value=0,
            mask_value=0,
            p=0.5),
        A.RandomBrightnessContrast(
            brightness_limit=0.1, contrast_limit=0.1, p=0.5),
        A.HueSaturationValue(
            hue_shift_limit=5,
            sat_shift_limit=5,
            val_shift_limit=5,
            p=0.5),
    ]
    return A.Compose(augmentations + resize_normalize_to_tensor())

## Model

In [None]:
def init_layer(layer):
    """Xavier-uniform initialise ``layer.weight`` and zero its bias, if it has one."""
    nn.init.xavier_uniform_(layer.weight)

    bias = getattr(layer, "bias", None)
    if bias is not None:
        bias.data.fill_(0.)


def gem(x: torch.Tensor, p=3, eps=1e-6):
    """Generalised-mean pooling over the last two dimensions of ``x``.

    Values are clamped to at least ``eps``, raised to the power ``p``, averaged
    over the full spatial extent and raised to ``1 / p``.
    """
    full_extent = (x.size(-2), x.size(-1))
    powered = x.clamp(min=eps).pow(p)
    pooled = F.avg_pool2d(powered, full_extent)
    return pooled.pow(1. / p)


class GeM(nn.Module):
    """Generalised-mean pooling layer whose exponent ``p`` is a learnable parameter.

    Args:
        p: Initial value of the exponent.
        eps: Lower clamp applied to the input before exponentiation.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.p = nn.Parameter(torch.ones(1) * p)

    def forward(self, x):
        return gem(x, p=self.p, eps=self.eps)

    def __repr__(self):
        layer_name = type(self).__name__
        exponent = self.p.data.tolist()[0]
        return f"{layer_name}(p={exponent:.4f}, eps={self.eps})"


class TimmModel(nn.Module):
    """timm backbone with its classification layer replaced by a fresh ``nn.Linear``.

    The classifier is looked up as ``fc``, then ``classifier``, then
    ``head.fc``; a backbone exposing none of them raises
    ``NotImplementedError``.

    Args:
        base_model_name: Name passed to ``timm.create_model``.
        pooling: ``"GeM"`` attaches a ``GeM`` module to the backbone; any other
            value leaves the backbone's pooling untouched.
        pretrained: Forwarded to ``timm.create_model``.
        num_classes: Output size of the new classification layer.
    """

    def __init__(self, base_model_name="tf_efficientnet_b0_ns", pooling="GeM", pretrained=True, num_classes=24):
        super().__init__()
        self.base_model = timm.create_model(base_model_name, pretrained=pretrained)

        owner, attr = self._classifier_slot()
        in_features = getattr(owner, attr).in_features
        setattr(owner, attr, nn.Linear(in_features, num_classes))

        if pooling == "GeM":
            # The attribute names below end up in the state-dict keys (GeM owns a
            # parameter), so they must stay as they are for checkpoints to load.
            # NOTE(review): whether the backbone's forward actually reads
            # `avg_pool` depends on the timm architecture -- confirm.
            if hasattr(self.base_model, "head"):
                self.base_model.head.global_pool = GeM()
            else:
                self.base_model.avg_pool = GeM()

        self.init_layer()

    def _classifier_slot(self):
        # Locate the classification layer: returns (owning module, attribute name).
        backbone = self.base_model
        if hasattr(backbone, "fc"):
            return backbone, "fc"
        if hasattr(backbone, "classifier"):
            return backbone, "classifier"
        if hasattr(backbone, "head"):
            return backbone.head, "fc"
        raise NotImplementedError

    def init_layer(self):
        """Re-initialise the classification layer with the module-level ``init_layer``."""
        owner, attr = self._classifier_slot()
        init_layer(getattr(owner, attr))

    def forward(self, x):
        return self.base_model(x)

## Configuration

In [None]:
# Dataset class used for each experiment id: the first three experiments use
# TestDatasetV1, all later ones use TestDataset.
DATASET_TYPE = {
    **{exp_id: TestDatasetV1 for exp_id in ("012", "014", "015")},
    **{
        exp_id: TestDataset
        for exp_id in ("022", "023", "026", "027", "030", "038", "039", "040")
    },
}


# Square input resolution (pixels per side) used for each experiment id.
IMAGE_SIZE = {
    "012": 320,
    "014": 320,
    "015": 480,
    "022": 480,
    "023": 320,
    "026": 640,
    "027": 320,
    "030": 320,
    "038": 384,
    "039": 456,
    "040": 384,
}


# Constructor arguments of TimmModel for each experiment id. Every entry is built
# without pretrained weights and with 29 outputs; only the backbone and the
# pooling choice vary.
MODELS = {
    exp_id: {
        "base_model_name": backbone,
        "pooling": pooling,
        "pretrained": False,
        "num_classes": 29,
    }
    for exp_id, backbone, pooling in [
        ("012", "tf_efficientnet_b0_ns", "GeM"),
        ("014", "tf_efficientnet_b1_ns", "GeM"),
        ("015", "tf_efficientnet_b0_ns", "GeM"),
        ("022", "tf_efficientnet_b0_ns", "GeM"),
        ("023", "tf_efficientnet_b0_ns", "GeM"),
        ("026", "tf_efficientnet_b0_ns", "GeM"),
        ("027", "dm_nfnet_f0", ""),
        ("030", "tf_efficientnet_b3_ns", "GeM"),
        ("038", "tf_efficientnet_b3_ns", "GeM"),
        ("039", "tf_efficientnet_b5_ns", "GeM"),
        ("040", "tf_efficientnet_b4_ns", "GeM"),
    ]
}


# Keyword arguments for every test DataLoader. Shuffling is disabled so that
# predictions come out in the row order of the test frame.
loader_params = dict(batch_size=32, num_workers=4, shuffle=False)


# Names of the prediction columns, in the order of the model outputs.
target_columns = [
    "Disease_Risk",
    "DR", "ARMD", "MH", "DN", "MYA", "BRVO", "TSLN",
    "ERM", "LS", "MS", "CSR", "ODC", "CRVO", "TV",
    "AH", "ODP", "ODE", "ST", "AION", "PT", "RT",
    "RS", "CRS", "EDN", "RPEC", "MHL", "RP", "OTHER",
]

In [None]:
# Experiments taking part in the ensemble. The order matters: blending applies
# the i-th weight of each per-target list to the i-th entry of this list.
all_models = [
    "012", "014", "015",
    "022", "023", "026", "027",
    "030", "038", "039", "040",
]

## Inference loop

In [None]:
t0 = time.time()

set_seed(1213)
device = get_device()

# Each experiment ships five checkpoints named best0_<id>.pth ... best4_<id>.pth;
# their sigmoid outputs are averaged with equal weight.
n_checkpoints = 5

model_predictions = {}
for model_name in all_models:
    img_size = IMAGE_SIZE[model_name]
    dataset_cls = DATASET_TYPE[model_name]
    test_dataset = dataset_cls(test_df, datadir=DATADIR, transform=get_transforms(img_size, mode="test"))
    test_loader = torchdata.DataLoader(test_dataset, **loader_params)

    checkpoint_models = [
        prepare_model_for_inference(
            TimmModel(**MODELS[model_name]),
            MODELDIR / f"best{checkpoint_idx}_{model_name}.pth",
        ).to(device)
        for checkpoint_idx in range(n_checkpoints)
    ]

    predictions = []
    ids = []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f"Model: {model_name}"):
            input_ = batch["image"].to(device)
            ids.extend(batch["ID"].cpu().numpy().tolist())
            # Same arithmetic as summing out_k / 5 term by term.
            out = sum(
                torch.sigmoid(checkpoint_model(input_)) / n_checkpoints
                for checkpoint_model in checkpoint_models
            )
            predictions.append(out.cpu().numpy())

    # Assemble the frame once per experiment, after all batches are collected.
    # Doing this inside the batch loop re-concatenates everything on every batch
    # and leaves `pred_df` undefined (or stale) when the loader is empty.
    if predictions:
        pred_array = np.concatenate(predictions, axis=0)
    else:
        pred_array = np.empty((0, len(target_columns)), dtype=np.float32)
    pred_df = pd.concat(
        [pd.DataFrame({"ID": ids}), pd.DataFrame(pred_array, columns=target_columns)],
        axis=1,
    )
    model_predictions[model_name] = pred_df

    # Drop the references before the next experiment builds its own models, so two
    # full sets of checkpoints are never held at the same time.
    del checkpoint_models

elapsed = time.time() - t0
elapsed_min = int(elapsed // 60)
elapsed_sec = elapsed % 60
print(f"Elapsed time: {elapsed_min}min {elapsed_sec:.4f}seconds.")

In [None]:
# Sanity check: preview the first rows of one experiment's predictions instead of
# rendering the whole frame.
model_predictions["012"].head()

## Blending

In [None]:
# Per-target blending weights. Each list holds one weight per entry of
# `all_models`, in the same order (the blending cell pairs them by position).
weights = {
    "Disease_Risk": [
        0.056106002705367315,
        0.0705532053116401,
        0.04997655821435919,
        0.18143715276113276,
        0.03404688654322489,
        0.1525338006486202,
        0.10291028003467365,
        0.03781232312031158,
        0.047585711586795906,
        0.19312290400942153,
        0.0739151750644528
    ],
    "DR": [
        0.07198363177835458,
        0.00296027390362025,
        0.012653951328298996,
        0.032170520284754256,
        0.04279789377360028,
        0.09263422763553994,
        0.16614894387991772,
        0.18296522800764742,
        0.03542622767823671,
        0.2598184636697298,
        0.10044063806030003
    ],
    "ARMD": [
        0.057843489172587775,
        0.20248546755195435,
        0.02373066185550471,
        0.03295197047227677,
        0.03320676673880146,
        0.0401967049626236,
        0.003720803304882151,
        0.19942890821432857,
        0.2417245089032871,
        0.14174976312725734,
        0.022960955696496036
    ],
    "MH": [
        0.18063768877183053,
        0.12812994889046606,
        0.025071869951291095,
        0.005656853197003577,
        0.03304591463137401,
        0.025032441154225468,
        0.11804684372971683,
        0.1443760855567267,
        0.040293048980509026,
        0.15860267553985744,
        0.1411066295969991
    ],
    "DN": [
        0.005555578435664791,
        0.009461843717928782,
        0.02619002967485514,
        0.24672396271576713,
        0.09831273817841059,
        0.19991777433766447,
        0.12245972479660357,
        0.06283758943073323,
        0.03495926801965508,
        0.061713026745616115,
        0.131868463947101
    ],
    "MYA": [
        0.012658375475527782,
        0.14876486352517537,
        0.03530771434213804,
        0.19634499647164505,
        0.05017875645252962,
        0.016782261343092974,
        0.22460975976495337,
        0.011413311089070968,
        0.18560344913667234,
        0.02972406682598728,
        0.08861244557320717
    ],
    "BRVO": [
        0.23370240056638322,
        0.006032935764162051,
        0.003794645292139292,
        0.022909738502784147,
        0.055597755279284206,
        0.2053972692562277,
        0.013285495756493383,
        0.07288308882556965,
        0.031124035526781182,
        0.1846973554073865,
        0.17057527982278875
    ],
    "TSLN": [
        0.037977535980047494,
        0.08343754676179876,
        0.019029574855066095,
        0.08669341941694406,
        0.15803134801908564,
        0.05401161985098037,
        0.22472671926196883,
        0.022968059342951406,
        0.21156717029289818,
        0.02826659865645776,
        0.07329040756180134
    ],
    "ERM": [
        0.02987969639610825,
        0.11601818286443207,
        0.004368242077352041,
        0.010978346487480818,
        0.15686054755221301,
        0.18812373995743054,
        0.04906336998625799,
        0.20077657513546693,
        0.042369353245629224,
        0.04271018680379536,
        0.15885175949383373
    ],
    "LS": [
        0.013475172072800584,
        0.06761712055928822,
        0.005202017030359687,
        0.013836877201252509,
        0.015488355609524464,
        0.15936495891531166,
        0.133979239273297,
        0.02743590060208177,
        0.21341277637556347,
        0.18079331478114724,
        0.16939426757937331
    ],
    "MS": [
        0.18603178688696712,
        0.19682821622652283,
        0.10488846154101043,
        0.1416172053648892,
        0.06736749790196353,
        0.01684103130532239,
        0.010080807749732506,
        0.032777056198835856,
        0.03797495350611392,
        0.04308032678273061,
        0.1625126565359116
    ],
    "CSR": [
        0.057350920206716206,
        0.061296793225116074,
        0.12545705584565167,
        0.1372387652767952,
        0.06068788930201581,
        0.005560240141343137,
        0.005004647684992925,
        0.16087871903243436,
        0.10312285229510063,
        0.10777370944001727,
        0.17562840754981668
    ],
    "ODC": [
        0.01195084027988455,
        0.07675929050051418,
        0.0030758554601049816,
        0.0033816356769156663,
        0.06464306918256114,
        0.16868674206947862,
        0.11464176461383777,
        0.08480850394403597,
        0.04859804254378671,
        0.23296727277207402,
        0.19048698295680636
    ],
    "CRVO": [
        0.12102894112999622,
        0.0031509059245401704,
        0.13679362364474282,
        0.09899274595582706,
        0.009164664100228311,
        0.022047566613933846,
        0.15906731196113055,
        0.1269106862968374,
        0.12420508313400691,
        0.031122954325171627,
        0.16751551691358507
    ],
    "TV": [
        0.17062068225925456,
        0.0692155548789209,
        0.04804573175048877,
        0.019901312958279273,
        0.038990165222716956,
        0.18476438490637934,
        0.03408905613670613,
        0.17766772858592783,
        0.07476804442359734,
        0.027603256063401724,
        0.15433408281432714
    ],
    "AH": [
        0.0033388256708909688,
        0.0541692934828976,
        0.08612813146312515,
        0.03136878799982958,
        0.16733652087812734,
        0.028659361987300472,
        0.18154730215053777,
        0.06875072775404474,
        0.17131128227419462,
        0.1857738141109412,
        0.021615952228110465
    ],
    "ODP": [
        0.05305579799351704,
        0.10626602615312912,
        0.07622993045776295,
        0.029767343899165695,
        0.021463328952529205,
        0.00708842717739818,
        0.08526899815623941,
        0.04301537901825068,
        0.1945457784972824,
        0.23710910477716382,
        0.14618988491756144
    ],
    "ODE": [
        0.006645959887180597,
        0.22230155846173286,
        0.020712902820439652,
        0.12460869515530355,
        0.11614078629843315,
        0.036919185881771406,
        0.061225201067100714,
        0.08466864420934254,
        0.17641524324191088,
        0.04509678392179723,
        0.10526503905498734
    ],
    "ST": [
        0.05579449519377708,
        0.14607751358489285,
        0.03571624663130894,
        0.09789365139466331,
        0.002355030098990704,
        0.029921377114556013,
        0.07055169114960552,
        0.15532684929337387,
        0.16655048630466712,
        0.17945188215469346,
        0.06036077707947124
    ],
    "AION": [
        0.1253097445158149,
        0.010609858617663016,
        0.02414591511047431,
        0.02267531836833228,
        0.18376319058087542,
        0.1879722453959952,
        0.030761469108309188,
        0.170634687147818,
        0.005135651616286887,
        0.15506477455432696,
        0.08392714498410385
    ],
    "PT": [
        0.04257503481566441,
        0.22939887369083786,
        0.0311583003444886,
        0.039664902221607026,
        0.04199424916904942,
        0.05360881339210934,
        0.013725477194634123,
        0.008448911423130704,
        0.23817423893661696,
        0.12443357823568726,
        0.17681762057617426
    ],
    "RT": [
        0.0634731983460384,
        0.16038012253823364,
        0.008558991261568201,
        0.021056623924998723,
        0.13048489441904895,
        0.09263821493358795,
        0.006869220305358243,
        0.06525941509926471,
        0.21613863283704512,
        0.10461268665414064,
        0.13052799968071538
    ],
    "RS": [
        0.08804964471352277,
        0.03859347885588542,
        0.05109940389344141,
        0.16556706594047788,
        0.021829330715815054,
        0.14642460936672982,
        0.007882200489736717,
        0.19488223732391807,
        0.08861198275562883,
        0.0963850475655661,
        0.10067499837927785
    ],
    "CRS": [
        0.07926642442851686,
        0.05880524826639643,
        0.1297631329688441,
        0.17953766475290905,
        0.05394588896044773,
        0.1486489446572299,
        0.0028443503450392505,
        0.17732452101836324,
        0.13901813095338925,
        0.018398667696572262,
        0.012447025952291944
    ],
    "EDN": [
        0.11035786480778485,
        0.013689350044015053,
        0.19969147311488267,
        0.16144229238216626,
        0.0015713391526851511,
        0.06222407980433003,
        0.14387780194835653,
        0.06996229454262844,
        0.19933498437765143,
        0.0031547262880191667,
        0.034693793537480505
    ],
    "RPEC": [
        0.10560357351808497,
        0.010968325783864087,
        0.07659658526818408,
        0.15456612319864108,
        0.05085413397046142,
        0.0027468535400073774,
        0.1097888922714797,
        0.04402210839258307,
        0.17352724754791207,
        0.09133712572970676,
        0.1799890307790754
    ],
    "MHL": [
        0.005278870991800131,
        0.008465487750850442,
        0.04883094612939742,
        0.1714560178166613,
        0.016440023199584004,
        0.007782228844155979,
        0.18643351113327544,
        0.10275259136556744,
        0.07631094005311971,
        0.20161979244347814,
        0.17462959027211003
    ],
    "RP": [
        0.046659215896313236,
        0.13550470865806924,
        0.13729709238051147,
        0.12828666773359546,
        0.11580858932263724,
        0.055451832302981646,
        0.046575480858785254,
        0.11132009991747341,
        0.03460867738356404,
        0.09163839330279117,
        0.09684924224327789
    ],
    "OTHER": [
        0.11277343565542515,
        0.22937878979814655,
        0.0037428550518870275,
        0.01902220378335148,
        0.01716102632074993,
        0.047214018465549436,
        0.3387451735681028,
        0.08186764813762831,
        0.0967134016431368,
        0.0024065151482787436,
        0.05097493242774383
    ]
}

In [None]:
# Weights are paired with models by position, so every target needs exactly one
# weight per ensemble member; a longer list would otherwise be silently truncated.
for column in target_columns:
    if len(weights[column]) != len(all_models):
        raise ValueError(
            f"weights['{column}'] has {len(weights[column])} entries, "
            f"expected {len(all_models)} (one per model)"
        )

# Start from a copy of one prediction frame to inherit the "ID" column and the
# row order, then overwrite every target with the weighted sum over all models.
submission_df = model_predictions[all_models[0]].copy()
for column in target_columns:
    submission_df[column] = 0.0
    for weight, model_name in zip(weights[column], all_models):
        submission_df[column] += weight * model_predictions[model_name][column]

In [None]:
# Write the blended predictions without the index column, then read the file
# back so the cell displays exactly what was stored.
submission_path = "chizu & arai & okada_results.csv"
submission_df.to_csv(submission_path, index=False)
pd.read_csv(submission_path)