In [None]:
%cd ..

%load_ext line_profiler
from IPython.core.debugger import set_trace

In [None]:
from genetals.core import *
from genetals.callbacks import GAStatus, MultiObjectiveReport
from genetals.operators import TwoPointXover, BiasedMutation, ShuffleOperator, NSGAOperator
from genetals.initializers import RandomStdInit
from evgena.datasets import load_emnist, load_mnist, load_nprecord
from evgena.models import Model, TfModel
from evgena.metrics import SSIM
from evgena.utils.large_files import maybe_download

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

In [None]:
class BestImgReport(CallbackBase):
    """GA callback that draws the best individual of the latest captured population.

    Args:
        ax: Axes to draw into. When omitted, a new 1x1 figure is created.
        best_picker: Callable taking the population's fitness array and returning the index of
            the individual to show. Defaults to ``fitness.argmax()``.
    """

    def __init__(self, ax: plt.Axes = None, best_picker=None):
        super().__init__()

        if ax is not None:
            self._fig = ax.figure
            self._ax = ax
        else:
            self._fig, self._ax = plt.subplots(1, 1)

        if best_picker is None:
            best_picker = lambda fitness: fitness.argmax()
        self._best_picker = best_picker

    def __call__(self, ga: GeneticAlgorithm) -> None:
        # ga.capture(-1) presumably yields the most recent population snapshot — TODO confirm.
        latest = ga.capture(-1)
        best_index = self._best_picker(latest.fitnesses)

        # Individuals are displayed on a fixed [-1, 1] color scale.
        self._ax.imshow(latest.individuals[best_index], cmap='plasma', vmin=-1, vmax=1)

In [None]:
class MultiSigmaRandomInit(InitializerBase):
    """Initializer drawing each individual uniformly from ``mu + sigma * [-1, 1)``.

    Sigmas are assigned to individuals cyclically: individual ``i`` uses ``sigmas[i % len(sigmas)]``.
    Note that "sigma" here scales the half-width of a *uniform* distribution, not the standard
    deviation of a normal one.

    Args:
        individual_shape: Shape of one individual, e.g. ``(28, 28)``. An int is treated as a 1-D shape.
        sigmas: A single sigma or a non-empty sequence of sigmas.
        mu: Offset added to every individual. Must be broadcastable to ``individual_shape``.

    Raises:
        ValueError: If ``sigmas`` is empty.
    """

    def __init__(self, individual_shape, sigmas = (1,), mu: np.ndarray = 0):
        super(MultiSigmaRandomInit, self).__init__()

        # Normalize once: an int shape becomes a 1-tuple, and sequences become a tuple of ints.
        self._individual_shape = tuple(np.atleast_1d(individual_shape).tolist())
        # Flatten so a scalar sigma works too (the previous code called len() on it and crashed).
        self._sigmas = np.ravel(sigmas)
        if self._sigmas.size == 0:
            raise ValueError('sigmas must contain at least one value')
        self._mu = mu

    def __call__(self, population_size: int, *args, **kwargs) -> np.ndarray:
        """Return a ``(population_size, *individual_shape)`` array of random individuals."""
        shape = self._individual_shape

        # np.resize repeats the sigmas cyclically until exactly population_size values exist.
        per_individual_sigma = np.resize(self._sigmas, population_size)
        noise = (np.random.random((population_size,) + shape) * 2) - 1

        # Reshape the sigmas to (population_size, 1, ..., 1) so each one scales its whole individual.
        return self._mu + noise * per_individual_sigma.reshape((population_size,) + (1,) * len(shape))

In [None]:
class PrePopulationInit(InitializerBase):
    """Initializer that returns a fixed, pre-computed population (e.g. to warm-start a new run).

    Args:
        prepopulation: Array-like of individuals. Its first axis is the population axis.
    """

    def __init__(self, prepopulation):
        super(PrePopulationInit, self).__init__()

        # Convert once so __call__ can return an ndarray, as its annotation promises.
        self._prepopulation = np.asarray(prepopulation)

    def __call__(self, population_size: int, *args, **kwargs) -> np.ndarray:
        """Return a copy of the stored population.

        Raises:
            ValueError: If ``population_size`` differs from the number of stored individuals.
        """
        # Explicit check rather than ``assert``, so it still fires under ``python -O``.
        if population_size != len(self._prepopulation):
            raise ValueError('Wrong pop size: requested {}, prepopulation has {} individuals'.format(
                population_size, len(self._prepopulation)))  # TODO maybe tile or so

        # Hand out a copy: any in-place change made downstream must not alter the stored population,
        # otherwise repeated runs from this initializer would start from different individuals.
        return self._prepopulation.copy()

In [None]:
class Images2LabelObjectiveFnc(ObjectiveFncBase):
    """Two-objective fitness for additive perturbations that push source images toward a target label.

    Each individual is added to every image of a rotating sample of ``source_images``, and the sum
    is clipped to ``[0, 1]``. For every individual the function returns, averaged over the sample:

    * column 0 — the model's output column ``target_label`` on the perturbed images;
    * column 1 — the SSIM between each perturbed image and its unperturbed original.

    The sample is refreshed stochastically between calls. Every sampled image carries a ``ttl``
    that is multiplied by ``sample_ttl`` on each call. An image whose ttl is ``t`` is replaced by
    the next unused source image with probability ``1 - t``. Source images are consumed in
    (optionally shuffled) order. Once all have been used, the order restarts, and is reshuffled
    if ``shuffle`` is set.

    Args:
        model: Callable mapping a ``(batch, H, W, 1)`` image batch to a ``(batch, n_classes)`` score array.
        target_label: Column of the model output used as the first objective.
        source_images: Images, presumably ``(N, H, W)`` floats in ``[0, 1]`` — TODO confirm.
        sample_size: Number of source images each individual is evaluated on.
        sample_ttl: Per-call ttl decay factor. Lower values refresh the sample faster.
        shuffle: Whether source images are consumed in random order.

    Raises:
        ValueError: If ``source_images`` is empty.
    """

    def __init__(self, model: Model, target_label: int, source_images: np.ndarray, sample_size: int = 64, sample_ttl: float = 0.9, shuffle: bool = True):
        super(Images2LabelObjectiveFnc, self).__init__()

        self._ssim = SSIM()
        self._model = model
        self._target_label = target_label
        self._source_images = source_images
        self._sample_size = sample_size
        self._sample_ttl = sample_ttl
        self._shuffle_source = shuffle

        if len(self._source_images) == 0:
            raise ValueError('source_images must not be empty')

        if self._shuffle_source:
            self._source_index = np.random.permutation(len(self._source_images))
        else:
            self._source_index = np.arange(len(self._source_images))
        # Position in self._source_index of the next source image to hand out.
        self._source_i = 0

        # Current sample: the *dataset* index of each sampled image and its remaining ttl.
        # Storing dataset indices (not positions in the permutation) keeps surviving samples
        # pinned to the same images when the permutation is reshuffled.
        self._samples = np.recarray((self._sample_size,), dtype=[('index', np.int32), ('ttl', np.float32)])
        self._samples.index = self._next_source_indices(self._sample_size)
        self._samples.ttl = 1

    def _next_source_indices(self, count: int) -> np.ndarray:
        """Return dataset indices of the next ``count`` source images, restarting the order when exhausted."""
        chunks = []
        remaining = count
        while remaining > 0:
            if self._source_i >= len(self._source_index):
                # Every source image has been handed out: start a new pass over the dataset.
                if self._shuffle_source:
                    np.random.shuffle(self._source_index)
                self._source_i = 0
            # Copy, because a later in-place shuffle would otherwise rewrite this slice.
            taken = self._source_index[self._source_i:self._source_i + remaining].copy()
            chunks.append(taken)
            self._source_i += len(taken)
            remaining -= len(taken)

        if not chunks:
            return np.empty(0, dtype=self._source_index.dtype)
        return np.concatenate(chunks)

    def _refresh_samples(self) -> None:
        """Age every sampled image and replace the ones that die with fresh source images."""
        self._samples.ttl *= self._sample_ttl
        # ttl < U(0, 1) happens with probability 1 - ttl, so older samples die more often.
        death_mask = self._samples.ttl < np.random.random(len(self._samples))

        self._samples.index[death_mask] = self._next_source_indices(int(np.count_nonzero(death_mask)))
        self._samples.ttl[death_mask] = 1

    def __call__(self, individuals: np.ndarray) -> np.ndarray:
        """Evaluate ``individuals`` (presumably ``(P, H, W)`` perturbations) on the current sample.

        Returns:
            ``(P, 2)`` array: column 0 is the mean target-label score, and column 1 the mean SSIM.
        """
        # Images used for this evaluation. The refresh below only affects subsequent calls.
        images = self._source_images[self._samples.index]

        self._refresh_samples()

        # (P, S, H, W): every individual added to every sampled image, clipped to the pixel range.
        augmented_images = images + np.expand_dims(individuals, 1)
        np.clip(augmented_images, 0, 1, out=augmented_images)
        batch_shape = (-1,) + augmented_images.shape[2:] + (1,)
        augmented_images_batch_shaped = augmented_images.reshape(batch_shape)
        # Unperturbed originals laid out exactly like the augmented batch.
        reference_batch_shaped = np.broadcast_to(images, augmented_images.shape).reshape(batch_shape)

        # For each individual, compute SSIM and the target-label score on each sample, in (P, S) layout.
        norms = self._ssim(augmented_images_batch_shaped, reference_batch_shaped)
        norms = norms.reshape(augmented_images.shape[:2])
        logits = self._model(augmented_images_batch_shaped)[:, self._target_label]
        logits = logits.reshape(augmented_images.shape[:2])

        avg_norms = np.average(norms, axis=-1)
        avg_logits = np.average(logits, axis=-1)

        # Merge the two per-individual objectives as columns.
        return np.stack((avg_logits, avg_norms), axis=-1)

In [None]:
# Trained Fashion-MNIST CNN. The two strings are presumably its input and output tensor names — TODO confirm.
model = TfModel(
    'models/fashion_mnist_cnn/model',
    'end_points/images',
    'end_points/scores',
    batch_size=8192,
)

# Dataset splits (`train` / `test` expose `.X` images and `.y` labels), plus synset and metadata.
train, test, synset, metadata = load_nprecord('fashion_mnist.npz')

In [None]:
# Perturb images of `source_class` so that the model's `target_class` score rises.
source_class = 0
target_class = 5

# The notebook is stochastic (sample permutation, initialization, sample refresh). Seed NumPy's
# global RNG so a Restart & Run All reproduces the same results.
SEED = 0
np.random.seed(SEED)

images = train.X[train.y == source_class]
assert len(images) > 0, 'no training images of class {}'.format(source_class)

In [None]:
# Variation pipeline of the GA. Each operator receives the previous operator as its input:
#   graph.init_op -> ShuffleOperator -> TwoPointXover -> BiasedMutation,
# and NSGAOperator receives both graph.init_op and the mutated offspring.
# (graph.init_op is presumably the population entering each generation, and NSGAOperator
#  presumably performs NSGA-style survivor selection — confirm against genetals.)
graph = OperatorGraph()

select_op = ShuffleOperator(graph.init_op)
# NOTE(review): 0.6 is presumably the crossover probability — confirm against TwoPointXover.
xover_op = TwoPointXover(select_op, 0.6)
# Mutation bounds are [-1, 1], the same range used when displaying individuals.
mutation_op = BiasedMutation(xover_op, sigma=0.1, l_bound=-1.0, u_bound=1.0)
moea_op = NSGAOperator(graph.init_op, mutation_op)

In [None]:
%matplotlib notebook

fig, ax = plt.subplots(1, 1, figsize=(10,5))
# fig.tight_layout()

ax.set_xlim(0.0000000000000001, 1)
ax.set_xlabel('Target class prediction probability')
ax.set_ylim(-1, 1)
ax.set_ylabel('mean SSIM')
ax.set_xscale('log')
ax.grid(axis='both')
ax.vlines(0.5, -1, 1, colors='g')

callbacks = [GAStatus(fig), MultiObjectiveReport(ax)] # TODO BestImgReport(ax[1], best_picker=lambda fit: np.argmax(np.sum(fit, axis=-1)))]

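- **TODO:** uniform vs. standard-normal distribution for initializing individuals?
- **TODO:** make `ga.run` easy to resume from where a previous run stopped.
- **TODO:** a proper mechanism, with a standardized format, for storing GA run results.
- **TODO:** a callback for inspecting intermediate individuals during a run.
- **TODO:** persisting the GA configuration and its runs.
- **TODO:** clipping of values outside the `[0, 1]` image range — where should it happen, given that individuals live in `[-1, 1]`?
- **TODO:** how sensitive is SSIM to changes near the image border?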

In [None]:
# To warm-start from an earlier run instead, pass
# `initializer=PrePopulationInit(<earlier final population>.individuals)` — but only after that
# run exists in this kernel, so a fresh Restart & Run All does not depend on hidden state.
ga = GeneticAlgorithm(
    # 100 sigmas growing exponentially from 0 to 1: (e^t - 1) / (e^5 - 1) for t evenly spaced in [0, 5].
    # The individual shape is taken from the data rather than hard-coded, so it always matches the images.
    initializer=MultiSigmaRandomInit(images.shape[1:], (np.exp(np.linspace(0, 5, 100)) - 1) / (np.exp(5) - 1)),
    operator_graph=graph,
    objective_fnc=Images2LabelObjectiveFnc(model, target_class, images, sample_size=64, sample_ttl=0.95),
    callbacks=callbacks,
)

In [None]:
%time final_pop, fitnesses, objectives = ga.run(population_size=512, generation_cap=256)

In [None]:
# Keep this run's results (final population, fitnesses, objectives) under one name,
# e.g. to warm-start a later run via PrePopulationInit(first_run[0].individuals).
first_run = (final_pop, fitnesses, objectives)

In [None]:
prediction_bound = 0.5
ssim_bound = 0.8

# Individuals that exceed both bounds: objective column 0 (target-class score) above
# `prediction_bound`, and column 1 (mean SSIM) above `ssim_bound`.
filtered_indices = np.flatnonzero(
    (final_pop.objectives[:, 0] > prediction_bound) & (final_pop.objectives[:, 1] > ssim_bound)
)
filtered_individuals = final_pop.individuals[filtered_indices]
filtered_objectives = final_pop.objectives[filtered_indices]

In [None]:
# Generalization check: apply the filtered perturbations to held-out test images of the source class.
test_data = test.X[test.y == source_class]
test_individuals = filtered_individuals

if len(test_individuals) == 0:
    # Nothing passed the filter: keep the (individuals, test images) layout, with zero rows,
    # instead of feeding an empty batch to the model.
    generalization = np.empty((0, len(test_data)))
else:
    # (P, N, H, W): every perturbation added to every test image, clipped to [0, 1]
    # the same way the objective function does.
    augmented_images = test_data + np.expand_dims(test_individuals, 1)
    np.clip(augmented_images, 0, 1, out=augmented_images)
    # Use shape[2:] (not shape[2:4]) so the batch layout does not hard-code the image rank.
    augmented_images_batch_shaped = augmented_images.reshape(-1, *augmented_images.shape[2:], 1)

    # Target-class score of every (perturbation, test image) pair, in (P, N) layout.
    generalization = model(augmented_images_batch_shaped)[:, target_class].reshape(augmented_images.shape[:2])

In [None]:
# Mean target-class score over all (perturbation, test image) pairs.
np.mean(generalization)

In [None]:
# Number of (perturbation, test image) pairs whose target-class score exceeds the same
# `prediction_bound` used to filter individuals (previously a duplicated literal 0.5).
np.count_nonzero(generalization > prediction_bound)

In [None]:
# Population index of the filtered individual with the highest mean SSIM.
filtered_indices[np.argmax(filtered_objectives[:, 1])]

In [None]:
%matplotlib notebook

individual_i = 141
image_i = 0

compare_fig, compare_ax = plt.subplots(1, 3, figsize=(13, 6))
compare_ax[0].imshow(images[image_i], cmap='gray', vmin=0, vmax=1)
compare_ax[1].imshow(final_pop.individuals[individual_i], cmap='plasma', vmin=-1, vmax=1)
compare_ax[2].imshow(images[image_i] + final_pop.individuals[individual_i], cmap='gray', vmin=0, vmax=1)

In [None]:
%matplotlib notebook

fig, ax = plt.subplots(8, 8, figsize=(10, 10))
fig.tight_layout()

In [None]:
# Draw one filtered individual per grid cell, in row-major order. Cells beyond the number of
# individuals are left blank. The old `i > len(...)` guard let `i == len(...)` index past the end.
for i, cell_ax in enumerate(ax.flat):
    cell_ax.axis('off')
    if i < len(filtered_individuals):
        cell_ax.imshow(filtered_individuals[i], cmap='plasma', vmin=-1, vmax=1)