In [None]:
# Notebook setup: move the working directory one level up and load profiling/debugging helpers.
# NOTE: `%cd ..` is relative, so re-running this cell climbs one directory further each time.
%cd ..

# Registers the line_profiler IPython extension (provides the %lprun magic).
%load_ext line_profiler
# Interactive breakpoint helper; it is not called in any of the cells visible here.
from IPython.core.debugger import set_trace

In [None]:
from genetals.core import *
from genetals.callbacks import GAStatus, MultiObjectiveReport
from genetals.operators import TwoPointXover, BiasedMutation, ShuffleOperator, NSGAOperator
from genetals.initializers import RandomStdInit
from evgena.datasets import load_emnist, load_mnist, load_nprecord
from evgena.models import TfModel
from evgena.metrics import SSIM
from evgena.utils.large_files import maybe_download

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

In [None]:
class MultiSigmaRandomStdInit(InitializerBase):
    """Initializer producing random individuals whose scale differs per population split.

    The population is cut into ``len(sigmas)`` consecutive splits of (nearly)
    equal size. Every individual is drawn with ``np.random.random`` and the
    individuals of the i-th split are multiplied by ``sigmas[i]``; ``mu`` is
    added to the whole population afterwards.

    NOTE: despite the class name, ``np.random.random`` samples uniformly from
    ``[0, 1)``, so the values are non-negative and not normally distributed.
    """

    def __init__(self, individual_shape, sigmas = (1,), mu: np.ndarray = 0):
        """
        :param individual_shape: shape of a single individual, e.g. ``(28, 28)``
        :param sigmas: non-empty sequence of scale factors, one per population split
        :param mu: offset added to every individual (scalar or broadcastable array)
        :raises ValueError: if ``sigmas`` is empty
        """
        super(MultiSigmaRandomStdInit, self).__init__()

        # Fail at construction time instead of with a ZeroDivisionError on the first call.
        sigmas = tuple(sigmas)
        if not sigmas:
            raise ValueError('sigmas must contain at least one value')

        self._individual_shape = individual_shape
        self._sigmas = sigmas
        self._mu = mu

    def __call__(self, population_size: int, *args, **kwargs) -> np.ndarray:
        """Create ``population_size`` individuals.

        :param population_size: number of individuals to generate (0 yields an empty population)
        :return: array of shape ``(population_size,) + individual_shape``
        """
        result = np.random.random((population_size,) + tuple(self._individual_shape))

        # Ceil division, so the splits together cover the whole population.
        split_size = (population_size + (len(self._sigmas) - 1)) // len(self._sigmas)

        # Slices past the end of the population are empty, so iterating over every
        # sigma is safe even for tiny (or empty) populations.
        for i, sigma in enumerate(self._sigmas):
            result[i * split_size:(i + 1) * split_size] *= sigma

        return self._mu + result

In [None]:
class PrePopulationInit(InitializerBase):
    """Initializer that hands back an already existing population unchanged."""

    def __init__(self, prepopulation):
        """
        :param prepopulation: the individuals to return as the initial population
        """
        super().__init__()
        self._prepopulation = prepopulation

    def __call__(self, population_size: int, *args, **kwargs) -> np.ndarray:
        """Return the stored population; its length must equal ``population_size``."""
        stored_size = len(self._prepopulation)
        # TODO maybe tile or so
        assert population_size == stored_size, 'Wrong pop size'

        return self._prepopulation

