In [None]:
import numpy as np
import tensorflow.keras as keras
import tensorflow as tf

LATENT_SIZE = 64  # dimensionality of the autoencoder bottleneck

# ---- Encoder: (28, 28, 1) image -> LATENT_SIZE code ------------------------
inputs = keras.Input((28, 28, 1))
x = keras.layers.Conv2D(32, (3, 3), strides=(2, 2), activation="selu", padding="same")(inputs)
x = keras.layers.MaxPooling2D((2, 2), padding="same")(x)
x = keras.layers.Conv2D(32, (3, 3), activation="selu", padding="same")(x)
# pool_size / strides spelled out by keyword (previously passed positionally as `2, 2`).
x = keras.layers.MaxPooling2D(pool_size=2, strides=2, padding="same")(x)
x = keras.layers.Flatten()(x)
x = keras.layers.Dense(4 * 4 * 16, activation="selu")(x)
# Sigmoid bounds every latent coordinate to (0, 1).
latent = keras.layers.Dense(LATENT_SIZE, activation="sigmoid")(x)

# ---- Decoder: LATENT_SIZE code -> (28, 28, 1) image ------------------------
# The shape is passed as a 1-tuple: `keras.Input` documents `shape` as a tuple,
# and a bare integer is not accepted by every Keras version.
latent_inputs = keras.Input((LATENT_SIZE,))
x = keras.layers.Dense(4 * 4 * 16, activation="selu")(latent_inputs)
x = keras.layers.Reshape((4, 4, 16))(x)
# 4x4 kernel, stride 1, "valid" padding: 4x4 -> 7x7.
x = keras.layers.Conv2DTranspose(32, (4, 4), strides=1, activation="selu", padding="valid")(x)
# Two stride-2 "same" transposed convolutions: 7x7 -> 14x14 -> 28x28.
x = keras.layers.Conv2DTranspose(32, (3, 3), strides=2, activation="selu", padding="same")(x)
x = keras.layers.Conv2DTranspose(32, (2, 2), strides=2, activation="selu", padding="same")(x)
x = keras.layers.Conv2DTranspose(32, (3, 3), strides=1, activation="selu", padding="same")(x)
# Single-channel sigmoid output, matching the binary cross-entropy loss below.
ae_outputs = keras.layers.Conv2D(1, (3, 3), activation="sigmoid", padding="same")(x)

# Stand-alone encoder and decoder models so each half can be reused on its own.
encoder = keras.Model(inputs, latent)

decoder = keras.Model(latent_inputs, ae_outputs)

# Full autoencoder: decoder applied to the encoder's output.
enc_dec = encoder(inputs)

ae_model_outputs = decoder(enc_dec)

ae_model = keras.Model(inputs, ae_model_outputs)

ae_model.compile(loss=keras.losses.BinaryCrossentropy(), optimizer=keras.optimizers.Adam())


In [None]:
import tensorflow_datasets as tfds

# Load the MNIST train/test splits shipped with Keras.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Convert to float32, divide by 255, and append a trailing channel axis so
# each image has shape (28, 28, 1).
x_train = (x_train.astype("float32") / 255)[..., np.newaxis]
x_test = (x_test.astype("float32") / 255)[..., np.newaxis]

In [None]:
# Train the autoencoder to reproduce its own input; the test split is used
# only for validation.
ae_model.fit(
    x_train,
    x_train,
    batch_size=256,
    epochs=6,
    validation_data=(x_test, x_test),
)

In [None]:
# Freeze both halves of the autoencoder so models built on top of them later
# do not update their weights.
for frozen_part in (encoder, decoder):
    frozen_part.trainable = False

In [None]:
# Reconstruct the first four test images with the trained autoencoder.
res = ae_model(x_test[0:4])

In [None]:
import matplotlib.pyplot as plt

# Original test image at index 1, for comparison with its reconstruction.
plt.imshow(x_test[1, ...])

In [None]:
# Autoencoder reconstruction of the same image (index 1 of the batch in `res`).
plt.imshow(res[1, ...])

In [None]:
# Digit classifier that sees each image only after it has passed through the
# frozen autoencoder.
ae_model.trainable = False

inputs = keras.Input((28, 28, 1))
x = ae_model(inputs)

# Three conv + max-pool stages; only the first convolution is strided.
for conv_strides in ((2, 2), (1, 1), (1, 1)):
    x = keras.layers.Conv2D(
        16, (3, 3), strides=conv_strides, activation="selu", padding="same"
    )(x)
    x = keras.layers.MaxPooling2D((2, 2), padding="same")(x)

x = keras.layers.Flatten()(x)
# NOTE(review): SELU is applied to the 10 class scores before the softmax;
# confirm this is intended rather than a plain linear layer.
outputs = keras.layers.Dense(10, activation="selu")(x)
outputs = keras.layers.Softmax()(outputs)

classifier = keras.Model(inputs, outputs)
classifier.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(),
    metrics=["accuracy"],
)

In [None]:
# Train the classifier on integer digit labels, validating on the test split.
classifier.fit(
    x_train,
    y_train,
    batch_size=256,
    epochs=6,
    validation_data=(x_test, y_test),
)

In [None]:
# Boolean mask over the test set: True where the label is one of 0, 6, 8, 9.
concept_index = np.isin(y_test, np.array([0, 6, 8, 9]))

