In [1]:
import tensorflow as tf
import numpy as np
from path_explain.utils import set_up_environment

In [2]:
set_up_environment(visible_devices='2')

In [3]:
batch_size = 10
test_size  = (32, 64)
data = np.random.randn(batch_size, *test_size)

In [4]:
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Input(shape=test_size))
model.add(tf.keras.layers.Conv1D(filters=16, kernel_size=3, strides=1, activation=tf.keras.activations.softplus))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(units=32, activation=tf.keras.activations.softplus))
model.add(tf.keras.layers.Dense(units=1))

In [5]:
batch_input = tf.convert_to_tensor(data)

In [41]:
with tf.GradientTape() as second_order_tape:
    second_order_tape.watch(batch_input)
    with tf.GradientTape() as first_order_tape:
        first_order_tape.watch(batch_input)
        output = model(batch_input)
    gradient = first_order_tape.gradient(output, batch_input)

hessian = second_order_tape.batch_jacobian(gradient, batch_input)

In [48]:
input_times_hessian = hessian * batch_input[:, :, :, tf.newaxis, tf.newaxis] * batch_input[:, tf.newaxis, tf.newaxis, :, :]
summed_hessian_from_jacobian = tf.reduce_sum(input_times_hessian, axis=(2, 4))

In [81]:
with tf.GradientTape(persistent=True) as second_order_tape:
    second_order_tape.watch(batch_input)
    with tf.GradientTape() as first_order_tape:
        first_order_tape.watch(batch_input)
        output = model(batch_input)
    gradient = first_order_tape.gradient(output, batch_input)
    gradient_times_input = batch_input * gradient
    summed_gradient_times_input = tf.reduce_sum(gradient_times_input, axis=2)
batch_summed_hessian = second_order_tape.batch_jacobian(summed_gradient_times_input, batch_input)
batch_summed_hessian_times_input = batch_summed_hessian * batch_input[:, tf.newaxis, :, :]
batch_summed_summed_hessian_times_input = tf.reduce_sum(batch_summed_hessian_times_input, axis=3)

In [85]:
batch_summed_hessian.shape

TensorShape([10, 32, 32, 64])

In [82]:
batch_summed_summed_hessian_times_input

<tf.Tensor: shape=(10, 32, 32), dtype=float64, numpy=
array([[[-1.05502390e-01, -1.70172225e-02, -5.44662466e-02, ...,
         -1.33994116e-03,  2.69511727e-03, -6.65550233e-04],
        [-1.70172197e-02,  9.86773240e-02,  5.66338611e-02, ...,
         -1.95719531e-03,  4.48519505e-04,  3.24663182e-05],
        [-5.44662534e-02,  5.66338586e-02, -3.22873205e-01, ...,
          5.97984305e-03,  1.52160315e-03,  3.36281641e-03],
        ...,
        [-1.33994129e-03, -1.95719384e-03,  5.97984041e-03, ...,
         -4.31650734e-02, -1.94025766e-02, -1.09749889e-02],
        [ 2.69511760e-03,  4.48518748e-04,  1.52160548e-03, ...,
         -1.94025774e-02,  6.06804017e-03, -4.90592735e-02],
        [-6.65548825e-04,  3.24665108e-05,  3.36281812e-03, ...,
         -1.09749909e-02, -4.90592671e-02,  1.43230623e-01]],

       [[ 4.98012997e-02,  1.09805460e-02,  2.28644936e-03, ...,
         -2.48018607e-04, -4.47010343e-04, -7.25157506e-04],
        [ 1.09805466e-02, -1.24665227e-02, -1.216