In [1]:
# Automatically reload edited modules (e.g. the local datasets/models/callbacks packages) before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
from datasets.skyline12 import Skyline12
from sklearn.model_selection import train_test_split

# Materialise every sample, keep 100 of them for model development and hold
# the remainder out as the test set.
skyline12 = Skyline12('/storage/skyline12/data')
train_set, test_set = train_test_split(
    list(skyline12),
    train_size=100,
    random_state=0,
)
# Of those 100 samples, 80% train the model and 20% are kept for validation.
train_set, validation_set = train_test_split(
    train_set,
    train_size=0.8,
    random_state=42,
)
tuple(map(len, (train_set, validation_set, test_set)))

(80, 20, 20)

In [3]:
import tensorflow as tf
import numpy as np
from datasets.skyline12 import create_augment_fn

# Number of one-hot channels in the label target; `preprocess` folds every
# label >= NUM_CLASSES - 1 into the last class.
NUM_CLASSES = 5
# Augmentation callable from datasets.skyline12, applied to each (x, y, z)
# sample inside `preprocess`.
augment = create_augment_fn()


def preprocess(x, y, z):
    """Augment one (image, label map, mask) sample and convert it to model inputs/targets.

    Args:
        x: image array; cast to float32 and divided by 255
            (NOTE(review): assumes 8-bit pixel values — confirm).
        y: class-label map; every label >= NUM_CLASSES - 1 is folded into the
            last class, then one-hot encoded to NUM_CLASSES channels.
        z: mask array; every positive value becomes 1 and a trailing
            channel axis of size 1 is appended.

    Returns:
        Tuple ``(x, y, z)`` of the processed arrays.

    The label map and mask are never modified in place. The arrays returned by
    ``augment`` may be the very arrays stored in ``train_set`` /
    ``validation_set`` (TODO confirm for the augmentation pipeline in use);
    writing into them would silently alter the in-memory dataset that this
    function is re-applied to on every ``repeat()`` pass.
    """
    x, y, z = augment(x, y, z)
    # astype() returns a new array, so scaling never touches the caller's image.
    x = x.astype('float32') / 255.0
    # Out-of-place equivalent of ``y[y >= NUM_CLASSES - 1] = NUM_CLASSES - 1``.
    y = np.minimum(y, NUM_CLASSES - 1)
    y = tf.keras.utils.to_categorical(y, num_classes=NUM_CLASSES)
    # Out-of-place equivalent of ``z[z > 0] = 1``: keeps z's dtype and leaves
    # values <= 0 unchanged.
    z = np.where(z > 0, np.ones_like(z), z)
    z = np.expand_dims(z, -1)
    return x, y, z


# Passes over the training samples per Keras epoch (validation uses FOLDS // 2).
# Each pass re-runs `preprocess`, so each pass sees a fresh call to `augment`.
FOLDS = 10


def _make_dataset(samples, repeats):
    """Build a tf.data pipeline that preprocesses `samples` lazily.

    Args:
        samples: sequence of ``(x, y, z)`` tuples, e.g. ``train_set``.
        repeats: how many times the whole pipeline is repeated.

    Returns:
        A dataset whose elements are ``(image, [one_hot_labels, mask])``
        (presumably matching the two outputs of the U-Net — confirm). One-hot
        labels are emitted as uint8, images and masks as float32.
    """
    return tf.data.Dataset.from_generator(
        # A new generator is created for every iteration of the dataset, so
        # each repeat calls `preprocess` (and `augment`) again.
        lambda: (preprocess(x, y, z) for x, y, z in samples),
        (tf.dtypes.float32, tf.dtypes.uint8, tf.dtypes.float32),
        (
            tf.TensorShape([512, 512, 3]),
            tf.TensorShape([512, 512, NUM_CLASSES]),
            tf.TensorShape([512, 512, 1]),
        ),
    ).map(lambda x, y, z: (x, [y, z])).repeat(repeats)


