In [None]:
from __future__ import (absolute_import, division,
                        print_function, unicode_literals)
from builtins import *

In [None]:
 import sys
 print(sys.executable)
 print(sys.version)
 print(sys.version_info)

# Input pipeline

In [None]:
import os

import tensorflow as tf

# Data extraction. Override the dataset location with the PROSTATE_DATA_PATH
# environment variable (e.g. "/home/jupyter/prostate_ct_small") instead of
# editing this cell.
DATA_PATH = os.environ.get("PROSTATE_DATA_PATH",
                           "/home/matthew/priv/PROSTATE_TEST/")
STRUCTURE_NAMES = ["BLADDER"]

# Train / validation / test fractions.
SPLIT_RATIO = (0.7, 0.2, 0.1)

# Data shape: each sample stacks 2*CONTEXT + 1 neighbouring slices.
CONTEXT = 1
BATCH_SIZE = 1
SIZE = 128
INPUT_SHAPE = (2*CONTEXT + 1, SIZE, SIZE, 1)
OUTPUT_SHAPE = (1, SIZE, SIZE, len(STRUCTURE_NAMES))

# Model compiling.
EPOCHS = 1
OPTIMIZER = 'adam'
# Every model in this notebook ends in a sigmoid, so the loss must take
# probabilities. tf.nn.sigmoid_cross_entropy_with_logits expects raw logits
# and would apply a second sigmoid to the model output.
LOSS = 'binary_crossentropy'
METRICS = ['accuracy']






In [None]:
import glob
import random

def get_paths(data_path, context, seed=None):
    """Collect CT slice and structure-set paths under ``data_path``.

    Expects one sub-directory per patient. Each holds CT slice files (names
    containing "CT") and a structure set (name containing "RS"). Slice order
    is taken from file-name sort order; this assumes names sort in z order
    (TODO confirm for this dataset).

    Args:
        data_path: root directory containing one folder per patient.
        context: number of slices dropped at each end of every patient's
            series, so each kept slice has ``context`` neighbours per side.
        seed: optional seed for a reproducible shuffle of ``input_paths``.
            When None, the global ``random`` state is used.

    Returns:
        ``(input_paths, context_paths, label_paths)``:
        ``input_paths`` are the shuffled centre-slice paths used as samples,
        ``context_paths`` are all CT paths sorted (for neighbour lookup), and
        ``label_paths`` are all structure-set paths, sorted.
    """
    patient_paths = sorted(glob.glob(data_path + "/*"))

    context_paths = sorted(glob.glob(data_path + "/*/*CT*"))

    input_paths = []
    for patient_path in patient_paths:
        # glob returns files in arbitrary order; sort so trimming removes
        # the true end slices (consistent with the sorted context_paths).
        slices = sorted(glob.glob(patient_path + "/*CT*"))
        # An explicit stop index keeps context == 0 working
        # (slices[0:-0] would be empty).
        input_paths.extend(slices[context:len(slices) - context])

    if seed is None:
        random.shuffle(input_paths)
    else:
        random.Random(seed).shuffle(input_paths)

    label_paths = sorted(glob.glob(data_path + "/*/*RS*"))

    return input_paths, context_paths, label_paths


def split_paths(input_paths, ratio):
    """Split ``input_paths`` into consecutive train / valid / test lists.

    Args:
        input_paths: sequence to split. Shuffle it beforehand if a random
            split is wanted.
        ratio: ``(train, valid, test)`` fractions. The train and valid sizes
            are floored. Test receives every remaining item, so no path is
            lost to rounding and the printed test size matches the returned
            one.

    Returns:
        ``(train_paths, valid_paths, test_paths)``.
    """
    num = len(input_paths)
    num_train = int(num * ratio[0])
    num_valid = int(num * ratio[1])
    # Test is the remainder, so ratio[2] is informational only.
    num_test = num - num_train - num_valid

    print("Total:", num)
    print("==========")
    print("Train:", num_train)
    print("Valid:", num_valid)
    print("Test:", num_test)

    train_paths = input_paths[:num_train]
    valid_paths = input_paths[num_train:num_train + num_valid]
    test_paths = input_paths[num_train + num_valid:]
    return train_paths, valid_paths, test_paths

In [None]:
# Collect the paths first. This also defines context_paths and label_paths,
# which every DataGen below needs.
input_paths, context_paths, label_paths = get_paths(DATA_PATH, CONTEXT)
train_paths, valid_paths, test_paths = split_paths(input_paths, SPLIT_RATIO)

