In [1]:
import tensorflow as tf
import matplotlib.pylab as plt
import tensorflow_hub as hub
import tensorflow_datasets as tfds
import cv2

In [2]:
tf_model = tf.keras.Sequential([
    hub.KerasLayer(
        name='inception_v1',
        handle='https://tfhub.dev/google/imagenet/inception_v1/classification/4',
        trainable=False),
])

tf_model.build([None, 224, 224, 3])

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


In [12]:
def interpolate_images(baseline,
                       image,
                       alphas):
  alphas_x = alphas[:, tf.newaxis, tf.newaxis, tf.newaxis]
  baseline_x = tf.expand_dims(baseline, axis=0)
  input_x = tf.expand_dims(image, axis=0)
  print(input_x)
  print(baseline_x)
  delta = input_x - baseline_x
  images = baseline_x +  alphas_x * delta
  return images

In [4]:
plt.tight_layout()

def compute_gradients(tf_model, images, target_class_idx):
    with tf.GradientTape() as tape:
        tape.watch(images)
        logits = tf_model(images)
        probs = tf.nn.softmax(logits, axis=-1)[:, target_class_idx]
    return tape.gradient(probs, images)

<Figure size 640x480 with 0 Axes>

In [5]:
@tf.function
def one_batch(tf_model, baseline, image, alpha_batch, target_class_idx):
    # Generate interpolated inputs between baseline and input.
    interpolated_path_input_batch = interpolate_images(baseline=baseline,
                                                       image=image,
                                                       alphas=alpha_batch)

    # Compute gradients between model outputs and interpolated inputs.
    gradient_batch = compute_gradients(tf_model=tf_model, images=interpolated_path_input_batch,
                                       target_class_idx=target_class_idx)
    return gradient_batch

In [6]:
def integral_approximation(gradients):
  # riemann_trapezoidal
  grads = (gradients[:-1] + gradients[1:]) / tf.constant(2.0)
  integrated_gradients = tf.math.reduce_mean(grads, axis=0)
  return integrated_gradients

In [7]:
def integrated_gradients(tf_model, baseline,
                         image,
                         target_class_idx,
                         m_steps=50,
                         batch_size=32):
  # Generate alphas.
  alphas = tf.linspace(start=0.0, stop=1.0, num=m_steps+1)

  # Collect gradients.
  gradient_batches = tf.TensorArray(tf.float32, size=m_steps+1)

  # Iterate alphas range and batch computation for speed, memory efficiency, and scaling to larger m_steps.
  for alpha in tf.range(0, len(alphas), batch_size):
    from_ = alpha
    to = tf.minimum(from_ + batch_size, len(alphas))
    alpha_batch = alphas[from_:to]

    interpolated_path_input_batch = interpolate_images(baseline=baseline, image=image, alphas=alpha_batch)

    gradient_batch = compute_gradients(tf_model=tf_model, images=interpolated_path_input_batch, target_class_idx=target_class_idx)
    gradient_batches = gradient_batches.scatter(tf.range(from_, to), gradient_batch)

  # Concatenate path gradients together row-wise into single tensor.
  total_gradients = gradient_batches.stack()

  # Integral approximation through averaging gradients.
  avg_gradients = integral_approximation(gradients=total_gradients)

  # Scale integrated gradients with respect to input.
  integrated_gradients = (image - baseline) * avg_gradients

  return integrated_gradients

In [15]:
def plot_img_attributions(baseline,
                          image,
                          target_class_idx,
                          m_steps=50,
                          cmap=None,
                          overlay_alpha=0.4):

  integrated_gradients(tf_model=tf_model,baseline=baseline, image=image, target_class_idx=555)

  attributions = integrated_gradients(baseline=baseline,
                                      image=image,
                                      target_class_idx=target_class_idx,
                                      m_steps=m_steps)

  # Sum of the attributions across color channels for visualization.
  # The attribution mask shape is a grayscale image with height and width
  # equal to the original image.
  attribution_mask = tf.reduce_sum(tf.math.abs(attributions), axis=-1)

  fig, axs = plt.subplots(nrows=2, ncols=2, squeeze=False, figsize=(8, 8))

  axs[0, 0].set_title('Baseline image')
  axs[0, 0].imshow(baseline)
  axs[0, 0].axis('off')

  axs[0, 1].set_title('Original image')
  axs[0, 1].imshow(image)
  axs[0, 1].axis('off')

  axs[1, 0].set_title('Attribution mask')
  axs[1, 0].imshow(attribution_mask, cmap=cmap)
  axs[1, 0].axis('off')

  axs[1, 1].set_title('Overlay')
  axs[1, 1].imshow(attribution_mask, cmap=cmap)
  axs[1, 1].imshow(image, alpha=overlay_alpha)
  axs[1, 1].axis('off')

  plt.tight_layout()
  return fig


In [None]:
import cv2

baseline = tf.zeros(shape=(1080,1920,3))
video = cv2.VideoCapture('video.mp4')
fps = video.get(cv2.CAP_PROP_FPS)
print('frames per second =',fps)
minutes = 0
seconds = 28
frame_id = int(fps*(minutes*60 + seconds))
print('frame id =',frame_id)
video.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
ret, frame = video.read()
print(frame.shape)
cv2.imshow('frame', frame); cv2.waitKey(0)
cv2.imwrite('my_video_frame.png', frame)
image = tf.image.convert_image_dtype(frame, dtype=tf.float32, saturate=False, name=None)

_ = plot_img_attributions(image=image,
                          baseline=baseline,
                          target_class_idx=555,
                          m_steps=240,
                          cmap=plt.cm.inferno,
                          overlay_alpha=0.4)


frames per second = 23.976023976023978
frame id = 671
(1080, 1920, 3)
tf.Tensor(
[[[[0.01176471 0.03137255 0.03137255]
   [0.01176471 0.03137255 0.03137255]
   [0.01176471 0.03137255 0.03137255]
   ...
   [0.01176471 0.03137255 0.03137255]
   [0.01176471 0.03137255 0.03137255]
   [0.01176471 0.03137255 0.03137255]]

  [[0.01176471 0.03137255 0.03137255]
   [0.01176471 0.03137255 0.03137255]
   [0.01176471 0.03137255 0.03137255]
   ...
   [0.01176471 0.03137255 0.03137255]
   [0.01176471 0.03137255 0.03137255]
   [0.01176471 0.03137255 0.03137255]]

  [[0.01176471 0.03137255 0.03137255]
   [0.01176471 0.03137255 0.03137255]
   [0.01176471 0.03137255 0.03137255]
   ...
   [0.01176471 0.03137255 0.03137255]
   [0.01176471 0.03137255 0.03137255]
   [0.01176471 0.03137255 0.03137255]]

  ...

  [[0.00784314 0.01568628 0.02352941]
   [0.00784314 0.01568628 0.02352941]
   [0.00784314 0.01568628 0.02352941]
   ...
   [0.00784314 0.01568628 0.02352941]
   [0.00784314 0.01568628 0.02352941]
   [