In [None]:
import numpy as np
import pandas as pd
import torch
import os
import time
import copy
from torch.utils.data import Dataset, ConcatDataset
from torchvision import transforms
import torchvision
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm_notebook as tqdm
from joblib import Parallel, delayed
from sklearn import metrics
from sklearn.metrics import confusion_matrix
from scipy.stats import spearmanr
from sklearn.model_selection import KFold, StratifiedKFold
import math
import uuid
import cv2
import gc
import albumentations
from albumentations import torch as AT
from numba import jit
!pip install adabound
import adabound
from contextlib import contextmanager, redirect_stdout

from bisect import bisect_right
import numpy as np
from torch.optim import Optimizer

from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

!pip install efficientnet_pytorch --upgrade
from efficientnet_pytorch import EfficientNet

device = torch.device("cuda:0")

class Logger(object):
    """Append text lines to a log file, optionally echoing them to stdout.

    Args:
        file_path: Path of the log file.
        mode: Mode used to open the file once at construction time
            (``'w'`` truncates an existing file, ``'a'`` keeps it).
        verbose: When true, every appended line is also printed.
    """

    def __init__(self, file_path, mode='w', verbose=False):
        self.file_path = file_path
        self.verbose = verbose
        # Create/truncate the file up front. The context manager closes the
        # handle immediately instead of leaving it to the garbage collector.
        with open(file_path, mode=mode):
            pass

    def append(self, line, print_line=None):
        """Write ``line`` (plus a newline) to the log file.

        Args:
            line: Object to log; it is formatted the way ``print`` formats it.
            print_line: When truthy, also print the line even if the logger
                was created with ``verbose=False``.
        """
        if print_line or self.verbose:
            print(line)
        with open(self.file_path, 'a') as f:
            # Write directly to the file rather than swapping sys.stdout.
            print(line, file=f)

