In [None]:
import os
import pickle

import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch

from lucent.optvis import render, param, objectives
from lucent.modelzoo import inceptionv1


In [None]:

def store_images(images, location):
    """
    Store a sequence of images as PNG files in the directory ``location``.

    The directory (and any missing parent directories) is created if needed.
    The image at position ``idx`` is written to ``<location>/image_<idx>.png``.

    Args:
        images: iterable of float image arrays. Values are expected in [0, 1];
            anything outside that range is clipped before the 8-bit conversion.
            NOTE(review): each element must be a single image that
            ``imageio.imwrite`` accepts (e.g. HxW or HxWxC) -- confirm when
            passing batched render results.
        location: path of the target directory.
    """
    os.makedirs(location, exist_ok=True)

    for idx, img in enumerate(images):
        # Clip first: converting out-of-range floats to uint8 does not saturate
        # (the result is platform-dependent), which produces garbage pixels.
        pixels = np.uint8(np.clip(img, 0.0, 1.0) * 255)
        imageio.imwrite(os.path.join(location, f"image_{idx}.png"), pixels)

In [None]:
def grad_norm_interrupt(optimizer, params, threshold=7):
    """
    Interrupt condition for the optimization: stop once gradient norms are
    smaller than ``threshold``.

    The default of 7 is an empirically found value at which the optimization
    has reached colorful images.

    Args:
        optimizer: the optimizer of the running optimization (presumably passed
            in by ``render.render_vis`` -- confirm against its implementation).
        params: the parameters being optimized.
        threshold: gradient-norm level below which to interrupt (default 7).

    Returns:
        The result of ``render.gradient_norm_interrupt(optimizer, params,
        threshold)`` -- presumably truthy when the run should stop (TODO confirm).
    """
    return render.gradient_norm_interrupt(optimizer, params, threshold)

In [None]:

# Seed the RNG: the image parameterization starts from random values, so
# without a seed every run (and its gradient-norm curve) is different.
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")
model = inceptionv1(pretrained=True)
model.to(device).eval()

# A batch of 9 images of size 224, optimized with Adam (lr 5e-2).
param_f = lambda: param.image(224, batch=9)
opt = lambda params: torch.optim.Adam(params, 5e-2)
layer = "mixed3a_1x1_pre_relu_conv"  # alternative explored: "mixed5b_pool_reduce_pre_relu_conv"
channel = 34  # alternative explored: 63
obj = objectives.channel(layer, channel)

# Weight of the diversity term. A much larger weight (6062.5) was also tried;
# it is large enough to hinder the optimization of the channel objective.
diversity_alpha = 100
obj -= diversity_alpha * objectives.diversity(layer)

# thresholds=(20000,) caps the run at 20000 steps. interrupt_interval /
# interrupt_condition belong to this project's render_vis variant; presumably
# grad_norm_interrupt is evaluated every step and can stop the run early.
images, grad_norms = render.render_vis(model, obj, param_f, opt, thresholds=(20000,), verbose=True, interrupt_interval=1, interrupt_condition=grad_norm_interrupt)

In [None]:
import pickle

# Persist the gradient-norm history so the plot can be regenerated without
# re-running the optimization.
grad_norms_path = f"{layer}__{channel}__gradnorms.pkl"
with open(grad_norms_path, "wb") as fh:
    pickle.dump(grad_norms, fh)

# To reload a previously saved history instead (only unpickle files you
# trust -- pickle.load can execute arbitrary code):
# with open(grad_norms_path, "rb") as fh:
#     grad_norms = pickle.load(fh)

In [None]:
# Only show the relevant (late) part of the run, where the gradient norm
# approaches the interrupt threshold. The slice is stored under new names so
# that re-running this cell does not keep shrinking `grad_norms`.
start = 10_000
if len(grad_norms) <= start:
    # The run stopped before `start` steps: plot the whole history instead of
    # an empty slice (on which np.min below would raise).
    start = 0
xs = list(range(start, len(grad_norms)))
grad_norms_shown = grad_norms[start:]

fig, ax = plt.subplots()
ax.set_title(f"Grad Norm for {layer}__{channel}, alpha {diversity_alpha}")
ax.set_xlabel("Training Steps")
ax.set_ylabel("Mean Gradient Norm")
# ax.set_yscale("log")
ax.set_ylim(1, 20)

ax.plot(xs, grad_norms_shown)
# Horizontal reference levels; explicit colors follow the default cycle (C0 is the data).
ax.axhline(8, color="C1")  # NOTE(review): meaning of the 8 level is not documented here
ax.axhline(7, color="C2")  # threshold used by grad_norm_interrupt
# fig.savefig("grad_norms.png")
plt.show()
plt.close(fig)

if len(grad_norms_shown) > 0:
    print(np.min(grad_norms_shown))