In [None]:
import datetime

import tensorflow.keras as keras
import matplotlib.pyplot as plt
import numpy as np

from model import unet
import config
import utils
from dataset import mk_dataset


In [None]:
# Input pipelines built from the satellite/map paths in config.
# The training set uses mk_dataset's default batch size (no batch_size is passed);
# the validation set is batched one example at a time.
train_ds = mk_dataset.mk_dataset(SAT_PATH=config.TR_SAT_PATH, MAP_PATH=config.TR_MAP_PATH)
valid_ds = mk_dataset.mk_dataset(
    SAT_PATH=config.VA_SAT_PATH,
    MAP_PATH=config.VA_MAP_PATH,
    batch_size=1,
)

# Pull the first validation batch once; it is reused below as a fixed sample.
sample_inp, sample_tar = next(iter(valid_ds))


In [None]:
# Build the U-Net for (height, width, channels) inputs taken from config.
input_shape = (config.IMG_HEIGHT, config.IMG_WIDTH, config.IMG_CH)
model = unet.big_unet_model(
    input_shape=input_shape, output_channels=config.OUTPUT_CLASSES
)

# Adam learning rate. config.LEARNING_RATE takes precedence when it is defined, so the
# value can be tuned next to the other hyper-parameters without editing this cell.
# NOTE(review): 0.05 is 50x Keras' Adam default (1e-3) and is aggressive for a U-Net;
# confirm it is intentional.
LEARNING_RATE = getattr(config, "LEARNING_RATE", 0.05)

# from_logits=True: the loss applies softmax itself, so the model's last layer is
# expected to output raw (un-normalised) class scores; targets are integer class ids.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)


In [None]:
# Print the layer-by-layer architecture and parameter counts of the model built above.
model.summary()

In [None]:
# Project callback (utils.DisplayCallback) given the model and the first example of the
# fixed validation sample batch.
# NOTE(review): what it renders is defined in utils.DisplayCallback, not shown here.
disp_cb = utils.DisplayCallback(
    sample_tar=sample_tar[0],
    sample_inp=sample_inp[0],
    model=model,
)
# Fire the end-of-epoch hook once by hand, presumably to preview the untrained
# model's output before fit() starts.
disp_cb.on_epoch_end(0, None)


In [None]:
# Sanity-check one training batch before fitting: tensor shapes, value ranges and a
# visual of the first (input, target) pair.
for batch_inp, batch_tar in train_ds.take(1):
    print("input  shape:", batch_inp.shape, "max:", np.max(batch_inp), "min:", np.min(batch_inp))
    print("target shape:", batch_tar.shape, "max:", np.max(batch_tar), "min:", np.min(batch_tar))

    # matplotlib's imshow only accepts (H, W), (H, W, 3) or (H, W, 4) arrays, so drop
    # singleton axes first; this is a no-op for RGB images and channel-less masks.
    # NOTE(review): assumes a batch element is (H, W) or (H, W, C) — confirm in mk_dataset.
    plt.imshow(np.squeeze(np.asarray(batch_inp[0])))
    plt.title("Training input (first in batch)")
    plt.show()
    plt.imshow(np.squeeze(np.asarray(batch_tar[0])))
    plt.title("Training target (first in batch)")
    plt.show()

In [None]:
# One TensorBoard directory per launch, stamped with the local start time
# (YYYYmmdd-HHMMSS) so runs started at different times get separate logs.
log_dir = f"logs/fit/{datetime.datetime.now():%Y%m%d-%H%M%S}"

# histogram_freq=1: also record weight histograms after every epoch.
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

# Fit on train_ds and evaluate on valid_ds each epoch, with disp_cb and TensorBoard
# attached as callbacks; the returned History feeds the loss plot.
model_history = model.fit(
    train_ds,
    validation_data=valid_ds,
    epochs=config.EPOCHS,
    steps_per_epoch=config.STEPS_PER_EPOCH,
    validation_steps=config.VALIDATION_STEPS,
    callbacks=[disp_cb, tensorboard_callback],
)


In [None]:
# Qualitative check: for two validation batches, show the input, the ground-truth mask
# and the predicted mask side by side.
for val_inp, val_tar in valid_ds.take(2):
    # The model is compiled with from_logits=True, so these are raw class scores.
    pred_logits = model.predict(val_inp)
    # NOTE(review): utils.create_mask is not visible here; presumably it argmaxes the
    # class axis and returns one image-shaped mask that array_to_img accepts.
    pred_mask = utils.create_mask(pred_logits)

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    # imshow rejects (H, W, 1) arrays; squeeze away a possible trailing channel axis.
    axes[0].imshow(np.squeeze(np.asarray(val_inp[0])))
    axes[0].set_title("Input")
    axes[1].imshow(np.squeeze(np.asarray(val_tar[0])))
    axes[1].set_title("Ground truth")
    axes[2].imshow(keras.preprocessing.image.array_to_img(pred_mask))
    axes[2].set_title("Prediction")
    for ax in axes:
        ax.axis("off")
    plt.show()

In [None]:
# Training vs. validation loss per epoch, saved to disk and shown inline.
loss = model_history.history["loss"]
val_loss = model_history.history["val_loss"]

fig, ax = plt.subplots()
ax.plot(model_history.epoch, loss, "r", label="Training loss")
ax.plot(model_history.epoch, val_loss, "bo", label="Validation loss")
ax.set_title("Training and Validation Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss Value")
# Anchor the axis at 0 but let the top auto-scale. A fixed [0, 1] range silently
# clipped any loss above 1, which sparse categorical cross-entropy can easily exceed
# (roughly ln(num_classes) for near-uniform initial predictions).
ax.set_ylim(bottom=0)
ax.legend()
# NOTE(review): this hard-coded file name is overwritten on every run; consider
# deriving it from the run timestamp used for the TensorBoard log directory.
fig.savefig('10282205.png')
plt.show()