In [None]:
# Switch the kernel's working directory to the parent directory.
# Presumably this makes the relative 'datasets/...' and 'models/...' paths
# used in later cells resolve from the project root — confirm.
# NOTE(review): this cell is not idempotent; re-running it without restarting
# the kernel moves up one more directory level each time.
%cd ..

In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from IPython.display import display, Markdown

from evgena.dataset import Dataset
from evgena.model import TrainableTfModel

In [None]:
# TODO move to network as supportive routines
def conv_bn(x, filters, kernel_size, stride, padding, is_training):
    """Apply a bias-free 2-D convolution, then batch normalization, then ReLU.

    Args:
        x: input tensor passed to ``tf.layers.conv2d``.
        filters: number of convolution output channels.
        kernel_size: side length of the square convolution kernel.
        stride: step used along both spatial dimensions.
        padding: padding mode forwarded to ``tf.layers.conv2d``.
        is_training: forwarded as ``training`` to batch normalization.

    Returns:
        The activated tensor.

    Note:
        The convolution has no bias term; batch normalization follows it
        directly. ``tf.layers.batch_normalization`` registers its
        moving-average updates in ``tf.GraphKeys.UPDATE_OPS``; presumably the
        training code runs them — TODO confirm.
    """
    convolved = tf.layers.conv2d(
        inputs=x,
        filters=filters,
        kernel_size=(kernel_size,) * 2,
        strides=(stride,) * 2,
        padding=padding,
        use_bias=False,
    )
    normalized = tf.layers.batch_normalization(convolved, training=is_training)
    return tf.nn.relu(normalized)

def residual_conv_bn(x, filters, kernel_size, stride, padding, is_training):
    """Convolve ``x``, add ``x`` back as a skip connection, then apply BN and ReLU.

    Args:
        x: input tensor; it is both convolved and used as the shortcut.
        filters: number of convolution output channels.
        kernel_size: side length of the square convolution kernel.
        stride: step used along both spatial dimensions.
        padding: padding mode forwarded to ``tf.layers.conv2d``.
        is_training: forwarded as ``training`` to batch normalization.

    Returns:
        The activated tensor.

    Note:
        The shortcut is added with ``tf.add`` without any projection, so the
        convolution output must be shape-compatible with ``x``. Because
        ``tf.add`` broadcasts, a mismatch in a size-1 dimension would not
        raise an error.
    """
    convolved = tf.layers.conv2d(
        inputs=x,
        filters=filters,
        kernel_size=(kernel_size,) * 2,
        strides=(stride,) * 2,
        padding=padding,
        use_bias=False,
    )
    # The skip connection is merged before normalization and activation.
    merged = tf.add(convolved, x)
    normalized = tf.layers.batch_normalization(merged, training=is_training)
    return tf.nn.relu(normalized)

## Dataset preparation ##

### Complete surrogate ###

In [None]:
# Build the "complete" surrogate dataset: the inputs of every split of the
# source dataset are kept, and the labels are replaced by the outputs of the
# target model on those inputs.
source_dataset = Dataset.from_nprecord('datasets/split_fashion_mnist.npz')
target_model_path = 'models/different_seeds/2018-05-29_190930.bs-0128.lr-0.0010.seed-42/30-best_loss'
# NOTE(review): TfModel is not imported in the visible import cell (only
# TrainableTfModel is). It presumably lives in evgena.model as well — confirm
# and import it, otherwise this line raises NameError on a fresh kernel.
target_model = TfModel(target_model_path)
target_dataset_path = 'datasets/2018-05-29_190930_fashion_mnist.npz'

# Fixed: the predictions were computed from `source.dataset.<split>.X`, which
# references an undefined name `source`; the loaded dataset is `source_dataset`.
target_dataset = Dataset.from_splits(
    source_dataset.train.X, target_model(source_dataset.train.X),
    source_dataset.val.X, target_model(source_dataset.val.X),
    source_dataset.test.X, target_model(source_dataset.test.X),
    metadata=dict(target_model_path=target_model_path, **source_dataset.metadata)
)
target_dataset.to_nprecord(target_dataset_path)

### Targeted surrogate ###