print("----------")
print("Batch size", BATCH_SIZE)
# Round up, matching DataGen.__len__ (the last partial batch is still a batch).
print("Steps per epoch (# batches per epoch)",
      (len(train_paths) + BATCH_SIZE - 1) // BATCH_SIZE)


In [None]:
# %load generator
import tensorflow as tf
import functools
import skimage.draw
import numpy as np
# Perhaps use tf.io instead
import pydicom
from pathlib import Path
import skimage.transform

# Per-sample work in __getitem__ is split into helpers: locating the context
# window, rasterising the structure mask and loading the CT slices.

class DataGen(tf.keras.utils.Sequence):
    """Keras Sequence yielding (CT context stack, structure mask) batches.

    Each sample is centred on one CT slice from ``input_paths``.

    - Input: the ``2 * context + 1`` slices around that slice in the sorted
      ``context_paths``, each resized to ``resize x resize``.
    - Label: a ``(1, resize, resize, len(structure_names))`` mask rasterised
      from the structure set in ``label_paths`` that sits in the same folder.

    A sample is skipped in any of these cases:

    - its context window runs past either end of the list, or into another
      patient's folder;
    - its folder has no structure set;
    - that structure set lacks a requested ROI.

    So a batch can hold fewer than ``batch_size`` samples.

    Args:
        input_paths: centre-slice CT paths, one per sample.
        context_paths: sorted list of every CT path, used for neighbour
            lookup. Assumes file names sort in slice order (TODO confirm).
        label_paths: structure-set paths, one per patient folder.
        context: number of neighbouring slices taken on each side.
        batch_size: samples per batch. The last batch may be smaller.
        structure_names: ROI names to extract, one mask channel each, in
            this order.
        resize: edge length in pixels of the returned images and masks.
    """

    # Largest |contour z - slice z| for a contour to count as lying on the
    # slice. DICOM patient coordinates are in mm; exact float equality is
    # brittle.
    Z_TOLERANCE = 1e-3

    def __init__(self, input_paths, context_paths, label_paths, context,
                 batch_size, structure_names, resize):
        self.input_paths = input_paths
        self.context_paths = context_paths
        self.label_paths = label_paths
        self.context = context
        self.batch_size = batch_size
        self.structure_names = structure_names
        self.resize = resize

        # Per-instance cache of parsed structure sets. A functools.lru_cache
        # on the method would key on ``self`` (keeping every generator alive)
        # and evict entries beyond its default size of 128.
        self._structure_cache = {}
        # Folder name -> structure-set path (first one wins). Matching whole
        # folder names avoids "P1" also matching "P10".
        self._label_by_dir = {}
        for path in self.label_paths:
            self._label_by_dir.setdefault(self.get_parent_dir(path), path)
            self.pre_cached_structures(path)
        # CT path -> first position in context_paths (O(1) vs list.index).
        self._context_index = {}
        for position, path in enumerate(self.context_paths):
            self._context_index.setdefault(path, position)

        self.on_epoch_end()

    def pre_cached_structures(self, path):
        """Return the structure set at ``path``, reading the file only once."""
        try:
            return self._structure_cache[path]
        except KeyError:
            dataset = pydicom.dcmread(path, force=True)
            self._structure_cache[path] = dataset
            return dataset

    def get_parent_dir(self, path):
        """Name of the directory containing ``path`` (the patient folder)."""
        return Path(path).parent.name

    def _context_window(self, image_path):
        """Return the ``2*context + 1`` CT paths centred on ``image_path``.

        Returns None when the window would run past either end of
        ``context_paths`` or include slices from another folder.

        Raises:
            ValueError: if ``image_path`` is not in ``context_paths``.
        """
        try:
            centre = self._context_index[image_path]
        except KeyError:
            raise ValueError("{} is not in context_paths".format(image_path))
        start = centre - self.context
        stop = centre + self.context + 1
        # A negative start would silently wrap around to the end of the list.
        if start < 0 or stop > len(self.context_paths):
            return None
        window = self.context_paths[start:stop]
        parent_dir = self.get_parent_dir(image_path)
        if any(self.get_parent_dir(p) != parent_dir for p in window):
            return None
        return window

    def _load_images(self, paths):
        """Read each CT slice in ``paths`` and resize it to resize x resize."""
        # TODO(review): no RescaleSlope/Intercept is applied, and resize()
        # converts to float, which may rescale intensities depending on the
        # pixel dtype. Confirm the intended normalisation.
        images = []
        for dcm_path in paths:
            dicom_ct = pydicom.dcmread(dcm_path, force=True)
            # Files read with force=True may lack a transfer syntax; assume
            # implicit VR little endian so pixel_array can be decoded.
            try:
                dicom_ct.file_meta.TransferSyntaxUID
            except AttributeError:
                dicom_ct.file_meta.TransferSyntaxUID = (
                    pydicom.uid.ImplicitVRLittleEndian)
            images.append(skimage.transform.resize(
                dicom_ct.pixel_array, (self.resize, self.resize)))
        return images

    def _structure_mask(self, img, dicom_structures):
        """Rasterise ``self.structure_names`` on the slice described by ``img``.

        Args:
            img: CT slice dataset (only its header is used).
            dicom_structures: parsed structure set for the same patient.

        Returns:
            A ``(1, resize, resize, len(structure_names))`` float mask, or
            None when the structure set lacks any requested ROI name.

        Raises:
            ValueError: if the structure set references a different frame
                of reference than the image.
        """
        roi_sequence = dicom_structures.StructureSetROISequence
        if (img.FrameOfReferenceUID !=
                roi_sequence[0].ReferencedFrameOfReferenceUID):
            raise ValueError(
                "Structure set and image frames of reference differ")

        # ROI name -> ROI number (first occurrence wins).
        roi_numbers = {}
        for roi in roi_sequence:
            roi_numbers.setdefault(roi.ROIName, roi.ROINumber)
        if any(name not in roi_numbers for name in self.structure_names):
            return None

        # Contours link to ROIs via ReferencedROINumber. The two sequences
        # need not share an order, and an ROI without contours has no
        # ContourSequence.
        contours_by_roi = {}
        for item in dicom_structures.ROIContourSequence:
            contours_by_roi.setdefault(item.ReferencedROINumber,
                                       getattr(item, 'ContourSequence', []))

        rows, cols = int(img.Rows), int(img.Columns)
        # PixelSpacing is (spacing between rows, spacing between columns),
        # i.e. (dy, dx).
        dy, dx = float(img.PixelSpacing[0]), float(img.PixelSpacing[1])
        x0, y0, z0 = (float(v) for v in img.ImagePositionPatient)
        # Only the x/y direction cosines are used, which assumes an
        # axis-aligned axial slice (TODO confirm for tilted acquisitions).
        ox = float(img.ImageOrientationPatient[0])
        oy = float(img.ImageOrientationPatient[4])

        n_channels = len(self.structure_names)
        mask = np.zeros(shape=(1, rows, cols, n_channels))
        for channel, name in enumerate(self.structure_names):
            # Draw every contour of this ROI on the slice, not just the first.
            for contour in contours_by_roi.get(roi_numbers[name], []):
                xyz = contour.ContourData  # flat x, y, z triples
                # The first point's z stands for the contour plane
                # (assumes planar axial contours).
                if abs(float(xyz[2]) - z0) > self.Z_TOLERANCE:
                    continue
                r = (np.array(xyz[1::3], dtype=float) - y0) / dy * oy
                c = (np.array(xyz[0::3], dtype=float) - x0) / dx * ox
                # shape= clips vertices outside the image instead of raising
                # IndexError on assignment.
                rr, cc = skimage.draw.polygon(r, c, shape=(rows, cols))
                mask[:, rr, cc, channel] = True

        return skimage.transform.resize(
            mask, (1, self.resize, self.resize, n_channels))

    def __getitem__(self, batch_index):
        """Return ``(inputs, labels)`` for batch number ``batch_index``.

        When at least one sample is kept:

        - ``inputs`` has shape ``(n, 2*context + 1, resize, resize, 1)``;
        - ``labels`` has shape ``(n, 1, resize, resize,
          len(structure_names))``;

        where ``n <= batch_size`` counts the samples that were not skipped.
        """
        # Slice locally rather than shrinking self.batch_size for the final
        # partial batch, which would corrupt later batches and __len__.
        start = batch_index * self.batch_size
        batch_paths = self.input_paths[start:start + self.batch_size]

        batch_inputs = []
        batch_labels = []
        for image_path in batch_paths:
            window = self._context_window(image_path)
            if window is None:
                continue
            mask_path = self._label_by_dir.get(self.get_parent_dir(image_path))
            if mask_path is None:
                continue
            img = pydicom.dcmread(image_path, force=True)
            mask = self._structure_mask(img,
                                        self.pre_cached_structures(mask_path))
            if mask is None:
                continue
            batch_inputs.append(self._load_images(window))
            batch_labels.append(mask)

        return (np.array(batch_inputs)[..., np.newaxis],
                np.array(batch_labels))

    def __len__(self):
        """Number of batches per epoch (the last one may be partial)."""
        return int(np.ceil(len(self.input_paths) / float(self.batch_size)))

    def on_epoch_end(self):
        """Keras end-of-epoch hook.

        Currently a no-op, so batches come back in the same order every
        epoch.
        """
        pass


In [None]:
# One generator per split; all share the same CT index and structure sets.
train_gen, valid_gen, test_gen = [
    DataGen(subset, context_paths, label_paths,
            context=CONTEXT, batch_size=BATCH_SIZE,
            structure_names=STRUCTURE_NAMES, resize=SIZE)
    for subset in (train_paths, valid_paths, test_paths)
]

In [None]:
# Get a random batch. len(test_gen) is the real batch count: the old
# round(len / BATCH_SIZE) - 1 bound could be -1 for tiny splits.

from random import randint
assert len(test_gen) > 0, "test split is empty"
batch_index = randint(0, len(test_gen) - 1)
print("Batch index:", batch_index, "/", len(test_gen) - 1)
test_inputs, test_labels = test_gen[batch_index]

print("inputs:", test_inputs.shape)
print("labels:", test_labels.shape)

In [None]:
test_inputs.__class__

In [None]:
# Plot a random example from the batch: the centre input slice and its label.

import matplotlib.pyplot as plt
index = randint(0, test_inputs.shape[0]-1)
print("index:", index, "/", test_inputs.shape[0]-1)

# The label is drawn for the centre slice of the context window, so show
# that slice rather than the first one.
centre = test_inputs.shape[1] // 2

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 5), sharex=True, sharey=True)
axes[0].imshow(test_inputs[index, centre, :, :, 0])
axes[0].set_title("Input (centre slice)")
axes[1].imshow(test_labels[index, 0, :, :, 0])
axes[1].set_title("Label (channel 0)")