In [None]:
class CyclicLR(object):
    """Cyclical learning-rate (CLR) scheduler that is stepped once per batch.

    The learning rate of every parameter group oscillates between a lower
    bound (``base_lr``) and an upper bound (``max_lr``) with a fixed period of
    ``2 * step_size`` batches, following `Cyclical Learning Rates for Training
    Neural Networks`_. The amplitude of the oscillation can additionally be
    rescaled per cycle or per iteration.

    Call :meth:`batch_step` once for every training batch. To resume a run,
    store ``last_batch_iteration`` and pass it back to the constructor.

    Built-in policies (``mode``):

    ``"triangular"``
        Plain triangular wave, constant amplitude.
    ``"triangular2"``
        Triangular wave whose amplitude is halved after every cycle.
    ``"exp_range"``
        Triangular wave whose amplitude is multiplied by
        ``gamma ** iteration`` at every iteration.

    Adapted from the GitHub repository `bckenstler/CLR`_.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        base_lr (float or list): Lower learning-rate bound, one value or one
            value per parameter group. Default: 0.001
        max_lr (float or list): Upper learning-rate bound, one value or one
            value per parameter group. The cycle amplitude is
            ``max_lr - base_lr``; depending on the scaling function the upper
            bound may never actually be reached. Default: 0.006
        step_size (int): Number of batches in half a cycle. The paper's
            authors suggest 2-8 times the number of batches per epoch.
            Default: 2000
        mode (str): One of ``triangular``, ``triangular2``, ``exp_range``.
            Ignored when ``scale_fn`` is given. Default: 'triangular'
        gamma (float): Base of the ``exp_range`` scaling function
            ``gamma ** iteration``. Default: 1.0
        scale_fn (callable): Custom single-argument scaling policy with
            ``0 <= scale_fn(x) <= 1`` for all ``x >= 0``. When given, ``mode``
            is ignored. Default: None
        scale_mode (str): ``'cycle'`` or ``'iterations'``; selects whether
            ``scale_fn`` receives the cycle number or the batch iteration
            counter. Only used together with a custom ``scale_fn``.
            Default: 'cycle'
        last_batch_iteration (int): Index of the last processed batch.
            Default: -1

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = CyclicLR(optimizer)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         scheduler.batch_step()
        >>>         train_batch(...)

    .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
    .. _bckenstler/CLR: https://github.com/bckenstler/CLR
    """

    def __init__(self, optimizer, base_lr=1e-3, max_lr=6e-3,
                 step_size=2000, mode='triangular', gamma=1.,
                 scale_fn=None, scale_mode='cycle', last_batch_iteration=-1):

        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer
        n_groups = len(optimizer.param_groups)

        # Lower bounds: broadcast a scalar, or validate a per-group sequence.
        if not isinstance(base_lr, (list, tuple)):
            self.base_lrs = [base_lr] * n_groups
        elif len(base_lr) == n_groups:
            self.base_lrs = list(base_lr)
        else:
            raise ValueError("expected {} base_lr, got {}".format(
                n_groups, len(base_lr)))

        # Upper bounds: same treatment.
        if not isinstance(max_lr, (list, tuple)):
            self.max_lrs = [max_lr] * n_groups
        elif len(max_lr) == n_groups:
            self.max_lrs = list(max_lr)
        else:
            raise ValueError("expected {} max_lr, got {}".format(
                n_groups, len(max_lr)))

        self.step_size = step_size

        custom_policy = scale_fn is not None
        if not custom_policy and mode not in ['triangular', 'triangular2', 'exp_range']:
            raise ValueError('mode is invalid and scale_fn is None')

        self.mode = mode
        self.gamma = gamma

        if custom_policy:
            self.scale_fn = scale_fn
            self.scale_mode = scale_mode
        elif self.mode == 'triangular':
            self.scale_fn, self.scale_mode = self._triangular_scale_fn, 'cycle'
        elif self.mode == 'triangular2':
            self.scale_fn, self.scale_mode = self._triangular2_scale_fn, 'cycle'
        elif self.mode == 'exp_range':
            self.scale_fn, self.scale_mode = self._exp_range_scale_fn, 'iterations'

        # Apply the learning rate of the *next* batch to the optimizer right
        # away, then rewind the counter so the first batch_step() lands on it.
        self.batch_step(last_batch_iteration + 1)
        self.last_batch_iteration = last_batch_iteration

    def batch_step(self, batch_iteration=None):
        """Advance to ``batch_iteration`` (default: next batch) and update the optimizer."""
        if batch_iteration is None:
            batch_iteration = self.last_batch_iteration + 1
        self.last_batch_iteration = batch_iteration
        new_lrs = self.get_lr()
        for group, new_lr in zip(self.optimizer.param_groups, new_lrs):
            group['lr'] = new_lr

    def _triangular_scale_fn(self, x):
        """Constant amplitude."""
        return 1.

    def _triangular2_scale_fn(self, x):
        """Amplitude halves with every cycle ``x`` (1-based)."""
        return 1 / (2. ** (x - 1))

    def _exp_range_scale_fn(self, x):
        """Amplitude decays as ``gamma ** x``."""
        return self.gamma**(x)

    def get_lr(self):
        """Return the learning rate of every parameter group for the current iteration."""
        half_cycle = float(self.step_size)
        iteration = self.last_batch_iteration

        # 1-based cycle number and position inside the triangle
        # (x == 1 at the base, x == 0 at the peak).
        cycle = np.floor(1 + iteration / (2 * half_cycle))
        x = np.abs(iteration / half_cycle - 2 * cycle + 1)

        scale_arg = cycle if self.scale_mode == 'cycle' else iteration

        lrs = []
        for _, lower, upper in zip(self.optimizer.param_groups, self.base_lrs, self.max_lrs):
            amplitude = (upper - lower) * np.maximum(0, (1 - x))
            lrs.append(lower + amplitude * self.scale_fn(scale_arg))
        return lrs

In [None]:
# Building apex from source takes a while. The training cell below uses apex.amp, so it must be modified if you want to train without apex.
!git clone https://github.com/NVIDIA/apex
!pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" apex/

In [None]:
!rm -r apex

In [None]:
from apex.parallel import DistributedDataParallel as DDP
from apex.fp16_utils import *
from apex import amp, optimizers
from apex.multi_tensor_apply import multi_tensor_applier

In [None]:
!ls ../input/

In [None]:
from PIL import Image, ImageFile, ImageFilter
import cv2

