In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt

In [None]:
# Single seed shared by TensorFlow's global RNG and the dataset shuffle further down.
SEED = 42

In [None]:
import os
# Select the tf.keras backend for segmentation_models. This has to be set
# before the `import segmentation_models` line that follows, because the
# library reads SM_FRAMEWORK when it is imported.
os.environ["SM_FRAMEWORK"] = "tf.keras"

import segmentation_models as sm

In [None]:
# Seed TensorFlow's global random generator so runs are repeatable.
# NOTE(review): this does not seed Python's `random` or NumPy — confirm that
# nothing else in the pipeline draws from those generators.
tf.random.set_seed(SEED)

In [None]:
# Dataset root; images and labels live in sibling sub-directories beneath it.
data_p_dir = 'data_p/l/'
image_dir = f'{data_p_dir}image/'
label_dir = f'{data_p_dir}label/'

In [None]:
# Deterministic listings (shuffle=False): the i-th image is paired with the
# i-th label purely by its position in the two listings.
image_file_paths = tf.data.Dataset.list_files(image_dir + '*.jpg', shuffle=False)
label_file_paths = tf.data.Dataset.list_files(label_dir + '*.jpg', shuffle=False)

# `zip` silently stops at the shorter input, which would hide a missing or
# extra file and could leave images paired with the wrong labels. Fail loudly
# instead of training on misaligned pairs.
num_image_files = int(image_file_paths.cardinality().numpy())
num_label_files = int(label_file_paths.cardinality().numpy())
if num_image_files != num_label_files:
    raise ValueError(
        f'Found {num_image_files} image files in {image_dir!r} but '
        f'{num_label_files} label files in {label_dir!r}; '
        'the two listings must match one-to-one.'
    )

dataset = tf.data.Dataset.zip((image_file_paths, label_file_paths))

# for i, l in dataset.take(3):
#     print(i, l)

In [None]:
def process_img(path):
    """Load a JPEG file as a 3-channel float32 tensor scaled by 1/255.

    Args:
        path: Path of the JPEG file to read (string or scalar string tensor).

    Returns:
        The decoded image cast to float32 and divided by 255.
    """
    raw_bytes = tf.io.read_file(path)
    # channels=3 forces an RGB result regardless of how the file was encoded.
    pixels = tf.io.decode_jpeg(raw_bytes, channels=3)
    return tf.cast(pixels, tf.float32) / 255.

def process_label(path):
    """Read the file at `path` and deserialize it into a float32 tensor.

    The contents are handed to `tf.io.parse_tensor`, i.e. the file is treated
    as a serialized tensor rather than decoded as an image.
    NOTE(review): the label listing globs '*.jpg' — confirm these files really
    contain serialized tensors and not JPEG-encoded masks.

    Args:
        path: Path of the label file (string or scalar string tensor).

    Returns:
        The parsed float32 tensor.
    """
    serialized = tf.io.read_file(path)
    return tf.io.parse_tensor(serialized, out_type=tf.float32)

def process_batch(image, label):
    """Turn one (image path, label path) pair into an `(X, y)` example.

    Despite the name, this handles a single pair: it is the function mapped
    over the zipped path dataset.

    Args:
        image: Path of the image file, forwarded to `process_img`.
        label: Path of the label file, forwarded to `process_label`.

    Returns:
        Tuple `(X, y)` of the loaded image and the loaded label.
    """
    return process_img(image), process_label(label)

# Load image/label pairs lazily. AUTOTUNE lets tf.data run the file reads and
# decoding in parallel; elements are still produced in their original order
# (tf.data's default deterministic behaviour), so downstream seeded shuffling
# is unaffected.
dataset = dataset.map(process_batch, num_parallel_calls=tf.data.AUTOTUNE)

In [None]:
def plot_ds_element_overlay(background, overlay):
    """Draw `overlay` semi-transparently on top of `background` in one axes.

    Args:
        background: Image drawn first (fully opaque).
        overlay: Image drawn on top with alpha=0.3.
    """
    # Use the explicit Axes for both layers instead of mixing it with the
    # pyplot state machine, and render like `plot_ds_element` does.
    _, ax = plt.subplots()
    ax.imshow(background)
    ax.imshow(overlay, alpha=0.3)
    plt.show()

def plot_ds_element(background, overlay):
    """Show `background` and `overlay` side by side in a 1x2 figure.

    Args:
        background: Image drawn in the left panel.
        overlay: Image drawn in the right panel.
    """
    _, (left_ax, right_ax) = plt.subplots(1, 2)
    left_ax.imshow(background)
    right_ax.imshow(overlay)
    plt.show()