fig.tight_layout()

In [None]:
# Checkpoint: every cell above should run cleanly on a fresh kernel.

# Test Model building

In [None]:
def test_model(input_shape, output_shape):
    """Small baseline: two conv + depth-pooling stages collapse the context.

    Args:
        input_shape: ``(depth, height, width, 1)`` of one sample.
        output_shape: only ``output_shape[-1]`` (number of output channels)
            is used.

    Returns:
        Keras Model mapping ``input_shape`` to sigmoid probabilities of
        shape ``(1, height, width, output_shape[-1])``.
    """
    depth = input_shape[0]
    # Two 'valid' stride-1 depth poolings remove (p1 - 1) + (p2 - 1) slices.
    # Choosing p1 + p2 = depth + 1 leaves exactly one slice for any depth
    # (11 and 11 for a 21-slice context).
    pool_1 = (depth + 1) // 2
    pool_2 = depth + 1 - pool_1

    inputs = tf.keras.layers.Input(shape=input_shape)
    x = tf.keras.layers.Conv3D(1, (3, 3, 3), padding='same')(inputs)
    x = tf.keras.layers.AveragePooling3D(pool_size=(pool_1, 1, 1),
                                         strides=1,
                                         padding='valid')(x)
    x = tf.keras.layers.ReLU()(x)

    x = tf.keras.layers.Conv3D(1, (3, 3, 3), padding='same')(x)
    x = tf.keras.layers.AveragePooling3D(pool_size=(pool_2, 1, 1),
                                         strides=1,
                                         padding='valid')(x)
    x = tf.keras.layers.ReLU()(x)

    # 1x1x1 projection to one channel per structure, then probabilities.
    x = tf.keras.layers.Conv3D(output_shape[-1], 1, activation=None)(x)
    x = tf.keras.activations.sigmoid(x)
    return tf.keras.Model(inputs=inputs, outputs=x)

