# Integrated Gradient Correlation
## Example on models predicting localized image statistics from natural images

### Download and unpack images from the NSD dataset (73k images)

- This step requires a working Amazon S3 client (AWS).

```sh
python igcbenchmark/download_data.py
```

### Compute image std values

In [None]:
from igcbenchmark.image_1v0 import ImgSet

# Image type identifier handed to ImgSet.
img_type = 'log2_Y'

# Build the image set first, then compute its std values; `imst` holds
# whatever compute_std_values() returns.
img_set = ImgSet(img_type)
imst = img_set.compute_std_values()

### Compute localized image statistics and std values

In [None]:
from igcbenchmark.imgstat_1v0 import ImgStatSet, ImgStatsExtractor

# Names of the image-statistic sets, one per experiment.
imgstat_set_names = (
    'comb_01-log2_Y-w_sum',            # Exp.A : scalar
    'ccat_01-log2_Y-max_mean',         # Exp.B : scalar
    'ccat_04-log2_Y-max_sim_rand',     # Exp.C : scalar
    'ccat_03-log2_Y-argmax_sim_rand',  # Exp.D : categorical
)
# Extraction keyword arguments, positionally aligned with the names above
# (entry i belongs to imgstat_set_names[i]); None means no extra arguments
# are supplied for that set.
extract_kwargs = (
    None,                       # Exp.A
    None,                       # Exp.B
    {'probs': (0.0, 0.5)},      # Exp.C
    {'permute': True},          # Exp.D
)
img_size = 64

# Process every (set name, extraction kwargs) pair in turn.
for imst_name, ext_kw in zip(imgstat_set_names, extract_kwargs):
    # Run the extraction for this set.
    extractor = ImgStatsExtractor(imst_name, img_size)
    extractor.extract(
        batch_size=100,
        num_workers=8,
        extract_kwargs=ext_kw,
    )

    # Load the set by name (presumably the one the extractor just wrote;
    # confirm against ImgStatSet) and compute its std values.
    imst_set = ImgStatSet(imst_name).load()
    imst_set.compute_std_values()

### Init dataset and model

In [None]:
import numpy as np
# For scalar outputs
from igcbenchmark.model_msk_stat_1v0 import Dataset, Model
# For categorical outputs
# from igcbenchmark.model_msk_stat_cat_1v0 import Dataset, Model

# Statistic set used by the dataset; uncomment exactly one alternative to
# switch experiments.
imgstat_set_name = 'comb_01-log2_Y-w_sum'
# imgstat_set_name = 'ccat_01-log2_Y-max_mean'
# imgstat_set_name = 'ccat_04-log2_Y-max_sim_rand'
# imgstat_set_name = 'ccat_03-log2_Y-argmax_sim_rand'

# Dataset settings; they are passed positionally to Dataset below.
img_size = 64
img_aug = True
imst_aug_std = None
val_ratio = 0.1
seed = 100

# Model settings.
device = 'cpu'
# device = 'cuda'
model_name = 'msk_stat_1v0_a000'

# Model parameters; the same seed is given to both the dataset and the model.
parameters = dict(
    conv_stem_kernel=2,
    conv_sizes=(16, 32, 64, 128, 256),
    lin_sizes=(128, 16),
    learning_rate=5e-5,
    seed=seed,
)

dataset = Dataset(
    imgstat_set_name,
    img_size,
    img_aug,
    imst_aug_std,
    val_ratio,
    seed,
)
model = Model(
    dataset,
    model_name,
    trainable=True,
    device=device,
    parameters=parameters,
)

### Train model

In [None]:
# Train the model created in the previous cell.
model.train(
    n_epoch=50,
    batch_size=64,
    num_workers=16,
)

### Compute R2 score

In [None]:
# model.score() returns a pair; only its first element (the R2 score) is used.
r2_score, _ = model.score(batch_size=100)

# Report the mean R2, rounded to four decimals.
r2_mean = np.mean(r2_score)
print('r2', np.round(r2_mean, 4))

### Compute IGC

In [None]:
# Compute the integrated gradient correlation. The return value is bound to
# `_` so the notebook does not echo it; the following cells read the result
# from 'int_grad_corr.npz'.
_ = model.int_grad_corr(
    x_0=8,
    n_steps=64,
    x_batch_size=100,
    x_0_seed=100,
    check_error=True,
    num_workers=8,
)

### Check IGC error

In [None]:
# Check the error of the IGC result file. The file name is a plain string
# (it has no placeholders, so no f-string prefix is needed). The return value
# is bound to `_` so the notebook does not echo it.
_ = model.igc_error('int_grad_corr.npz', x_batch_size=100)

### Visualize IGC attributions

In [None]:
import matplotlib.pyplot as plt

# Load the first entry of the 'data' array from the saved IGC result.
# np.load() on an .npz file returns an archive object that holds the file
# open, so it is used as a context manager to release the handle once the
# array has been read into memory.
with np.load(model.get_result_path('int_grad_corr.npz')) as igc_file:
    igc = igc_file['data'][0]  # assumed to be a 2-D map for imshow; confirm

# One square axes; the top margin leaves room for the title annotation.
fig, ax = plt.subplots(
    1, 1, figsize=(4, 4),
    gridspec_kw={'top': 0.85, 'bottom': 0.01, 'left': 0.01, 'right': 0.99})

ax.axis('off')
# Title: name of the statistic set, centred just above the axes.
ax.annotate(
    imgstat_set_name, (0.5, 1.05), va='bottom',
    ha='center', xycoords='axes fraction')

# Symmetric colour range around zero, set to 1.25x the 99th percentile of
# the absolute attributions so that a few extreme values do not dominate
# the colour scale.
v_max = 1.25 * np.quantile(np.abs(igc), 0.99)
ax.imshow(
    igc, cmap='RdBu_r', vmin=-v_max, vmax=v_max,
    extent=(-0.5, 0.5, -0.5, 0.5), interpolation='nearest')

# Thin black frame around the image extent (the axes themselves are hidden).
rect = plt.Rectangle(
    (-0.5, -0.5), 1.0, 1.0, fill=False, color='k', linewidth=1.0)
ax.add_patch(rect)