In [None]:
class Images2LabelObjectiveFnc(ObjectiveFncBase):
    """Two-objective evaluation of additive perturbations on a batch of source images.

    Every individual is added to every image of the current batch. For each
    individual the objective is the pair
    ``(mean model output at target_label, mean SSIM to the unperturbed images)``.
    """

    def __init__(self, model: Model, target_label: int, source_images: np.ndarray, sample_size: int = 16):
        """
        :param model: callable invoked on a batch shaped ``(n, *image_shape, 1)``;
            its result is indexed as ``[:, target_label]``
        :param target_label: column of the model output used as the first objective
        :param source_images: images the batches are taken from
        :param sample_size: number of source images used per call
        """
        super(Images2LabelObjectiveFnc, self).__init__()

        self._ssim = SSIM()
        self._model = model
        self._target_label = target_label
        self._source_images = source_images
        self._sample_size = sample_size
        self._source_i = 0

    def _next_batch(self) -> np.ndarray:
        """Return the next ``sample_size`` source images, wrapping around at the end."""
        total = len(self._source_images)
        u_bound = self._source_i + self._sample_size

        if u_bound <= total:
            batch = self._source_images[self._source_i:u_bound]
            self._source_i = u_bound
        else:
            # Tail of the data followed by the head; a full batch is only
            # guaranteed while sample_size <= len(source_images).
            overflow = u_bound - total
            batch = np.concatenate((self._source_images[self._source_i:], self._source_images[:overflow]))
            self._source_i = overflow  # TODO reshuffle?

        return batch

    def __call__(self, individuals: np.ndarray) -> np.ndarray:
        """Evaluate all individuals against the same batch of source images.

        :param individuals: array of perturbations, one per individual
        :return: array of shape ``(len(individuals), 2)`` with columns
            ``(mean target output, mean SSIM)``
        """
        # same batch for each individual
        images = self._next_batch()

        # (individual, image, *image_shape): every individual added to every image
        augmented_images = images + np.expand_dims(individuals, 1)
        pair_shape = augmented_images.shape[:2]
        batch_shape = (-1,) + augmented_images.shape[2:] + (1,)

        # augment random batch with various individuals
        # TODO

        # Flatten (individual, image) pairs into one batch with a trailing channel axis.
        augmented_batch = augmented_images.reshape(batch_shape)
        reference_batch = np.expand_dims(images, 0).repeat(len(individuals), axis=0).reshape(batch_shape)

        # for each individual and image: SSIM to the original and target output
        norms = self._ssim(augmented_batch, reference_batch).reshape(pair_shape)
        # Use the model given to the constructor, not whatever global `model` happens to exist.
        logits = self._model(augmented_batch)[:, self._target_label].reshape(pair_shape)

        avg_norms = np.average(norms, axis=-1)
        avg_logits = np.average(logits, axis=-1)

        # create array by merging columns
        return np.stack((avg_logits, avg_norms), axis=-1)

In [None]:
# Load the classifier that the perturbations are evaluated against.
# The meaning of the positional arguments is defined by evgena's TfModel and is not visible here;
# presumably model path, input tensor name and output tensor name — TODO confirm.
model = TfModel('models/fashion_mnist_cnn/model', 'end_points/images', 'end_points/scores', batch_size=8192)

# Dataset record; `train` is indexed through `train.X` / `train.y` in the next cell.
# Both paths in this cell are relative — presumably resolved against the working directory.
train, test, synset, metadata = load_nprecord('fashion_mnist.npz')

In [None]:
# Experiment setup: images labelled `source_class` are perturbed, and the model output for
# `target_class` is what the objective function reads.
source_class = 0
target_class = 5
# Keep only the training images whose label equals the source class.
images = train.X[train.y == source_class]

In [None]:
# Build the GA operator pipeline. Each operator receives its upstream operator(s) first:
# init -> shuffle -> two-point crossover -> biased mutation; NSGAOperator then receives both
# the graph's init op and the mutation op.
graph = OperatorGraph()

select_op = ShuffleOperator(graph.init_op)
# 0.6 is presumably the crossover probability — TODO confirm against TwoPointXover.
xover_op = TwoPointXover(select_op, 0.6)
mutation_op = BiasedMutation(xover_op, sigma=0.1, l_bound=-1.0, u_bound=1.0)
moea_op = NSGAOperator(graph.init_op, mutation_op)

In [None]:
# Figure handed to the GA callbacks: target-class score on a log-scaled x axis,
# mean SSIM on the y axis.
%matplotlib notebook

fig, ax = plt.subplots(1, 1, figsize=(10,5))
# fig.tight_layout()

# Tiny positive lower limit because a log-scaled axis cannot display 0.
ax.set_xlim(0.0000000000000001, 1)
ax.set_xlabel('Target class prediction probability')
ax.set_ylim(0.6, 1)
ax.set_ylabel('mean SSIM')
ax.set_xscale('log')
ax.grid(axis='both')
# Green vertical marker at x = 0.5.
ax.vlines(0.5, 0, 1, colors='g')

# The callbacks receive the figure / axes created above; what they draw is defined in genetals.callbacks.
callbacks = [GAStatus(fig), MultiObjectiveReport(ax)]

- better initialization - continuous sigma
- exponential time to live for samples
- make it easy to continue a `ga.run`
- proper mechanism and standardized format for storing GA run results
- callback for checking intermediate individuals
- persisting the GA and its runs
- clipping values that fall outside the `[0,1]` bounds

