In [None]:
import cv2
import os
import torch
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Dataset

import albumentations
from albumentations import pytorch as AT

IMAGE_SIZE = 512

# Per-channel (R, G, B, Y) statistics.
# NOTE(review): not referenced by any transform below (there is no Normalize
# step) -- confirm they are applied elsewhere or drop them.
rgby_mean = [0.08123, 0.05293, 0.05398, 0.08153]
rgby_std  = [0.13028, 0.08611, 0.14256, 0.12620]

# Extra keyword inputs that albumentations processes with the same transforms
# (and the same sampled random parameters) as `image`. ImageDataset passes
# only `y` (the yellow channel); `r`, `g`, `b` are declared but not passed.
# Defined once so the three pipelines cannot drift apart.
ADDITIONAL_TARGETS = {
    'r': 'image',
    'g': 'image',
    'b': 'image',
    'y': 'image',
}

# NOTE(review): `Flip`, `Cutout` and `ElasticTransform(alpha_affine=...)` may be
# deprecated or removed in newer albumentations releases -- pin the version
# this notebook was written against.

# Training augmentation: scale to [0, 1], resize, then geometric, colour,
# blur and dropout augmentations, and finally convert to a CHW tensor.
train_transform = albumentations.Compose([
    # Divides by 65535; assumes 16-bit source PNGs -- TODO confirm
    # (8-bit images would end up in [0, ~0.0039]).
    albumentations.ToFloat(max_value=65535.0),
    albumentations.Resize(IMAGE_SIZE, IMAGE_SIZE),
    albumentations.RandomRotate90(p=0.5),
    albumentations.Transpose(p=0.5),
    albumentations.Flip(p=0.5),
    albumentations.OneOf([
        albumentations.ElasticTransform(alpha=1, sigma=20, alpha_affine=10),
        albumentations.GridDistortion(num_steps=6, distort_limit=0.1),
        albumentations.OpticalDistortion(distort_limit=0.05, shift_limit=0.05),
    ], p=0.2),
    albumentations.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05, p=0.5),
    # NOTE(review): inside OneOf the child p values act as selection weights;
    # whether a blur happens is governed by OneOf's own (default) p, not by
    # the p=.05 values -- confirm this is the intended blur probability.
    albumentations.core.composition.PerChannel(
        albumentations.OneOf([
            albumentations.MotionBlur(p=.05),
            albumentations.MedianBlur(blur_limit=3, p=.05),
            albumentations.Blur(blur_limit=3, p=.05),
        ]),
        p=1.0),
    albumentations.OneOf([
        albumentations.CoarseDropout(max_holes=16, max_height=IMAGE_SIZE//16, max_width=IMAGE_SIZE//16, fill_value=0, p=0.5),
        albumentations.GridDropout(ratio=0.09, p=0.5),
        albumentations.Cutout(num_holes=8, max_h_size=IMAGE_SIZE//16, max_w_size=IMAGE_SIZE//16, p=0.2),
    ], p=0.5),
    albumentations.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.5),
    AT.ToTensorV2(),
    ],
    additional_targets=dict(ADDITIONAL_TARGETS),
)

# Evaluation: deterministic scaling + resize only.
test_transform = albumentations.Compose([
    albumentations.ToFloat(max_value=65535.0),
    albumentations.Resize(IMAGE_SIZE, IMAGE_SIZE),
    AT.ToTensorV2(),
    ],
    additional_targets=dict(ADDITIONAL_TARGETS),
)

# Test-time augmentation: random dihedral flips/rotations, then resize.
tta_transform = albumentations.Compose([
    albumentations.ToFloat(max_value=65535.0),
    albumentations.RandomRotate90(p=0.5),
    albumentations.Transpose(p=0.5),
    albumentations.Flip(p=0.5),
    albumentations.Resize(IMAGE_SIZE, IMAGE_SIZE),
    AT.ToTensorV2(),
    ],
    additional_targets=dict(ADDITIONAL_TARGETS),
)
    

