In [None]:
# Change into the parent directory: every path below (playground/..., models/...,
# fashion_mnist.npz) is relative. NOTE(review): presumably the notebook lives one
# level below the repository root — confirm before moving this notebook.
%cd ..

In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from time import sleep

from evgena.datasets import load_nprecord
from evgena.models import TfModel

In [None]:
# Saved genetic-algorithm run (an .npz archive); its arrays are read out in the next cell.
GA_RUN_PATH = 'playground/ga_runs/18-04-15-19-55-56.npz'
ga_run = np.load(GA_RUN_PATH)

In [None]:
# Pull the stored arrays out of the GA run archive.
# `objectives` is used below as [generation, individual, objective], where
# objective 0 is the target-class prediction and objective 1 the mean SSIM.
# NOTE(review): `individuals` presumably holds the last generation's population
# (it is indexed with masks built from objectives[-1]) — confirm.
individuals, objectives, fitnesses = (
    ga_run[key] for key in ('individuals', 'objectives', 'fitnesses')
)

In [None]:
%matplotlib notebook

fig, ax = plt.subplots(1, 1, figsize=(10,5))
# fig.tight_layout()

ax.set_xlim(0.0000000001, 1)
ax.set_xlabel('Target class prediction probability')
ax.set_ylim(-1, 1)
ax.set_ylabel('mean SSIM')
ax.set_xscale('log')
ax.grid(axis='both')
ax.vlines(0.5, -1, 1, colors='g')

for epoch, curr_objectives in enumerate(objectives):
    ax.lines = []
    ax.plot(*curr_objectives.transpose(), 'r+')
    
    fig.canvas.set_window_title('Current generation: {}'.format(epoch))
    fig.canvas.draw()
    
    sleep(0.05)

In [None]:
%matplotlib notebook

fig, ax = plt.subplots(8, 8, figsize=(10, 10))
fig.tight_layout()