In [None]:
# Latent codes of every test image selected by the concept mask
# (displayed as the cell output).
encoder(x_test[concept_index, ...])

In [None]:
# Probe: a single sigmoid unit on top of the encoder's latent code, trained to
# predict the boolean concept mask. `inputs` is whichever input tensor was
# defined most recently in the kernel.
probe_input = encoder(inputs)
probe_head = keras.layers.Dense(
    1,
    activity_regularizer=keras.regularizers.L1(0.1),  # L1 penalty on the probe's output
    activation="sigmoid",
)
probe_output = probe_head(probe_input)

probe = keras.Model(inputs, probe_output)
probe.compile(
    loss=keras.losses.BinaryCrossentropy(),
    optimizer=keras.optimizers.Adam(),
    metrics=["accuracy"],
)

# The probe is fitted on the first 7000 MNIST *test* images and validated on
# the remainder; the MNIST training split is not used here.
probe_x_train, probe_x_test = x_test[:7000], x_test[7000:]
probe_y_train, probe_y_test = concept_index[:7000], concept_index[7000:]

In [None]:
# Fit the probe, validating on the held-out part of the split.
probe.fit(
    probe_x_train,
    probe_y_train,
    epochs=25,
    validation_data=(probe_x_test, probe_y_test),
)

In [None]:
# Maximiser: learns one additive offset in latent space ("max_plane") that
# pushes the probe's output towards 1 while the decoded image is kept close to
# a target image.
encoded_image = encoder(inputs)

# Scalar side input that drives the offset layer. The shape is passed as a
# 1-tuple: `keras.Input` documents `shape` as a tuple, and a bare integer is
# not accepted by every Keras version.
maximiser_input = keras.Input((1,))
max_plane = keras.layers.Dense(
    LATENT_SIZE,
    name="max_plane",
    activity_regularizer=keras.regularizers.L1(0.01),  # L1 penalty on the offset itself
    activation="linear",
)(maximiser_input)

# Shifted latent code = encoder output + learned offset.
additive = keras.layers.Add()([max_plane, encoded_image])

image_output = decoder(additive)
# Sets the private `_name` attribute on the symbolic output tensor.
# NOTE(review): presumably meant to label this output; confirm it has any
# effect with the installed Keras version.
image_output._name = "image_output"

# Stand-in for the probe's Dense layer, applied to the shifted latent code.
maximised_probe_output = keras.layers.Dense(1, name="max_probe_repl", activation="sigmoid")(additive)
maximiser_model = keras.Model([inputs, maximiser_input], [maximised_probe_output, image_output])

# Second model exposing only the decoded image of the shifted latent code.
look_at_results_output = decoder(additive)
look_at_results_model = keras.Model([inputs, maximiser_input], look_at_results_output)

# Copy the probe's final-layer weights into the stand-in and freeze it, so the
# offset layer is the only layer explicitly marked trainable here.
maximiser_model.get_layer("max_probe_repl").set_weights(probe.layers[-1].get_weights())
maximiser_model.get_layer("max_probe_repl").trainable = False

maximiser_model.get_layer("max_plane").trainable = True

# Losses are matched to the outputs by position: binary cross-entropy for the
# probe output (weight 1), mean squared error for the image (weight 0.25).
maximiser_model.compile(loss=["binary_crossentropy", "mean_squared_error"], loss_weights=[1, 0.25], optimizer=keras.optimizers.Adam())

In [None]:
# Batch of 64 identical copies of one probe-test image (index 73).
case = np.repeat(probe_x_test[73][np.newaxis, ...], 64, axis=0)
# Matching (64, 1) column of ones.
ones = np.ones(64)[:, np.newaxis]

In [None]:
# `mpl` is otherwise only imported in a later cell, so import it here too;
# without this the cell raises NameError when the notebook is run top to
# bottom on a fresh kernel.
import matplotlib as mpl

mpl.rcParams['figure.dpi'] = 100
# Show one of the 64 identical copies of the selected image.
plt.imshow(case[3])

In [None]:
# Targets, in output order: 1 for the probe output, and the autoencoder's
# reconstruction of `case` for the image output.
maximiser_model.fit(
    [case, ones],
    [ones, decoder(encoder(case))],
    batch_size=32,
    epochs=75,
)

In [None]:
import matplotlib as mpl
import matplotlib.style

# The bundled seaborn style sheets were renamed with a "seaborn-v0_8-" prefix
# and the old names later removed; prefer the new name when this Matplotlib
# provides it and fall back to the legacy name otherwise.
muted_style = "seaborn-v0_8-muted" if "seaborn-v0_8-muted" in mpl.style.available else "seaborn-muted"
mpl.style.use(muted_style)
mpl.rcParams['figure.dpi'] = 300
mpl.rcParams['font.family'] = "serif"

# Second model output is the decoded image of the shifted latent code; the
# probe output is not needed for the figure.
_, image = maximiser_model([case, ones])

# Side-by-side comparison: input image (left) vs. maximised image (right).
_, axs = plt.subplots(1, 2, figsize=(12, 12))
axs[1].axis("off")
# Pixel values are clipped to [0.1, 1] before display.
axs[1].imshow(np.clip(image[0], 0.1, 1), cmap="gray")
axs[1].set_title("Maximised", fontdict={"fontsize": 35})

axs[0].axis("off")
axs[0].imshow(case[0], cmap="gray")
axs[0].set_title("Normal image", fontdict={"fontsize": 35})