def pil_to_cv2(image):
    """Return the pixel data of ``image`` as a ``uint8`` NumPy array.

    ``image`` may be anything ``np.array`` accepts (for example a PIL image).
    Channel order is not changed.
    """
    pixels = np.array(image)
    return pixels.astype(np.uint8)


def _subtract_local_mean(img, scale):
    """Return ``4 * (img - gaussian_blur(img)) + 128`` clipped to [0, 255] (integer array)."""
    pil_img = Image.fromarray(img)
    blurred = pil_img.filter(ImageFilter.GaussianBlur(radius=scale/30))
    return np.clip((np.array(pil_img, dtype=int) - np.array(blurred, dtype=int)) * 4 + 128, 0, 255)


def get_train_image(rec, ds="train19", image_size=420):
    """Load one fundus image from disk and preprocess it.

    Steps: crop away the dark border, rescale, subtract the local mean
    colour, pad to a square with grey (128), blank everything outside a
    centred circle, and resize to ``image_size`` x ``image_size``.

    Args:
        rec: Record describing the image. ``rec['id_code']`` (plus
            ``rec['diagnosis']`` for ``train19``) is read for the 2019 sets,
            ``rec['image']`` and ``rec['level']`` for the 2015 sets.
        ds: Which dataset the record belongs to: ``"train19"``, ``"test19"``,
            ``"train15"`` or ``"test15"``.
        image_size: Side length of the returned square image; also used as
            the target retina radius and to derive the blur radius.

    Returns:
        dict with ``'image'`` (``uint8`` array of shape
        ``(image_size, image_size, 3)``, RGB), ``'diag'`` (label; always 0
        for ``test19``) and ``'src'`` (``'2019'`` or ``'2015'``).

    Raises:
        ValueError: If ``ds`` is not one of the supported names.
        FileNotFoundError: If the image file cannot be read.
    """
    if ds == "train19":
        img_name = os.path.join('../input/aptos2019-blindness-detection/train_images', rec['id_code'] + '.png')
        diag = rec['diagnosis']
        src = '2019'
    elif ds == "test19":
        img_name = os.path.join('../input/aptos2019-blindness-detection/test_images', rec['id_code'] + '.png')
        diag = 0  # test labels are unknown; placeholder
        src = '2019'
    elif ds == "train15":
        img_name = os.path.join("../input/resized-2015-2019-blindness-detection-images/resized train 15", rec['image'] + '.jpg')
        diag = rec["level"]
        src = "2015"
    elif ds == "test15":
        img_name = os.path.join("../input/resized-2015-2019-blindness-detection-images/resized test 15", rec['image'] + '.jpg')
        diag = rec["level"]
        src = "2015"
    else:
        raise ValueError(
            f"unknown dataset {ds!r}; expected one of 'train19', 'test19', 'train15', 'test15'")

    image = cv2.imread(img_name)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"could not read image file: {img_name}")
    # imread's default flag yields a 3-channel BGR image, so img is always 3-D here.
    img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # center retina: keep only rows/columns that contain a pixel brighter than 7
    mask = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) > 7
    rows, cols = mask.any(1), mask.any(0)
    if rows.any():  # skip the crop for a completely dark image
        crop = np.ix_(rows, cols)
        img = np.stack([img[:, :, channel][crop] for channel in range(3)], axis=-1)

    # scale radius
    scale = image_size

    height, width, _ = img.shape
    if (width, height) == (640, 480):
        # Images that are exactly 640x480 after cropping: filter first, then
        # pad to a 736x736 grey canvas and rescale by a fixed factor.
        img = _subtract_local_mean(img, scale)

        new_sz = int(640 * 1.15)

        delta_w = new_sz - img.shape[1]
        delta_h = new_sz - img.shape[0]
        top = delta_h//2
        side = delta_w//2

        img = cv2.copyMakeBorder(img, top, top, side, side, cv2.BORDER_CONSTANT, value=[128, 128, 128])

        r = new_sz / 2
        s = scale * 1.0 / r
        img = cv2.resize(img.astype(np.uint8), (0,0), fx=s, fy=s)
    else:
        # Estimate the retina radius from the middle row: half the number of
        # pixels whose channel sum exceeds a tenth of the row mean.
        x = img[int(img.shape[0]/2),:,:].sum(1)
        r = (x>x.mean() / 10).sum() / 2
        s = scale * 1.0 / r
        img = cv2.resize(img, (0,0), fx=s, fy=s)

        # remove local mean color
        img = _subtract_local_mean(img, scale)

    # pad the shorter side with grey so the image becomes square
    height, width, _ = img.shape
    if height != width:
        square_side = max(height, width)
        padder = albumentations.Compose([
                    albumentations.augmentations.transforms.PadIfNeeded(
                        min_height=square_side, min_width=square_side, value=(128, 128, 128), border_mode=0, always_apply=True)])
        img = padder(image=img)['image']

    # circle crop: everything outside the centred disc becomes grey (128)
    b = np.zeros(img.shape)
    cv2.circle(b, (img.shape[1]//2, img.shape[0]//2), int(scale*0.9), (1,1,1), -1, 8, 0)
    img = img*b+128*(1-b)

    img = cv2.resize(img, (image_size, image_size))
    img = img.astype(np.uint8)

    return {'image': img, 'diag': diag, 'src': src}



class RetinopathyDataset(Dataset):
    """Dataset of preprocessed retina images with regression and one-hot targets.

    Args:
        transform: Callable invoked as ``transform(image=array)["image"]``
            on every item, or ``None`` to return the raw preprocessed array.
        ds: Dataset name: ``"train19"``, ``"test19"``, ``"train15"`` or
            ``"test15"``; selects the CSV file and the image folder.
        preload: When true, all images are loaded and preprocessed up front
            (two worker processes) and kept in ``self.data``; otherwise each
            image is processed on access.

    Raises:
        ValueError: If ``ds`` is not a supported dataset name.
    """

    # CSV file listing the records of each supported dataset.
    _CSV_PATHS = {
        "train19": '../input/aptos2019-blindness-detection/train.csv',
        "test19": '../input/aptos2019-blindness-detection/sample_submission.csv',
        "train15": '../input/resized-2015-2019-blindness-detection-images/labels/trainLabels15.csv',
        "test15": '../input/resized-2015-2019-blindness-detection-images/labels/testLabels15.csv',
    }

    def __init__(self, transform, ds="train19", preload=False):
        if ds not in self._CSV_PATHS:
            raise ValueError(
                f"unknown dataset {ds!r}; expected one of {sorted(self._CSV_PATHS)}")
        df = pd.read_csv(self._CSV_PATHS[ds])
        self.ds = ds
        self.df = df
        self.preload = preload
        if self.preload:
            # Iterate over positions (not index labels) because rows are
            # fetched with .iloc.
            self.data = Parallel(n_jobs=2, temp_folder="/tmp", max_nbytes=None, backend="multiprocessing")\
                        (delayed(get_train_image)(df.iloc[pos], ds) for pos in tqdm(range(len(df))))
        self.transform = transform

    def __len__(self):
        """Number of records in the dataset's CSV file."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return one sample.

        Returns:
            dict with ``'image'`` (transformed image), ``'diag1'`` (label),
            ``'diag2'`` (length-5 one-hot vector of the label) and ``'src'``
            (source identifier from the preprocessed record).
        """
        if self.preload:
            rec = self.data[idx]
        else:
            rec = get_train_image(self.df.iloc[idx], self.ds)
        if self.transform is not None:
            image = self.transform(image=rec["image"])["image"]
        else:
            image = rec["image"]
        target = np.zeros(5)
        target[rec['diag']] = 1
        return {'image': image, 'diag1':rec['diag'], 'diag2': target, 'src': rec["src"]}
 

In [None]:
def aug_image(is_infer=False, transform_type=0):
    """Build the albumentations pipeline for training or test-time augmentation.

    Every pipeline ends with normalisation (mean ``(0.485, 0.456, 0.406)``,
    std ``(0.229, 0.224, 0.225)``) followed by conversion to a tensor.

    Args:
        is_infer: ``False`` builds the random training augmentation
            (rotation, flips, occasional CLAHE / brightness-contrast).
            ``True`` builds a deterministic test-time variant.
        transform_type: Test-time variant, only used when ``is_infer`` is
            true: 0 = identity, 1 = horizontal flip, 2 = vertical flip,
            3 = vertical then horizontal flip.

    Returns:
        An ``albumentations.Compose`` object.

    Raises:
        ValueError: If ``is_infer`` is true and ``transform_type`` is not 0-3.
    """
    if is_infer:
        if transform_type not in (0, 1, 2, 3):
            raise ValueError(
                f"transform_type must be 0, 1, 2 or 3 for inference, got {transform_type!r}")
        steps = []
        if transform_type in (2, 3):
            steps.append(albumentations.VerticalFlip(always_apply=True))
        if transform_type in (1, 3):
            steps.append(albumentations.HorizontalFlip(always_apply=True))
    else:
        steps = [
            albumentations.Rotate(limit=(-120,120)),
            albumentations.HorizontalFlip(),
            albumentations.VerticalFlip(),
            albumentations.OneOf([
                albumentations.CLAHE(clip_limit=2.0),
                albumentations.RandomBrightnessContrast(),
            ], p=0.3),
        ]

    # Shared tail of every pipeline.
    steps += [
        albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), p=1.0),
        AT.ToTensor(),
    ]
    return albumentations.Compose(steps)