In [None]:
# Data shape: 2*CONTEXT + 1 = 21 input slices of SIZE x SIZE pixels.
CONTEXT, BATCH_SIZE, SIZE = 10, 50, 128

INPUT_SHAPE = (2 * CONTEXT + 1, SIZE, SIZE, 1)
OUTPUT_SHAPE = (1, SIZE, SIZE, len(STRUCTURE_NAMES))

# Rebuild the generators so they pick up the new shape settings.
train_gen, valid_gen, test_gen = [
    DataGen(subset, context_paths, label_paths,
            context=CONTEXT, batch_size=BATCH_SIZE,
            structure_names=STRUCTURE_NAMES, resize=SIZE)
    for subset in (train_paths, valid_paths, test_paths)
]

In [None]:
# Build and compile the baseline model, then show its layer shapes.
model = test_model(INPUT_SHAPE, OUTPUT_SHAPE)
model.compile(loss=LOSS, optimizer=OPTIMIZER)
model.summary()

In [None]:
# steps_per_epoch comes from the generator: DataGen.__len__ rounds up, so
# the final partial batch is included and the value is never 0 for a
# non-empty split. NOTE: fit_generator is deprecated in TF 2.x, where
# model.fit accepts a keras Sequence directly.
model.fit_generator(generator=train_gen,
                    validation_data=valid_gen,
                    steps_per_epoch=len(train_gen),
                    epochs=EPOCHS)

In [None]:
# Get a random testing batch for predictions. len(test_gen) is the real
# batch count: the old round(len / BATCH_SIZE) - 1 bound could be -1.

from random import randint
assert len(test_gen) > 0, "test split is empty"
batch_index = randint(0, len(test_gen) - 1)
print("Batch index:", batch_index, "/", len(test_gen) - 1)
test_inputs, test_labels = test_gen[batch_index]

print("inputs:", test_inputs.shape)
print("labels:", test_labels.shape)

In [None]:
# Predicted probabilities for the sampled test batch.
predicts = model.predict(x=test_inputs)
predicts.shape

