In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras_unet.models import vanilla_unet, custom_unet
import glob
from PIL import Image

-----------------------------------------
keras-unet init: TF version is >= 2.0.0 - using `tf.keras` instead of `Keras`
-----------------------------------------


In [3]:
orgs = glob.glob("images/dibco/segments/original/*.bmp")
masks = list(map(lambda x: x.replace("original", "gt"), orgs))

In [6]:
from sklearn.model_selection import train_test_split

x_train, x_val, y_train, y_val = train_test_split(orgs, masks, test_size=0.1, random_state=0)

print("x_train: ", len(x_train))
print("y_train: ", len(y_train))
print("x_val: ", len(x_val))
print("y_val: ", len(y_val))

x_train:  3106
y_train:  3106
x_val:  346
y_val:  346


In [7]:
from tensorflow.keras.preprocessing.image import load_img


class PairsGenerator(keras.utils.Sequence):
    """Helper to iterate over the data (as Numpy arrays)."""

    def __init__(self, batch_size, img_size, original_img_paths, gt_img_paths):
        self.batch_size = batch_size
        self.img_size = img_size
        self.original_img_paths = original_img_paths
        self.gt_img_paths = gt_img_paths

    def __len__(self):
        return len(self.original_img_paths) // self.batch_size

    def __getitem__(self, idx):
        """Returns tuple (original, gt) correspond to batch #idx."""
        i = idx * self.batch_size
        batch_original_img_paths = self.original_img_paths[i : i + self.batch_size]
        batch_gt_img_paths = self.gt_img_paths[i : i + self.batch_size]
        x = np.zeros((self.batch_size,) + self.img_size + (3,), dtype="float32")
        for j, path in enumerate(batch_original_img_paths):
            img = load_img(path, target_size=self.img_size)
            x[j] = img
        y = np.zeros((self.batch_size,) + self.img_size + (1,), dtype="uint8")
        for j, path in enumerate(batch_gt_img_paths):
            img = load_img(path, target_size=self.img_size, color_mode="grayscale")
            y[j] = np.expand_dims(img, 2)
            # Ground truth labels are 1, 2, 3. Subtract one to make them 0, 1, 2:
            y[j] -= 1
        return x, y

In [8]:
pairgen = PairsGenerator(32, (256, 256), orgs, masks)

In [17]:
x, y = pairgen[0]
print(x.shape, y.shape)
print(x.dtype, y.dtype)
print(x[0].max(), y[0].max())

(32, 256, 256, 3) (32, 256, 256, 1)
float32 uint8
255.0 255


In [18]:
model = custom_unet(input_shape=(256, 256, 3))

In [19]:
from keras.callbacks import ModelCheckpoint


model_filename = 'segm_model_v3.h5'
callback_checkpoint = ModelCheckpoint(
    model_filename, 
    verbose=1, 
    monitor='val_loss', 
    save_best_only=True,
)

In [20]:
from keras.optimizers import Adam, SGD
from keras_unet.metrics import iou, iou_thresholded
from keras_unet.losses import jaccard_distance

model.compile(
    optimizer=Adam(), 
    #optimizer=SGD(lr=0.01, momentum=0.99),
    loss='binary_crossentropy',
    #loss=jaccard_distance,
    metrics=[iou, iou_thresholded]
)

In [22]:
history = model.fit(
    pairgen,
    steps_per_epoch=len(orgs) // 32,
    epochs=50,
    callbacks=[callback_checkpoint]
)

Epoch 1/50

KeyboardInterrupt: 

In [19]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 256, 256, 16) 432         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 256, 256, 16) 64          conv2d[0][0]                     
__________________________________________________________________________________________________
spatial_dropout2d (SpatialDropo (None, 256, 256, 16) 0           batch_normalization[0][0]        
______________________________________________________________________________________________

In [15]:
from notebook.services.config import ConfigManager
cm = ConfigManager().update('notebook', {'limit_output': 200000})