# Build the 2019 training set by running the preprocessing on every image
# (preload=True keeps the results in memory). To reuse a previously pickled
# dataset instead, comment this line out and uncomment the pickle block below.
train_dataset_2019 = RetinopathyDataset(transform=aug_image(is_infer=False), ds="train19", preload=True)

# NOTE(review): pickle.load executes arbitrary code; only use it on a file you created yourself.
# import pickle
# with open('../input/train-dataset-2019-v1/train_dataset_2019.p', "rb") as output_file:
#     train_dataset_2019 = pickle.load(output_file)

# Shallow copy: the validation view shares the preloaded images and the
# dataframe with the training set; only the `transform` attribute differs.
val_dataset_2019 = copy.copy(train_dataset_2019)

# Random augmentation for training, deterministic (un-flipped) pipeline for validation.
train_dataset_2019.transform = aug_image(is_infer=False)
val_dataset_2019.transform = aug_image(is_infer=True, transform_type=0)


In [None]:
import matplotlib.pyplot as plt
# Visual sanity check of one validation sample; permute moves the channel axis last for imshow.
# The validation transform includes Normalize, so the colours shown are not those of the raw image.
plt.imshow(val_dataset_2019[120]["image"].permute(1, 2, 0)  )

In [None]:
!ls ../input/b5_is=420_bs=24_lr=0.002_seed=123_v5