for i in range(64):
    if i > len(individuals):
        break
    
    ax[i // 8, i % 8].axis('off')
    ax[i // 8, i % 8].imshow(individuals[i], cmap='plasma', vmin=-1, vmax=1)

In [None]:
class ImageAugmentation:
    """Add augmentation patterns to base images, every pattern to every image.

    A private TF1 graph and session are built once in ``__init__``; calling the
    instance runs that graph. Each augmentation is bilinearly resized to the
    spatial size of the base images, added to every base image, and the sums are
    clipped to ``[clip_min, clip_max]``.

    Args:
        clip_min: lower clipping bound of the output pixels.
        clip_max: upper clipping bound of the output pixels. Defaults to 1.0; the
            previous hard-coded 1.1 let augmented pixels exceed the [0, 1] range
            the notebook plots images in (vmin=0, vmax=1).
            NOTE(review): confirm 1.0 matches the model's expected input range.
    """

    def __init__(self, clip_min=0.0, clip_max=1.0):
        graph = tf.Graph()
        self.session = tf.Session(graph=graph)

        with graph.as_default():
            # Single-channel inputs. Batch and spatial sizes are dynamic, so the
            # augmentation resolution may differ from the image resolution.
            # TODO (from original): consider linking/constraining these dimensions.
            self.augmentations = tf.placeholder(tf.float32, [None, None, None, 1], name='augmentations')
            self.base_images = tf.placeholder(tf.float32, [None, None, None, 1], name='base_images')

            # Resize every augmentation to the (height, width) of the base images.
            resized_augmentations = tf.image.resize_images(
                self.augmentations, tf.shape(self.base_images)[1:3],
                method=tf.image.ResizeMethod.BILINEAR, align_corners=True
            )

            # (n_aug, 1, h, w, 1) + (n_img, h, w, 1) broadcasts to (n_aug, n_img, h, w, 1):
            # each augmentation is applied to each base image.
            self.augmented_images = tf.clip_by_value(
                self.base_images + tf.expand_dims(resized_augmentations, 1), clip_min, clip_max
            )

    def __call__(self, augmentations, base_images):
        """Return every augmentation added to every base image.

        Args:
            augmentations: array of shape (n_aug, h_a, w_a) or (n_aug, h_a, w_a, 1).
            base_images: array of shape (n_img, h, w) or (n_img, h, w, 1).

        Returns:
            Array of shape (n_aug, n_img, h, w, 1), clipped to the bounds given
            to the constructor.
        """
        augmentations = np.asarray(augmentations)
        base_images = np.asarray(base_images)

        # Add the missing channel axis to (batch, height, width) inputs.
        if augmentations.ndim == 3:
            augmentations = augmentations[..., np.newaxis]
        if base_images.ndim == 3:
            base_images = base_images[..., np.newaxis]

        return self.session.run(
            self.augmented_images,
            feed_dict={self.augmentations: augmentations, self.base_images: base_images}
        )
    
# Shared helpers for the cells below: the augmentation graph and the trained CNN
# stored under models/fashion_mnist_cnn.
# NOTE(review): TfModel's positional args look like (checkpoint path, input
# tensor name, output tensor name) — confirm against evgena.models.
MODEL_PATH = 'models/fashion_mnist_cnn/model'
augment_images = ImageAugmentation()
model = TfModel(MODEL_PATH, 'end_points/images', 'end_points/scores', batch_size=8192)

In [None]:
# Mean target-class score the model gives the unmodified source-class test images (baseline).
# NOTE(review): relies on `test_data` and `target_class`, which are only defined in a
# later cell — this cell fails on Restart & Run All; move it below that cell.
np.mean(model(test_data.reshape(-1, 28, 28, 1))[:, target_class])

In [None]:
# Number of unmodified source-class test images whose target-class score exceeds 0.5.
# NOTE(review): relies on `test_data` and `target_class`, which are only defined in a
# later cell — this cell fails on Restart & Run All; move it below that cell.
np.sum(model(test_data.reshape(-1, 28, 28, 1))[:, target_class] > 0.5)

In [None]:
# Number of test images on which the top-ranked filtered individual (highest mean SSIM)
# pushes the target-class score above 0.5.
# NOTE(review): relies on `generalization` and `leaderboard`, which are only defined in a
# later cell — this cell fails on Restart & Run All; move it below that cell.
np.sum(generalization[leaderboard[0]] > 0.5)

In [None]:
%matplotlib notebook

train, test, synset, metadata = load_nprecord('fashion_mnist.npz')

source_class = 0
target_class = 5

prediction_bound = 0.5
ssim_bound = 0.2

filtered_indices, *_ = np.where(np.logical_and(objectives[-1, :, 0] > prediction_bound, objectives[-1, :, 1] > ssim_bound))
filtered_individuals = individuals[filtered_indices]
filtered_objectives = objectives[-1, filtered_indices]

test_data = test.X[test.y == source_class]
test_individuals = filtered_individuals

augmented_images = augment_images(test_individuals, test_data)
augmented_images_batch_shaped = augmented_images.reshape(-1, *augmented_images.shape[2:4], 1)

generalization = model(augmented_images_batch_shaped)[:, target_class].reshape(augmented_images.shape[:2])

print('Mean test prediction: {:4f}, target predicted {} out of {}.'.format(generalization.mean(), np.sum(generalization > 0.5), generalization.size))

leaderboard = np.flip(np.argsort(filtered_objectives[:,1]), 0)

compare_fig, compare_ax = plt.subplots(1, 3, figsize=(13, 6))

In [None]:
# Side-by-side view of one (individual, test image) pair:
# source image | augmentation pattern | augmented image.
individual_i = leaderboard[0]  # best-ranked filtered individual (highest mean SSIM)
image_i = 0

# Clear first so re-running this cell replaces the panels instead of stacking images.
for panel_ax in compare_ax:
    panel_ax.clear()

compare_ax[0].imshow(test_data[image_i], cmap='gray', vmin=0, vmax=1)
compare_ax[0].set_title('Test image {}'.format(image_i))

compare_ax[1].imshow(filtered_individuals[individual_i], cmap='plasma', vmin=-1, vmax=1)
compare_ax[1].set_title('Individual {}'.format(individual_i))

compare_ax[2].imshow(augmented_images[individual_i, image_i, :, :, 0], cmap='gray', vmin=0, vmax=1)
compare_ax[2].set_title('Augmented (target score {:.3f})'.format(generalization[individual_i, image_i]))

# The figure was created in an earlier cell; request a redraw of the updated panels.
compare_fig.canvas.draw_idle()