In [None]:
# GoogleDrive をマウント

from google.colab import drive
drive.mount("/gdrive")
# データ解凍
!unzip /gdrive/MyDrive/卒研/datasets_21110115.zip >&/dev/null
! git clone https://github.com/straxFromIbr/UnetModelScript.git
% cd UnetModelScript/home

In [None]:
import pathlib

import tensorflow.keras as keras

import config
from model import unet, losses, metrics
from utils import callbacks
from dataset_utils import mk_dataset


In [None]:
# Define model
def compile_model(loss):
    """Build the U-Net from ``unet.big_unet_model`` and compile it.

    The network is created with ``config.INPUT_SIZE`` as its input shape and
    ``config.OUT_CH`` output channels, then compiled with a default-configured
    Adam optimizer. ``"accuracy"`` and ``metrics.iou_coef`` are tracked.

    Args:
        loss: Loss passed straight through to ``Model.compile``.

    Returns:
        The compiled model.
    """
    network = unet.big_unet_model(
        input_shape=config.INPUT_SIZE, output_channels=config.OUT_CH
    )
    network.compile(
        optimizer=keras.optimizers.Adam(),
        loss=loss,
        metrics=["accuracy", metrics.iou_coef],
    )
    return network


In [None]:
# Define Callbacks
def get_callbacks(filename):
    """Create the callbacks used for one training run.

    Args:
        filename: Run name. It becomes the sub-directory under
            ``config.LOG_PATH`` handed to the TensorBoard callback, and both
            the sub-directory and the final path component under
            ``config.CHECKPOINT_PATH`` handed to the checkpoint callback.

    Returns:
        A list ``[tensorboard_callback, checkpoint_callback]``.
    """
    return [
        callbacks.get_tboard_callback(str(config.LOG_PATH / filename)),
        callbacks.get_checkpoint_callback(
            str(config.CHECKPOINT_PATH / filename / filename)
        ),
    ]


In [None]:
def _loss_run_name(loss):
    """Derive a run name from whatever was passed as the model's loss."""
    if isinstance(loss, str):
        # Loss given as a string identifier: use it directly.
        return loss
    # Loss objects may expose `.name`; plain functions expose `__name__`.
    name = getattr(loss, "name", None) or getattr(loss, "__name__", None)
    return str(name) if name else type(loss).__name__


def train(model: keras.Model, train_ds, valid_ds, NB_EPOCHS):
    """Fit ``model`` and save it under a name derived from its loss.

    The run name is taken from ``model.loss`` (its ``name`` when available,
    otherwise the function name, the string identifier, or the class name).
    It labels the callback outputs and the saved model directory under
    ``config.MODEL_SAVE_PATH``.

    Args:
        model: A compiled Keras model.
        train_ds: Training dataset passed to ``model.fit``.
        valid_ds: Validation dataset passed to ``model.fit``.
        NB_EPOCHS: Number of epochs to train for.

    Returns:
        The history object returned by ``model.fit``.
    """
    filename = _loss_run_name(model.loss)
    model_history = model.fit(
        train_ds,
        epochs=NB_EPOCHS,
        validation_data=valid_ds,
        steps_per_epoch=config.STEPS_PER_EPOCH,
        # Validation is limited to a fixed 10 batches per epoch.
        validation_steps=10,
        callbacks=get_callbacks(filename),
    )
    model.save(str(config.MODEL_SAVE_PATH / filename))
    return model_history


In [None]:
# Location of the extracted dataset archive and its image directories.
basepath = pathlib.Path("/content/datasets_21110115")
TR_SAT_PATH = basepath / "sat"
TR_MAP_PATH = basepath / "map"
VA_SAT_PATH = basepath / "valid" / "sat"
VA_MAP_PATH = basepath / "valid" / "map"

# Training pipeline: base dataset -> augmentation -> post-processing.
train_ds = mk_dataset.mk_base_dataset(TR_SAT_PATH, TR_MAP_PATH)
train_ds = mk_dataset.augument_ds(train_ds)
train_ds = mk_dataset.post_process_ds(train_ds)

# Validation pipeline: same as training but without the augmentation step.
valid_ds = mk_dataset.mk_base_dataset(VA_SAT_PATH, VA_MAP_PATH)
valid_ds = mk_dataset.post_process_ds(valid_ds)


In [None]:
%load_ext tensorboard

loss = losses.TverskyLoss(name="Tversky")
model = compile_model(loss=loss)
hist = train(
    model=model,
    train_ds=train_ds,
    valid_ds=valid_ds,
    NB_EPOCHS=10,
)


In [None]:
! tensorboard dev upload --logdir ./logs  \
    --name "Train with Tversky on Colab" \
    --description "検証データにはCutMix未適用" \
    --one_sho