In [None]:
# ---- Run configuration -------------------------------------------------
# NOTE(review): `seed`, `chunk_size` and `patience` do not appear to be
# consumed by any active code in the visible cells (seed only ends up in the
# run name and the log) — confirm before relying on them.
seed = 0
n_folds = 4
chunk_size = 1000
n_workers = 2

image_size = (420, 420)
n_epochs = 4
n_freeze = 0          # epoch at which all parameters are made trainable
patience = 5
model_weight = 15     # index of the pretrained checkpoint file (weights_<n>.pt)

bag_size = 3          # number of times the whole K-fold training is repeated
n_tta = 4             # number of test-time augmentation variants (aug_image types 0..3)

batch_size = 8
gradient_accumulation = 4   # optimizer step every N batches

# NOTE(review): the training cell constructs its own optimizer/scheduler
# learning rates; confirm this value is actually used there.
lr = 5e-4
parallel = True       # wrap models in torch.nn.DataParallel

pretrain_path = '../input/b5_is=420_bs=24_lr=0.002_seed=123_v5'
model_path = '.'
model_type = "efficientnet-b5"
model_name = f"{model_type}__tile=f_fr={n_freeze}_is={image_size}_bs={batch_size}_ga={gradient_accumulation}_lr={lr}_seed={seed}_v=3"
pretrained_file_path = f"{pretrain_path}/weights_{model_weight}.pt"

# exist_ok avoids the check-then-create race and is a no-op when the
# directory already exists.
os.makedirs(model_path, exist_ok=True)