In [None]:
# Plot a random example from the batch: centre input slice, label and prediction.

import matplotlib.pyplot as plt
index = randint(0, test_inputs.shape[0]-1)
print("index:", index, "/", test_inputs.shape[0]-1)

# Labels (and hence predictions) describe the centre slice of the context
# window, so show that slice rather than the first one.
centre = test_inputs.shape[1] // 2

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10, 5), sharex=True, sharey=True)
panels = [(test_inputs[index, centre, :, :, 0], "Input (centre slice)"),
          (test_labels[index, 0, :, :, 0], "Label (channel 0)"),
          (predicts[index, 0, :, :, 0], "Prediction (channel 0)")]
for ax, (image, title) in zip(axes, panels):
    ax.imshow(image)
    ax.set_title(title)
fig.tight_layout()

# U-Net building blocks

In [None]:
import tensorflow as tf

def down_block(x, m, n, c, size):
    """Residual encoder block that also shrinks the depth (slice) axis.

    Args:
        x: 5-D tensor (batch, depth, height, width, channels).
        m: number of (1, 3, 3) conv + ReLU layers applied first.
        n: number of (1, 3, 3) conv + (size, 1, 1) 'valid' conv pairs. Each
            pair removes ``size - 1`` depth slices.
        c: output channels.
        size: depth kernel size of the 'valid' convolutions.

    Returns:
        The sum of a cropped, projected shortcut and the conv branch. Depth
        is reduced by ``n * (size - 1)`` and there are ``c`` channels.
    """
    depth_loss = n * (size - 1)
    # Crop the shortcut by exactly what the valid convs remove. Split the
    # crop unevenly when the total is odd so both branches always match.
    crop_front = depth_loss // 2
    crop_back = depth_loss - crop_front
    crop = tf.keras.layers.Cropping3D(
        cropping=((crop_front, crop_back), (0, 0), (0, 0)))(x)
    # 1x1x1 projection so the shortcut has ``c`` channels.
    crop = tf.keras.layers.Conv3D(c, 1, activation=None)(crop)

    result = x
    for _ in range(m):
        result = tf.keras.layers.Conv3D(c, (1, 3, 3),
                                        strides=1,
                                        padding='same')(result)
        result = tf.keras.layers.ReLU()(result)

    for repeat in range(n):
        result = tf.keras.layers.Conv3D(c, (1, 3, 3),
                                        strides=1,
                                        padding='same', activation=None)(result)
        result = tf.keras.layers.Conv3D(c, (size, 1, 1),
                                        strides=1,
                                        padding='valid', activation=None)(result)
        # No activation after the last pair: the residual sum stays linear.
        if repeat < n - 1:
            result = tf.keras.layers.ReLU()(result)

    return tf.keras.layers.Add()([crop, result])


def pool(x, size_xy):
    """Average-pool both in-plane axes by ``size_xy``; depth is untouched."""
    pooling = tf.keras.layers.AveragePooling3D(
        pool_size=(1, size_xy, size_xy), strides=None, padding='valid')
    return pooling(x)


def fc_block(x, r, inp=1024, out=256):
    """Fully connected bottleneck: 8x8 feature map -> ``(1, 8, 8, out)``.

    Processing steps:

    1. A (1, 8, 8) 'valid' convolution acts as a dense layer over each 8x8
       slice.
    2. The result for all depth slices is flattened into one vector.
    3. The vector passes through ``r`` pre-activation residual layers
       (ReLU -> Dense, added to their input) of width ``inp``.
    4. It is projected back to an 8x8 map with ``out`` channels and a depth
       of one.

    Args:
        x: 5-D tensor (batch, depth, 8, 8, channels).
        r: number of residual dense layers.
        inp: width of the hidden dense layers.
        out: channels of the returned map.

    Raises:
        ValueError: if the in-plane size of ``x`` is not 8x8.
    """
    if tuple(x.shape[2:4]) != (8, 8):
        raise ValueError(
            "fc_block expects an 8x8 in-plane input, got shape {}".format(x.shape))

    result = tf.keras.layers.Conv3D(inp, (1, 8, 8),
                                    strides=1,
                                    padding='valid')(x)
    # Merge every depth slice into one feature vector per sample.
    result = tf.keras.layers.Flatten()(result)
    result = tf.keras.layers.Dense(inp)(result)
    for _ in range(r):
        skip = result
        result = tf.keras.layers.ReLU()(result)
        result = tf.keras.layers.Dense(inp)(result)
        result = tf.keras.layers.Add()([skip, result])

    result = tf.keras.layers.ReLU()(result)
    result = tf.keras.layers.Dense(8 * 8 * out)(result)
    # Reshape ``result`` (not ``x``) so the layers above are part of the model.
    return tf.keras.layers.Reshape((1, 8, 8, out))(result)