# for i, m in dataset.skip(9).take(1):
#     plot_ds_element(i, m)

In [None]:
# Shuffle ONCE into a fixed order before splitting.
# `reshuffle_each_iteration=False` is essential: with the default (True) the
# dataset is reshuffled every time it is iterated, so the `take`/`skip` split
# below would select different samples on every epoch and validation samples
# would leak into training.
shuffle_buffer_size = 10000
dataset = dataset.shuffle(
    buffer_size=shuffle_buffer_size,
    seed=SEED,
    reshuffle_each_iteration=False,
)

# Define the split ratio (e.g., 80% for training, 20% for validation)
split_ratio = 0.8
num_samples = int(dataset.cardinality().numpy())
# Negative cardinalities are TensorFlow's sentinels for "infinite"/"unknown".
# A negative count would make `take` return everything and `skip` drop
# everything, silently leaving the validation set empty.
if num_samples < 0:
    raise ValueError(
        'Dataset cardinality is unknown or infinite; cannot split it by ratio.'
    )

num_train = int(split_ratio * num_samples)
num_val = num_samples - num_train

# Split the dataset into training and validation sets
train_ds = dataset.take(num_train)
val_ds = dataset.skip(num_train)

# The split itself is now frozen, so restore per-epoch randomness for the
# training samples only (validation order does not matter).
if num_train > 0:
    train_ds = train_ds.shuffle(
        buffer_size=min(num_train, shuffle_buffer_size), seed=SEED
    )

In [None]:
# U-Net with an 'efficientnetb2' backbone, one output class, sigmoid activation.
model = sm.Unet('efficientnetb2', classes=1, activation='sigmoid')

# Report IoU and F-score (both with threshold=0.5) while optimising the Dice
# loss with the Adam optimizer.
metrics = [
    sm.metrics.IOUScore(threshold=0.5),
    sm.metrics.FScore(threshold=0.5),
]
model.compile(optimizer='adam', loss=sm.losses.DiceLoss(), metrics=metrics)

# model.summary()

In [None]:
def callbacks():
    """Build the Keras callbacks passed to `model.fit`.

    Returns:
        list: `[tensorboard, checkpoint]` where
        - the TensorBoard callback logs to `logs/fit/<YYYYmmdd-HHMMSS>`,
          stamped with the current local time at call time, and
        - the ModelCheckpoint callback writes weights only to
          `model_checkpoint.h5`, saving only when `val_loss` reaches a new
          minimum.
    """
    from datetime import datetime

    # A fresh, timestamped log directory per call keeps runs separate.
    run_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir="logs/fit/" + run_stamp)

    checkpoint_options = dict(
        filepath="model_checkpoint.h5",
        save_weights_only=True,
        save_best_only=True,
        monitor='val_loss',
        mode='min',
        verbose=1,
    )
    checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(**checkpoint_options)

    return [tensorboard_cb, checkpoint_cb]

In [None]:
# Training hyper-parameters, named so they can be tuned in one place.
BATCH_SIZE = 4
EPOCHS = 100

# NOTE: this rebinds `train_ds`/`val_ds` to their batched versions, so
# re-running this cell without re-running the split cell batches them twice.
train_ds = train_ds.batch(BATCH_SIZE).prefetch(buffer_size=tf.data.AUTOTUNE)
val_ds = val_ds.batch(BATCH_SIZE).prefetch(buffer_size=tf.data.AUTOTUNE)

# Keep the returned History so per-epoch losses and metrics can be inspected
# or plotted after training instead of being discarded.
history = model.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=val_ds,
    callbacks=callbacks(),
)

In [None]:
# Position (in file-listing order) of the sample to visualise.
# NOTE(review): this indexes the full, unshuffled listing, so the sample may
# be one the model was trained on.
SKIP = 80

# Pick the sample BEFORE decoding so only one image/label pair is read,
# instead of loading and discarding the first SKIP elements.
sample_paths = (
    tf.data.Dataset.zip((image_file_paths, label_file_paths))
    .skip(SKIP)
    .take(1)
)
item = sample_paths.map(process_batch).batch(1)

# One pass over the single-element dataset covers both figures.
for i, l in item:
    # Ground truth: input image next to its label.
    plot_ds_element(i[0], l[0])

    # Prediction: the same input image next to the model output for it
    # (previously the prediction was drawn in both panels).
    p = model.predict(i)
    plot_ds_element(i[0], p[0])