In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to local .py files are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from datasets.skyline12 import Skyline12
from sklearn.model_selection import train_test_split

# Load the dataset and materialise it as a list so it can be split.
skyline12 = Skyline12('/storage/skyline12/data')
all_samples = list(skyline12)

# First split: 100 samples form the training pool, the rest is the test set.
train_pool, test_set = train_test_split(all_samples, train_size=100, random_state=0)

# Second split: 80% of the pool for training, the remainder for validation.
train_set, validation_set = train_test_split(train_pool, train_size=0.8, random_state=42)

# Show the size of each split as the cell output.
tuple(map(len, (train_set, validation_set, test_set)))

In [None]:
import tensorflow as tf
from datasets.skyline12 import create_augment_fn

# Depth of the one-hot label encoding used by `preprocess` and the datasets.
NUM_CLASSES = 5
# Augmentation callable from the dataset module; applied to every
# (x, y, z) sample inside `preprocess`.
augment = create_augment_fn()


def preprocess(x, y, z):
    """Convert one dataset sample into an (image, one-hot labels) pair.

    The sample is first passed through ``augment``. The resulting image is
    cast to float32 and divided by 255; label values of ``NUM_CLASSES - 1``
    or above are collapsed into the last class, and the labels are then
    one-hot encoded with ``NUM_CLASSES`` channels.

    Args:
        x: Image array (assumed to be a numpy array with 0-255 values --
            TODO confirm against the dataset).
        y: Integer label map belonging to ``x``.
        z: Third component of the sample; forwarded to ``augment`` and then
            discarded.

    Returns:
        Tuple ``(image, one_hot_labels)``.

    Note:
        The clamping is done in place on the label array returned by
        ``augment``.
    """
    image, labels, _ = augment(x, y, z)

    # Cast and rescale in one expression.
    image = image.astype('float32') / 255.0

    # Fold every label at or above the last class index into that class.
    last_class = NUM_CLASSES - 1
    labels[labels >= last_class] = last_class

    one_hot = tf.keras.utils.to_categorical(labels, num_classes=NUM_CLASSES)
    return image, one_hot  # ignore z


# Number of times each dataset is repeated per Keras epoch.
FOLDS = 50


def _build_dataset(samples):
    """Wrap ``samples`` in a tf.data pipeline of preprocessed pairs.

    Each (x, y, z) sample is run through ``preprocess`` lazily, every time
    the dataset is iterated. Elements are declared as a float32 image of
    shape (512, 512, 3) and a uint8 label tensor of shape
    (512, 512, NUM_CLASSES). The dataset is repeated ``FOLDS`` times.
    """
    def generate():
        for x, y, z in samples:
            yield preprocess(x, y, z)

    element_types = (tf.dtypes.float32, tf.dtypes.uint8)
    element_shapes = (
        tf.TensorShape([512, 512, 3]),
        tf.TensorShape([512, 512, NUM_CLASSES]),
    )
    dataset = tf.data.Dataset.from_generator(generate, element_types, element_shapes)
    return dataset.repeat(FOLDS)


train_ds = _build_dataset(train_set)
validation_ds = _build_dataset(validation_set)

In [None]:
from models.unet import create_unet
from metrics import CategoricalMeanIou

# Build the U-Net and configure it for multi-class segmentation.
unet = create_unet()
unet.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    # Tie the IoU metric to NUM_CLASSES, the same constant used to one-hot
    # encode the labels, so the two cannot silently drift apart.
    metrics=[CategoricalMeanIou(num_classes=NUM_CLASSES), 'accuracy'],
    run_eagerly=False
)

In [None]:
# Stop once the monitored quantity (Keras default: validation loss) has not
# improved for 10 consecutive epochs. restore_best_weights=True rolls the
# model back to the best epoch; without it the model keeps the weights from
# the last, already degraded, epoch.
early_stopper = tf.keras.callbacks.EarlyStopping(
    patience=10,
    restore_best_weights=True
)

# Train with batch size 1; validation runs on the validation pipeline after
# each epoch.
unet.fit(
    train_ds.batch(1),
    epochs=200,
    validation_data=validation_ds.batch(1),
    callbacks=[early_stopper]
)