In [None]:
# Objective built on the first 64 source images; with sample_size=64 each call uses all of them
# (assuming at least 64 source images exist).
objFnc = Images2LabelObjectiveFnc(model, target_class, images[:64], sample_size=64)

In [None]:
# Assemble the GA from the initializer, operator graph, objective and callbacks defined in earlier cells.
# The commented-out initializer would restart from a previous run, but needs `first_run`,
# which is only assigned after `ga.run` has been executed.
ga = GeneticAlgorithm(
#     initializer=PrePopulationInit(first_run[0].individuals),
    initializer=MultiSigmaRandomStdInit((28, 28), [0.1, 0.2, 0.4]), # TODO population splits with different sigma
    operator_graph=graph,
    objective_fnc=objFnc, # TODO next sample overlaps with last sample, exponential time to live / or subset of dataset
    callbacks=callbacks
)

In [None]:
# Run the GA (timed with IPython's %time); the three results are used by the cells below.
%time final_pop, fitnesses, objectives = ga.run(population_size=512, generation_cap=256)

In [None]:
# Keep a reference to this run's results under a separate name.
first_run = final_pop, fitnesses, objectives

In [None]:
# Position of the largest objective-0 value (target-class score) within the last entry of axis 0.
# Presumably the axes are (generation, individual, objective) — TODO confirm.
objectives[-1, :, 0].argmax()

In [None]:
# Highest objective-1 value (mean SSIM) among the entries whose objective-0 value exceeds 0.5.
# NOTE: argmax is taken on the masked array, so the result is a position within the filtered
# subset, not an index into the full population; it raises ValueError when nothing exceeds 0.5.
objectives[-1, :, 1][objectives[-1, :, 0] > 0.5].argmax()

In [None]:
# Full-array index of the 31st entry (position 30) whose objective-0 value exceeds 0.5;
# raises IndexError when fewer than 31 entries qualify.
np.where(objectives[-1, :, 0] > 0.5)[0][30]

In [None]:
# Inspect the layout of the recorded objectives.
objectives.shape

In [None]:
# Apply individual 108 to the source images that were not passed to the objective (images[64:])
# and read the target-class output for each of them.
# NOTE(review): other cells reshape the input to (-1, 28, 28, 1) before calling the model;
# this call does not — confirm the model accepts input without the trailing channel axis.
generalization = model(images[64:] + final_pop.individuals[108])[:, target_class]

In [None]:
# Number of those images whose target-class output exceeds 0.5.
np.sum(generalization > 0.5)

In [None]:
# Side-by-side view of one source image, one individual and their sum.
%matplotlib notebook

# Which individual and which source image to show.
individual_i = 108
image_i = 0

compare_fig, compare_ax = plt.subplots(1, 3, figsize=(13, 6))
# Left: the original image.
compare_ax[0].imshow(images[image_i], cmap='gray', vmin=0, vmax=1)
# Middle: the individual alone; values outside [0, 1] are clipped by vmin/vmax for display.
compare_ax[1].imshow(final_pop.individuals[individual_i], cmap='gray', vmin=0, vmax=1)
# Right: the image with the individual added.
compare_ax[2].imshow(images[image_i] + final_pop.individuals[individual_i], cmap='gray', vmin=0, vmax=1)

In [None]:
# Full model output for source image 4472 with individual 45 added, reshaped to a batch of one
# 28x28 single-channel image.
model(np.reshape(final_pop.individuals[45] + images[4472], (-1, 28, 28, 1)))

In [None]:
# Target-class output for the first 32 source images, each with individual 45 added.
target_scores = model(np.reshape(final_pop.individuals[45] + images[:32], (-1, 28, 28, 1)))[:, target_class]

In [None]:
# Target-class output for the third of those images.
target_scores[2]

In [None]:
# 8x8 grid of axes, filled by the next cell.
# NOTE: this rebinds `fig` and `ax`, which an earlier cell used for the GA progress figure.
%matplotlib notebook

fig, ax = plt.subplots(8, 8, figsize=(10, 10))
fig.tight_layout()

In [None]:
# Draw the first 64 individuals on the 8x8 grid, one per tile, with the axes hidden.
for i in range(64):
    row, col = divmod(i, 8)
    tile = ax[row, col]
    tile.axis('off')
    tile.imshow(final_pop.individuals[i])