In [1]:
# Load IPython's autoreload extension so edits to imported project modules are
# picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules (except those excluded via %aimport) before
# executing each cell.
%autoreload 2

In [2]:
from datasets.skyline12 import Skyline12
from sklearn.model_selection import train_test_split

# Build the Skyline12 dataset object from its data directory.
skyline12 = Skyline12('/storage/skyline12/data')

# First split: 100 samples are kept for training/validation, the remainder
# becomes the held-out test set.
train_set, test_set = train_test_split(
    list(skyline12),
    train_size=100,
    random_state=0,
)

# Second split: 80% of those 100 samples are used for fitting and the other
# 20% for validation.
train_set, validation_set = train_test_split(
    train_set,
    train_size=0.8,
    random_state=42,
)

# Show the three split sizes as the cell output.
tuple(map(len, (train_set, validation_set, test_set)))

(80, 20, 20)

In [3]:
import tensorflow as tf
from datasets.skyline12 import create_augment_fn
from functools import partial

# Number of label classes: ds_gen clamps label values to NUM_CLASSES - 1 and
# one-hot encodes them to this depth; it is also the channel count of the
# label tensors declared for the tf.data datasets.
NUM_CLASSES = 5
# Number of passes ds_gen makes over the training samples (the validation
# dataset uses FOLDS // 2 passes).
FOLDS = 20


def ds_gen(sample_set, folds):
    """Yield augmented, normalised ``(image, one_hot_labels)`` pairs.

    ``sample_set`` is traversed ``folds`` times. Every sample is passed through
    an augmentation function obtained once per generator from
    ``create_augment_fn()`` and then converted into model inputs.

    Args:
        sample_set: Re-iterable collection of ``(x, y, z)`` triples.
        folds: Number of complete passes to make over ``sample_set``.

    Yields:
        A ``(x, y)`` tuple: ``x`` is the augmented first element cast to
        float32 and divided by 255; ``y`` is the augmented second element with
        values above ``NUM_CLASSES - 1`` clamped to ``NUM_CLASSES - 1`` and
        then one-hot encoded to ``NUM_CLASSES`` channels. The third element of
        the triple is not part of the output.
    """
    augment = create_augment_fn()

    def to_model_inputs(image, labels, extra):
        image, labels, _unused = augment(image, labels, extra)

        # Cast to float32 and scale by 1/255.
        image = image.astype('float32') / 255.0

        # Collapse every label at or above the last class index into that
        # class; this writes into the array returned by ``augment``.
        top_class = NUM_CLASSES - 1
        labels[labels >= top_class] = top_class

        one_hot = tf.keras.utils.to_categorical(labels, num_classes=NUM_CLASSES)
        return image, one_hot

    for _ in range(folds):
        yield from (
            to_model_inputs(image, labels, extra)
            for image, labels, extra in sample_set
        )


# Both datasets stream from ds_gen and declare the same element signature:
# a float32 512x512x3 image and a uint8 512x512xNUM_CLASSES one-hot label map.
# They differ only in the sample set and in the number of passes over it.
train_ds = tf.data.Dataset.from_generator(
    partial(ds_gen, train_set, FOLDS),
    output_types=(tf.float32, tf.uint8),
    output_shapes=(
        tf.TensorShape([512, 512, 3]),
        tf.TensorShape([512, 512, NUM_CLASSES]),
    ),
)

# Validation makes half as many passes over its samples as training does.
validation_ds = tf.data.Dataset.from_generator(
    partial(ds_gen, validation_set, FOLDS // 2),
    output_types=(tf.float32, tf.uint8),
    output_shapes=(
        tf.TensorShape([512, 512, 3]),
        tf.TensorShape([512, 512, NUM_CLASSES]),
    ),
)

In [4]:
from models.unet import create_unet
from metrics import CategoricalMeanIou

# Create the model via create_unet() and configure it for training on one-hot
# encoded labels with categorical cross-entropy.
unet = create_unet()
unet.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    # The metric's class count comes from NUM_CLASSES (the same constant that
    # drives label clamping, one-hot depth and dataset shapes) instead of a
    # repeated literal, so the two cannot drift apart.
    metrics=[CategoricalMeanIou(num_classes=NUM_CLASSES), 'accuracy'],
    run_eagerly=False
)

In [5]:
!mkdir -p temp
import wandb
from wandb.keras import WandbCallback
from callbacks import WandbLogPredictions

# Lets tf.data choose the prefetch buffer size dynamically.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Class index -> display name, handed to the prediction-logging callback.
class_labels = dict(enumerate([
    'Borders',
    'Sky',
    'Unknown',
    'Unknown',
    'Skyscrapers',
]))

# Take the first batch (up to 20 elements) of the validation dataset as the
# fixed data passed to the prediction-logging callback.
data_to_log = next(iter(validation_ds.batch(20)))
log_predictions = WandbLogPredictions(data_to_log, class_labels)

# Stop once the monitored quantity (Keras default) has not improved for
# 5 consecutive epochs.
early_stopper = tf.keras.callbacks.EarlyStopping(patience=5)

# Persist weights only, and only when the monitored quantity improves.
checkpointer = tf.keras.callbacks.ModelCheckpoint(
    filepath='checkpoints/baseline-weights.h5',
    save_weights_only=True,
    save_best_only=True,
)

# Start a Weights & Biases run for this baseline experiment.
wandb.init(project="skyline12-augmentations", tags=['baseline'])

# Train with early stopping, W&B logging, prediction logging and best-weights
# checkpointing.
#
# The training and validation pipelines each get their OWN cache file.
# tf.data file caches are keyed only by filename: if both pipelines pointed at
# the same path, the validation pipeline would find the cache already written
# by the training pipeline and replay training batches as "validation" data.
#
# NOTE: the cache files persist on disk between runs. Delete the contents of
# temp/ after changing the data pipeline (samples, augmentations, batch size),
# otherwise the previously cached elements are replayed.
unet.fit(
    train_ds.batch(3).cache('temp/train').prefetch(AUTOTUNE),
    epochs=200,
    validation_data=validation_ds.batch(3).cache('temp/validation').prefetch(AUTOTUNE),
    callbacks=[early_stopper, WandbCallback(), log_predictions, checkpointer]
)

[34m[1mwandb[0m: Wandb version 0.10.2 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Epoch 1/200
Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
Epoch 11/200
Epoch 12/200
Epoch 13/200
Epoch 14/200
Epoch 15/200
Epoch 16/200
Epoch 17/200
Epoch 18/200
Epoch 19/200
Epoch 20/200
Epoch 21/200
Epoch 22/200
Epoch 23/200
Epoch 24/200
Epoch 25/200
Epoch 26/200
Epoch 27/200
Epoch 28/200
Epoch 29/200
Epoch 30/200
Epoch 31/200
Epoch 32/200
Epoch 33/200


<tensorflow.python.keras.callbacks.History at 0x7f987413f630>