In [None]:
import random
from typing import List
import importlib

import tensorflow.keras as keras
import matplotlib.pyplot as plt

import config
from dataset_utils import mk_dataset
from model import losses, unet
from utils import callbacks

In [None]:
# Build the training / validation datasets
def make_datasets(
    tr_path: List[str], va_path: List[str], use_cumix: bool, nbmix: int = 3
):
    """Create the training and validation datasets.

    Both splits are read from the *training* directories
    (``config.TR_SAT_PATH`` / ``config.TR_MAP_PATH``), i.e. the validation set
    is a hold-out subset of the training data chosen by the caller.

    Args:
        tr_path: File names used for the training split.
        va_path: File names used for the validation split.
        use_cumix: When True, apply CutMix augmentation
            (``mk_dataset.augument_ds``) to the training split only.
        nbmix: Forwarded to ``mk_dataset.augument_ds`` when ``use_cumix`` is
            True (presumably the number of images mixed — TODO confirm).

    Returns:
        ``(train_ds, valid_ds)``, each passed through
        ``mk_dataset.post_process_ds``.
    """

    def _base(names):
        # Same source directories for both splits (see docstring).
        return mk_dataset.mk_base_dataset(
            path_list=names,
            sat_path=config.TR_SAT_PATH,
            map_path=config.TR_MAP_PATH,
        )

    train_ds = _base(tr_path)
    if use_cumix:
        train_ds = mk_dataset.augument_ds(train_ds, nbmix)
    train_ds = mk_dataset.post_process_ds(train_ds)

    valid_ds = mk_dataset.post_process_ds(_base(va_path))
    return train_ds, valid_ds


In [None]:
# Define model
def compile_model(loss, num_classes: int = 2):
    """Build the big U-Net and compile it with Adam and a mean-IoU metric.

    Args:
        loss: Loss object/function passed to ``model.compile``
            (e.g. a ``losses.TverskyLoss`` instance).
        num_classes: ``num_classes`` for ``keras.metrics.MeanIoU``. Defaults to
            2, the value previously hard-coded here.

    Returns:
        The compiled Keras model.
    """
    model = unet.big_unet_model(
        input_shape=config.INPUT_SIZE,
        output_channels=config.OUT_CH,
    )
    # Compile the model (Adam with Keras' default hyper-parameters).
    optimizer = keras.optimizers.Adam()
    # An explicit name keeps the history key stable ("mean_io_u" / "val_mean_io_u").
    # Without it Keras auto-generates a unique name, which becomes e.g.
    # "mean_io_u_1" when this function is called again in the same kernel.
    # NOTE(review): MeanIoU compares class indices; if big_unet_model outputs
    # probabilities (sigmoid/softmax) they are truncated to ints before the
    # confusion matrix is built, so consider thresholding / BinaryIoU — TODO
    # confirm the model's output activation.
    metrics = [keras.metrics.MeanIoU(num_classes=num_classes, name="mean_io_u")]
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
    return model


In [None]:
# Define Callbacks
def get_callbacks(filename):
    """Return ``[tensorboard_callback, checkpoint_callback]`` for one run.

    Args:
        filename: Run name. It is appended to ``config.LOG_PATH`` for the
            TensorBoard callback, and used twice
            (``config.CHECKPOINT_PATH / filename / filename``) for the path
            handed to the checkpoint callback factory.
    """
    log_dir = config.LOG_PATH / filename
    checkpoint_target = config.CHECKPOINT_PATH / filename / filename
    return [
        callbacks.get_tboard_callback(str(log_dir)),
        callbacks.get_checkpoint_callback(str(checkpoint_target)),
    ]


In [None]:
# Train / validation split over the *.png files in TR_MAP_PATH.
# make_datasets looks each name up in both TR_SAT_PATH and TR_MAP_PATH
# (presumably paired image / label files — confirm in dataset_utils).
# Sorting before a seeded shuffle makes the split independent of the
# filesystem's glob order, so it is reproducible across runs and machines.
# For live edits to dataset_utils, prefer `%autoreload 2` over importlib.reload.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.8

pathlist = sorted(path.name for path in config.TR_MAP_PATH.glob("*.png"))
if not pathlist:
    raise FileNotFoundError(f"No *.png files found in {config.TR_MAP_PATH}")
random.Random(SPLIT_SEED).shuffle(pathlist)

nb_tr = int(len(pathlist) * TRAIN_FRACTION)
tr_pathlist = pathlist[:nb_tr]
va_pathlist = pathlist[nb_tr:]
if not tr_pathlist or not va_pathlist:
    raise ValueError(
        f"Splitting {len(pathlist)} files gave {len(tr_pathlist)} train / "
        f"{len(va_pathlist)} validation samples; need at least one of each."
    )

train_ds, valid_ds = make_datasets(
    tr_path=tr_pathlist,
    va_path=va_pathlist,
    use_cumix=False,
)


In [None]:
# Training
TVERSKY_ALPHA = 0.7  # `alpha` argument of losses.TverskyLoss
EPOCHS = 20
# Keras: number of validation batches drawn at the end of every epoch.
VALIDATION_STEPS = 5
# Run name: used for the TensorBoard log dir and the checkpoint path.
RUN_NAME = "test1124"

loss = losses.TverskyLoss(name="Tversky", alpha=TVERSKY_ALPHA)
model = compile_model(loss=loss)
filename = RUN_NAME
model_history = model.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=valid_ds,
    steps_per_epoch=config.STEPS_PER_EPOCH,
    validation_steps=VALIDATION_STEPS,
    callbacks=get_callbacks(filename),
)


In [None]:
# Training curves: train (red line) vs validation (blue dots) per epoch.
history = model_history.history

# Keras may suffix auto-named metrics (e.g. "mean_io_u_1") when a metric of the
# same class was already created in this kernel, so look the key up rather than
# hard-coding it. "val_..." keys do not start with "mean_io_u", so only the
# training entry matches here.
iou_key = next((key for key in history if key.startswith("mean_io_u")), None)
if iou_key is None:
    raise KeyError(f"No MeanIoU entry in history; available keys: {sorted(history)}")

tr_loss = history["loss"]
va_loss = history["val_loss"]
tr_iou = history[iou_key]
va_iou = history[f"val_{iou_key}"]


def plot_train_val(train_values, val_values, title):
    """Plot one metric's per-epoch training and validation values on new axes."""
    fig, ax = plt.subplots()
    ax.plot(train_values, "r", label="train")
    ax.plot(val_values, "bo", label="validation")
    ax.set(title=title, xlabel="epoch (0-based)", ylabel=title)
    ax.legend()
    plt.show()


plot_train_val(tr_iou, va_iou, "mean IoU")
plot_train_val(tr_loss, va_loss, "loss")


In [None]:
# Qualitative check: input, ground truth and prediction for the first sample
# of a few validation batches, shown side by side.
# Assumes valid_ds yields batched (input, target) pairs (it is indexed with [0]).
# NOTE(review): imshow accepts (H, W), (H, W, 1), (H, W, 3) or (H, W, 4); if
# config.OUT_CH is another value the prediction panel will fail — TODO confirm.
N_PREVIEW_BATCHES = 3
PANEL_TITLES = ("input", "ground truth", "prediction")

for images, targets in valid_ds.take(N_PREVIEW_BATCHES):
    preds = model.predict(images)
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for ax, panel, title in zip(axes, (images[0], targets[0], preds[0]), PANEL_TITLES):
        ax.imshow(panel)
        ax.set_title(title)
        ax.axis("off")
    plt.show()