train_ds = _make_dataset(train_set, FOLDS)
validation_ds = _make_dataset(validation_set, FOLDS // 2)

In [3]:
from models.unet import create_unet

# Build the U-Net and initialise it from the saved baseline weights file.
# NOTE(review): this cell's execution count (In [3]) duplicates the dataset
# cell's, so the notebook was run out of order — confirm it survives Restart & Run All.
unet = create_unet()
unet.load_weights('checkpoints/baseline-weights.h5')

In [22]:
import wandb
from wandb.keras import WandbCallback
from callbacks import WandbLogPredictions

# Stop after 5 epochs without improvement of the monitored metric (Keras
# default: val_loss) and roll the model back to its best weights.
early_stopper = tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True)
# Display names for the NUM_CLASSES label indices, passed to the prediction logger.
# NOTE(review): indices 2 and 3 are both 'Unknown' — confirm this is intended.
class_labels = {
    0: 'Borders',
    1: 'Sky',
    2: 'Unknown',
    3: 'Unknown',
    4: 'Skyscrapers',
}
# The first batch of 20 holds every validation sample from the first repeat.
data_to_log = next(iter(validation_ds.batch(20)))
log_predictions = WandbLogPredictions(data_to_log, class_labels)
AUTOTUNE = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 3

# Each pipeline needs its OWN cache prefix. With a shared prefix the training
# cache, which is finalised when the first training epoch is exhausted, is
# found by the validation pipeline and read back in place of the validation
# data. val_loss, early stopping and checkpointing would then all be computed
# on training samples.
# NOTE(review): existing cache files are reused as-is — delete temp/ds_cache/
# whenever preprocessing or the split changes. Caching after `preprocess` also
# stores the first epoch's augmented samples, so later epochs replay them
# rather than drawing new augmentations.
TRAIN_CACHE = 'temp/ds_cache/train'
VALIDATION_CACHE = 'temp/ds_cache/validation'

wandb.init(project="skyline12-augmentations", tags=['baseline'])
unet.fit(
    train_ds.batch(BATCH_SIZE).cache(TRAIN_CACHE).prefetch(AUTOTUNE),
    epochs=200,
    validation_data=validation_ds.batch(BATCH_SIZE).cache(VALIDATION_CACHE).prefetch(AUTOTUNE),
    callbacks=[
        early_stopper,
        WandbCallback(),
        log_predictions,
        tf.keras.callbacks.ModelCheckpoint('checkpoints/best-baseline', save_best_only=True, save_weights_only=True)
    ]
)

[34m[1mwandb[0m: Wandb version 0.10.2 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Epoch 1/200
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: checkpoints/best-baseline/assets
Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
Epoch 11/200
Epoch 12/200
Epoch 13/200
Epoch 14/200
Epoch 15/200
Epoch 16/200
Epoch 17/200
Epoch 18/200
Epoch 19/200
Epoch 20/200
Epoch 21/200
Epoch 22/200
Epoch 23/200
Epoch 24/200
Epoch 25/200
Epoch 26/200
Epoch 27/200
Epoch 28/200
Epoch 29/200
Epoch 30/200
Epoch 31/200
Epoch 32/200
Epoch 33/200
Epoch 34/200
Epoch 35/200
Epoch 36/200
Epoch 37/200
Epoch 38/200
Epoch 39/200
Epoch 40/200
Epoch 41/200
Epoch 42/200
Epoch 43/200
Epoch 44/200
Epoch 45/200
Epoch 46/200
Epoch 47/200
Epoch 48/200
Epoch 49/200
Epoch 50/200
Epoch 51/200
Epoch 52/200


<tensorflow.python.keras.callbacks.History at 0x7f81b807a390>

In [28]:
# Overwrite `unet`'s weights with the baseline file loaded before training,
# discarding the weights produced by the fit above.
# NOTE(review): if the intent was to continue with the best trained model,
# load 'checkpoints/best-baseline' instead — confirm.
unet.load_weights('checkpoints/baseline-weights.h5')