<a href="https://colab.research.google.com/github/pitzer42/img_gen_control/blob/master/Image_Generator_Control.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import matplotlib.pyplot as plt
import tensorflow as tf
from ipywidgets import interact as run_ui
from ipywidgets.widgets import FloatSlider
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler


def pca_on_activation_space(model,
                            n_components,
                            n_samples,
                            input_generator=tf.random.normal):
    """Fit a PCA to the outputs `model` produces on randomly generated inputs.

    A batch of `n_samples` inputs is drawn with `input_generator`, using
    `model.input_shape` with its batch dimension replaced by `n_samples`.
    The model's outputs are flattened to one row per sample, and a PCA with
    `n_components` components is fitted to those rows.

    The PCA is fitted on the raw activations; PCA centres the data itself.
    This keeps `mean_`, `components_` and `explained_variance_` of the
    returned object in the same units the downstream layers consume.

    Args:
        model: Callable Keras model or layer exposing `input_shape`.
        n_components: Number of principal components to keep.
        n_samples: Number of random inputs to draw.
        input_generator: Callable that takes a shape list and returns a tensor
            of that shape. Defaults to `tf.random.normal`.

    Returns:
        The fitted `sklearn.decomposition.PCA` instance.
    """
    input_shape = list(model.input_shape)
    input_shape[0] = n_samples
    input_tensor = input_generator(input_shape)
    activations = model(input_tensor)
    activations = tf.reshape(activations, [n_samples, -1]).numpy()
    # Deliberately no StandardScaler: standardizing would express mean_ and
    # components_ in z-score units, but callers add them up into points that
    # are fed directly to the next layer, which expects raw activations.
    return PCA(n_components).fit(activations)


def create_parametric_input_generator(means, variances, directions, standard_deviation_range=2.5):
    """Build a function that maps per-component parameters to a single input row.

    The returned `_g(**kwargs)` starts from `means` (reshaped to one row) and,
    for each (direction, parameter, variance) triple, adds
    ``direction * param * sqrt(variance) * standard_deviation_range``.
    A parameter of +/-1 therefore moves `standard_deviation_range` standard
    deviations along that direction.

    Args:
        means: Centre point; flattened to shape [1, n_features].
        variances: One variance per direction.
        directions: Iterable of direction vectors, each flattened to
            [1, n_features].
        standard_deviation_range: How many standard deviations a parameter
            value of 1 corresponds to.

    Returns:
        `_g(**kwargs) -> tf.Tensor` of shape [1, n_features]. Parameters are
        matched to directions by position, i.e. in the order the keyword
        arguments arrive; their names are ignored. `zip` silently drops any
        surplus parameters or directions.
    """
    def _g(**kwargs):
        # Reshape up front so the result is always a batch of one, even when
        # no parameters are given (tf.identity(means) would have stayed 1-D).
        input_tensor = tf.reshape(means, [1, -1])
        for direction, param, variance in zip(directions, kwargs.values(), variances):
            direction = tf.reshape(direction, [1, -1])
            standard_deviation = tf.sqrt(variance)
            input_tensor += direction * param * standard_deviation * standard_deviation_range
        return input_tensor

    return _g


def create_slider_widget_dict(n_sliders):
    """Create one slider per component, keyed so they can be passed to `interact`.

    Args:
        n_sliders: Number of sliders to create.

    Returns:
        Dict mapping 'comp_0', 'comp_1', ... (ascending index order) to
        `FloatSlider` widgets spanning [-1, 1] with a step of 0.001.
    """
    sliders = {}
    for index in range(n_sliders):
        sliders[f'comp_{index}'] = FloatSlider(min=-1, max=1, step=0.001)
    return sliders


def plot_model_output(model, input_tensor, img_dim):
    """Run `model` on `input_tensor` and draw the result with `plt.imshow`.

    Args:
        model: Callable model whose output has exactly prod(img_dim) elements.
        input_tensor: Input passed to `model`.
        img_dim: Target shape for the output before it is drawn.
    """
    image = tf.reshape(model(input_tensor), img_dim)
    plt.imshow(image, cmap='YlGn')


def control_first_layer(model, n_components, n_samples, img_dim):
    """Show sliders that steer `model`'s output through its first layer's activations.

    The model is split after its first layer. A PCA of the first layer's
    outputs on `n_samples` random inputs provides `n_components` directions,
    each bound to a slider. Moving the sliders builds a synthetic first-layer
    output, runs it through the remaining layers, and plots the result
    reshaped to `img_dim`.

    Args:
        model: Keras model whose `layers` can be split into head and tail.
        n_components: Number of principal components (and sliders).
        n_samples: Number of random inputs used to fit the PCA.
        img_dim: Shape the final output is reshaped to for plotting.
    """
    head_layer = model.layers[0]
    tail_model = tf.keras.models.Sequential(model.layers[1:])

    fitted_pca = pca_on_activation_space(head_layer, n_components, n_samples)
    sliders_to_activation = create_parametric_input_generator(
        fitted_pca.mean_,
        fitted_pca.explained_variance_,
        fitted_pca.components_,
    )
    slider_widgets = create_slider_widget_dict(n_components)

    def redraw(**slider_values):
        # Re-render on every slider change.
        synthetic_activation = sliders_to_activation(**slider_values)
        plot_model_output(tail_model, synthetic_activation, img_dim)

    run_ui(redraw, **slider_widgets)


In [None]:
model_path = '/content/drive/My Drive/Colab Notebooks/Models/mnist_gen_100_60000'  # @param {type:"string"}
_n_components = 8  # @param {type: "integer"}
_n_samples = 1000  # @param {type: "integer"}
img_width = 28  # @param {type: "integer"}
img_height = 28  # @param {type: "integer"}
# plt.imshow interprets a 2-D array as (rows, cols) == (height, width), so the
# height must come first; width-first would garble any non-square output.
_img_dim = [img_height, img_width]

import tensorflow as tf

# Load the trained generator and open the interactive slider UI for it.
_model = tf.keras.models.load_model(model_path)

control_first_layer(
    _model,
    _n_components,
    _n_samples,
    _img_dim
)