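# Image Optimization
Gradient-based image optimization: starting from random noise, iteratively update an input image so that it strongly activates a chosen unit of a pretrained AlexNet.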

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

In [None]:
from models import alexnet_wrapper
from utils import norm_image

### Set up the TensorFlow graph

In [None]:
# Hyper-parameters for the image optimization:
#   regularization_scale -- weight of the L2 penalty on the image pixels
#   learning_rate        -- step size handed to the Adam optimizer
params = dict(
    regularization_scale=1e-4,
    learning_rate=0.05,
)

In [None]:
CHECKPOINT_PATH: str = "checkpoints/model.ckpt-115000"  # pretrained AlexNet checkpoint passed to saver.restore

In [None]:
NUM_STEPS: int = 64  # number of optimizer updates applied to the image

#### Initialize the image as random noise

In [None]:
# A single 128x128 RGB image (layout presumably NHWC, matching the
# model_output[:, row, col, channel] indexing used for the loss -- confirm).
image_shape = (1, 128, 128, 3)

# Pixels start as uniform noise between 0 and 1. The L2 regularizer registers
# its penalty in the REGULARIZATION_LOSSES collection when the variable is made.
image_initializer = tf.random_uniform_initializer(minval=0, maxval=1)
image_regularizer = tf.contrib.layers.l2_regularizer(scale=params['regularization_scale'])

# The image itself is the trainable variable that the optimizer will update.
images = tf.get_variable(
    name="images",
    shape=image_shape,
    initializer=image_initializer,
    regularizer=image_regularizer,
)

print(images)

#### Get the tensor whose activation we want to maximize

In [None]:
# Feed the optimizable image through AlexNet (inference mode) and read out the
# tensor called 'conv' -- the 1st conv layer of alexnet.
tensor_name = 'conv'
model_output = alexnet_wrapper(images, tensor_name=tensor_name, train=False)

print(model_output)

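#### Specify the loss to minimize
In this example, I optimize for high activity of a single unit in the conv output:
(x, y) = (4, 4)
channel = 16

Because the optimizer *minimizes*, the loss is the negated activation of this unit, plus the L2 regularization penalty on the image.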

In [None]:
# Target unit: spatial position (4, 4), channel 16 of `model_output`.
unit_row, unit_col, unit_channel = 4, 4, 16

# Sum of every loss registered in REGULARIZATION_LOSSES (the image's L2 penalty
# among them). NOTE(review): if alexnet_wrapper also registers weight
# regularizers they are summed here too; they do not depend on the image, so
# they only shift the reported loss and not its gradient -- confirm.
total_regularization = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

# The optimizer minimizes, so negate the activation in order to maximize it.
# The raw slice has shape (batch,); averaging over the batch axis gives a
# scalar loss, so sess.run(loss) returns a plain number instead of a
# 1-element array. With batch size 1 the gradient is unchanged.
unit_activation = model_output[:, unit_row, unit_col, unit_channel]
loss = tf.negative(tf.reduce_mean(unit_activation)) + total_regularization
print(loss)

#### Now minimize the loss (with respect to the image only)

In [None]:
# Only the input image variable is optimized; every network weight stays fixed.
variables_to_train = list(filter(lambda v: v.name == "images:0", tf.trainable_variables()))
print(variables_to_train)

In [None]:
# Adam step that updates only the variables in variables_to_train (the image).
optimizer = tf.train.AdamOptimizer(learning_rate=params['learning_rate'])
train_op = optimizer.minimize(loss=loss, var_list=variables_to_train)

#### Create a Session and restore the model weights

In [None]:
# Open a session and initialize every global variable (image, model and
# optimizer state); the pretrained weights are restored over these next.
sess = tf.Session()
tf.global_variables_initializer().run(session=sess)

In [None]:
# Restore pretrained weights for every global variable except those whose name
# contains:
#   "images" -- the optimized input (and, presumably, the optimizer slot
#               variables created for it), which are not in the checkpoint;
#   "beta"   -- presumably Adam's beta1_power/beta2_power accumulators.
# NOTE(review): the substring match would also skip any model variable with
# "beta" in its name (e.g. a batch-norm offset), leaving it at its initial
# value -- confirm the model has none.
all_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
excluded_substrings = ("images", "beta")
restorable_variables = [
    var for var in all_variables
    if not any(token in var.name for token in excluded_substrings)
]
saver = tf.train.Saver(var_list=restorable_variables)
saver.restore(sess, CHECKPOINT_PATH)

In [None]:
# Per-step history, recorded *before* each update: the loss value and the
# normalized image, used by the plot and the animation below.
loss_list = []
image_list = []

for step in range(NUM_STEPS):
    # One session call fetches both the loss and the current image (previously
    # two separate runs). The update runs in its own later call, so the
    # recorded values are always the pre-update state.
    loss_value, image_value = sess.run([loss, images])
    loss_list.append(loss_value)
    image_list.append(norm_image(image_value))
    sess.run(train_op)

#### Plot outputs

In [None]:
# Loss curve over the optimization steps.
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(loss_list, c='k', linewidth=4)

# Minimal styling: hide the top/right frame and label both axes.
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
ax.set(ylabel='Loss', xlabel='Steps')


In [None]:
# Stack the per-step frames into one (steps, height, width, channels) array.
# Reshaping on the trailing three axes (rather than np.squeeze) drops only the
# singleton batch axis, so a run with a single step still yields a 4-D array
# that can be indexed frame by frame. Assumes each frame is (..., H, W, C)
# with only size-1 leading axes -- confirm against norm_image's output.
stacked_image_list = np.stack(image_list)
stacked_image_list = stacked_image_list.reshape((-1,) + stacked_image_list.shape[-3:])

In [None]:
from matplotlib import animation
from IPython.display import HTML

plt.rcParams["animation.html"] = "jshtml"  # for matplotlib 2.1 and above, uses JavaScript

fig, ax = plt.subplots(figsize=(10, 10))
# Size the placeholder from the recorded frames instead of hard-coding 128x128x3.
frame_artist = ax.imshow(np.zeros_like(stacked_image_list[0]))


def animate(i):
    """Display optimization frame ``i`` and return the updated artist.

    Returning the artist is required if the animation is ever switched to
    ``blit=True``; it is harmless otherwise.
    """
    frame_artist.set_data(stacked_image_list[i])
    return (frame_artist,)


# One animation frame per recorded image, so the frame count cannot drift from the data.
ani = animation.FuncAnimation(fig, animate, frames=len(stacked_image_list))
# Close the static figure so the inline backend does not render it as an
# additional image below the animation.
plt.close(fig)
HTML(ani.to_jshtml())

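# Compare to the true weights
For a unit in the first conv layer, the optimized image can be compared directly with that unit's learned filter (`conv1` weights, channel 16).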

In [None]:
# Learned filter bank of the first conv layer, looked up by name in the session's graph.
weights_tensor = sess.graph.get_tensor_by_name("conv1/weights:0")

In [None]:
weights = weights_tensor.eval(session=sess)  # pull the restored filter values into NumPy

In [None]:
# Filter 16 (last axis) -- the same channel index maximized by the loss above.
plt.imshow(norm_image(weights[..., 16]))

# For fun, let's look at all the weights
We can pick one of these filters, put its index in place of channel 16 in the loss above, and see what the optimized image looks like.

In [None]:
# Bring the last (filter) axis to the front so iterating over `weights` yields
# one filter at a time. NOTE: this rebinds `weights`, so re-running this cell
# without re-fetching the weights permutes the axes a second time.
weights = np.transpose(weights, (3, 0, 1, 2))

In [None]:
# Tile every filter in a 12-column grid with as many rows as needed (8 rows,
# i.e. the original 24x16 layout, for 96 filters). Previously the grid was
# fixed at 8x12, which silently dropped filters beyond 96 and left empty cells
# with visible axes when there were fewer.
ncols = 12
nrows = max(1, (len(weights) + ncols - 1) // ncols)
fig, axes = plt.subplots(figsize=(24, 2 * nrows), nrows=nrows, ncols=ncols, squeeze=False)

# Hide ticks and frames everywhere, including any unused trailing cells.
for ax in axes.flat:
    ax.axis('off')

for kernel_idx, (kernel, ax) in enumerate(zip(weights, axes.flat)):
    ax.imshow(norm_image(kernel))
    ax.set_title(kernel_idx)