class ImageDataset(Dataset):
    """Multi-label RGBY image dataset backed by a DataFrame.

    Columns read from ``df``:

    * ``folder`` -- sub-directory of ``data_path`` that holds the images;
    * ``ID``     -- file-name stem; the channels are loaded from
      ``<stem>_red.png``, ``<stem>_green.png``, ``<stem>_blue.png`` and
      ``<stem>_yellow.png``;
    * ``Label``  -- ``'|'``-separated class ids; ``'18'`` marks a negative
      sample and produces an all-zero target.

    ``__getitem__`` returns ``(img_rgby, label)``. ``img_rgby`` holds the first
    four channels of the transformed RGB and yellow images concatenated along
    axis 0. ``label`` is a float64 0/1 vector of length ``NUM_CLASSES``.
    Concatenating on axis 0 assumes ``transform`` emits channel-first output;
    the transforms defined in this notebook end with ``ToTensorV2``.
    """

    NUM_CLASSES = 18   # length of the target vector (class ids 0..17)
    NEG_LABEL = '18'   # Label string of a negative sample -> all-zero target

    def __init__(self, df, data_path='../input', transform=train_transform):
        """
        Args:
            df: DataFrame with ``folder``, ``ID`` and ``Label`` columns. Any
                index is accepted; rows are addressed by position.
            data_path: root directory that ``folder`` is joined onto.
            transform: callable invoked as ``transform(image=rgb, y=yellow)``
                that returns a dict with ``'image'`` and ``'y'`` entries.
        """
        self.df = df
        self.data_path = data_path
        self.imgs_path = self.df['folder']
        self.imgs_name = self.df['ID']
        self.labels = self.df['Label']
        self.transform = transform

    def __len__(self):
        return self.df.shape[0]

    @staticmethod
    def _read(path, flags):
        """cv2.imread that raises instead of silently returning None."""
        img = cv2.imread(path, flags)
        if img is None:
            raise FileNotFoundError(f'could not read image: {path}')
        return img

    def __getitem__(self, idx):
        # DataLoader passes positions 0..len-1. Use .iloc so a DataFrame whose
        # index is not a fresh RangeIndex still yields the right row.
        # int() also accepts numpy/torch integer scalars.
        pos = int(idx)
        img_path = os.path.join(self.data_path,
                                self.imgs_path.iloc[pos],
                                self.imgs_name.iloc[pos])

        img_r = self._read(img_path + '_red.png', cv2.IMREAD_UNCHANGED)
        img_g = self._read(img_path + '_green.png', cv2.IMREAD_UNCHANGED)
        img_b = self._read(img_path + '_blue.png', cv2.IMREAD_UNCHANGED)
        img_rgb = np.dstack((img_r, img_g, img_b))
        # Yellow is read as 3-channel colour at native bit depth
        # (IMREAD_COLOR | IMREAD_ANYDEPTH == 3). Only its first channel
        # survives the [:4] slice below.
        # NOTE(review): presumably 3-channel so colour-only transforms
        # (e.g. ColorJitter) accept it -- confirm.
        img_y = self._read(img_path + '_yellow.png',
                           cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)

        # Same transform call for both images, so they get identical random
        # parameters via the `y` additional target.
        trans_img = self.transform(image=img_rgb, y=img_y)
        img_rgb, img_y = trans_img['image'], trans_img['y']
        img_rgby = np.concatenate([img_rgb, img_y], axis=0)[:4]

        # Multi-hot target. str() also tolerates Label columns parsed as int.
        onehot = np.zeros(self.NUM_CLASSES)
        label = str(self.labels.iloc[pos])
        if label != self.NEG_LABEL:
            for cls in set(label.split('|')):
                onehot[int(cls)] = 1.0
        return img_rgby, onehot
        
        
class HPAMixup(Dataset):
    """Wrap a dataset and randomly sum extra samples into each item.

    For each item, up to ``max_mix - 1`` additional random samples are drawn
    from ``dataset``, each added independently with probability ``prob``.
    Images are summed element-wise (not averaged). Labels are OR-ed and
    returned as a float64 0/1 vector marking every class present in any of
    the combined samples.

    Args:
        dataset: indexable dataset returning ``(img, label)``. ``img`` must
            support ``+``; ``label`` must be a numeric numpy array.
        num_class: number of classes (stored on the instance only).
        max_mix: maximum number of samples combined into one item; ``<= 1``
            disables mixing and returns the wrapped item unchanged.
        prob: probability of adding each extra sample.
    """

    def __init__(self, dataset, num_class, max_mix=3, prob=0.1):
        self.dataset = dataset
        self.num_class = num_class
        self.max_mix = int(max_mix)
        self.prob = prob
        self.data_size = len(self)

    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        if self.max_mix <= 1:
            return img, label

        for _ in range(self.max_mix - 1):
            # torch's RNG is used; PyTorch's DataLoader seeds it per worker.
            if torch.rand(1).item() < self.prob:
                # .item() gives a plain Python int, a safe index for any
                # dataset (a 0-d numpy array is not, e.g. for pandas).
                rand_idx = torch.randint(self.data_size, (1,)).item()
                img_aug, label_aug = self.dataset[rand_idx]
                # Out-of-place add: never mutate arrays owned by the
                # wrapped dataset.
                img = img + img_aug
                label = label + label_aug

        # Any class present in a combined sample is positive.
        # np.float64 replaces the np.float alias removed in NumPy 1.24.
        label = (label > 0).astype(np.float64)
        return img, label

    def __len__(self):
        return len(self.dataset)

In [None]:
# Usage example. train_df, tr_folds, fold, base_dir, CLASSES_NUM, BATCH_SIZE and
# worker_init_fn are assumed to be defined elsewhere in the notebook.
#
# trainset = ImageDataset(train_df.iloc[tr_folds[fold]].reset_index(), base_dir, train_transform)
# trainset_mixup = HPAMixup(trainset, CLASSES_NUM, max_mix=3, prob=0.1)
# train_loader = torch.utils.data.DataLoader(trainset_mixup, batch_size=BATCH_SIZE, num_workers=16, shuffle=True, drop_last=True, worker_init_fn=worker_init_fn)