# In [3]:
import numpy as np
import keras
import glob
import os
import pandas as pd
import cv2
from keras.preprocessing.image import ImageDataGenerator


class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that yields batches of augmented images and labels.

    Images are read from ``<base_dataset_path>/train_images/`` (or
    ``validate_images/`` when ``is_validation`` is true). Labels come from
    the ``diagnosis`` column of ``<base_dataset_path>/train.csv``, matched
    to files through ``id_code`` + ``".png"``.

    Each sample is randomly cropped to ``CROP_RATIO`` of its original size,
    converted from BGR to RGB, resized to ``img_res`` and scaled by 1/255.
    Training samples additionally receive a random rotation/brightness
    transform from a Keras ``ImageDataGenerator``.
    """

    # Number of diagnosis classes used for the one-hot encoding.
    NUM_CLASSES = 5
    # Fraction of the original height/width kept by the random crop.
    CROP_RATIO = 0.9

    def __init__(self,
                 base_dataset_path,
                 is_validation=False,
                 img_res=(224, 224),
                 batch_size=12,
                 shuffle=True):
        """Index the image folder and load the label table.

        Args:
            base_dataset_path: Directory containing the image sub-folder
                and ``train.csv``.
            is_validation: Read from ``validate_images/`` and skip the
                Keras augmentation when true.
            img_res: Output image size, interpreted as ``(height, width)``.
            batch_size: Number of samples per batch.
            shuffle: Reshuffle the sample order at the end of every epoch.
        """
        # Harmless on older Keras; newer versions warn when a Sequence
        # subclass does not initialise its base class.
        super(DataGenerator, self).__init__()

        self.base_dataset_path = base_dataset_path

        end_point = 'validate_images/' if is_validation else 'train_images/'

        self.images_path = os.path.join(self.base_dataset_path, end_point)
        self.image_names = os.listdir(self.images_path)

        labels_path = os.path.join(self.base_dataset_path, 'train.csv')
        self.data_frame = pd.read_csv(labels_path)
        self.data_frame['diagnosis'] = self.data_frame['diagnosis'].astype('str')
        self.data_frame["id_code"] = self.data_frame["id_code"].apply(self.append_text)

        # Build the file-name -> class lookup once, so fetching a label is a
        # dict access instead of a full DataFrame scan per image.
        self._labels = {
            name: int(diagnosis)
            for name, diagnosis in zip(self.data_frame['id_code'],
                                       self.data_frame['diagnosis'])
        }

        self.keras_generator = ImageDataGenerator(rotation_range=17,
                                                  brightness_range=[0.6, 1.2])

        self.is_validation = is_validation
        self.batch_size = batch_size
        self.img_res = img_res
        # Honour the caller's choice (this used to be forced to True).
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch (remainder dropped)."""
        return len(self.image_names) // self.batch_size

    def __getitem__(self, index):
        """Return the ``(X, y)`` batch at position ``index``."""
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]

        list_IDs_temp = [self.image_names[k] for k in indexes]
        X, y = self.__data_generation(list_IDs_temp)

        return X, y

    def get_random_crop(self, image, crop_height, crop_width):
        """Return a randomly positioned crop of ``image``.

        The top-left corner is drawn uniformly from the positions that keep
        a ``crop_height`` x ``crop_width`` window inside the image. If the
        requested window is larger than the image, the corner is (0, 0) and
        the slice is simply clipped by the image bounds.
        """
        max_x = max(image.shape[1] - crop_width, 0)
        max_y = max(image.shape[0] - crop_height, 0)

        # randint's upper bound is exclusive, hence the +1.
        x = np.random.randint(0, max_x + 1)
        y = np.random.randint(0, max_y + 1)

        return image[y: y + crop_height, x: x + crop_width]

    def on_epoch_end(self):
        """Rebuild the sample order, shuffling it when ``self.shuffle`` is set."""
        self.indexes = np.arange(len(self.image_names))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        """Load, preprocess and label the images named in ``list_IDs_temp``.

        Returns:
            ``(X, y)`` where ``X`` has shape ``(n, height, width, 3)`` with
            values scaled by 1/255 and ``y`` is the one-hot encoded label
            array with ``NUM_CLASSES`` columns.

        Raises:
            IOError: If an image file cannot be read by OpenCV.
            KeyError: If an image has no row in the label table.
        """
        h, w = self.img_res
        n_samples = len(list_IDs_temp)

        X = np.empty((n_samples, h, w, 3))
        y = np.empty(n_samples, dtype=int)

        for i, ID in enumerate(list_IDs_temp):
            img_path = os.path.join(self.images_path, ID)
            image = cv2.imread(img_path)
            if image is None:
                # cv2.imread signals failure by returning None, not raising.
                raise IOError('Could not read image: {}'.format(img_path))

            new_height = int(image.shape[0] * self.CROP_RATIO)
            new_width = int(image.shape[1] * self.CROP_RATIO)
            image = self.get_random_crop(image, new_height, new_width)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            # cv2.resize takes the target size as (width, height).
            image = cv2.resize(image, (w, h))

            if ID not in self._labels:
                raise KeyError('No diagnosis found for image: {}'.format(ID))

            if not self.is_validation:
                transform = self.keras_generator.get_random_transform((h, w, 3))
                image = self.keras_generator.apply_transform(image, transform)

            X[i,] = image / 255
            y[i] = self._labels[ID]

        return X, keras.utils.to_categorical(y, num_classes=self.NUM_CLASSES)

    @staticmethod
    def append_text(fn):
        """Return ``fn`` with the ``.png`` extension appended."""
        return fn + ".png"