# save parameters
logger = Logger(f"{model_path}/log.txt", verbose=True)
logger.append(model_name)
logger.append("\nParameters:\n  " + "\n  ".join([
    f"seed: {seed}", f"n_folds: {n_folds}", f"n_workers: {n_workers}",
    f"n_epochs: {n_epochs}", f"n_freeze: {n_freeze}", 
    f"batch_size: {batch_size}", f"image_size: {image_size}", f"lr: {lr}",
    f"bag size: {bag_size}", f"n_tta: {n_tta}",
]))

# frozen_layer_map = {
#     "efficientnet-b2": 299,
#     "efficientnet-b3": 338,
#     "efficientnet-b4": 416,
#     "efficientnet-b5": 504,
#     "efficientnet-b6": 582,
#     "efficientnet-b7": 709,
# }
class OwnSampler(torch.utils.data.sampler.Sampler):
    """Sampler that yields a given collection of indices in its own order."""

    def __init__(self, idx):
        # Kept by reference; the collection is neither copied nor shuffled.
        self.idx = idx

    def __len__(self):
        """Number of indices that one pass will yield."""
        n_indices = len(self.idx)
        return n_indices

    def __iter__(self):
        """Start a fresh pass over the stored indices."""
        index_iterator = iter(self.idx)
        return index_iterator




In [None]:
# Stratified folds with a fixed random_state, so every call to kf.split
# produces the same partition of the training set.
kf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)

# Ground-truth labels, and the accumulator for out-of-fold predictions
# (one float per training image).
cv_targets = pd.read_csv('../input/aptos2019-blindness-detection/train.csv')['diagnosis']
cv_preds = np.zeros(len(cv_targets))



# NOTE(review): allocated with one column per bag but never written in the visible cells — confirm whether it is still needed.
cv_results = np.zeros((len(cv_targets), bag_size))

def pred_to_int(x):
    """Convert continuous regression outputs to integer class labels 0-4.

    Values are clipped to ``[0, 4]`` and rounded with ``np.round`` (exact
    halves go to the nearest even integer), then cast to an integer dtype so
    the result can be used directly as class labels.

    Args:
        x: Scalar or array-like of predictions.

    Returns:
        Integer NumPy array (or NumPy integer scalar) with the shape of ``x``.
    """
    return np.round(np.clip(x, 0, 4)).astype(int)

