# In [None]:
# ===============
# ResNet50 inference
# ===============
import os
import gc
import json
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from contextlib import contextmanager
from sklearn.metrics import cohen_kappa_score
import torch
from torch.utils.data import DataLoader

import sys

sys.path.append("../input/aptos-src")
from util import seed_torch
from metric import OptimizedKappaRounder
from model import ResNet
from logger import setup_logger, LOGGER


import os
import cv2
import torch
import random
import pydicom
import numpy as np
from torch.utils.data import Dataset


class APTOSDatasetTest(Dataset):
    """Inference-time dataset yielding normalized images together with their ids.

    Each item is read from ``<image_path>/<id>.png`` with OpenCV, resized to
    ``img_size`` x ``img_size``, optionally passed through ``transforms``
    (called as ``transforms(image=img)['image']``), scaled to [0, 1],
    normalized per channel with ``means``/``stds`` and returned channel-first
    as a float tensor.

    Args:
        df: DataFrame with one row per image; must contain ``id_colname``.
        img_size: Side length, in pixels, that every image is resized to.
        image_path: Directory holding the ``.png`` files.
        id_colname: Column of ``df`` holding the image id (the file stem).
        target_colname: Stored on the instance; not read by this class.
        transforms: Optional callable applied before normalization.
        means: Per-channel values subtracted after scaling to [0, 1].
        stds: Per-channel values divided out after centering.
    """

    def __init__(self,
                 df,
                 img_size,
                 image_path,
                 id_colname="id_code",
                 target_colname="diagnosis",
                 transforms=None,
                 means=(0.485, 0.456, 0.406),
                 stds=(0.229, 0.224, 0.225)
                 ):
        # Tuple defaults avoid the shared-mutable-default pitfall; both are
        # converted to arrays below, so list arguments keep working.
        self.df = df
        self.img_size = img_size
        self.image_path = image_path
        self.transforms = transforms
        self.means = np.array(means)
        self.stds = np.array(stds)
        self.id_colname = id_colname
        self.target_colname = target_colname

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Return ``(image_tensor, img_id)`` for row ``idx`` of ``df``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        cur_idx_row = self.df.iloc[idx]
        img_id = cur_idx_row[self.id_colname]
        img_path = os.path.join(self.image_path, img_id + ".png")

        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising;
            # fail here with the offending path rather than deep in cv2.resize.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.resize(img, (self.img_size, self.img_size))

        if self.transforms is not None:
            augmented = self.transforms(image=img)
            img = augmented['image']

        # NOTE(review): cv2.imread yields BGR channel order, while the default
        # means/stds are listed in the usual ImageNet RGB order. Left as-is
        # because it must match how the model was trained — confirm.
        img = img / 255
        img -= self.means
        img /= self.stds
        img = img.transpose((2, 0, 1))  # HWC -> CHW

        return torch.Tensor(img), img_id


def test(model, valid_loader, device, batch_size, last=False):
    """Run ``model`` over ``valid_loader`` and collect raw outputs and ids.

    Args:
        model: torch module; switched to eval mode here.
        valid_loader: Iterable of ``(features, img_id)`` batches.
        device: Device the features are moved to before the forward pass.
        batch_size: Unused; kept for call-site compatibility.
        last: Unused; kept for call-site compatibility.

    Returns:
        ``(all_preds, ids)``: a float32 numpy array of model outputs
        concatenated along the batch axis (an empty 1-D array if the loader
        yields nothing), and a flat numpy array of ids in loader order.
    """
    model.eval()
    ids = []
    preds_cat = []
    with torch.no_grad():
        for features, img_id in valid_loader:
            logits = model(features.to(device))
            ids.extend(img_id)
            # Move each batch off the device immediately so outputs do not
            # accumulate in GPU memory over the whole test set.
            preds_cat.append(logits.float().cpu())

    ids = np.array(ids).reshape(-1)
    if not preds_cat:
        # torch.cat raises on an empty list; an empty loader has no outputs.
        return np.empty((0,), dtype=np.float32), ids
    all_preds = torch.cat(preds_cat).numpy()
    return all_preds, ids


# Despite its name this is not a pytest test case; keep runners from collecting it.
test.__test__ = False


# ===============
# Constants
# ===============
DATA_DIR = "../input/aptos2019-blindness-detection/"
# Directory of test images; APTOSDatasetTest reads "<id>.png" from here.
IMAGE_PATH = os.path.join(DATA_DIR, "test_images/")
LOGGER_PATH = "log.txt"
TRAIN_PATH = os.path.join(DATA_DIR, "train.csv")
TEST_PATH = os.path.join(DATA_DIR, "test.csv")
# Column names used both for reading test.csv and writing submission.csv.
ID_COLUMNS = "id_code"
TARGET_COLUMNS = "diagnosis"

# ===============
# Settings
# ===============
seed = 0
# Fall back to CPU so the script can still run where CUDA is unavailable.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
img_size = 256
batch_size = 64
model_path = "../input/exp1-resnet50/exp1_resnet50_fold0.pth"

setup_logger(out_file=LOGGER_PATH)
seed_torch(seed)


@contextmanager
def timer(name):
    """Log how long the wrapped ``with`` body took, rounded to whole seconds.

    The message goes through ``LOGGER.info`` and is emitted only when the
    body completes without raising.
    """
    # perf_counter is monotonic, so wall-clock adjustments cannot skew timings.
    t0 = time.perf_counter()
    yield
    LOGGER.info(f'[{name}] done in {time.perf_counter() - t0:.0f} s')


with timer('load data'):
    df = pd.read_csv(TEST_PATH)

with timer('preprocessing'):
    # No test-time augmentation: images are only resized and normalized.
    test_augmentation = None
    test_dataset = APTOSDatasetTest(df, img_size, IMAGE_PATH, id_colname=ID_COLUMNS,
                                    transforms=test_augmentation)
    # shuffle=False keeps predictions in the same row order as test.csv.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=2, pin_memory=True)

    # The dataset keeps its own reference to df; this only drops the global name.
    del df
    gc.collect()

with timer('create model'):
    # num_classes=1: a single output per image, later binned by
    # OptimizedKappaRounder (presumably a regression head — confirm).
    model = ResNet(num_classes=1, pretrained=None)
    # map_location lets a checkpoint saved on one device load onto `device`.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)

with timer('predict'):
    test_pred, ids = test(model, test_loader, device, batch_size)

    # Hard-coded thresholds passed to OptimizedKappaRounder.predict to turn the
    # continuous outputs into diagnosis labels (presumably tuned on validation
    # predictions — confirm).
    optR = OptimizedKappaRounder()
    coefficients = [0.5867620312940832, 1.3155473393431514, 2.5984588076287, 3.1550712728719525]
    test_pred = optR.predict(test_pred, coefficients)

    sub_df = pd.DataFrame({ID_COLUMNS: ids, TARGET_COLUMNS: test_pred})
    LOGGER.info(sub_df.head())
    sub_df.to_csv('submission.csv', index=False)