In [32]:
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
import jax.nn as jnn
import jax.random as jr
from jaxtyping import Float, Array, Key
import diffrax
import equinox as eqx
import pandas as pd
import optax

from models.AutoregressiveCDE import AutoregressiveCDE
from utils import train_utils, integrated_gradients, plots

This is notebook is supposed to check whether Integrated Gradients is correctly implemented together with standardization.
More specifically, in the Integrated Gradients paper it is mentioned that a correct implementation will always fulfill

\sum_{i=1}^{n} IG_i(x) = F(x) - F(x')

where IG_i is the integraded gradient at the i-th input, x is the original input vector and x' is the baseline vector

In [45]:
model = eqx.tree_deserialise_leaves("serialised_models/autoregressive_cde_14_day_ahead.eqx", like=AutoregressiveCDE(data_size=3, hidden_size=3, width_size=64, depth=3, key=jr.key(5678)))
jax.config.update('jax_enable_x64', True)


In [46]:
fname = 'data/clipped_data.csv'
df = pd.read_csv(fname)
real_world_ys = jnp.array(df['new_cases'].values) 
real_world_ys = jnp.log10(1+real_world_ys)
ts = jnp.linspace(0,1,100)

def integrated_gradients_all_outputs(model: callable,
                                     control_until, 
                                     predict_until,
                                     ys: jnp.ndarray,
                                     baseline: jnp.ndarray,
                                     steps: int,
                                     standardize_baseline: bool):

    ts = jnp.linspace(0, 1, 100)

    ys, mean, std = train_utils.standardize(ys)

    if(standardize_baseline):
        baseline = (baseline - mean)/std

    alphas = jnp.linspace(0.0, 1.0, steps)
    interpolated_inputs = baseline + alphas[:, None] * (ys - baseline)

    def wrapper(input):
        """
            To preserve Completeness we must treat the inverse log transforms as part of the function being differentiated
        """
        return train_utils.inverse_log_transform(model(ts, ys=input[:control_until], control_until=control_until, saveat=ts, train_until=predict_until) * std + mean)
    jac_fn = jax.jacrev(wrapper)     
    jacobian_path = jax.vmap(jac_fn)(interpolated_inputs)   
    avg_jacobian = jacobian_path.mean(axis=0)              


    ig = avg_jacobian * (ys - baseline)[None, :]
    
    return ig

In [47]:
"""
Verify Completeness Axiom for non-standardized baseline
"""
baseline = jnp.zeros(100)
ig = integrated_gradients_all_outputs(model, 10, 24, real_world_ys[:10], baseline[:10], 400, False)


ys_std, mean, std = train_utils.standardize(real_world_ys[:10])


f = model(ts, ys=ys_std[:10], control_until=10, saveat=ts, train_until=24)
f = train_utils.inverse_log_transform(train_utils.destandardize(f, mean, std))

f_bar = model(ts, ys=baseline[:10], control_until=10, saveat=ts, train_until=24)
f_bar = train_utils.inverse_log_transform(train_utils.destandardize(f_bar, mean, std))


ig_sum = jnp.sum(ig, axis=-1)
output_difference = f - f_bar

print(ig_sum)
print(output_difference)

assert(jnp.allclose(output_difference[:10], ig_sum[:10], atol=10e-2))




[-43.87084203 -46.91622536 -31.6649515   -8.99878639  29.54805714
  85.86941838 160.34558679 159.16836789 237.79821607 199.13648506
 167.8756028  141.96125513 120.05665005 101.25275285  84.90866145
  70.55855983  57.85490565  46.53237693  36.38423892  27.24642141
  18.98653722  11.49615447   4.68526205  -1.52175701   0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.         

In [None]:
"""
Verify Completeness Axiom for standardized baseline. 

This passes only with more integration steps and a higher tolerance,
probably due to accumulation of floating point errors with the additional
standardization and destandardization.
"""

baseline = jnp.zeros(100)
ig = integrated_gradients_all_outputs(model, 10, 24, real_world_ys[:10], baseline[:10], 2000, True)


ys_std, mean, std = train_utils.standardize(real_world_ys[:10])

baseline = (baseline - mean)/std


f = model(ts, ys=ys_std[:10], control_until=10, saveat=ts, train_until=24)
f = train_utils.inverse_log_transform(train_utils.destandardize(f, mean, std))

f_bar = model(ts, ys=baseline[:10], control_until=10, saveat=ts, train_until=24)
f_bar = train_utils.inverse_log_transform(train_utils.destandardize(f_bar, mean, std))


ig_sum = jnp.sum(ig, axis=-1)
output_difference = f - f_bar

print(ig_sum)
print(output_difference)

assert(jnp.allclose(output_difference[:10], ig_sum[:10], atol=10e-2, rtol=10e-2))

[ 33.01436421  30.04210934  45.37130874  68.12645263 106.77729541
 163.21062442 237.79340792 236.68968063 315.40692365 276.81528784
 245.69542818 219.98724554 198.34848705 179.86477962 163.88989895
 149.9526604  137.70006558 126.86122007 117.22367807 108.61750235
 100.9042689   93.96932837  87.71626365  82.06285823   0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.         