# Purpose
To train a network efficiently, a good data generator is needed: one that does not exceed the RAM limit and is both efficient and fast.
Here I present several data-generating classes, along with some helper functions, to achieve that.



In [None]:
# Identifiers (file stems) of the images treated as the "train" dataset.
TRAIN_IMG_IDS = [
    '095bf7a1f', 'afa5e8098', 'e79de561c', '4ef6695ce', '0486052bb',
    'b9a3865fc', 'aaa6a05cc', '2f6ecfcdf', '54f2eec69', 'c68fe75ea',
    '8242609fa', '1e2425f28', 'b2dc8411c', 'cb2d976f4', '26dc41664',
]

# Identifiers (file stems) of the images treated as the "test" dataset.
TEST_IMG_IDS = [
    'aa05346ff',
    '2ec3f1bb9',
    '3589adb90',
    'd488c759a',
    '57512b7f1',
]

# Root directory from which images, annotations and the dataset CSV are read.
DATA_DIR = "/kaggle/input/hubmap-kidney-segmentation"


import json
import os
import random

import tifffile as tiff
import cv2
from PIL import Image
from keras_preprocessing.image import load_img
from matplotlib.patches import Rectangle
from tensorflow import keras
import pandas as pd

import numpy as np
from matplotlib import pyplot as plt

#%%


def scale_down(image, scale):
    """Shrink ``image`` by the integer factor ``scale``.

    Trailing rows and columns are dropped so that both leading dimensions
    are exact multiples of ``scale``; the cropped array is then resized
    with ``cv2.resize`` (default interpolation).

    Returns the resized array, whose first two dimensions are
    ``image.shape[0] // scale`` and ``image.shape[1] // scale``.
    """
    out_rows = image.shape[0] // scale
    out_cols = image.shape[1] // scale
    cropped = image[:out_rows * scale, :out_cols * scale]
    # cv2.resize expects the target size as (width, height).
    return cv2.resize(cropped, (out_cols, out_rows))


def check_if_in_dataset(filename):
    """Return the name of the dataset that lists ``filename``.

    Returns ``"train"`` if the name is in ``TRAIN_IMG_IDS``, ``"test"`` if
    it is in ``TEST_IMG_IDS`` (train is checked first), and an empty string
    if it is in neither.
    """
    for dataset, ids in (("train", TRAIN_IMG_IDS), ("test", TEST_IMG_IDS)):
        if filename in ids:
            return dataset
    return ""


def show_hsv(im):
    """Plot per-channel histograms of ``im`` after converting RGB to HSV.

    One curve is drawn for each of the H, S and V channels. Bin edges are
    ``0 .. max-1`` where ``max`` is the largest value over all HSV
    channels, and the first bin is left out of the plot.
    """
    hsv = cv2.cvtColor(im, cv2.COLOR_RGB2HSV)
    # The same edges are used for every channel.
    edges = np.arange(hsv.max())
    plt.figure()
    for channel, name in enumerate(("H", "S", "V")):
        count, bins = np.histogram(hsv[:, :, channel], bins=edges)
        # Skip the first bin; the x values are the remaining left edges.
        plt.plot(bins[1:-1], count[1:], label=name)
    plt.legend()
    plt.show()


def get_image(file_name, dataset, initial_scale, PATH=DATA_DIR):
    """Read ``<PATH>/<dataset>/<file_name>.tiff`` and return it scaled down.

    5-dimensional arrays are reduced to their ``[0, 0]`` sub-array, and
    arrays whose first axis has length 3 are transposed with ``(1, 2, 0)``
    so that axis ends up last. The result is passed through ``scale_down``
    with ``initial_scale``.
    """
    tiff_path = f"{PATH}/{dataset}/{file_name}.tiff"
    image = tiff.imread(tiff_path)

    if image.ndim == 5:
        # Drop the two leading axes, then move the first remaining axis last.
        image = image[0, 0, :, :, :].transpose((1, 2, 0))
    elif image.shape[0] == 3:
        image = image.transpose((1, 2, 0))

    return scale_down(image, initial_scale)


