This kernel shows an example of EDA and data augmentation, inspired by the following kernels and library.
* https://www.kaggle.com/artgor/basic-eda-and-baseline-pytorch-model
* https://www.kaggle.com/abhishek/pytorch-inference-kernel-lazy-tta
* https://github.com/EthanRosenthal/spacecutter

In [None]:
import numpy as np
import pandas as pd
import torchvision
import torch.nn as nn
from tqdm import tqdm
from PIL import Image, ImageFile
from torch.utils.data import Dataset
import torch
from torchvision import transforms
import os
import matplotlib.pyplot as plt
import seaborn as sns
import collections.abc
from sklearn.model_selection import train_test_split
import re
from typing import Optional


# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Ask PIL to load image files that are truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Number of passes over the training set.
epochs = 15

In [None]:
from albumentations import (
    HorizontalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine,
    IAASharpen, IAAEmboss, RandomContrast, RandomBrightness, Flip, OneOf, Compose, RandomGamma, 
    ElasticTransform, ChannelShuffle,RGBShift, Rotate
)

First, let's visualize what the images look like in the test and training sets.

In [None]:

# Preview the first nine files of the test image folder on a 3x3 grid.
test_image_dir = '../input/aptos2019-blindness-detection/test_images'
for slot, file_name in enumerate(os.listdir(test_image_dir)[:9], start=1):
    picture = Image.open(os.path.join(test_image_dir, file_name), 'r')
    plt.subplot(3, 3, slot).imshow(picture)

In [None]:

# Preview the first nine files of the training image folder on a 3x3 grid.
train_image_dir = '../input/aptos2019-blindness-detection/train_images'
for slot, file_name in enumerate(os.listdir(train_image_dir)[:9], start=1):
    picture = Image.open(os.path.join(train_image_dir, file_name), 'r')
    plt.subplot(3, 3, slot).imshow(picture)

Okay, it looks like there is a great variety in shape and color in both the training and test datasets.
Therefore, we would like to investigate the effect of data augmentation. In particular, since the images come in different sizes and aspect ratios, we have to crop them no matter what.
One of the topics I would like to look into is center crop vs. random crop.

  


## Class balance
Also, let's see the class balance within the training set.
If there is a huge class imbalance, the model is likely to be optimized to predict only specific class(es), and we definitely want to avoid that.

In [None]:
# Load the training labels and plot how many images carry each diagnosis.
training_class = pd.read_csv('../input/aptos2019-blindness-detection/train.csv')
# Name the column by keyword: newer seaborn releases interpret the first
# positional argument as `data`, not as the variable to count.
sns.countplot(x='diagnosis', data=training_class)

Not that terrible, but we should probably still consider correcting the balance.
The strategy to correct the balance here is simply to apply more data augmentation to images of the less frequent classes. 

In [None]:
def data_transforms(mode = 'random',img_size = 256):
    """Build the crop and augmentation pipelines shared by the datasets.

    Parameters
    ----------
    mode : str
        Selects which crop list backs the validation transform. Both
        choices currently resolve to ``CenterCrop(img_size)``.
    img_size : int
        Side length of the final centre crop. The pre-crop is a random crop
        of ``round(1.2 * img_size)`` pixels.

    Returns
    -------
    pre_crop : transforms.Compose
        Random crop applied before any augmentation.
    train_img_transform : transforms.Compose
        Shape augmentations, then colour augmentations, then an identity
        step; takes and returns a PIL image.
    norm_transform : transforms.Compose
        Centre crop followed by ``ToTensor``.
    val_transform : transforms.Compose
        Centre crop only.
    """
    def as_pil_transform(albu_aug):
        # albumentations takes a numpy array by keyword and returns a dict,
        # so wrap it into a PIL-in / PIL-out callable.
        def apply(pil_image):
            augmented = albu_aug(image=np.array(pil_image))
            return Image.fromarray(augmented['image'])
        return apply

    def normalize_to_full_image(img):
        # Identity placeholder kept as the last step of the training pipeline.
        return img

    # Shape augmentations: one flip/rotation, an occasional small affine
    # change, then one non-linear distortion.
    flip_or_rotate = OneOf([
        Transpose(),
        HorizontalFlip(),
        RandomRotate90(),
    ])
    small_affine = ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=15, p=.2)
    distortion = OneOf([
        OpticalDistortion(p=0.2),
        GridDistortion(distort_limit=0.2, p=.1),
        ElasticTransform(),
    ], p=1.)
    general_aug = Compose([flip_or_rotate, small_affine, distortion], p=1)

    # Colour augmentations: sharpen, contrast or brightness, with p=0.3.
    image_specific = Compose([
        OneOf([
            IAASharpen(),
            RandomContrast(),
            RandomBrightness(),
        ], p=0.3),
    ])

    crop_after_augment = [transforms.CenterCrop(img_size)]
    crop_center_only = [transforms.CenterCrop(img_size)]

    pre_crop = transforms.Compose([transforms.RandomCrop(round(1.2 * img_size))])
    train_img_transform = transforms.Compose([
        as_pil_transform(general_aug),
        as_pil_transform(image_specific),
        normalize_to_full_image,
    ])
    norm_transform = transforms.Compose(crop_after_augment + [transforms.ToTensor()])
    if mode == 'random':
        val_transform = transforms.Compose(crop_after_augment)
    else:
        val_transform = transforms.Compose(crop_center_only)

    return pre_crop, train_img_transform, norm_transform, val_transform