def up_block(x, m, c):
    """Pre-activation residual decoder block.

    ReLU is applied to ``x``, followed by ``m`` (1, 3, 3) conv + ReLU layers.
    A 1x1x1 projection of ``x`` to ``c`` channels is added back. Depth and
    in-plane size are preserved ('same' padding).

    Args:
        x: 5-D tensor (batch, depth, height, width, channels).
        m: number of conv + ReLU layers.
        c: output channels.
    """
    shortcut = tf.keras.layers.Conv3D(c, 1, activation=None)(x)

    result = tf.keras.layers.ReLU()(x)
    for _ in range(m):
        result = tf.keras.layers.Conv3D(c, (1, 3, 3),
                                        strides=1,
                                        padding='same')(result)
        result = tf.keras.layers.ReLU()(result)
    return tf.keras.layers.Add()([shortcut, result])


def upscale(x, size_xy):
    """Repeat pixels ``size_xy`` times along both in-plane axes (depth kept)."""
    upsampling = tf.keras.layers.UpSampling3D(size=(1, size_xy, size_xy))
    return upsampling(x)


def stack(x, skip):
    """Concatenate ``x`` and ``skip`` along the depth axis.

    Axis 0 is the batch, so axis 1 is the depth (slice) dimension. The
    in-plane sizes and channel counts of both tensors must already agree.
    """
    depth_concat = tf.keras.layers.Concatenate(axis=1)
    return depth_concat([x, skip])

In [None]:
# Scratch: inspect a symbolic model input of INPUT_SHAPE.
inputs = tf.keras.layers.Input(INPUT_SHAPE)
inputs

In [None]:
# Scratch: the same input cast to half precision.
inputs = tf.cast(inputs, dtype=tf.float16)
inputs

# 3-level U-Net: 512 × 512 slices with a context of 10 slices on each side

In [None]:
def unet_3_512(input_shape, output_shape):
    """3-level U-Net for ``(depth, 512, 512, 1)`` inputs (float16 input layer).

    With a depth of 21, the down blocks reduce depth to 17, 9 and 1 while the
    pools shrink 512 -> 128 -> 32. The 32x32 map is pooled to the 8x8 that
    fc_block expects and upscaled back before joining the last skip.

    Args:
        input_shape: shape of one sample, ``(depth, 512, 512, 1)``.
        output_shape: only ``output_shape[-1]`` (number of structures) is used.

    Returns:
        Keras Model producing sigmoid probabilities of shape
        ``(1, 512, 512, output_shape[-1])``.
    """
    inputs = tf.keras.layers.Input(shape=input_shape, dtype=tf.float16)
    skips = []

    x = down_block(inputs, 0, 1, 64, 5)
    skips.append(x)
    x = pool(x, 4)

    x = down_block(x, 0, 2, 128, 5)
    skips.append(x)
    x = pool(x, 4)

    x = down_block(x, 0, 2, 256, 5)
    skips.append(x)

    # fc_block works on an 8x8 map: pool 32 -> 8, then upscale 8 -> 32 so the
    # result can be stacked with the 32x32 skip connection.
    x = pool(x, 4)
    x = fc_block(x, 2)
    x = upscale(x, 4)

    x = stack(skips[-1], x)
    x = up_block(x, 1, 128)

    x = upscale(x, 4)
    x = stack(skips[-2], x)
    x = up_block(x, 1, 64)

    x = upscale(x, 4)
    x = stack(skips[-3], x)
    x = up_block(x, 1, 1)

    # Collapse all remaining depth slices into one output slice (28 for a
    # 21-slice input) instead of hard-coding the kernel depth.
    x = tf.keras.layers.Conv3D(filters=output_shape[-1],
                               kernel_size=(int(x.shape[1]), 1, 1),
                               strides=1,
                               activation='sigmoid',
                               padding='valid')(x)

    return tf.keras.Model(inputs=inputs, outputs=x)

In [None]:
# Data shape: 21-slice context at full 512 x 512 resolution.
CONTEXT, BATCH_SIZE, SIZE = 10, 2, 512

INPUT_SHAPE = (2 * CONTEXT + 1, SIZE, SIZE, 1)
OUTPUT_SHAPE = (1, SIZE, SIZE, len(STRUCTURE_NAMES))

# Rebuild the generators so they pick up the new shape settings.
train_gen, valid_gen, test_gen = [
    DataGen(subset, context_paths, label_paths,
            context=CONTEXT, batch_size=BATCH_SIZE,
            structure_names=STRUCTURE_NAMES, resize=SIZE)
    for subset in (train_paths, valid_paths, test_paths)
]

In [None]:
# Build and compile the 512-pixel 3-level U-Net, then show its layer shapes.
model = unet_3_512(INPUT_SHAPE, OUTPUT_SHAPE)
model.compile(loss=LOSS, optimizer=OPTIMIZER)
model.summary()

