In [None]:
from datetime import datetime

import tensorflow.keras as keras
import matplotlib.pyplot as plt

from model import unet, losses, metrics

import config
from utils import callbacks
from dataset import mk_dataset


In [None]:
# Define Dataset
# Build the training and validation datasets from the SAT/MAP paths in config.
# The validation set is built with batch_size=1 (presumably one sample per batch).

train_ds = mk_dataset.mk_dataset(
    MAP_PATH=config.TR_MAP_PATH,
    SAT_PATH=config.TR_SAT_PATH,
)
valid_ds = mk_dataset.mk_dataset(
    batch_size=1,
    MAP_PATH=config.VA_MAP_PATH,
    SAT_PATH=config.VA_SAT_PATH,
)

# First validation (input, target) batch; reused below by the display
# callback and by the final output-shape check.
sample_inp, sample_tar = next(iter(valid_ds))


In [None]:
# Define Model

input_shape = (config.IMG_HEIGHT, config.IMG_WIDTH, config.IMG_CH)
model = unet.big_unet_model(input_shape=input_shape, output_channels=config.OUT_CH)

# Compile the model.
# The metric list is bound to `model_metrics`, not `metrics`: rebinding
# `metrics` would shadow the imported `metrics` module, and re-running this
# cell would then fail with AttributeError on `metrics.iou_coef`.
optimizer = keras.optimizers.Adam()
loss = losses.DICELoss("DICE")
model_metrics = ["accuracy", metrics.iou_coef]
model.compile(optimizer=optimizer, loss=loss, metrics=model_metrics)


In [None]:
# Sanity check: display the name of the loss object the model was compiled with.
compiled_loss = model.loss
compiled_loss.name


In [None]:
# Render the model graph, annotating layers with their shapes and dtypes
keras.utils.plot_model(
    model,
    show_shapes=True,
    show_dtype=True,
)

In [None]:
# Define callbacks

# Visual check on one fixed validation sample. Calling on_epoch_end once here,
# before fit(), presumably renders the untrained model's output as a baseline.
disp_cb = callbacks.DisplayCallback(
    model=model,
    sample_inp=sample_inp[0],
    sample_tar=sample_tar[0],
)
disp_cb.on_epoch_end(0, None)

# Run identifier shared by the TensorBoard log dir and the checkpoint path.
# The loss name is read from the compiled model rather than the global `loss`,
# because a later cell may rebind `loss` (e.g. to the loss-history list),
# which would make re-running this cell fail.
filename = datetime.now().strftime("%Y%m%d%H%M_") + model.loss.name

tboard_cb = callbacks.get_tboard_callback(str(config.LOG_PATH / filename))
checkpoint_cb = callbacks.get_checkpoint_callback(
    str(config.CHECKPOINT_PATH / filename)
)


In [None]:
# Number of training epochs; each epoch runs config.STEPS_PER_EPOCH steps.
NB_Epochs = 100

# Train with validation after each epoch; the returned History is plotted below.
model_history = model.fit(
    train_ds,
    validation_data=valid_ds,
    epochs=NB_Epochs,
    steps_per_epoch=config.STEPS_PER_EPOCH,
    validation_steps=config.VALIDATION_STEPS,
    callbacks=[disp_cb, tboard_cb, checkpoint_cb],
)


In [None]:
# Qualitative check on two validation batches: for the first sample of each
# batch, show input, ground truth and prediction side by side in one figure.
for inp_batch, tar_batch in valid_ds.take(2):
    pred_batch = model.predict(inp_batch)

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    axes[0].imshow(inp_batch[0])
    axes[0].set_title("Input")
    axes[1].imshow(tar_batch[0])
    axes[1].set_title("Target")
    # Only the first output channel of the prediction is displayed.
    axes[2].imshow(pred_batch[0][:, :, 0])
    axes[2].set_title("Prediction (channel 0)")
    for ax in axes:
        ax.axis("off")

    plt.show()
    # Release the figure so repeated runs do not accumulate open figures.
    plt.close(fig)


In [None]:
# Plot training (and, when recorded, validation) loss per epoch.
# `train_loss` is used instead of `loss` so the DICELoss object bound to
# `loss` in the model cell is not shadowed by this list.
train_loss = model_history.history["loss"]
# fit() above is given validation_data, so Keras records "val_loss"; .get()
# keeps this cell working (no NameError/KeyError) if validation is disabled.
val_loss = model_history.history.get("val_loss")

fig, ax = plt.subplots()
ax.plot(model_history.epoch, train_loss, "r", label="Training loss")
if val_loss is not None:
    ax.plot(model_history.epoch, val_loss, "bo", label="Validation loss")
ax.set_title("Training and Validation Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss Value")
ax.set_ylim((0, 1))
ax.legend()
# Name the figure after the current run (`filename`, from the callbacks cell)
# instead of a hard-coded timestamp that every run would overwrite.
fig.savefig(f"{filename}_loss.png")
plt.show()


In [None]:
# Shape check: run the trained model on the validation batch captured at the top.
pred = model.predict(x=sample_inp)
pred.shape