For the augmentation, I adopted a variety of techniques, which can be categorized as follows.
1. Shape transformation
This includes affine transformation, such as rotation, shifting and scaling, flipping and nonlinear transformation.    
2. Color transformation
This includes the change in brightness and contrast. 

In this kernel, I first resize each image so that all examples share the same aspect ratio, and then crop twice, once before and once after applying the data augmentation, to keep the black regions introduced by the augmentation as small as possible.
Below are some examples of augmented samples.

In [None]:
# Show four training images, each next to one augmented version of itself.
pre_crop, train_img_transform, _, center_crop = data_transforms()
train_image_dir = '../input/aptos2019-blindness-detection/train_images'
fig = plt.figure()
for row, file_name in enumerate(os.listdir(train_image_dir)[:4]):
    source = Image.open(os.path.join(train_image_dir, file_name), 'r')
    # Resize to a common square, then take the random pre-crop.
    cropped = pre_crop(source.resize((320, 320), resample=Image.BILINEAR))

    left = fig.add_subplot(4, 2, 2 * row + 1)
    left.imshow(center_crop(cropped))
    left.set_title('original')

    right = fig.add_subplot(4, 2, 2 * row + 2)
    right.set_title('augmented')
    right.imshow(center_crop(train_img_transform(cropped)))
fig.suptitle('original vs augmented')
fig.tight_layout()

In [None]:
class RetinopathyDatasetTest(Dataset):
    """Dataset serving centre-cropped images for validation or testing.

    Parameters
    ----------
    data : pandas.DataFrame
        Rows to serve. Needs an ``id_code`` column, plus ``diagnosis`` when
        ``mode`` is not ``'test'``.
    mode : str
        ``'test'`` reads from the test image folder and returns images only;
        any other value reads from the training folder and also returns the
        label.
    """

    def __init__(self, data,mode = 'test'):
        self.mode = mode
        # reset_index so that positional idx values work with .loc below.
        self.data = data.reset_index()
        if mode == 'test':
            self.img_dir = '../input/aptos2019-blindness-detection/test_images'
        else:
            self.img_dir = '../input/aptos2019-blindness-detection/train_images'
        # Only the last of the four returned transforms is needed here.
        self.transform = data_transforms('center')[3]

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        file_name = self.data.loc[idx, 'id_code'] + '.png'
        picture = Image.open(os.path.join(self.img_dir, file_name))
        picture = picture.resize((320, 320), resample=Image.BILINEAR)
        picture = self.transform(picture)
        sample = {'image': transforms.ToTensor()(picture)}
        if self.mode != 'test':
            sample['label'] = self.data.loc[idx,'diagnosis']
        return sample

In [None]:
# Hold out 5% of the labelled images for validation, stratified by diagnosis.
# A fixed random_state makes the split identical on every run, so validation
# losses from different runs are comparable.
train_df = pd.read_csv('../input/aptos2019-blindness-detection/train.csv')
tr, val = train_test_split(train_df, stratify=train_df.diagnosis, test_size=0.05, random_state=42)