def get_pink_mask(image, mask_scale, blur_kernel):
    """Build a binary mask of pixels whose HSV value lies in a fixed range.

    The image is scaled down by ``mask_scale`` and converted from RGB to
    HSV. It is then thresholded with ``cv2.inRange`` between
    ``(120, 10, 20)`` and ``(150, 190, 255)``, box-blurred with
    ``blur_kernel`` and binarised at 127.

    Returns a ``uint8`` array of zeros and ones at the reduced resolution.
    """
    lower_bound = (120, 10, 20)
    upper_bound = (150, 190, 255)
    hsv = cv2.cvtColor(scale_down(image, mask_scale), cv2.COLOR_RGB2HSV)
    in_range = cv2.inRange(hsv, lower_bound, upper_bound)
    smoothed = cv2.blur(in_range, blur_kernel)
    keep = smoothed > 127
    return keep.astype(np.uint8)


def get_glomerous_polygon_list(file_name):
    """Load the polygon annotations stored for a training image.

    Reads ``<DATA_DIR>/train/<file_name>.json`` and, for every element of
    the decoded JSON, takes the first entry of its
    ``['geometry']['coordinates']`` list.

    Returns a list of numpy arrays, one per annotation element.
    """
    annotation_path = '{}/{}/{}.json'.format(DATA_DIR, "train", file_name)
    # A context manager closes the file handle even if decoding fails.
    with open(annotation_path) as annotation_file:
        json_obj = json.load(annotation_file)

    return [np.array(elem['geometry']['coordinates'][0]) for elem in json_obj]

def get_glomerulus_mask(filename, initial_scale, PATH=DATA_DIR):
    """Rasterise the annotation polygons of ``filename`` into a mask.

    The full-resolution mask size is looked up in
    ``<PATH>/HuBMAP-20-dataset_information.csv`` (``width_pixels`` and
    ``height_pixels`` of the row whose ``image_file`` is
    ``<filename>.tiff``). Every polygon is filled with value 1, and the
    mask is cropped to a multiple of ``initial_scale`` and resized down by
    that factor.
    """
    info = pd.read_csv(PATH + "/HuBMAP-20-dataset_information.csv")
    entry = info[info.image_file == (filename + ".tiff")]
    width = entry.width_pixels.tolist()[0]
    height = entry.height_pixels.tolist()[0]

    mask = np.zeros((height, width))
    for polygon in get_glomerous_polygon_list(filename):
        cv2.fillPoly(mask, pts=[polygon.astype(np.int32)], color=1)

    out_rows = mask.shape[0] // initial_scale
    out_cols = mask.shape[1] // initial_scale
    cropped = mask[:out_rows * initial_scale, :out_cols * initial_scale]
    # cv2.resize expects the target size as (width, height).
    return cv2.resize(cropped, (out_cols, out_rows))

#%%

