In [None]:
# from https://github.com/LukeTonin/simple-deep-learning/blob/main/semantic_segmentation.ipynb
!poetry run python -m pip install git+https://github.com/LukeTonin/simple-deep-learning

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib
from matplotlib import pyplot as plt

# Record the library versions this notebook was run with (one per line: TF, NumPy, Matplotlib).
print(*(module.__version__ for module in (tf, np, matplotlib)), sep='\n')

In [None]:
from simple_deep_learning.mnist_extended.semantic_segmentation import create_semantic_segmentation_dataset

In [None]:
# Seed NumPy's global RNG right before generation; the dataset generator presumably
# draws from it (TODO confirm), which makes the synthetic digit layouts reproducible.
np.random.seed(1)

dataset_kwargs = {
    'num_train_samples': 1000,
    'num_test_samples': 200,
    'image_shape': (60, 60),
    'max_num_digits_per_image': 4,
    'num_classes': 3,
}
train_x, train_y, test_x, test_y = create_semantic_segmentation_dataset(**dataset_kwargs)

In [None]:
# Images and masks must align sample-for-sample and pixel-for-pixel: the network below
# maps (N, H, W, C_in) -> (N, H, W, n_classes) and the loss compares it to the masks directly.
assert train_x.shape[:3] == train_y.shape[:3], (train_x.shape, train_y.shape)
assert test_x.shape[:3] == test_y.shape[:3], (test_x.shape, test_y.shape)

train_x.shape, train_y.shape, test_x.shape, test_y.shape

In [None]:
# Keep the first training pair around for ad-hoc inspection.
# train_x / train_y are NumPy arrays, not tf.data.Dataset objects: ndarray.take(1) returns a
# single scalar from the flattened array, so zipping over it raises TypeError. Index instead.
sample_image, sample_mask = train_x[0], train_y[0]

In [None]:
from simple_deep_learning.mnist_extended.semantic_segmentation import (
    display_grayscale_array,
    plot_class_masks,
)

# Visual sanity check: one randomly chosen training image and its per-class target masks.
print(train_x.shape, train_y.shape)

sample_idx = np.random.randint(len(train_x))
display_grayscale_array(array=train_x[sample_idx])
plot_class_masks(train_y[sample_idx])

In [None]:
from tensorflow.keras import layers, models

# Reset Keras' global state so re-running this cell builds a fresh model.
tf.keras.backend.clear_session()
# Seed TensorFlow's RNG so weight initialisation is reproducible between runs.
tf.random.set_seed(1)


def _conv3x3(filters, activation='relu', **kwargs):
    """Return a 3x3 Conv2D with 'same' padding, so the layer preserves spatial size."""
    return layers.Conv2D(filters=filters, kernel_size=(3, 3), activation=activation,
                         padding='same', **kwargs)


# Fully-convolutional encoder/decoder: two 2x2 poolings downsample by 4 and two 2x2
# upsamplings restore the input resolution, so the output has one channel per mask in train_y.
model = models.Sequential([
    _conv3x3(16, input_shape=train_x.shape[1:]),
    _conv3x3(32),
    layers.MaxPooling2D(pool_size=(2, 2)),
    _conv3x3(32),
    _conv3x3(32),
    layers.MaxPooling2D(pool_size=(2, 2)),
    _conv3x3(32),
    _conv3x3(32),
    layers.UpSampling2D(size=(2, 2)),
    _conv3x3(32),
    _conv3x3(32),
    layers.UpSampling2D(size=(2, 2)),
    _conv3x3(32),
    _conv3x3(16),
    # Sigmoid (not softmax): each class channel is scored independently, which pairs with
    # the binary cross-entropy loss used when the model is compiled.
    _conv3x3(train_y.shape[-1], activation='sigmoid'),
])
# The model is compiled in its own cell below; compiling here as well was redundant.

In [None]:
# Print each layer's output shape and parameter count; the spatial size should return to
# the input resolution after the two MaxPooling2D / UpSampling2D pairs.
model.summary()

In [None]:
# Each output channel is an independent per-pixel binary decision (sigmoid output),
# so train with binary cross-entropy and report pixel-level accuracy, recall and precision.
segmentation_metrics = [
    tf.keras.metrics.BinaryAccuracy(),
    tf.keras.metrics.Recall(),
    tf.keras.metrics.Precision(),
]
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=segmentation_metrics,
)

In [None]:
# Train for 20 epochs; the held-out test split is passed as validation data so
# per-epoch validation metrics are reported alongside the training metrics.
history = model.fit(
    x=train_x,
    y=train_y,
    epochs=20,
    validation_data=(test_x, test_y),
)

In [None]:
# Per-pixel, per-channel sigmoid scores for every test image.
test_y_predicted = model.predict(x=test_x)

In [None]:
# Import everything this cell uses so it does not depend on an earlier cell's imports.
from simple_deep_learning.mnist_extended.semantic_segmentation import (
    display_grayscale_array,
    display_segmented_image,
    plot_class_masks,
)

# Show a few test predictions: the input image, the segmentation rendered from the
# predicted masks, and target vs predicted masks sliced along the class channel.
NUM_EXAMPLES_TO_SHOW = 3

np.random.seed(6)  # fixed seed so the same test examples are picked on every run
for _ in range(NUM_EXAMPLES_TO_SHOW):
    i = np.random.randint(len(test_y_predicted))
    print(f'Example {i}')
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    display_grayscale_array(test_x[i], ax=ax1, title='Input image')
    display_segmented_image(test_y_predicted[i], ax=ax2, title='Segmented image')
    plot_class_masks(test_y[i], test_y_predicted[i], title='y target and y predicted sliced along the channel axis')
    # Flush this example's figures now; otherwise the inline backend renders every figure
    # only after the cell finishes, detaching them from their "Example i" label.
    plt.show()