In [None]:
# Inputs for the targeted surrogate: the source dataset, the target model
# whose outputs will be imitated, and the output column of interest.
source_dataset = Dataset.from_nprecord('datasets/split_fashion_mnist.npz')
target_model_path = 'models/different_seeds/2018-05-29_190930.bs-0128.lr-0.0010.seed-42/30-best_loss'
# NOTE(review): TfModel is not imported in the visible import cell (only
# TrainableTfModel is). It presumably lives in evgena.model as well — confirm
# and import it, otherwise this line raises NameError on a fresh kernel.
target_model = TfModel(target_model_path)
# Index of the target model's output column used to build the binary labels.
target_label = 0
# NOTE(review): the ".0" in the file name presumably encodes target_label —
# keep the two in sync when changing the label.
target_dataset_path = 'datasets/2018-05-29_190930.0_fashion_mnist.npz'

In [None]:
# For every split, take the target model's output column `target_label` as p
# and build two-column labels [p, 1 - p] along a new last axis.
target_labels = []
for split_name in ['train', 'val', 'test']:
    source_split = getattr(source_dataset, split_name)
    target_predictions = target_model(source_split.X)[:, target_label]
    target_labels.append(
        np.stack((target_predictions, 1 - target_predictions), axis=-1)
    )
train_labels, val_labels, test_labels = target_labels

# Fixed: work on a copy of the metadata. The previous code updated
# `source_dataset.metadata` in place, so the loaded source dataset silently
# gained the target-specific keys, which makes re-using it in other cells
# order-dependent.
metadata = dict(source_dataset.metadata)
metadata.update(
    target_model_path=np.asarray(target_model_path),
    target_label=np.asarray(target_label),
    # Two entries, matching the two label columns [p, 1 - p] built above.
    synset=np.array([True, False])
)
target_dataset = Dataset.from_splits(
    source_dataset.train.X, train_labels,
    source_dataset.val.X, val_labels,
    source_dataset.test.X, test_labels,
    metadata=metadata
)
target_dataset.to_nprecord(target_dataset_path)

## Model training ##

In [None]:
def binary_cnn(images, labels, is_training, global_step):
    """Build the convolutional feature network used for the surrogate model.

    Architecture: two conv-BN-ReLU + 2x2 max-pool stages (32 and 64 filters),
    a third conv-BN-ReLU (128 filters), flatten, dropout, a 128-unit ReLU
    dense layer, and a final dropout.

    Args:
        images: input tensor fed to the first convolution.
        labels: not used inside this function.
        is_training: toggles batch normalization and dropout training behavior.
        global_step: not used inside this function.

    Returns:
        The output of the last dropout layer (128 features per example).

    Note:
        ``labels`` and ``global_step`` are presumably required by the builder
        signature that ``TrainableTfModel.construct`` expects — TODO confirm.
        No output/logit layer is created here; presumably the caller adds
        it — TODO confirm.
    """
    net = images

    # The first two stages share one shape: 3x3 'same' convolution, then 2x2 pooling.
    for stage_filters in (32, 64):
        net = conv_bn(net, stage_filters, 3, 1, 'same', is_training)
        net = tf.layers.max_pooling2d(net, pool_size=(2, 2), strides=(2, 2))

    # The last convolutional stage has no pooling.
    net = conv_bn(net, 128, 3, 1, 'same', is_training)

    net = tf.layers.flatten(net)
    net = tf.layers.dropout(net, training=is_training)
    net = tf.layers.dense(net, 128, activation=tf.nn.relu)

    return tf.layers.dropout(net, training=is_training)

In [None]:
# Construct a trainable model from the `binary_cnn` builder and the targeted
# surrogate dataset file.
# NOTE(review): the positional 128 and 0.001 are presumably batch size and
# learning rate — confirm against TrainableTfModel.construct.
model = TrainableTfModel.construct(
    binary_cnn,
    'datasets/2018-05-29_190930.0_fashion_mnist.npz',
    128,
    0.001,
    tag='test_sigmoid_surrogate',
    inference_batch_size=4096,
)

In [None]:
# Train the constructed model for 60 epochs.
model.train(epochs=60)