In [0]:
# NOTE(review): `%tensorflow_version` appears to be a Colab-specific magic; a
# plain Jupyter/IPython kernel may not provide it — confirm the target runtime.
%tensorflow_version 2.x

In [0]:
import matplotlib as mpl
import numpy as np
import tensorflow as tf

from IPython.display import clear_output
from matplotlib import pyplot as plt
from tensorflow.keras.preprocessing import image

In [0]:
# URL of the JPEG that is downloaded and used as the starting image.
url = 'https://storage.googleapis.com/download.tensorflow.org/example_images/YellowLabradorLooking_new.jpg'

In [0]:
def download(url, target_size=None):
  """Fetch an image from `url` and load it with Keras' `load_img`.

  Args:
    url: Remote image location. Its last path component is used as the
      local file name passed to `tf.keras.utils.get_file`.
    target_size: Optional (height, width) to resize the image to while
      loading; None keeps the original size.

  Returns:
    Whatever `tf.keras.preprocessing.image.load_img` returns for the file.
  """
  name = url.split('/')[-1]
  image_path = tf.keras.utils.get_file(name, origin=url)
  # `target_size` must be passed by keyword: the second positional parameter
  # of `load_img` is `grayscale`, so passing it positionally loaded the image
  # as grayscale and silently ignored the requested size.
  return tf.keras.preprocessing.image.load_img(image_path, target_size=target_size)

def show(img):
  """Display `img` alone in an 8x8-inch figure with grid and axes turned off."""
  _, ax = plt.subplots(figsize=(8, 8))
  ax.grid(False)
  ax.axis('off')
  ax.imshow(img)
  plt.show()

# Fetch the sample image at 225x375 and keep it as a NumPy array.
original_img = np.array(download(url, target_size=[225, 375]))
show(original_img)

In [0]:
def preprocess(img):
  """Convert RGB values from [0, 255] to float32 values in [-1, 1].

  Uses the `x / 127.5 - 1` scaling, so 0 maps to -1 and 255 maps exactly to 1.
  This is the exact inverse of the `255 * (x + 1) / 2` mapping used by
  `unprocess`.
  """
  img = tf.cast(img, tf.float32)
  # 127.5 (not 128) is half of 255. Dividing by 128 left the top of the range
  # at ~0.992 instead of 1.0 and broke the round trip with `unprocess`.
  return img / 127.5 - 1.0

def unprocess(img):
  """Map values in [-1, 1] back to uint8 pixel values in [0, 255].

  This is the inverse of `x / 127.5 - 1` scaling. Results are rounded to the
  nearest integer rather than truncated, so float error such as 254.9999
  becomes 255. They are also clipped to [0, 255] first, so out-of-range input
  saturates instead of overflowing the uint8 cast.
  """
  img = 255 * (img + 1.0) / 2.0
  img = tf.clip_by_value(img, 0.0, 255.0)
  return tf.cast(tf.round(img), tf.uint8)

In [0]:
# ImageNet-pretrained InceptionV3 loaded without its top classification layers.
conv_base = tf.keras.applications.InceptionV3(include_top=False,
                                              weights='imagenet')

In [0]:
# Intermediate InceptionV3 layers whose activations become the model outputs.
names = ['mixed2', 'mixed3', 'mixed4', 'mixed5']
layers = [conv_base.get_layer(name=layer_name).output for layer_name in names]
# Multi-output model: one activation tensor per entry in `names`.
model = tf.keras.Model(outputs=layers, inputs=conv_base.input)

In [0]:
def calc_loss(img):
  """Return the sum, over all outputs of `model`, of each output's mean value.

  `img` is a single image without a batch axis. A leading batch dimension of
  size 1 is added before it is fed to `model`.
  """
  batch = tf.expand_dims(img, axis=0)
  per_layer_means = []
  for activation in model(batch):
    per_layer_means.append(tf.math.reduce_mean(activation))
  return tf.reduce_sum(per_layer_means)

In [0]:
@tf.function
def step(img, lr=0.001):
  """Perform one gradient-ascent update of the variable `img` in place.

  The gradient of `calc_loss` with respect to `img` is divided by its
  standard deviation and scaled by `lr`. It is added to `img`, and the result
  is clipped to [-1, 1] before being stored back.
  """
  with tf.GradientTape() as tape:
    loss = calc_loss(img)

  grads = tape.gradient(loss, img)
  # Normalising by the std makes the step size depend on `lr` rather than on
  # the raw gradient magnitude; the epsilon avoids division by zero.
  normalized = grads / (tf.math.reduce_std(grads) + 1e-8)

  img.assign(tf.clip_by_value(img + lr * normalized, -1, 1))

In [0]:
img = tf.Variable(preprocess(original_img))

steps = 1000
for i in range(steps):
  step(img)
  # Only refresh the preview every 200 steps (starting with the first step).
  if i % 200:
    continue
  clear_output(wait=True)
  print("Step {}".format(i))
  show(unprocess(img.numpy()))

# Replace the last preview with the final image.
clear_output(wait=True)
show(unprocess(img.numpy()))