# Bagged, stratified K-fold fine-tuning of a single-output (regression) model.
# For every bag and fold: load the pretrained checkpoint, train for n_epochs
# with gradient accumulation and a cyclical learning rate, validate with
# n_tta test-time augmentations, save the weights, and add the fold's
# validation predictions (divided by bag_size) to cv_preds.
for bag in range(bag_size):
    print(f"Bag {bag}")

    for fold, (train_index, valid_index) in enumerate(kf.split(train_dataset_2019, cv_targets)):
        print()
        print(f"Fold {fold}")

        # Training batches are drawn in random order. Validation uses the fold
        # indices in their given order, which is what lets batch `idx` be
        # mapped to a slice of the prediction buffer further down.
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_index)
        valid_sampler = OwnSampler(valid_index)

        train_loader = torch.utils.data.DataLoader(train_dataset_2019, batch_size=batch_size,
                                                   sampler=train_sampler, num_workers=n_workers)

        model = EfficientNet.from_pretrained(model_type, num_classes=1)
        # Freeze the first 502 parameter tensors; everything is made trainable
        # again at epoch n_freeze.
        for i, param in enumerate(model.parameters()):
            if i < 502:
                param.requires_grad = False
        model = model.cuda()

        criterion = torch.nn.MSELoss()

        # Local name so the notebook-level `lr` setting is not overwritten.
        # CyclicLR sets every param group's lr when it is constructed, so this
        # value only serves as the optimizer's initial default.
        initial_lr = 0.001
        optimizer = torch.optim.Adam(model.parameters(), lr=initial_lr, eps=2e-4, weight_decay=1e-5)

        # One half-cycle spans half of the whole run (all epochs of this fold).
        scheduler = CyclicLR(optimizer, step_size=len(train_loader)*n_epochs/2, base_lr=0.00001, max_lr=0.0005)

        model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0)

        if parallel:
            model = torch.nn.DataParallel(model).cuda()

        # NOTE(review): loaded after wrapping — presumably the checkpoint was
        # saved from a DataParallel model, so its keys only match when
        # `parallel` is True; confirm.
        model.load_state_dict(torch.load(pretrained_file_path))

        best_kappa = -np.inf
        for epoch in range(n_epochs):

            if epoch == n_freeze:
                for param in model.parameters():
                    param.requires_grad = True
                del param
            torch.cuda.empty_cache()

            start_time = time.time()

            # ---- training ----
            model.train()
            avg_loss = 0.0
            optimizer.zero_grad()
            all_preds = []
            all_targets = []
            for idx, data in enumerate(train_loader):
                images = data['image']
                labels = data['diag1'].float()
                preds = model(images.cuda())
                loss = criterion(preds.reshape(-1), labels.float().cuda())
                avg_loss += loss.item() / len(train_loader)
                # gradient accumulation
                loss = loss / gradient_accumulation
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if (idx+1) % gradient_accumulation == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2)
                    optimizer.step()
                    optimizer.zero_grad()

                all_preds += list(preds.detach().cpu().numpy())
                all_targets += list(labels.detach().cpu().numpy())

                # The learning rate changes after every batch, not every epoch.
                scheduler.batch_step()

            all_preds = np.array(all_preds)
            all_targets = np.array(all_targets)

            kappa_train = (metrics.cohen_kappa_score(all_targets, pred_to_int(all_preds), weights='quadratic'))
            mse_train = (metrics.mean_squared_error(all_targets, all_preds))

            # ---- validation with test-time augmentation ----
            model.eval()
            avg_val_loss = 0.0

            all_preds = np.zeros(len(valid_index))
            all_targets = np.zeros(len(valid_index))
            with torch.no_grad():

                for i in range(n_tta):
                    val_dataset_2019.transform = aug_image(is_infer=True, transform_type=i)
                    valid_loader = torch.utils.data.DataLoader(val_dataset_2019, batch_size=batch_size, shuffle=False,
                                                  sampler=valid_sampler, num_workers=n_workers)

                    for idx, data in enumerate(valid_loader):
                        images = data['image']
                        labels = data['diag1'].float()
                        preds = model(images.cuda())
                        loss = criterion(preds.reshape(-1), labels.float().cuda())
                        avg_val_loss += loss.item() / len(valid_loader) / n_tta

                        # Average the predictions over the TTA variants.
                        all_preds[idx * batch_size:(idx + 1) * batch_size] += preds.detach().cpu().numpy().reshape(-1) / n_tta
                        all_targets[idx * batch_size:(idx + 1) * batch_size] = labels.detach().cpu().numpy()

            kappa = (metrics.cohen_kappa_score(all_targets, pred_to_int(all_preds), weights='quadratic'))
            mse = (metrics.mean_squared_error(all_targets, all_preds))

            elapsed_time = time.time() - start_time

            # log results
            logger.append(
                f"{epoch:>2d}: time={elapsed_time:0.2f}s  "
                f"lr={scheduler.get_lr()[0]:<8.3g}"
                f"loss={avg_loss:0.5f}  "
                f"kappa={kappa_train:0.5f}  "
                f"mse={mse_train:0.5f}  "
                f"val_loss={avg_val_loss:0.5f}  "
                f"val_kappa={kappa:0.5f}  "
                f"val_mse={mse:0.5f}  "
            )

            # No model selection / early stopping: the most recent epoch
            # always overwrites the saved weights and the kept predictions.
            best_preds = all_preds
            best_kappa = kappa
            torch.save(model.state_dict(), f"{model_path}/weights_{bag}_{fold}.pt")

        cv_preds[valid_index] += best_preds / bag_size
        print(best_kappa)
        # cv_preds holds the sum over the bags finished so far divided by the
        # *total* bag_size; rescale to the mean of the completed bags before
        # rounding, otherwise the score of every bag but the last is computed
        # on shrunken predictions.
        bag_avg_preds = cv_preds[valid_index] * bag_size / (bag + 1)
        print("Bag kappa:", metrics.cohen_kappa_score(all_targets, pred_to_int(bag_avg_preds), weights='quadratic'))

        # Release GPU memory before the next fold.
        del model
        del train_loader
        del valid_loader
        del images
        del labels
        torch.cuda.empty_cache()
        gc.collect()
            