In [None]:
# model.fit_generator(generator=train_gen,
#                     validation_data=valid_gen,
#                     steps_per_epoch=len(train_paths) // BATCH_SIZE,
#                     epochs=EPOCHS)

# 3-level U-Net: 128 × 128 slices

In [None]:
def unet_3_128(input_shape, output_shape):
    """3-level U-Net for ``(depth, 128, 128, 1)`` inputs.

    The down blocks use a depth kernel of 1, so depth is preserved while the
    pools shrink 128 -> 32 -> 8. fc_block condenses the 8x8 bottleneck into
    a single slice. Each decoder stage then stacks its skip connection on
    the depth axis.

    Args:
        input_shape: shape of one sample, ``(depth, 128, 128, 1)``.
        output_shape: only ``output_shape[-1]`` (number of structures) is used.

    Returns:
        Keras Model producing sigmoid probabilities with
        ``output_shape[-1]`` channels and a depth of one.
    """
    inputs = tf.keras.layers.Input(shape=input_shape)
    skips = []

    x = down_block(inputs, 0, 1, 64, 1)
    skips.append(x)
    x = pool(x, 4)

    x = down_block(x, 0, 1, 128, 1)
    skips.append(x)
    x = pool(x, 4)

    x = down_block(x, 0, 1, 256, 1)
    skips.append(x)

    x = fc_block(x, 2)

    x = stack(skips[-1], x)
    x = up_block(x, 1, 128)

    x = upscale(x, 4)
    x = stack(skips[-2], x)
    x = up_block(x, 1, 64)

    x = upscale(x, 4)
    x = stack(skips[-3], x)
    x = up_block(x, 1, 1)

    # Collapse all remaining depth slices into one output slice (64 for a
    # 21-slice input) instead of hard-coding the kernel depth.
    x = tf.keras.layers.Conv3D(filters=output_shape[-1],
                               kernel_size=(int(x.shape[1]), 1, 1),
                               strides=1,
                               activation='sigmoid',
                               padding='valid')(x)

    return tf.keras.Model(inputs=inputs, outputs=x)

In [None]:
# Data shape: 21-slice context at 128 x 128 resolution.
CONTEXT, BATCH_SIZE, SIZE = 10, 2, 128

INPUT_SHAPE = (2 * CONTEXT + 1, SIZE, SIZE, 1)
OUTPUT_SHAPE = (1, SIZE, SIZE, len(STRUCTURE_NAMES))

# Rebuild the generators so they pick up the new shape settings.
train_gen, valid_gen, test_gen = [
    DataGen(subset, context_paths, label_paths,
            context=CONTEXT, batch_size=BATCH_SIZE,
            structure_names=STRUCTURE_NAMES, resize=SIZE)
    for subset in (train_paths, valid_paths, test_paths)
]

In [None]:
# Build and compile the 128-pixel 3-level U-Net, then show its layer shapes.
model = unet_3_128(INPUT_SHAPE, OUTPUT_SHAPE)
model.compile(loss=LOSS, optimizer=OPTIMIZER)
model.summary()

In [None]:
# steps_per_epoch comes from the generator: DataGen.__len__ rounds up, so
# the final partial batch is included and the value is never 0 for a
# non-empty split. NOTE: fit_generator is deprecated in TF 2.x, where
# model.fit accepts a keras Sequence directly.
model.fit_generator(generator=train_gen,
                    validation_data=valid_gen,
                    steps_per_epoch=len(train_gen),
                    epochs=EPOCHS)

# 2 Layer unet

In [None]:
def unet_model_2_layer(input_shape, output_shape):
    """2-level U-Net for ``(depth, 512, 512, 1)`` inputs.

    With a depth of 21, the down blocks reduce depth to 13 and then 1, while
    the pools shrink 512 -> 64 -> 8 to feed fc_block.

    Args:
        input_shape: shape of one sample, ``(depth, 512, 512, 1)``.
        output_shape: only ``output_shape[-1]`` (number of structures) is used.

    Returns:
        Keras Model producing sigmoid probabilities of shape
        ``(1, 512, 512, output_shape[-1])``.
    """
    inputs = tf.keras.layers.Input(shape=input_shape)
    skips = []

    x = down_block(inputs, 0, 1, 128, 9)
    skips.append(x)
    x = pool(x, 8)

    x = down_block(x, 0, 2, 256, 7)
    skips.append(x)
    x = pool(x, 8)

    x = fc_block(x, 2)

    x = upscale(x, 8)
    x = stack(skips[-1], x)
    x = up_block(x, 1, 128)

    x = upscale(x, 8)
    x = stack(skips[-2], x)
    x = up_block(x, 1, 1)

    # Collapse all remaining depth slices into one output slice (15 for a
    # 21-slice input) instead of hard-coding the kernel depth.
    x = tf.keras.layers.Conv3D(filters=output_shape[-1],
                               kernel_size=(int(x.shape[1]), 1, 1),
                               strides=1,
                               activation='sigmoid',
                               padding='valid')(x)

    return tf.keras.Model(inputs=inputs, outputs=x)