def get_mask_grid_sum(mask, mask_scale, grid_size, x, y):
    """Sum the part of a reduced-resolution ``mask`` covered by one grid.

    ``x`` and ``y`` are the grid's starting offsets along the first and
    second axis, and ``grid_size`` is its side length. All three are
    divided by ``mask_scale`` (floor division) to index into ``mask``.
    """
    first_axis = slice(x // mask_scale, (x + grid_size) // mask_scale)
    second_axis = slice(y // mask_scale, (y + grid_size) // mask_scale)
    return mask[first_axis, second_axis].sum()


def get_grids(image, mask, initial_scale, GRID_SIZE, OVERLAP, include_empty_grids=False):
    """Enumerate square windows of ``image`` that are worth keeping.

    Windows of side ``GRID_SIZE`` are laid out with a stride of
    ``GRID_SIZE - OVERLAP`` along both axes. A window is kept when the
    pink-tissue mask (see ``get_pink_mask``) sums to more than a threshold
    inside it and, unless ``include_empty_grids`` is true, the glomerulus
    ``mask`` has at least one positive pixel inside it.

    Parameters:
        image: image array at ``initial_scale`` resolution.
        mask: glomerulus mask aligned with ``image``.
        initial_scale: factor by which ``image`` was already scaled down.
        GRID_SIZE: side length of each window, in ``image`` pixels.
        OVERLAP: overlap between neighbouring windows, in ``image`` pixels.
        include_empty_grids: also keep windows whose glomerulus sum is 0.

    Returns:
        A list of ``[row, col, row + GRID_SIZE, col + GRID_SIZE]`` lists.

    Raises:
        ValueError: if ``OVERLAP`` is not smaller than ``GRID_SIZE``.
    """
    stride = GRID_SIZE - OVERLAP
    if stride <= 0:
        raise ValueError("OVERLAP must be smaller than GRID_SIZE")

    mask_scale = 10
    mask = scale_down(mask, mask_scale)
    # mask_threshold determined by roughly single glomerulus size
    mask_threshold = (400 / initial_scale / mask_scale) ** 2
    # blur range tested empirically to be about half of glomerulus size
    blur_range = max(200 // initial_scale // mask_scale, 3)
    img_shape = image.shape
    pink_mask = get_pink_mask(image, mask_scale, (blur_range, blur_range))

    grids = []
    # "+ 1" keeps the last window that still fits entirely inside the image
    # (e.g. an image exactly GRID_SIZE wide yields one window, not zero).
    for row in range(0, img_shape[0] - GRID_SIZE + 1, stride):
        for col in range(0, img_shape[1] - GRID_SIZE + 1, stride):
            valid_pixels = get_mask_grid_sum(pink_mask, mask_scale, GRID_SIZE, row, col)
            if valid_pixels <= mask_threshold:
                continue
            glomeruli_pixels = get_mask_grid_sum(mask, mask_scale, GRID_SIZE, row, col)
            if glomeruli_pixels > 0 or include_empty_grids:
                grids.append([row, col, row + GRID_SIZE, col + GRID_SIZE])
    return grids


def get_BB(coords):
    """Return the bounding box of a set of points as a float array.

    Column-wise minima and maxima of ``coords`` are taken along axis 0 and
    each pair is reversed. For two-column input the result is therefore
    ``[min(col 1), min(col 0), max(col 1), max(col 0)]``.
    """
    lowest = np.min(coords, axis=0)
    highest = np.max(coords, axis=0)
    box = np.empty(4, dtype=np.float64)
    # Reverse each pair so the second input column comes first.
    box[:2], box[2:] = lowest[::-1], highest[::-1]
    return box


def in_grid(x, GRID_SQUARE_SIZE):
    """Tell whether box ``x`` overlaps the square ``[0, GRID_SQUARE_SIZE)``.

    ``x`` holds four values ``[low0, low1, high0, high1]``. The box is
    outside when, along either axis, its low edge is at or beyond
    ``GRID_SQUARE_SIZE`` or its high edge is at or below 0.
    """
    outside = (
        x[0] >= GRID_SQUARE_SIZE
        or x[2] <= 0
        or x[1] >= GRID_SQUARE_SIZE
        or x[3] <= 0
    )
    return not outside


def turnicate(bb, GRID_SQUARE_SIZE):
    """Clamp every value of ``bb`` into ``[0, GRID_SQUARE_SIZE - 1]``.

    The array is modified in place and the same object is returned.
    """
    upper = GRID_SQUARE_SIZE - 1
    below_zero = bb < 0
    bb[below_zero] = 0
    # Evaluated after the lower clamp, on the already-updated array.
    above_upper = bb > upper
    bb[above_upper] = upper
    return bb


def get_BB_within(polys, x_start, y_start, GRID_SIZE, initial_scale):
    """Return the bounding boxes of ``polys`` that fall inside one grid.

    ``x_start``, ``y_start`` and ``GRID_SIZE`` are multiplied by
    ``initial_scale`` before being compared with the polygon coordinates.
    Each polygon's bounding box is shifted so the grid origin is at 0, and
    boxes that do not overlap the grid are dropped. The remaining boxes
    are clamped to the grid and floor-divided by ``initial_scale``.

    Returns a list of 4-element arrays.
    """
    full_size = GRID_SIZE * initial_scale
    origin_x = x_start * initial_scale
    origin_y = y_start * initial_scale
    origin = np.array([origin_x, origin_y, origin_x, origin_y])

    boxes = []
    for coords in polys:
        box = get_BB(coords) - origin
        if not in_grid(box, full_size):
            continue
        box = turnicate(box, full_size)
        # Checked again after clamping, as the original pipeline does.
        if in_grid(box, full_size):
            boxes.append(box // initial_scale)
    return boxes


#%%

#%%

def draw_img_boxes(img, boxes):
    """Show ``img`` with the given bounding boxes drawn on top.

    Each entry of ``boxes`` holds four values ``[lx, ly, rx, ry]`` that
    index the first (``lx:rx``) and second (``ly:ry``) axis of ``img``.
    The boxes are drawn twice: as a translucent coverage overlay and as
    white rectangle outlines.
    """
    plt.figure()
    plt.imshow(img)
    ax = plt.gca()
    mask = np.zeros(img.shape[:2])
    for coords in boxes:
        # Builtin int replaces the np.int alias, which was removed in NumPy 1.24.
        lx, ly, rx, ry = np.asarray(coords).astype(int)
        mask[lx:rx, ly:ry] += 1
    plt.imshow(mask, alpha=0.2)
    for coords in boxes:
        # flip coords for Rectangle
        y, x = coords[0:2]
        dy, dx = coords[2] - coords[0], coords[3] - coords[1]
        rect = Rectangle((x, y), dx, dy, linewidth=2, edgecolor='w', facecolor='none')
        ax.add_patch(rect)

#%%

def showcase_bbs(file_name, grid_i):
    """Display one grid of a training image with its bounding boxes.

    Parameters:
        file_name: stem of the training image to load.
        grid_i: index into the grid list returned by ``get_grids``.

    The image, mask and polygons are all loaded for ``file_name``.
    Previously the image and mask were hard-coded to one image while the
    polygons followed the argument.
    """
    initial_scale = 2
    grid_size = 1024
    overlap = 256

    im = get_image(file_name, "train", initial_scale)
    mask = get_glomerulus_mask(file_name, initial_scale)
    g = get_grids(im, mask, initial_scale, grid_size, overlap)
    lx, ly, rx, ry = g[grid_i]
    polygon_list = get_glomerous_polygon_list(file_name)
    bbs = get_BB_within(polygon_list, lx, ly, grid_size, initial_scale)
    draw_img_boxes(im[lx:rx, ly:ry], bbs)
#%%
# uncomment to see img with bounding boxes
# showcase_bbs("afa5e8098", 43)

#%%

def showcase_grid(file_name):
    """Plot a training image with its mask and grid coverage overlaid.

    The image and mask are loaded at scale 10 and grids of size 256 with
    overlap 64 are computed. Image, mask and a per-pixel grid-count map
    are each reduced by another factor of 10 before being shown.
    """
    scale = 10
    picture = get_image(file_name, "train", scale)
    glom_mask = get_glomerulus_mask(file_name, scale)
    grids = get_grids(picture, glom_mask, scale, 256, 64)

    picture = scale_down(picture, scale)
    glom_mask = scale_down(glom_mask, scale)

    # Count how many grids cover each pixel, then shrink it like the image.
    coverage = np.zeros((picture.shape[0] * scale, picture.shape[1] * scale))
    for lx, ly, rx, ry in grids:
        coverage[lx:rx, ly:ry] += 1
    small_size = (coverage.shape[1] // scale, coverage.shape[0] // scale)
    coverage = cv2.resize(coverage, small_size)

    plt.figure()
    plt.imshow(picture)
    plt.imshow(glom_mask, alpha=0.3)
    plt.imshow(coverage, alpha=0.3)
#%%
# uncomment to see grid
# showcase_grid("afa5e8098")

#%%


class MyDataSequenceSlow(keras.utils.Sequence):
    """Helper to iterate over the data (as Numpy arrays).

    Only grid coordinates are precomputed and stored; subclasses implement
    ``__getitem__`` to produce the actual arrays for a batch.
    """

    def __init__(self, batch_size, train_img_names, initial_scale=2, GRID_SIZE=1024, OVERLAP=256):
        """Precompute shuffled batches of grid coordinates per image.

        Parameters:
            batch_size: number of grids per batch.
            train_img_names: iterable of training image stems.
            initial_scale: down-scaling factor applied when loading.
            GRID_SIZE: side length of each grid.
            OVERLAP: overlap between neighbouring grids.
        """
        super().__init__()
        self.batch_size = batch_size
        # Each entry is (file_name, list of [lx, ly, rx, ry] grids).
        self.batches = []
        self.initial_scale = initial_scale
        print("starting preprocessing")
        for file_name in train_img_names:
            print("Preparing", file_name, "...")
            image = get_image(file_name, "train", initial_scale)
            mask = get_glomerulus_mask(file_name, initial_scale)
            # get_grids takes no file name; passing one shifted every argument.
            grids = get_grids(image, mask, initial_scale, GRID_SIZE, OVERLAP)
            random.shuffle(grids)
            print("Prepared", len(grids), "grids.", len(grids) % batch_size, "lost in batching.")
            # "+ 1" keeps the final full batch when len(grids) is an exact
            # multiple of batch_size, so only the printed remainder is lost.
            for i in range(0, len(grids) - batch_size + 1, batch_size):
                self.batches.append((file_name, grids[i:i + batch_size]))
            print("Batches saved.")

    def __len__(self):
        """Return the number of prepared batches."""
        return len(self.batches)

    def __getitem__(self, idx):
        """Returns tuple (input, target) correspond to batch #idx."""
        raise NotImplementedError("Not implemented.")


class MyDataSequenceSlowMask(MyDataSequenceSlow):
    """Sequence that yields image crops with matching mask crops."""

    def __getitem__(self, idx):
        """Return ``(images, masks)`` arrays for batch ``idx``.

        The source image and its glomerulus mask are reloaded on every
        call and cropped to each grid stored for the batch.
        """
        file_name, grids = self.batches[idx]

        image = get_image(file_name, "train", self.initial_scale)
        mask = get_glomerulus_mask(file_name, self.initial_scale)

        image_crops = [image[lx:rx, ly:ry] for lx, ly, rx, ry in grids]
        mask_crops = [mask[lx:rx, ly:ry] for lx, ly, rx, ry in grids]
        return np.array(image_crops), np.array(mask_crops)


class MyDataSequenceSaving(keras.utils.Sequence):
    """Sequence backed by grid crops cached on disk as PNG files.

    With ``preprocess=True`` the constructor cuts every training image
    into grids and writes image crops, mask crops and per-image
    bounding-box JSON files under ``/kaggle/working/cache/train``. With
    ``preprocess=False`` it only lists the image crops already there.
    Subclasses implement ``__getitem__``.
    """

    def __init__(self, batch_size, train_img_names, initial_scale=2, GRID_SIZE=1024, OVERLAP=256,preprocess=True):
        """Prepare (or discover) the cached crops and shuffle their order.

        Parameters:
            batch_size: number of crops per batch.
            train_img_names: iterable of training image stems; only used
                when ``preprocess`` is true.
            initial_scale: down-scaling factor applied when loading.
            GRID_SIZE: side length of each grid.
            OVERLAP: overlap between neighbouring grids.
            preprocess: when true, regenerate the cache; otherwise reuse it.
        """
        super().__init__()
        self.CACHE_PATH = "/kaggle/working/cache"
        self.TRAIN_PATH = self.CACHE_PATH + "/train"
        self.IMAGES = self.TRAIN_PATH + "/images"
        self.MASKS = self.TRAIN_PATH + "/masks"
        self.GRID_SIZE = GRID_SIZE
        self.batch_size = batch_size
        # Crop names of the form "<file_name>_<grid index>".
        self.img_list = []
        self.initial_scale = initial_scale
        if preprocess:
            print("starting preprocessing")
            # exist_ok creates each directory independently. Previously, if
            # the images directory existed, the masks one was never created.
            os.makedirs(self.IMAGES, exist_ok=True)
            os.makedirs(self.MASKS, exist_ok=True)

            for file_name in train_img_names:
                # bb_dict[file_name][i] holds the boxes of grid number i.
                bb_dict = {file_name: []}
                print("Preparing", file_name, "...")
                image = get_image(file_name, "train", initial_scale)
                mask = get_glomerulus_mask(file_name, initial_scale)
                grids = get_grids(image, mask, initial_scale, GRID_SIZE, OVERLAP)
                polygon_list = get_glomerous_polygon_list(file_name)
                print("Prepared", len(grids), "grids.")
                for i, (lx, ly, rx, ry) in enumerate(grids):
                    im = Image.fromarray(image[lx:rx, ly:ry])
                    im.save(f"{self.IMAGES}/{file_name}_{i}.png")

                    msk = Image.fromarray(mask[lx:rx, ly:ry].astype(np.uint8))
                    msk.save(f"{self.MASKS}/{file_name}_{i}.png")

                    bbs = get_BB_within(polygon_list, lx, ly, GRID_SIZE, initial_scale)
                    bb_dict[file_name].append([x.tolist() for x in bbs])

                    self.img_list.append(file_name + "_" + str(i))
                with open(f'{self.TRAIN_PATH}/bbs-{file_name}.json', 'w') as fp:
                    json.dump(bb_dict, fp)
                print("pics saved.")
        else:
            for _dirname, _, filenames in os.walk(self.IMAGES):
                for filename in filenames:
                    # Strip the extension without assuming its length.
                    self.img_list.append(os.path.splitext(filename)[0])
        random.shuffle(self.img_list)

    def __len__(self):
        """Return the number of full batches available."""
        return len(self.img_list)//self.batch_size

    def __getitem__(self, idx):
        """Returns tuple (input, target) correspond to batch #idx."""
        raise NotImplementedError("Not implemented.")


class MyDataSequenceSavingMask(MyDataSequenceSaving):
    """Sequence that yields cached image crops with their mask crops."""

    def __getitem__(self, idx):
        """Return ``(images, masks)`` arrays for batch ``idx``.

        Image and mask PNGs are loaded at ``GRID_SIZE`` x ``GRID_SIZE``.
        Masks are loaded as grayscale and given a trailing axis.
        """
        first = idx * self.batch_size
        names = self.img_list[first:first + self.batch_size]
        size = (self.GRID_SIZE, self.GRID_SIZE)

        images = []
        masks = []
        for name in names:
            picture = load_img(f"{self.IMAGES}/{name}.png", target_size=size)
            mask = load_img(f"{self.MASKS}/{name}.png", target_size=size, color_mode="grayscale")

            images.append(np.array(picture))
            masks.append(np.expand_dims(mask, axis=2))

        return np.array(images), np.array(masks)

class MyDataSequenceSavingBBs(MyDataSequenceSaving):
    """Sequence that yields cached image crops with their bounding boxes."""

    def __getitem__(self, idx):
        """Returns tuple (input, target) correspond to batch #idx.

        The input is an array of image crops. The target is a list with
        one entry per crop, each a list of 4-element bounding-box arrays
        read from the cached ``bbs-<file_name>.json`` files.
        """

        batch = self.img_list[idx*self.batch_size:(idx+1)*self.batch_size]
        grid_imgs = []
        grid_BBs = []
        # Parse each source image's JSON once per batch, not once per crop.
        bbs_by_file = {}

        for img_name in batch:
            # Split on the last underscore only, so stems that themselves
            # contain "_" still unpack into exactly two parts.
            file_name, img_idx = img_name.rsplit("_", 1)
            if file_name not in bbs_by_file:
                bbs_path = f"{self.TRAIN_PATH}/bbs-{file_name}.json"
                with open(bbs_path) as json_file:
                    bbs_by_file[file_name] = json.load(json_file)[file_name]

            img_path = f"{self.IMAGES}/{img_name}.png"
            image = load_img(img_path, target_size=(self.GRID_SIZE, self.GRID_SIZE))
            grid_imgs.append(np.array(image))

            bbs_lists = bbs_by_file[file_name][int(img_idx)]
            grid_BBs.append([np.array(x) for x in bbs_lists])

        return np.array(grid_imgs), grid_BBs

In [None]:
# Plot one training image with its mask and grid coverage overlaid.
showcase_grid("afa5e8098")

In [None]:
# Show grid number 43 with its bounding boxes drawn on top.
showcase_bbs("afa5e8098", 43)

In [None]:
# Bounding-box generator with batch size 3 over the first training image only.
# preprocess defaults to True, so crops and box files are written to the cache.
train_gen = MyDataSequenceSavingBBs(3, TRAIN_IMG_IDS[:1])

In [None]:
# Fetch the first batch and draw its second crop (index 1) with its boxes.
batchx,batchy = train_gen[0]
draw_img_boxes(batchx[1],batchy[1])