In [None]:
# Overall out-of-fold score: quadratic-weighted kappa of the bag-averaged predictions against all training labels.
print("Bag kappa:", metrics.cohen_kappa_score(cv_targets, pred_to_int(cv_preds), weights='quadratic'))

In [None]:
# Persist the continuous (unrounded) out-of-fold predictions; np.save appends the .npy extension.
np.save(f"{model_path}/cv_preds", cv_preds)

In [None]:
# Preprocess and cache the 2019 test images; "test19" records carry a placeholder diagnosis of 0.
test_dataset_2019 = RetinopathyDataset(transform=aug_image(is_infer=True), ds="test19", preload=True)

In [None]:
import matplotlib.pyplot as plt
# Visual sanity check of one test sample; permute moves the channel axis last for imshow.
# The inference transform includes Normalize, so the colours shown are not those of the raw image.
plt.imshow(test_dataset_2019[120]["image"].permute(1, 2, 0)  )

In [None]:
# Load every checkpoint saved in model_path into its own model instance for
# ensembling. All models are put in eval mode with gradients disabled.
models = []

# Sorted so the ensemble order does not depend on directory listing order.
# A separate loop variable keeps the `model_name` setting intact.
for weight_file in sorted(os.listdir(model_path)):

    # Match the extension exactly: a substring test for "pt" would also pick
    # up unrelated entries such as ".ipynb_checkpoints".
    if not weight_file.endswith(".pt"):
        continue

    # Same architecture as in training: EfficientNet with a single-output head.
    model = EfficientNet.from_name(model_type)
    model._fc = nn.Linear(model._fc.in_features, 1)
    if parallel:
        model = torch.nn.DataParallel(model).cuda()

    model.load_state_dict(torch.load(f"{model_path}/{weight_file}"))
    model = model.cuda()

    for param in model.parameters():
        param.requires_grad = False

    model.eval()
    models.append(model)

print(len(models))
print("Done")

In [None]:
# Ensemble inference: average the raw regression outputs over all loaded
# models and all n_tta test-time augmentation variants.
if not models:
    # Without this guard an empty ensemble would silently yield all-zero predictions.
    raise RuntimeError(f"no model checkpoints were loaded from {model_path!r}")

test_preds = np.zeros(len(test_dataset_2019))
with torch.no_grad():

    for i in range(n_tta):
        test_dataset_2019.transform = aug_image(is_infer=True, transform_type=i)
        # shuffle=False keeps batches in dataset order, so batch `idx` maps to
        # a contiguous slice of test_preds.
        test_loader = torch.utils.data.DataLoader(test_dataset_2019, batch_size=batch_size, shuffle=False, num_workers=n_workers)

        for idx, data in enumerate(tqdm(test_loader)):
            # Move the batch to the GPU once and reuse it for every model.
            images = data['image'].cuda()
            batch_slice = slice(idx * batch_size, (idx + 1) * batch_size)

            for model in models:
                preds = model(images)
                test_preds[batch_slice] += preds.detach().cpu().numpy().reshape(-1) / n_tta / len(models)

In [None]:
# Persist the continuous ensemble predictions before they are rounded to class labels.
np.save(f"{model_path}/test_preds", test_preds)

In [None]:
# Clip to [0, 4] and round to the nearest class. This rebinds test_preds, so the continuous values are only available from the saved file afterwards.
test_preds = pred_to_int(test_preds)

In [None]:
# Write the submission file. pred_to_int is idempotent, so applying it here
# is safe whether or not test_preds was already rounded; the explicit integer
# cast makes the labels come out as 0-4 rather than 0.0-4.0 in the CSV.
submission = pd.read_csv('../input/aptos2019-blindness-detection/sample_submission.csv')
submission['diagnosis'] = pred_to_int(test_preds).astype(int)
submission.to_csv('submission.csv', index=False)

In [None]:
import seaborn as sns
%matplotlib inline

sns.distplot(test_preds, kde=False)

In [None]:
submission['diagnosis'].value_counts()

In [None]:
!ls