In [None]:
from datetime import datetime

import tensorflow.keras as keras
import matplotlib.pyplot as plt

from model import unet, losses, metrics

import config
from utils import callbacks
from dataset import mk_dataset


In [None]:
def make_datasets():
    """Build the training and validation datasets from the paths in ``config``.

    Returns:
        tuple: ``(train_ds, valid_ds)``. The training set is built with
        ``mk_dataset``'s default batch size; the validation set is built
        with ``batch_size=1``.
    """
    training = mk_dataset.mk_dataset(
        SAT_PATH=config.TR_SAT_PATH, MAP_PATH=config.TR_MAP_PATH
    )
    validation = mk_dataset.mk_dataset(
        SAT_PATH=config.VA_SAT_PATH, MAP_PATH=config.VA_MAP_PATH, batch_size=1
    )
    return training, validation


In [None]:
def compile_model(loss, optimizer):
    """Build the big U-Net sized from ``config`` and compile it.

    Args:
        loss: Loss passed straight to ``model.compile``.
        optimizer: Optimizer passed straight to ``model.compile``.

    Returns:
        The compiled model, tracking ``"accuracy"`` and ``metrics.iou_coef``.
    """
    net = unet.big_unet_model(
        input_shape=(config.IMG_HEIGHT, config.IMG_WIDTH, config.IMG_CH),
        output_channels=config.OUT_CH,
    )
    net.compile(
        optimizer=optimizer,
        loss=loss,
        metrics=["accuracy", metrics.iou_coef],
    )
    return net


In [None]:
# Training callbacks (TensorBoard logging + checkpointing) for one run.

def get_callbacks(filename):
    """Return ``[tensorboard_callback, checkpoint_callback]`` for one run.

    Args:
        filename: Run identifier; the TensorBoard callback gets
            ``config.LOG_PATH / filename`` and the checkpoint callback gets
            ``config.CHECKPOINT_PATH / filename`` (both as ``str``).
    """
    log_target = str(config.LOG_PATH / filename)
    ckpt_target = str(config.CHECKPOINT_PATH / filename)
    return [
        callbacks.get_tboard_callback(log_target),
        callbacks.get_checkpoint_callback(ckpt_target),
    ]


In [None]:
def train(train_ds, valid_ds, NB_EPOCHS, loss, optimizer=None):
    """Compile a fresh U-Net with ``loss``, fit it, and save it.

    Args:
        train_ds: Training dataset passed to ``model.fit``.
        valid_ds: Validation dataset passed as ``validation_data``.
        NB_EPOCHS: Number of training epochs (previously ignored in favour of
            a hard-coded 10).
        loss: Loss to compile with; ``model.loss.name`` becomes part of the
            run's file names.
        optimizer: Optimizer to compile with. Defaults to a new
            ``keras.optimizers.Adam()`` created on every call.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    # Create the default optimizer per call. The old default ``Adam()`` was
    # evaluated once at definition time and shared by every model trained in
    # the loop, so its iteration count and variables leaked between models.
    if optimizer is None:
        optimizer = keras.optimizers.Adam()
    model = compile_model(loss=loss, optimizer=optimizer)
    # Run id: minute-resolution timestamp + loss name. Used for the TensorBoard
    # log path, the checkpoint path and the saved-model path below.
    filename = datetime.now().strftime("%Y%m%d%H%M_") + model.loss.name
    model_history = model.fit(
        train_ds,
        epochs=NB_EPOCHS,
        steps_per_epoch=config.STEPS_PER_EPOCH,
        validation_steps=config.VALIDATION_STEPS,
        validation_data=valid_ds,
        callbacks=get_callbacks(filename),
    )
    model.save(str(config.MODEL_SAVE_PATH / filename))
    return model_history


In [None]:
def main():
    lossfunc_list = [
        losses.DICELoss(name="DICE"),
        losses.FocalTverskyLoss("Focal"),
        losses.TverskyLoss("Tversky"),
    ]
    train_ds, valid_ds = make_datasets()
    for loss in lossfunc_list:
        hist = train(
            train_ds=train_ds,
            valid_ds=valid_ds,
            NB_EPOCHS=10,
            loss=loss,
        )


In [None]:
main()

In [None]:
for i, t in valid_ds.take(2):
    t_pred = model.predict(i)
    plt.imshow(i[0])
    plt.show()
    plt.imshow(t[0])
    plt.show()
    plt.imshow(t_pred[0][:, :, 0])
    plt.show()


In [None]:
loss = model_history.history["loss"]
# val_loss = model_history.history["val_loss"]

plt.figure()
plt.plot(model_history.epoch, loss, "r", label="Training loss")
plt.plot(model_history.epoch, val_loss, "bo", label="Validation loss")
plt.title("Training and Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss Value")
plt.ylim([0, 1])
plt.legend()
plt.savefig("10282205.png")
plt.show()


In [None]:
pred = model.predict(sample_inp)
pred.shape