In [None]:
%set_env SM_FRAMEWORK=tf.keras

In [None]:
import import_ipynb
import dataset

In [None]:
import random
from datetime import datetime

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import segmentation_models as sm

In [None]:
# Seed every RNG the run may touch so Restart & Run All is reproducible:
# TensorFlow (global op-level seed), NumPy and Python's `random`.
# NOTE(review): the `dataset` module is not visible here; the NumPy/`random`
# seeds only matter if it uses them (e.g. for augmentation) — confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
data_p_dir = 'data_p/s/'
image_dir = data_p_dir + 'image/'
label_dir = data_p_dir + 'label/'

SHUFFLE_BUFFER = 10000
split_ratio = 0.8

# Bind the pipeline to a new name: reassigning `dataset` here would shadow the
# imported `dataset` module and make this cell fail when re-run.
full_ds = dataset.train_dataset(image_dir, label_dir)

num_samples = full_ds.cardinality().numpy()
# cardinality() returns negative sentinels (UNKNOWN / INFINITE) when the size
# cannot be determined; a negative `num_train` would make take() keep everything
# and skip() drop everything, silently leaving the validation set empty.
assert num_samples >= 0, f"dataset size is unknown or infinite (cardinality={num_samples})"
num_train = int(split_ratio * num_samples)

# Shuffle ONCE into a fixed order. With the default reshuffle_each_iteration=True,
# take() and skip() would each iterate a different permutation, so training and
# validation samples would overlap (and change every epoch).
shuffled_ds = full_ds.shuffle(buffer_size=SHUFFLE_BUFFER, seed=SEED, reshuffle_each_iteration=False)

# Split into disjoint training and validation sets; only the training set is
# reshuffled every epoch.
# NOTE(review): this adds a second shuffle buffer to the training pipeline —
# lower SHUFFLE_BUFFER if memory is tight.
train_ds = shuffled_ds.take(num_train).shuffle(
    buffer_size=SHUFFLE_BUFFER, seed=SEED, reshuffle_each_iteration=True
)
val_ds = shuffled_ds.skip(num_train)

In [None]:
# U-Net with an EfficientNet-B2 encoder and a single sigmoid output channel.
model = sm.Unet('efficientnetb2', classes=1, activation='sigmoid')

# Both metrics binarize the sigmoid output at 0.5 before scoring.
metrics = [
    sm.metrics.IOUScore(threshold=0.5),
    sm.metrics.FScore(threshold=0.5),
]

model.compile(
    optimizer='adam',
    loss=sm.losses.DiceLoss(),
    metrics=metrics,
)

In [None]:
def make_callbacks(log_root="logs/fit/", checkpoint_path="model_checkpoint.h5"):
    """Build the Keras callbacks used for training.

    Args:
        log_root: Directory under which a per-run TensorBoard directory named
            after the current local time (``YYYYmmdd-HHMMSS``) is created.
        checkpoint_path: File that receives the weights of the epoch with the
            lowest ``val_loss``. The default path is shared by every run, so a
            new run overwrites the previous run's checkpoint; pass a per-run
            path to keep them apart.

    Returns:
        ``[TensorBoard, ModelCheckpoint]`` callbacks, in that order.
    """
    import os

    # os.path.join tolerates log_root with or without a trailing separator.
    logdir = os.path.join(log_root, datetime.now().strftime("%Y%m%d-%H%M%S"))
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)

    # Save only the weights, and only when val_loss improves on the best so far.
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_path,
        save_weights_only=True,
        save_best_only=True,
        monitor='val_loss',
        mode='min',
        verbose=1,
    )

    return [tensorboard_callback, checkpoint_callback]

In [None]:
BATCH_SIZE = 4
EPOCHS = 10

# Batch into NEW names: rebinding `train_ds = train_ds.batch(...)` would batch the
# already-batched data a second time if this cell were re-run.
train_batches = train_ds.batch(BATCH_SIZE).prefetch(buffer_size=tf.data.AUTOTUNE)
val_batches = val_ds.batch(BATCH_SIZE).prefetch(buffer_size=tf.data.AUTOTUNE)

# Keep the History object so per-epoch loss/metric curves can be inspected later.
history = model.fit(
    train_batches,
    epochs=EPOCHS,
    validation_data=val_batches,
    callbacks=make_callbacks(),
)