In [None]:
def this_collate_fn(batch):
    """Merge per-sample dicts of tensors into one dict of concatenated tensors.

    Each sample may hold several rows per key; the rows of all samples are
    concatenated along dimension 0. The keys are taken from the first sample.
    """
    merged = {}
    for field in batch[0]:
        pieces = [sample[field] for sample in batch]
        merged[field] = torch.cat(pieces, dim=0)
    return merged

In [None]:
class RetinopathyDatasetTrain(Dataset):
    """Training dataset that balances classes by augmenting rare ones more.

    Every item expands one source image into a stack of views: the
    un-augmented crop plus ``aug_times[label]`` augmented copies, all with
    the same label.

    Parameters
    ----------
    data : pandas.DataFrame
        Rows to serve; needs ``id_code`` and ``diagnosis`` columns.
    img_size : int
        NOTE(review): accepted but not used. The pipelines are built with
        the default size of ``data_transforms()``. Kept so existing callers
        keep working.
    """

    def __init__(self, data ,img_size = 224):
        self.data = data.reset_index()
        # Size every class against the largest one actually present, rather
        # than assuming that class 0 is the most frequent.
        class_sizes = self.data['diagnosis'].value_counts()
        largest = class_sizes.max()
        self.aug_times = {
            str(label): np.round(largest / size)
            for label, size in class_sizes.items()
        }
        # The pipelines are built on first item access and then reused.
        # Building them lazily keeps a freshly constructed dataset free of
        # the nested-function wrappers that data_transforms() returns.
        self.pre_crop = None
        self.train_img_transform = None
        self.norm_transform = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if self.pre_crop is None:
            self.pre_crop, self.train_img_transform, self.norm_transform, _ = data_transforms()
        label = self.data.loc[idx, 'diagnosis']
        img_name = os.path.join('../input/aptos2019-blindness-detection/train_images', self.data.loc[idx, 'id_code'] + '.png')
        image = Image.open(img_name)
        image = self.pre_crop(image.resize((320, 320), resample=Image.BILINEAR))
        n_augmented = int(self.aug_times[str(label)])
        # First view is the plain crop; the rest are independent augmentations.
        img_list = [self.norm_transform(image)]
        img_list.extend(
            self.norm_transform(self.train_img_transform(image))
            for _ in range(n_augmented)
        )
        labels = [torch.tensor(label)] * len(img_list)
        return {'image': torch.stack(img_list),'label':torch.stack(labels,dim = 0)}

Since the internet connection is disabled in this kernel, I copied the relevant classes and functions from Ethan Rosenthal's spacecutter library: https://github.com/EthanRosenthal/spacecutter

In [None]:
def _reduction(loss: torch.Tensor, reduction: str) -> torch.Tensor:
    """
    Reduce loss
    Parameters
    ----------
    loss : torch.Tensor, [batch_size, num_classes]
        Batch losses.
    reduction : str
        Method for reducing the loss. Options include 'elementwise_mean',
        'none', and 'sum'.
    Returns
    -------
    loss : torch.Tensor
        Reduced loss.
    """
    if reduction == 'elementwise_mean':
        return loss.mean()
    elif reduction == 'none':
        return loss
    elif reduction == 'sum':
        return loss.sum()
    else:
        raise ValueError(f'{reduction} is not a valid reduction')


def cumulative_link_loss(y_pred: torch.Tensor, y_true: torch.Tensor,
                         reduction: str = 'elementwise_mean',
                         class_weights: Optional[np.ndarray] = None
                         ) -> torch.Tensor:
    """
    Negative log likelihood of the true class under the logistic cumulative
    link model.
    For background see "On the consistency of ordinal regression methods",
    Pedregosa et. al.
    Parameters
    ----------
    y_pred : torch.Tensor, [batch_size, num_classes]
        Predicted target class probabilities. float dtype.
    y_true : torch.Tensor, [batch_size, 1]
        True target classes. long dtype.
    reduction : str
        How to reduce the per-sample losses: 'elementwise_mean', 'none' or
        'sum'.
    class_weights : np.ndarray, [num_classes] optional (default=None)
        Per-class weights. When given, each sample's loss is multiplied by
        the weight of its true class.
    Returns
    -------
    loss: torch.Tensor
    """
    eps = 1e-15
    # Probability assigned to each sample's true class, clamped so that the
    # logarithm below stays finite.
    true_class_prob = torch.gather(y_pred, 1, y_true)
    neg_log_likelihood = -torch.log(torch.clamp(true_class_prob, eps, 1 - eps))

    if class_weights is None:
        return _reduction(neg_log_likelihood, reduction)

    # Match dtype and device of the losses before indexing by true class.
    weights = torch.as_tensor(class_weights,
                              dtype=neg_log_likelihood.dtype,
                              device=neg_log_likelihood.device)
    neg_log_likelihood *= weights[y_true]
    return _reduction(neg_log_likelihood, reduction)


