# Evidence Transfer

In [None]:
import wandb

# Number of cross-validation folds; also passed to Skyline12.as_tf_dataset
# when the datasets are built further down.
FOLDS = 20

# Hyperparameters recorded on the W&B run. Later cells read them back
# through `hparams` (the W&B config object).
run_config = dict(
    unet_loss='categorical_crossentropy',
    connecting_unet_layer='expanding_block_64_conv2',
    q_loss='binary_crossentropy',
    q_lambda=1.0,
    optimizer='sgd',
    learning_rate=1e-4,
    momentum=0.9,
    batch_size=3,
    dataset=f'folds{FOLDS}',
    max_epochs=100,
    baseline_model='1f5s41d8',
)

# resume='allow' lets a restarted notebook attach to an existing run.
wandb.init(project='evidence-transfer', config=run_config, resume='allow')
hparams = wandb.config

In [None]:
from models.unet import create_unet
from models.evitram import create_evidence_transfer_model
import tensorflow as tf

# Rebuild the baseline U-Net and initialise it with the 'model-best.h5'
# weights stored on the baseline W&B run selected by hparams['baseline_model'].
unet = create_unet()
unet_weights = wandb.restore(
    'model-best.h5',
    run_path=f'vassilis_krikonis/unet-baseline/{hparams["baseline_model"]}'
)
# wandb.restore returns None when the file is not found on the run, and an
# open file object otherwise. Fail loudly in the first case instead of with an
# AttributeError on `.name`, and close the handle once Keras has read the file.
if unet_weights is None:
    raise FileNotFoundError(
        f"Could not restore 'model-best.h5' from baseline run "
        f"{hparams['baseline_model']!r}"
    )
with unet_weights:
    unet.load_weights(unet_weights.name)

# Q head: a single 1x1 convolution with a sigmoid, giving one output channel.
q_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(1, (1, 1), padding='same', activation='sigmoid'),
], name='Q')

# Attach Q to the U-Net layer named in the run config. This was previously
# hard-coded, so changing 'connecting_unet_layer' in the config had no effect.
evitram = create_evidence_transfer_model(
    unet,
    q_model,
    hparams['connecting_unet_layer'],
    loss_lambda=hparams['q_lambda']
)

In [None]:
import tensorflow as tf

# Draw the combined U-Net + Q architecture. Keep this call as the cell's last
# expression so the notebook displays the value it returns.
tf.keras.utils.plot_model(
    model=evitram,
    show_shapes=True,
    show_layer_names=True,
)

In [None]:
from metrics import CategoricalMeanIou
from tensorflow.keras.optimizers import SGD, Adam
import tensorflow as tf

# Build the optimizer described by the run config.
if hparams['optimizer'] == 'sgd':
    optimizer = SGD(learning_rate=hparams['learning_rate'], momentum=hparams['momentum'])
elif hparams['optimizer'] == 'adam':
    optimizer = Adam(learning_rate=hparams['learning_rate'])
else:
    # Any other value is treated as a Keras optimizer name. Passing the bare
    # name to compile() would instantiate it with Keras' default learning
    # rate and silently ignore hparams['learning_rate'], which is logged to
    # W&B. Resolve it here so the configured learning rate is applied.
    optimizer = tf.keras.optimizers.get({
        'class_name': hparams['optimizer'],
        'config': {'learning_rate': hparams['learning_rate']},
    })

# One metrics list per model output: mean IoU + accuracy on the first output,
# accuracy on the second.
# NOTE(review): no `loss=` is passed and hparams['unet_loss'] / ['q_loss'] are
# not read here; presumably create_evidence_transfer_model attaches the losses
# itself -- confirm.
evitram.compile(optimizer=optimizer, metrics=[
    [CategoricalMeanIou(num_classes=5), 'accuracy'],
    ['accuracy']
])

In [None]:
import os

# Make sure the tf.data cache directory used when building the datasets
# exists. os.makedirs(exist_ok=True) replaces the IPython-only `!mkdir -p`,
# so this cell is valid plain Python and does not shell out.
os.makedirs('/tmp/ds_cache/', exist_ok=True)

from datasets.skyline12 import Skyline12

skyline12 = Skyline12('datasets/skyline12/data/')


def split_outputs(x, y, z):
    """Regroup a flat (x, y, z) sample as (input, (target_1, target_2)).

    Used with Dataset.map so each element matches a model with one input
    and two outputs.
    """
    targets = (y, z)
    return x, targets


# Build the training and validation splits (in that order) from the same
# fold layout and cache directory, regrouping each element to (x, (y, z)).
train_ds, val_ds = (
    skyline12
    .as_tf_dataset(FOLDS, subset=subset_name, cache_dir='/tmp/ds_cache/')
    .map(split_outputs)
    for subset_name in ('training', 'validation')
)

In [None]:
# Visual check before training: run the model on the first three validation
# images and show each image next to both of its predicted outputs.
preview_x, _ = next(iter(val_ds.batch(3)))
y_batch_pred, z_batch_pred = evitram(preview_x, training=False)

for image, y_pred, z_pred in zip(preview_x, y_batch_pred, z_batch_pred):
    Skyline12.show_sample(image, [y_pred, z_pred], from_tensors=True)

In [None]:
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.data.experimental import AUTOTUNE
from utils import get_new_logdir
from callbacks import LogEviTRAMImagesWandb
from wandb.keras import WandbCallback

log_dir = get_new_logdir(root_dir='./logs')
batch_size = hparams['batch_size']

if wandb.run.resumed:
    # Resuming an existing W&B run: reload the 'model-best.h5' file stored on
    # it. wandb.restore returns None when the run has no such file (e.g. no
    # checkpoint was uploaded before the previous attempt stopped) and an open
    # file object otherwise. Fail with a clear message in the first case, and
    # close the handle in the second once Keras has read the weights.
    # NOTE(review): this reloads the *best* checkpoint rather than the latest
    # epoch's weights -- confirm that is intended.
    checkpoint = wandb.restore('model-best.h5', replace=True)
    if checkpoint is None:
        raise FileNotFoundError(
            "Run was resumed but 'model-best.h5' could not be restored from W&B"
        )
    with checkpoint:
        evitram.load_weights(checkpoint.name)

evitram.fit(
    train_ds.batch(batch_size).prefetch(AUTOTUNE),
    epochs=hparams['max_epochs'],
    # NOTE(review): resumes the epoch counter from the W&B step counter, which
    # assumes exactly one committed W&B step per epoch -- confirm the callbacks
    # below do not commit extra steps.
    initial_epoch=wandb.run.step,
    validation_data=val_ds.batch(batch_size).prefetch(AUTOTUNE),
    callbacks=[
        # histogram_freq=1: write weight histograms to TensorBoard every epoch.
        TensorBoard(
            log_dir=log_dir,
            histogram_freq=1,
        ),
        # A fixed batch of 10 validation samples handed to the image logger.
        LogEviTRAMImagesWandb(next(iter(val_ds.batch(10)))),
        WandbCallback(save_weights_only=True)
    ]
)

In [None]:
# Repeat the visual check with the trained weights: predict on the first three
# validation images and display each alongside both predicted outputs.
sample_batch = val_ds.batch(3)
batch_x, _ = next(iter(sample_batch))
preds = evitram(batch_x, training=False)
y_preds, z_preds = preds

for sample_x, sample_y, sample_z in zip(batch_x, y_preds, z_preds):
    Skyline12.show_sample(sample_x, [sample_y, sample_z], from_tensors=True)

In [None]:
# Close the W&B run started at the top of the notebook.
active_run = wandb.run
active_run.finish()