In [None]:
# DATA SHAPE
CONTEXT = 10
BATCH_SIZE = 2
SIZE = 512
INPUT_SHAPE = (2*CONTEXT + 1, SIZE, SIZE, 1)
# OUTPUT_SHAPE must follow SIZE; without this line a value computed for a
# different SIZE in an earlier cell would be reused.
OUTPUT_SHAPE = (1, SIZE, SIZE, len(STRUCTURE_NAMES))

train_gen = DataGen(train_paths,
                    context_paths,
                    label_paths,
                    context=CONTEXT,
                    batch_size=BATCH_SIZE,
                    structure_names=STRUCTURE_NAMES, resize=SIZE)

valid_gen = DataGen(valid_paths,
                    context_paths,
                    label_paths,
                    context=CONTEXT,
                    batch_size=BATCH_SIZE,
                    structure_names=STRUCTURE_NAMES, resize=SIZE)

test_gen = DataGen(test_paths,
                   context_paths,
                   label_paths,
                   context=CONTEXT,
                   batch_size=BATCH_SIZE,
                   structure_names=STRUCTURE_NAMES, resize=SIZE)

In [None]:
# Build the 2-level U-Net defined for this section.
model = unet_model_2_layer(INPUT_SHAPE, OUTPUT_SHAPE)
model.compile(optimizer=OPTIMIZER, loss=LOSS)
model.summary()

In [None]:
# steps_per_epoch comes from the generator: DataGen.__len__ rounds up, so
# the final partial batch is included and the value is never 0 for a
# non-empty split. NOTE: fit_generator is deprecated in TF 2.x, where
# model.fit accepts a keras Sequence directly.
model.fit_generator(generator=train_gen,
                    validation_data=valid_gen,
                    steps_per_epoch=len(train_gen),
                    epochs=EPOCHS)

# 1 Layer unet

In [None]:
def unet_model_1_layer(input_shape, output_shape):
    """1-level U-Net: one down block, an fc bottleneck, one up block.

    fc_block works on an 8x8 map, so the pooling/upscaling factor is derived
    from the input size (64 for 512-pixel inputs). A fixed factor of 8 would
    hand fc_block a 64x64 map at 512 px.

    Args:
        input_shape: shape of one sample, ``(depth, size, size, 1)`` with
            ``size`` divisible by 8.
        output_shape: only ``output_shape[-1]`` (number of structures) is used.

    Returns:
        Keras Model producing sigmoid probabilities with
        ``output_shape[-1]`` channels and a depth of one.
    """
    inputs = tf.keras.layers.Input(shape=input_shape)
    skips = []
    factor = input_shape[1] // 8

    x = down_block(inputs, 0, 3, 32, 3)
    skips.append(x)
    x = pool(x, factor)

    x = fc_block(x, 2, inp=2048, out=32)

    x = upscale(x, factor)
    x = stack(skips[-1], x)
    x = up_block(x, 1, 1)

    # Collapse all remaining depth slices into one output slice instead of
    # hard-coding the kernel depth.
    x = tf.keras.layers.Conv3D(filters=output_shape[-1],
                               kernel_size=(int(x.shape[1]), 1, 1),
                               strides=1,
                               activation='sigmoid',
                               padding='valid')(x)

    return tf.keras.Model(inputs=inputs, outputs=x)

In [None]:
# Data shape: 21-slice context at 512 x 512, batches of 50.
CONTEXT, BATCH_SIZE, SIZE = 10, 50, 512

INPUT_SHAPE = (2 * CONTEXT + 1, SIZE, SIZE, 1)
OUTPUT_SHAPE = (1, SIZE, SIZE, len(STRUCTURE_NAMES))

# Rebuild the generators so they pick up the new shape settings.
train_gen, valid_gen, test_gen = [
    DataGen(subset, context_paths, label_paths,
            context=CONTEXT, batch_size=BATCH_SIZE,
            structure_names=STRUCTURE_NAMES, resize=SIZE)
    for subset in (train_paths, valid_paths, test_paths)
]

In [None]:
# Build the 1-level U-Net defined for this section.
model = unet_model_1_layer(INPUT_SHAPE, OUTPUT_SHAPE)
model.compile(optimizer=OPTIMIZER, loss=LOSS)
model.summary()

In [None]:
# steps_per_epoch comes from the generator: DataGen.__len__ rounds up, so
# the final partial batch is included and the value is never 0 for a
# non-empty split. NOTE: fit_generator is deprecated in TF 2.x, where
# model.fit accepts a keras Sequence directly.
model.fit_generator(generator=train_gen,
                    validation_data=valid_gen,
                    steps_per_epoch=len(train_gen),
                    epochs=EPOCHS)