In [16]:
from keras.callbacks import ModelCheckpoint, Callback
from keras.optimizers import Adam

from networks.unet import Unet

from utils.batch_generator import BatchGenerator

In [19]:
# Batch size handed to BatchGenerator in train().
BATCH_SIZE = 1
# Size argument passed to generate_test_batch() for the loss-logging callback.
VAL_BATCH = 10
# Input height and width given to Unet.model().
IMG_ROWS = 256
IMG_COLS = 256
# Number of epochs passed to fit_generator().
NB_EPOCHS = 1001

In [36]:
class LossValidateCallback(Callback):
    """Keras callback that appends train/validation loss to a text file.

    At the end of every epoch the model is evaluated on one fixed training
    sample and one fixed validation sample, and a single line containing
    both losses is appended to ``results_file``.

    Args:
        train_batch: ``(images, masks)`` pair evaluated as the training sample.
        val_batch: ``(images, masks)`` pair evaluated as the validation sample.
        results_file: Path of the text log. Its parent directory is created
            if it does not exist yet.
    """

    def __init__(self, train_batch, val_batch, results_file):
        # Let the Keras base class initialise its own attributes first.
        super().__init__()
        self.train_batch = train_batch
        self.val_batch = val_batch

        import os
        basedir = os.path.dirname(results_file)
        # A bare filename has an empty dirname and os.makedirs('') raises,
        # so only create a directory when the path actually names one.
        # exist_ok avoids failing if the directory appears in the meantime.
        if basedir:
            os.makedirs(basedir, exist_ok=True)
        self.results_file = results_file

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate both stored batches and append the losses to the log.

        Args:
            epoch: Epoch index supplied by Keras; written to the log as-is.
            logs: Metrics dict supplied by Keras; unused here.
        """
        train_imgs, train_masks = self.train_batch
        val_imgs, val_masks = self.val_batch
        # The two-value unpacking requires evaluate() to return exactly
        # [loss, metric], i.e. a model compiled with a single metric.
        train_loss, _ = self.model.evaluate(train_imgs, train_masks)
        val_loss, _ = self.model.evaluate(val_imgs, val_masks)
        text = 'epoch: {0} train_loss: {1}, validation loss: {2}\n'.format(
            epoch, train_loss, val_loss
        )
        # Append mode keeps the history from earlier epochs and earlier runs.
        with open(self.results_file, 'a') as file:
            file.write(text)

In [37]:
def train(data_dir, val_data_dir, results_file):
    """Train the U-Net from generator batches and log per-epoch losses.

    Loads data through ``BatchGenerator``, builds ``Unet.model`` and compiles
    it with Adam plus the loss and metric defined on ``Unet``, then fits on
    the generator's training batches. A ``ModelCheckpoint`` and a
    ``LossValidateCallback`` writing to ``results_file`` are attached.

    Args:
        data_dir: Training data directory forwarded to ``BatchGenerator``.
        val_data_dir: Validation data directory forwarded to ``BatchGenerator``.
        results_file: Path of the text file the loss callback appends to.
    """
    generator = BatchGenerator(
        data_dir=data_dir,
        val_data_dir=val_data_dir,
        batch_size=BATCH_SIZE,
    )
    generator.load_data()

    unet = Unet.model(IMG_ROWS, IMG_COLS)
    unet.compile(optimizer=Adam(lr=1e-4), loss=Unet.loss, metrics=[Unet.metric])

    # Checkpoint file name is templated with the epoch number; period=50
    # sets the interval, in epochs, between saves.
    saver = ModelCheckpoint(
        filepath='deep_unet_batch_1_epoch_{epoch:02d}.hdf5',
        mode='auto',
        period=50,
    )

    # Read the training batches before building the evaluation sample so the
    # generator is touched in the same order as a single inline call would.
    train_stream = generator.train_batches
    loss_logger = LossValidateCallback(
        *generator.generate_test_batch(VAL_BATCH), results_file
    )

    unet.fit_generator(
        train_stream,
        steps_per_epoch=10,
        epochs=NB_EPOCHS,
        callbacks=[saver, loss_logger],
    )

In [38]:
# Text file the per-epoch loss lines are appended to.
results_file = './results_data.txt'

# Training and validation dataset directories passed to train().
data_dir = 'dataset/dataset_256/'
val_data_dir = 'dataset/val_dataset_256/'

In [39]:
# Start training; losses are logged to results_file after every epoch.
train(data_dir=data_dir, val_data_dir=val_data_dir, results_file=results_file)

Epoch 1/1001
train_loss: 0.37885230779647827, validation loss: 0.4507673382759094
Epoch 2/1001
train_loss: 0.3532446026802063, validation loss: 0.44476279616355896
Epoch 3/1001
train_loss: 0.2875822186470032, validation loss: 0.3841344118118286
Epoch 4/1001
train_loss: 0.28629666566848755, validation loss: 0.3804261088371277
Epoch 5/1001

KeyboardInterrupt: 