class CumulativeLinkLoss(nn.Module):
    """
    ``nn.Module`` wrapper around cumulative_link_loss().
    Stores the reduction mode and optional class weights at construction
    time and forwards them on every call.
    Parameters
    ----------
    reduction : str
        How to reduce the per-sample losses: 'elementwise_mean', 'none' or
        'sum'.
    class_weights : np.ndarray, [num_classes] optional (default=None)
        Per-class weights. When given, each sample's loss is multiplied by
        the weight of its true class.
    """

    def __init__(self, reduction: str = 'elementwise_mean',
                 class_weights: Optional[torch.Tensor] = None) -> None:
        super().__init__()
        self.reduction = reduction
        self.class_weights = class_weights

    def forward(self, y_pred: torch.Tensor,
                y_true: torch.Tensor) -> torch.Tensor:
        options = {
            'reduction': self.reduction,
            'class_weights': self.class_weights,
        }
        return cumulative_link_loss(y_pred, y_true, **options)

In [None]:
from copy import deepcopy

import torch
from torch import nn


class LogisticCumulativeLink(nn.Module):
    """
    Maps a single score per sample to probabilities over ordered classes.
    The module owns ``num_classes - 1`` learnable cutpoints; class
    probabilities are differences of sigmoids evaluated at those cutpoints.
    Parameters
    ----------
    num_classes : int
        Number of ordered classes; must be greater than 2.
    init_cutpoints : str (default='ordered')
        How to initialize the cutpoints of the model. Valid values are
        - ordered : evenly spaced, one unit apart, starting at
          ``-(num_classes - 1) / 2``.
        - random : sorted uniform random values.
    """

    def __init__(self, num_classes: int,
                 init_cutpoints: str = 'ordered') -> None:
        assert num_classes > 2, (
            'Only use this model if you have 3 or more classes'
        )
        super().__init__()
        self.num_classes = num_classes
        self.init_cutpoints = init_cutpoints

        if init_cutpoints not in ('ordered', 'random'):
            raise ValueError(f'{init_cutpoints} is not a valid init_cutpoints '
                             f'type')

        n_cuts = self.num_classes - 1
        if init_cutpoints == 'ordered':
            initial = torch.arange(n_cuts).float() - n_cuts / 2
        else:
            initial = torch.rand(n_cuts).sort()[0]
        self.cutpoints = nn.Parameter(initial)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Equation (11) from
        "On the consistency of ordinal regression methods", Pedregosa et. al.
        """
        # Column k holds sigmoid(cutpoint_k - score) for each sample.
        cumulative = torch.sigmoid(self.cutpoints - X)
        lowest = cumulative[:, :1]
        between = cumulative[:, 1:] - cumulative[:, :-1]
        highest = 1 - cumulative[:, -1:]
        return torch.cat((lowest, between, highest), dim=1)


class OrdinalLogisticModel(nn.Module):
    """
    Wraps a score-producing model so that it outputs ordinal class
    probabilities.
    The wrapped predictor is deep-copied at construction time, and its
    output is fed through a LogisticCumulativeLink module.
    Parameters
    ----------
    predictor : nn.Module
        When called, must return a torch.FloatTensor with shape [batch_size, 1]
    num_classes : int
        Number of ordered classes, passed on to the link module.
    init_cutpoints : str (default='ordered')
        Cutpoint initialisation passed on to the link module: 'ordered' or
        'random'.
    """

    def __init__(self, predictor: nn.Module, num_classes: int,
                 init_cutpoints: str = 'ordered') -> None:
        super().__init__()
        self.num_classes = num_classes
        # Work on a private copy so the caller's module is left untouched.
        own_predictor = deepcopy(predictor)
        self.predictor = own_predictor
        self.link = LogisticCumulativeLink(
            self.num_classes, init_cutpoints=init_cutpoints,
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        scores = self.predictor(X)
        return self.link(scores)


In [None]:
# Build a ResNet-34 and fill it with weights from a local file.
res_model = torchvision.models.resnet34(pretrained=False)
res_model.load_state_dict(
    torch.load("../input/resnet34/resnet34.pth")
)
# Swap the final layer for a single-output head; the ordinal wrapper turns
# that one score into probabilities over the 5 classes.
res_model.fc = nn.Linear(in_features=512, out_features=1)
model = OrdinalLogisticModel(res_model, 5).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Settings shared by the training and validation loaders.
loader_options = dict(batch_size=32, num_workers=4)

# Training items are stacks of views, so they need the concatenating collate.
train_ds = RetinopathyDatasetTrain(tr)
train_loader = torch.utils.data.DataLoader(
    train_ds, shuffle=True, collate_fn=this_collate_fn, **loader_options
)

# Validation images come from the training folder, hence a non-'test' mode.
val_ds = RetinopathyDatasetTest(val, mode='val')
val_loader = torch.utils.data.DataLoader(val_ds, shuffle=False, **loader_options)



In [None]:
# Train with the cumulative link loss and report the mean training and
# validation loss after every epoch.
criterion = CumulativeLinkLoss()
for epoch in range(epochs):
    train_loss = []
    val_loss = []

    model.train()
    for x_batch in train_loader:
        model.zero_grad()
        img = x_batch["image"].to(device).float()
        # The loss gathers along dim 1, so labels must be a column vector.
        label = x_batch['label'].to(device).long().reshape(-1, 1)
        loss = criterion(model(img), label)
        loss.backward()
        optimizer.step()
        # Record a plain float. Appending the tensor itself would keep every
        # batch's autograd graph alive until the end of the epoch.
        train_loss.append(loss.item())

    model.eval()
    with torch.no_grad():
        for x_batch in val_loader:
            img = x_batch["image"].to(device).float()
            label = x_batch['label'].to(device).long().reshape(-1, 1)
            val_loss.append(criterion(model(img), label).item())

    train_mean_loss = np.mean(train_loss)
    val_mean_loss = np.mean(val_loss)

    print(f'Epoch {epoch}, train loss: {train_mean_loss:.4f}, valid loss: {val_mean_loss:.4f}.')
    

In [None]:
# Turn off gradient tracking for every weight, then switch to eval mode.
for weight in model.parameters():
    weight.requires_grad_(False)

model.eval()

In [None]:
# Serve the test images in file order (no shuffling), so that predictions
# line up with the rows of test.csv.
test_df = pd.read_csv('../input/aptos2019-blindness-detection/test.csv')
test_dataset = RetinopathyDatasetTest(test_df, mode='test')
test_data_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=32,
    num_workers=4,
    shuffle=False,
)

In [None]:
# Predict the most probable class for every test image.
test_preds = np.zeros((len(test_dataset), 1))

# Track the write position from the actual batch sizes, so the loop does not
# depend on the loader's batch_size being 32.
offset = 0
with torch.no_grad():
    for x_batch in test_data_loader:
        class_probs = model(x_batch["image"].to(device))
        _, pred = torch.max(class_probs, 1)
        batch_preds = pred.cpu().numpy().reshape(-1, 1)
        test_preds[offset:offset + len(batch_preds)] = batch_preds
        offset += len(batch_preds)

In [None]:
# Write the predictions into the sample submission template.
sample = pd.read_csv("../input/aptos2019-blindness-detection/sample_submission.csv")
# test_preds has shape (N, 1); flatten it so that a 1-D array is assigned to
# the column instead of relying on pandas to squeeze a 2-D one.
sample['diagnosis'] = test_preds.astype(int).ravel()
sample.to_csv("submission.csv", index=False)

* To do: improve the data augmentations to avoid obviously wrong ones
