In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

In [None]:
from models import alexnet_wrapper
from utils import norm_image, total_variation_loss
import xforms

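### Set up the TensorFlow graph

For each randomly chosen `conv_3` channel, we optimize an input image — seen by the network only through random pad/jitter/scale/rotate transforms — to maximize that channel's mean activation, then animate the optimization.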

In [None]:
# Optimization hyperparameters shared by every channel's graph.
params = dict(
    regularization_scale=1e-4,     # scale of the L2 regularizer on the image variable
    learning_rate=0.05,            # Adam step size
    total_variation_weight=.5,     # multiplier on the total-variation loss
)

# Candidate zoom factors for xforms.random_scale: 0.90, 0.92, ..., 1.10.
scales = [1 + offset / 50. for offset in range(-5, 6)]
# Candidate angles for xforms.random_rotate: -10..10, with 0 listed five extra
# times (six in total) so that, presumably, "no rotation" is drawn more often.
angles = list(range(-10, 11)) + [0] * 5

In [None]:
# Trained model weights restored into every per-channel graph (path is relative
# to the notebook's working directory).
CHECKPOINT_PATH = "checkpoints/model.ckpt-115000"

In [None]:
# Number of Adam updates per channel; one image snapshot is recorded per step,
# and the animation uses NUM_STEPS frames.
NUM_STEPS = 32

In [None]:
# Choose a random -- but reproducible -- subset of channels to visualize.
CHANNEL_SEED = 0
# NOTE(review): 384 is presumably the channel count of the 'conv_3' layer
# (AlexNet's third conv layer has 384 filters) -- confirm against the model.
NUM_LAYER_CHANNELS = 384
NUM_CHANNELS_TO_SHOW = 16

# A dedicated seeded generator keeps the choice stable across kernel restarts
# without touching numpy's global random state.
channel_rng = np.random.RandomState(CHANNEL_SEED)
channels_to_use = channel_rng.choice(
    NUM_LAYER_CHANNELS, size=(NUM_CHANNELS_TO_SHOW,), replace=False)

# One list of image snapshots per channel, filled by the optimization cell.
channel_list = list()

In [None]:
def build_channel_graph(channel):
    """Build a fresh graph that optimizes an input image to excite one channel.

    The only trainable variable is a (1, 128, 128, 3) image initialised
    uniformly in [0, 1]. It is fed through the `xforms` transforms (pad,
    jitter, random scale, random rotation, jitter) and then through
    `alexnet_wrapper` up to the 'conv_3' tensor. The loss is the negated mean
    activation of `channel`, plus the weighted total-variation loss, plus the
    sum of the REGULARIZATION_LOSSES collection (which includes the L2 penalty
    on the image variable).

    Args:
        channel: index into the last axis of the 'conv_3' output.

    Returns:
        (image_var, train_op, saver): the raw image variable, the op applying
        one Adam update to it, and a Saver covering only the model variables.
    """
    tf.reset_default_graph()
    # Graph-level seed so the random init and the random transforms of each
    # channel are reproducible across runs.
    tf.set_random_seed(int(channel))

    image_shape = (1, 128, 128, 3)
    # Keep a separate handle on the raw variable: the augmented tensor below
    # must not shadow it, otherwise snapshots would show transformed images.
    image_var = tf.get_variable(
        "images",
        image_shape,
        initializer=tf.random_uniform_initializer(minval=0, maxval=1),
        regularizer=tf.contrib.layers.l2_regularizer(
            scale=params['regularization_scale']))

    transformed = xforms.pad(image_var, pad_amount=12)
    transformed = xforms.jitter(transformed, jitter_amount=8)
    transformed = xforms.random_scale(transformed, scales)
    transformed = xforms.random_rotate(transformed, angles)
    transformed = xforms.jitter(transformed, jitter_amount=4)

    model_output = alexnet_wrapper(
        transformed,
        tensor_name='conv_3',
        train=False
    )
    # Everything created so far except the image belongs to the model and is
    # restored from the checkpoint. Capturing the list before the optimizer
    # exists excludes Adam's slot and beta-power variables without filtering
    # on name substrings (a "beta" filter would also skip e.g. batch-norm offsets).
    model_variables = [v for v in tf.global_variables() if v.name != image_var.name]

    mean_channel_activity = tf.reduce_mean(model_output[:, :, :, channel])
    # NOTE(review): TV is measured on the transformed tensor, as before; it may
    # have been meant for `image_var` -- confirm.
    tv_loss = total_variation_loss(transformed)
    # NOTE(review): also sums any regularization losses alexnet_wrapper registers;
    # those presumably do not depend on the image and only shift the loss.
    tf_reg = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    total_regularization = tv_loss * params['total_variation_weight'] + tf_reg
    loss = tf.negative(mean_channel_activity) + total_regularization

    optimizer = tf.train.AdamOptimizer(params['learning_rate'])
    train_op = optimizer.minimize(loss, var_list=[image_var])
    saver = tf.train.Saver(var_list=model_variables)
    return image_var, train_op, saver


def optimize_channel(channel):
    """Run NUM_STEPS Adam updates for `channel` and return per-step snapshots.

    Returns:
        A list of NUM_STEPS arrays: `norm_image` applied to the raw image
        variable after each update, so the last entry is the final result.
    """
    image_var, train_op, saver = build_channel_graph(channel)
    snapshots = list()
    # The context manager closes the session even if restore or training fails.
    with tf.Session() as sess:
        # Initialise everything (image and Adam state), then overwrite the
        # model weights with the trained checkpoint.
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, CHECKPOINT_PATH)
        for _ in range(NUM_STEPS):
            sess.run(train_op)
            # Separate run so the read happens after the update has been applied.
            snapshots.append(norm_image(sess.run(image_var)))
    return snapshots


for channel in channels_to_use:
    print("-------------Working on channel %d-------------" % channel)
    channel_list.append(optimize_channel(channel))


#### Plot outputs: animate each channel's image over the optimization steps

In [None]:
from matplotlib import animation
from IPython.display import HTML

# Render animations as embedded JavaScript players (matplotlib 2.1 and above).
plt.rcParams["animation.html"] = "jshtml"
# Size limit, in MB, for the embedded animation (matplotlib's default is 20).
plt.rcParams['animation.embed_limit'] = 40

# Only animate as many frames as every channel has recorded.
num_frames = min(len(snapshots) for snapshots in channel_list) if channel_list else 0
if num_frames == 0:
    raise ValueError("channel_list has no snapshots; run the optimization cell first.")

# Lay the panels out on a near-square grid (4 x 4 at 20 x 20 inches for 16 channels).
num_panels = len(channel_list)
ncols = int(np.ceil(np.sqrt(num_panels)))
nrows = int(np.ceil(num_panels / float(ncols)))
fig, axes = plt.subplots(figsize=(5 * ncols, 5 * nrows), nrows=nrows, ncols=ncols,
                         squeeze=False)
axes = axes.ravel()

# One image artist per channel, initialised with that channel's first snapshot;
# surplus grid cells stay blank.
ax_data_list = list()
for panel_idx, ax in enumerate(axes):
    ax.axis('off')
    if panel_idx < num_panels:
        ax_data_list.append(ax.imshow(np.squeeze(channel_list[panel_idx][0])))


def animate(i):
    """Show snapshot `i` of every channel and return the updated artists."""
    for idx, ad in enumerate(ax_data_list):
        ad.set_data(np.squeeze(channel_list[idx][i]))
    return ax_data_list


ani = animation.FuncAnimation(fig, animate, frames=num_frames)
# Close the figure so the inline backend doesn't also render a static copy
# beneath the animation player.
plt.close(fig)
ani