# Image Optimization
Gradient-based image optimization

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

In [None]:
from models import alexnet_wrapper
from utils import norm_image, total_variation_loss
import xforms

### Set up the TensorFlow graph

In [None]:
# Hyperparameters for the optimization: strength of the L2 image regularizer,
# optimizer step size, and weight of the total-variation term in the loss.
params = dict(
    regularization_scale=1e-4,
    learning_rate=0.05,
    total_variation_weight=0.5,
)

In [None]:
# Checkpoint prefix handed to `saver.restore` below to load the model weights.
CHECKPOINT_PATH = 'checkpoints/model.ckpt-115000'

In [None]:
# Number of optimizer updates performed by the training loop below
# (also used as the frame count of the animation at the end).
NUM_STEPS = 128

#### Initialize image as random noise

In [None]:
# Shape of the image batch being optimized (presumably NHWC: one 128x128 RGB
# image -- confirm against what alexnet_wrapper expects).
image_shape = (1, 128, 128, 3)

# Start the optimization from uniform random noise between 0 and 1.
image_initializer = tf.random_uniform_initializer(minval=0, maxval=1)

# L2 penalty attached to the image variable itself; it is picked up later
# through the graph's REGULARIZATION_LOSSES collection.
image_regularizer = tf.contrib.layers.l2_regularizer(
    scale=params['regularization_scale'],
)

# The image is the variable we optimize.
images = tf.get_variable(
    name="images",
    shape=image_shape,
    initializer=image_initializer,
    regularizer=image_regularizer,
)

print(images)

#### Apply preprocessing transforms (pad, jitter, random scale, random rotation)

In [None]:
# Candidate scale factors: 11 values centred on 1.0, spaced 0.02 apart.
scales = [1 + offset / 50.0 for offset in range(-5, 6)]

# Candidate rotation angles: every integer from -10 to 10, plus five extra
# zeros so that "no rotation" appears more often in the list.
angles = list(range(-10, 11))
angles.extend([0] * 5)

print(scales)
print(angles)

In [None]:
# Preprocessing pipeline, applied in order. After this cell `images` no longer
# names the raw variable but the output of the last transform.
preprocessing_steps = (
    lambda tensor: xforms.pad(tensor, pad_amount=12),
    lambda tensor: xforms.jitter(tensor, jitter_amount=8),
    lambda tensor: xforms.random_scale(tensor, scales),
    lambda tensor: xforms.random_rotate(tensor, angles),
    lambda tensor: xforms.jitter(tensor, jitter_amount=4),
)
for apply_step in preprocessing_steps:
    images = apply_step(images)

#### Get the tensor we want to optimize

In [None]:
# 'conv_3' is the 4th conv layer of AlexNet.
tensor_name = 'conv_3'

# Build the network on top of the preprocessed image (train=False) and fetch
# the requested tensor from it.
model_output = alexnet_wrapper(images, tensor_name=tensor_name, train=False)

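#### Specify the loss to minimize
In this example, we optimize for high mean activity across an entire channel (channel index 1) of the selected layer. Because the optimizer minimizes, the loss uses the negative of that mean activity

(plus the regularization terms).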

In [None]:
# Select channel 1 (last axis) of the layer output and average over all
# remaining positions to get a single scalar to maximize.
target_channel = model_output[:, :, :, 1]
mean_channel_activity = tf.reduce_mean(target_channel)

In [None]:
# Total-variation term computed by utils.total_variation_loss. Note that by
# this point `images` has been rebound to the output of the xforms
# preprocessing chain, so the penalty is taken on that tensor rather than on
# the raw variable.
tv_loss = total_variation_loss(images)

In [None]:
# Sum every regularization term registered in the graph (this includes the
# L2 penalty attached to the image variable).
regularization_terms = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
tf_reg = tf.reduce_sum(regularization_terms)

# Weighted total-variation penalty plus the collected regularizers.
weighted_tv_loss = tv_loss * params['total_variation_weight']
total_regularization = weighted_tv_loss + tf_reg

# Minimizing the negated activity maximizes the channel's mean activation.
loss = tf.negative(mean_channel_activity) + total_regularization

#### Now we need to minimize the loss

In [None]:
# Only the image variable is optimized; every other trainable variable is
# left out of the optimizer's var_list.
variables_to_train = [
    candidate
    for candidate in tf.trainable_variables()
    if candidate.name == "images:0"
]
print(variables_to_train)

In [None]:
# Adam optimizer whose update op only touches the variables selected above.
optimizer = tf.train.AdamOptimizer(learning_rate=params['learning_rate'])
train_op = optimizer.minimize(
    loss,
    var_list=variables_to_train,
)

#### Create a Session and restore the model weights

In [None]:
# Initialize every global variable first; the next cell then overwrites the
# model weights with the values stored in the checkpoint.
init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_op)

In [None]:
# Restore everything from the checkpoint except variables whose name contains
# "images" or "beta".
# NOTE(review): "beta" presumably targets Adam's beta1_power/beta2_power
# accumulators; the substring match would also skip any model variable with
# "beta" in its name -- confirm none exist in the checkpoint.
all_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
restorable_variables = []
for variable in all_variables:
    if "images" in variable.name or "beta" in variable.name:
        continue
    restorable_variables.append(variable)

saver = tf.train.Saver(var_list=restorable_variables)
saver.restore(sess, CHECKPOINT_PATH)

In [None]:
# Per-step history: entry i of every list is measured *before* the i-th
# optimizer update, so the three lists have the same length and line up.
loss_list = []
tv_list = []
image_list = []

for step in range(NUM_STEPS):
    # Fetch all monitored tensors in a single run. Separate sess.run calls
    # would each re-execute the graph: that costs extra forward passes and, if
    # the preprocessing ops are stochastic (as the random_scale/random_rotate
    # names suggest), would record loss, TV and image from different samples.
    loss_value, tv_value, image_value = sess.run([loss, tv_loss, images])
    loss_list.append(loss_value)
    tv_list.append(tv_value)
    image_list.append(norm_image(image_value))

    # Apply one optimizer update to the image variable.
    sess.run(train_op)

#### Plot outputs

In [None]:
# Left panel: total loss per step. Right panel: total-variation loss per step.
fig, axes = plt.subplots(ncols=2, figsize=(12, 6))

for panel, history in zip(axes, (loss_list, tv_list)):
    panel.plot(history, c='k', linewidth=4)

    # plot formatting: drop the top/right frame lines and label the x axis
    for side in ('top', 'right'):
        panel.spines[side].set_visible(False)
    panel.set_xlabel('Steps')

axes[0].set_ylabel('Total Loss')
axes[1].set_ylabel('TV Loss')

In [None]:
# Show the last recorded image with singleton dimensions removed.
plt.imshow(np.squeeze(image_list[-1]))

In [None]:
from matplotlib import animation
from IPython.display import HTML

plt.rcParams["animation.html"] = "jshtml"  # for matplotlib 2.1 and above, uses JavaScript

fig, ax = plt.subplots(figsize=(10, 10))

# Seed the image artist with the first recorded frame instead of a hard-coded
# 128x128 placeholder: the recorded frames come from the preprocessed tensor,
# whose size need not match the raw variable, and the axes extent is fixed by
# whatever array imshow is first given.
frame_artist = ax.imshow(np.squeeze(image_list[0]))


def animate(i):
    """Display the i-th recorded image in the animation axes."""
    frame_artist.set_data(np.squeeze(image_list[i]))


# One frame per recorded image, so the animation stays in sync with
# image_list even if the recording loop changes.
ani = animation.FuncAnimation(fig, animate, frames=len(image_list))
ani