<h1 style="text-align: center; font-family: Verdana; font-size: 32px; font-style: normal; font-weight: bold; text-decoration: none; text-transform: none; font-variant: small-caps; letter-spacing: 3px; color: #74d5dd; background-color: #ffffff;">Human Protein Atlas - Single Cell Classification</h1>
<h2 style="text-align: center; font-family: Verdana; font-size: 22px; font-style: normal; font-weight: bold; text-decoration: underline; text-transform: none; letter-spacing: 2px; color: navy; background-color: #ffffff;">Exploring Integrated Gradients & Background on Explainable AI</h2>
<h5 style="text-align: center; font-family: Verdana; font-size: 12px; font-style: normal; font-weight: bold; text-decoration: None; text-transform: none; letter-spacing: 1px; color: black; background-color: #ffffff;">CREATED BY: DARIEN SCHETTLER</h5>


<h2 style="font-family: Verdana; font-size: 24px; font-style: normal; font-weight: bold; text-decoration: none; text-transform: none; letter-spacing: 3px; color: navy; background-color: #ffffff;">TABLE OF CONTENTS</h2>

---

<h3 style="text-indent: 10vw; font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: none; letter-spacing: 2px; color: navy; background-color: #ffffff;"><a href="#imports">0&nbsp;&nbsp;&nbsp;&nbsp;IMPORTS</a></h3>

---

<h3 style="text-indent: 10vw; font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: none; letter-spacing: 2px; color: navy; background-color: #ffffff;"><a href="#xai_background">1&nbsp;&nbsp;&nbsp;&nbsp;XAI BACKGROUND INFORMATION</a></h3>

---

<h3 style="text-indent: 10vw; font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: none; letter-spacing: 2px; color: navy; background-color: #ffffff;"><a href="#ig_background">2&nbsp;&nbsp;&nbsp;&nbsp;IG BACKGROUND INFORMATION</a></h3>

---

<h3 style="text-indent: 10vw; font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: none; letter-spacing: 2px; color: navy; background-color: #ffffff;"><a href="#setup">3&nbsp;&nbsp;&nbsp;&nbsp;SETUP</a></h3>

---

<h3 style="text-indent: 10vw; font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: none; letter-spacing: 2px; color: navy; background-color: #ffffff;"><a href="#helper_functions">4&nbsp;&nbsp;&nbsp;&nbsp;HELPER FUNCTIONS</a></h3>

---

<h3 style="text-indent: 10vw; font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: none; letter-spacing: 2px; color: navy; background-color: #ffffff;"><a href="#model">5&nbsp;&nbsp;&nbsp;&nbsp;LOAD AND DEMONSTRATE MODEL PREDICTION</a></h3>

---

<h3 style="text-indent: 10vw; font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: none; letter-spacing: 2px; color: navy; background-color: #ffffff;"><a href="#calc_ig">6&nbsp;&nbsp;&nbsp;&nbsp;CALCULATE INTEGRATED GRADIENTS</a></h3>

---

<h3 style="text-indent: 10vw; font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: none; letter-spacing: 2px; color: navy; background-color: #ffffff;"><a href="#viz_ig">7&nbsp;&nbsp;&nbsp;&nbsp;VISUALIZE INTEGRATED GRADIENTS</a></h3>

---

<a style="font-family: Verdana; font-size: 24px; font-style: normal; font-weight: bold; text-decoration: none; text-transform: none; letter-spacing: 3px; background-color: #ffffff; color: navy;" id="imports">0&nbsp;&nbsp;IMPORTS</a>

In [None]:
print("\n... IMPORTS STARTING ...\n")
print("\n\tVERSION INFORMATION")

# Machine Learning and Data Science Imports
import tensorflow_addons as tfa; print(f"\t\t– TENSORFLOW ADDONS VERSION: {tfa.__version__}");
import tensorflow as tf; print(f"\t\t– TENSORFLOW VERSION: {tf.__version__}");
import pandas as pd; pd.options.mode.chained_assignment = None;
import numpy as np; print(f"\t\t– NUMPY VERSION: {np.__version__}");
import scipy; print(f"\t\t– SCIPY VERSION: {scipy.__version__}");

# Built In Imports
from kaggle_datasets import KaggleDatasets
from collections import Counter
from datetime import datetime
import multiprocessing
from glob import glob
import warnings
import requests
import imageio
import IPython
import urllib
import zipfile
import pickle
import random
import shutil
import string
import math
import tqdm
import time
import gzip
import io
import os
import gc
import re

# Visualization Imports
from matplotlib.colors import ListedColormap
import matplotlib.patches as patches
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import plotly.express as px
import seaborn as sns
from PIL import Image
import matplotlib; print(f"\t\t– MATPLOTLIB VERSION: {matplotlib.__version__}");
import plotly
import PIL
import cv2

# PRESETS
# Class names; a name's position in this list is its integer class id
LBL_NAMES = ["Nucleoplasm", "Nuclear Membrane", "Nucleoli", "Nucleoli Fibrillar Center", "Nuclear Speckles", "Nuclear Bodies", "Endoplasmic Reticulum", "Golgi Apparatus", "Intermediate Filaments", "Actin Filaments", "Microtubules", "Mitotic Spindle", "Centrosome", "Plasma Membrane", "Mitochondria", "Aggresome", "Cytosol", "Vesicles", "Negative"]

# class id -> display name. Built with enumerate so the mapping always covers
# exactly the entries of LBL_NAMES (no hard-coded class count) and the keys are
# plain Python ints rather than NumPy scalars.
INT_2_STR = dict(enumerate(LBL_NAMES))
# class id -> lower-case name with spaces replaced by underscores
INT_2_STR_LOWER = {k:v.lower().replace(" ", "_") for k,v in INT_2_STR.items()}
# Reverse lookups: name -> class id
STR_2_INT_LOWER = {v:k for k,v in INT_2_STR_LOWER.items()}
STR_2_INT = {v:k for k,v in INT_2_STR.items()}

# Shared plotting presets: figure font and one "Spectral" palette colour per class
FIG_FONT = dict(family="Helvetica, Arial", size=14, color="#7f7f7f")
LABEL_COLORS = [px.colors.label_rgb(px.colors.convert_to_RGB_255(x)) for x in sns.color_palette("Spectral", len(LBL_NAMES))]
# Colour lookup keyed by the class id as a string
LABEL_COL_MAP = {str(i):x for i,x in enumerate(LABEL_COLORS)}

print("\n\n... IMPORTS COMPLETE ...\n")

print("\n\n... TPU SETUP ...\n")

# Try to attach to a TPU; TPUClusterResolver raises ValueError when none is found,
# in which case we fall back to whatever strategy is currently the default.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver() # TPU detection
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    # Stable API; tf.distribute.experimental.TPUStrategy is deprecated in its favour
    strategy = tf.distribute.TPUStrategy(tpu)
    print("Running on TPU:", tpu.master())
except ValueError: # no TPU found, use the default strategy
    strategy = tf.distribute.get_strategy()
    print("\n... USING GPU ...\n")

# Number of replicas (accelerator cores) the chosen strategy keeps in sync
N_REPLICAS = strategy.num_replicas_in_sync
print(f"... Number Of Accelerators: {N_REPLICAS} ...\n")

print("\n... TPU SETUP COMPLETE ...\n")

<a style="font-family: Verdana; font-size: 24px; font-style: normal; font-weight: bold; text-decoration: none; text-transform: none; letter-spacing: 3px; color: navy; background-color: #ffffff;" id="xai_background">1&nbsp;&nbsp;XAI BACKGROUND INFORMATION</a>

<h3 style="font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: none; letter-spacing: 2px; color: navy; background-color: #ffffff;">1.1  WHAT IS EXPLAINABLE AI - GENERAL INFO</h3>

---

For the purposes of this notebook and my explanation, I will be logically separating explainable AI into two separate branches. 

[**See this excerpt/paper that explains the branches in more detail.**](https://arxiv.org/pdf/1907.07374.pdf)

> "The two major categories presented here, namely perceptive interpretability and interpretability by mathematical structures, appear to present different polarities within the notion of interpretability. 
> 
> As an example for the difficulty with perceptive interpretability, when a visual evidence is given erroneously, the underlying mathematical structure may not seem to provide useful clues on the mistakes. 
> 
> On the other hand, a mathematical analysis of patterns may provide
information in high dimensions. They can only be easily perceived once the pattern is brought into lower dimensions, abstracting some fine-grained information we could not yet prove is not discriminative with measurable certainty."

[**<sup><sub>Tjoa, E., & Guan, C. (2019). A survey on explainable artificial intelligence (XAI): Towards
medical XAI. arXiv preprint arXiv:1907.07374</sub></sup>**](https://arxiv.org/pdf/1907.07374.pdf)
    
<br><br>

<b style="text-decoration: underline; font-family: Verdana;">PERCEPTIVE</b>
    
In short this is interpretability that can be observed by humans. Often the explanations arising through this branch are obvious to humans or already known.


<br>

<b style="text-decoration: underline; font-family: Verdana;">MATHEMATICAL</b>
    
In short this is interpretability that can only be observed by first applying mathematical manipulations to the data. An example technique that most are familiar with is clustering [**`(t-SNE)`**](https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding)


<h3 style="font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: none; letter-spacing: 2px; color: navy; background-color: #ffffff;">1.2  THE GOAL</h3>

---

**Pull the veil back on black-box machine learning models and help users understand how/why a model makes the decisions that it does. This can inform on how to improve the model as well as being useful for identifying things like bias and overfitting**

<img src="https://i.ibb.co/ZXdBQ4D/Screen-Shot-2020-07-07-at-10-24-16-AM.png">


<h3 style="font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: none; letter-spacing: 2px; color: navy; background-color: #ffffff;">1.3  CURRENT XAI APPROACHES</h3>

---

| Algorithm                     	| Type         	| Description                                                                                                                                                                                                                                                                                                                                                                                                  	|
|:-------------------------------	|:--------------	|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------	|
| Integrated Gradients | Gradient | Approximates the integral of gradients along the path (straight line from baseline to input) and multiplies with (input - baseline) |
| DeepLift                    	| Application  	| Explains differences in the non-linear activations' outputs in terms of the differences of the input from its corresponding reference.                                                                                                                                                                                                   	|
| DeepLiftSHAP                	| Gradient     	| An extension of DeepLift that approximates SHAP values.<br>For each input example it considers a distribution of baselines and computes the expected value of the attributions based on DeepLift algorithm across all input-baseline pairs.                                                                 	|
| GradientSHAP | Gradient | Approximates SHAP values based on the expected gradients.<br>It adds gaussian noise to each input example #samples times, selects a random point between each sample and randomly drawn baseline from baselines' distribution, computes the gradient for it and multiplies it with (input - baseline).<br>Final SHAP values represent the expected values of gradients * (input - baseline) for each input example. |
| Input * Gradient              	| Gradient     	| Multiplies model inputs with the gradients of the model outputs w.r.t. those inputs.                                                                                                                                                                                                                                                                                                                         	|
| Saliency                     	| Gradient     	| The gradients of the output w.r.t. inputs.                                                                                                                                                                                                                                                                                                                                                                   	|
| Guided BackProp / DeconvNet 	| Gradient     	| Computes the gradients of the model outputs w.r.t. its inputs.<br>If there are any RELUs present in the model, their gradients will be overridden so that only positive gradients of the inputs (in case of Guided BackProp) and outputs (in case of deconvnet) are back-propagated.                                                                                                                            	|
| Guided GradCam                	| Gradient     	| Computes the element-wise product of Guided BackProp and up-sampled positive GradCam attributions.                                                                                                                                                                                                                                                                                                           	|
| LayerGradCam                  	| Gradient     	| Computes the gradients of model outputs w.r.t. selected input layer, averages them for each output channel and multiplies with the layer activations.                                                                                                                                                                                                                                                        	|
| Layer Internal Influence      	| Gradient     	| Approximates the integral of gradients along the path from baseline to inputs for selected input layer.                                                                                                                                                                                                                                                                                                      	|
| Layer Conductance            	| Gradient     	| Decomposes integrated gradients via chain rule.<br>It approximates the integral of gradients defined by a chain rule, described as the gradients of the output w.r.t. to the neurons multiplied by the gradients of the neurons w.r.t. the inputs, along the path from baseline to inputs.<br>Finally, the latter is multiplied by (input - baseline).                                                             	|
| Layer Gradient * Activation   	| Gradient     	| Computes element-wise product of layer activations and the gradient of the output w.r.t. that layer.                                                                                                                                                                                                                                                                                                         	|
| Layer Activation              	| -            	| Computes the inputs or outputs of selected layer.                                                                                                                                                                                                                                                                                                                                                            	|
| Feature Ablation            	| Perturbation 	| Assigns an importance score to each input feature based on the magnitude changes in model output or loss when those features are replaced by a baseline (usually zeros) based on an input feature mask.                                                                                                                                                                                                      	|
| Feature Permutation           	| Perturbation 	| Assigns an importance score to each input feature based on the magnitude changes in model output or loss when those features are permuted based on input feature mask.                                                                                                                                                                                                                                       	|
| Occlusion | Perturbation | Assigns an importance score to each input feature based on the magnitude changes in model output when those features are replaced by a baseline (usually zeros) using rectangular sliding windows and sliding strides.<br>If a feature is located in multiple hyper-rectangles the importance scores are averaged across those hyper-rectangles. |
| Shapley Value | Perturbation | Computes feature importances based on all permutations of all input features.<br>It adds each feature for each permutation one-by-one to the baseline and computes the magnitudes of output changes for each feature which are ultimately being averaged across all permutations to estimate final attribution score. |
| Shapley Value Sampling | Perturbation | Similar to Shapley value, but instead of considering all feature permutations it considers only #samples random permutations. |
| NoiseTunnel                   	| -            	| Depends on the choice of above mentioned attribution algorithm                                                                                                                                                                                                                                                                                                                                                                        	|


<h3 style="font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: none; letter-spacing: 2px; color: navy; background-color: #ffffff;">1.4  WHERE IS XAI NEEDED?</h3>

---

***Obviously things like the weakly-supervised tasks in this competition may require XAI***

XAI can be used for a wide range of things that we won't get into here (protecting against bias, protecting against overfitting, detecting features, etc.)

One other place XAI can be used is when working with black-box models. To understand what that is we will see the definitions and examples of the terms: **Transparent and Black-Box Models**:

<br>

<b style="text-decoration: underline; font-family: Verdana;">TRANSPARENT MODELS</b>

These are models/algorithms that are easily interpretable and **DO NOT** (generally) require XAI. 

*Although occasionally post-hoc analysis is required or basic explainability tools.*

* Linear/Logistic Regression
* Decision Trees
* K-Nearest Neighbors
* Rule Based Learners
* Generalized Additive Models
* Bayesian Models
    
<br>

<b style="text-decoration: underline; font-family: Verdana;">BLACK-BOX MODELS</b>

These are models/algorithms that are NOT easily interpretable and **DO** require XAI. 

*This is not an exhaustive list of black-box models. It is simply the more common black-box models.*

* Tree Ensembles
* Support Vector Machines
* Multi-Layer Neural Network (MLPNN)
* Convolutional Neural Network (CNN)
* Recurrent Neural Network (RNN)


<h3 style="font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: uppercase; letter-spacing: 2px; color: navy; background-color: #ffffff;">1.5  Why Is XAI Needed: The Case For Growing Global AI Regulation</h3>

---

Many regulatory bodies have begun to encourage or enforce explainability in predictive algorithms used in the public domain.<br><br>**See below!**<br><sub>*(list was created roughly a year ago)*</sub>

- GDPR: Article 22 empowers individuals with the right to demand an explanation of how an
automated system made a decision that affects them.
- Algorithmic Accountability Act 2019: Requires companies to provide an assessment of the risks posed by
the automated decision system to the privacy or security and the risks that contribute to inaccurate, unfair,
biased, or discriminatory decisions impacting consumers
- California Consumer Privacy Act: Requires companies to rethink their approach to capturing,
storing, and sharing personal data to align with the new requirements by January 1, 2020.
- Washington Bill 1655: Establishes guidelines for the use of automated decision systems to protect
consumers, improve transparency, and create more market predictability.
- Massachusetts Bill H.2701: Establishes a commission on automated decision-making,
transparency, fairness, and individual rights.
- Illinois House Bill 3415: States predictive data analytics determining creditworthiness or hiring
decisions may not include information that correlates with the applicant's race or zip code.

<a style="font-family: Verdana; font-size: 24px; font-style: normal; font-weight: bold; text-decoration: none; text-transform: none; letter-spacing: 3px; color: navy; background-color: #ffffff;" id="ig_background">2&nbsp;&nbsp;INTEGRATED GRADIENTS BACKGROUND INFORMATION</a>

**This notebook will show how to implement Integrated Gradients (IG) for this competition.**

IG is an Explainable AI (XAI) technique introduced in the paper **[Axiomatic Attribution for Deep Networks](https://arxiv.org/abs/1703.01365)**

---

***LINKS***<br>
<sub>&nbsp;&nbsp;&nbsp;&nbsp;- **[TensorFlow Integrated Gradients Tutorial (Google Colab) – This Notebook Is Heavily Based On It](https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/interpretability/integrated_gradients.ipynb)**</sub><br>

---

**I**ntegrated **G**radients (**IG**) aims to explain the relationship between a model's predictions in terms of its features. It has many use cases including understanding feature importances, identifying data skew, and debugging model performance.

**IG** has become a popular interpretability technique due to its broad applicability to any differentiable model (e.g. images, text, structured data), ease of implementation, theoretical justifications, and computational efficiency relative to alternative approaches that allows it to scale to large networks and feature spaces such as images.

Go to this notebook to see the implementation of IG. In it, we will walk through an implementation of **IG** step-by-step to understand the pixel feature importances of an image classifier. 

---

As an example, consider this **[image](https://commons.wikimedia.org/wiki/File:San_Francisco_fireboat_showing_off.jpg)** of a fireboat spraying jets of water. 

You would classify this image as a **fireboat** and might highlight the pixels making up the **boat** and **water cannons** as being important to your decision. 

The model will also classify this image as a fireboat later on in this tutorial; however, does it highlight the same pixels as important when explaining its decision?

In the images below titled "**IG** Attribution Mask" and "Original + **IG** Mask Overlay" you can see that the model instead highlights (in purple) the pixels comprising the boat's **water cannons** and **jets of water** as being ***more important than the boat itself*** to its decision. 

How will the model generalize to new fireboats? What about fireboats without water jets? 

Read on to learn more about how **IG** works and how to apply **IG** to models to better understand the relationship between their predictions and underlying features.

![IG Example](https://www.tensorflow.org/tutorials/interpretability/images/IG_fireboat.png)

<a style="font-family: Verdana; font-size: 24px; font-style: normal; font-weight: bold; text-decoration: none; text-transform: none; letter-spacing: 3px; color: navy; background-color: #ffffff;" id="setup">3&nbsp;&nbsp;NOTEBOOK SETUP</a>

In [None]:
# Define the root data directories
#   - TRAIN_IMG_DIR is the GCS path of the "hpa-512512" Kaggle dataset
#   - LOCAL_DATA_DIR / MODEL_DIR are paths on the notebook's local filesystem
TRAIN_IMG_DIR = KaggleDatasets().get_gcs_path("hpa-512512")
LOCAL_DATA_DIR = "/kaggle/input/hpa-single-cell-image-classification"
MODEL_DIR = "/kaggle/input/hpa-xai-ig-tfrecords-tpu-training/model"

# Capture all the relevant full image paths (the remote listing is done only once)
TRAIN_IMG_PATHS = tf.io.gfile.glob(os.path.join(TRAIN_IMG_DIR, '*.png'))

print(f"\n... Recall that 4 training images compose one example (R,G,B,Y) ...")
print(f"... \t– i.e. The first 4 training files are:")
for path in [x.rsplit('/',1)[1] for x in TRAIN_IMG_PATHS[:4]]: print(f"... \t\t– {path}")
print(f"\n... The number of training images is {len(TRAIN_IMG_PATHS)} i.e. {len(TRAIN_IMG_PATHS)//4} 4-channel images ...")

# Define paths to the relevant csv files &
# create the relevant dataframe objects
TRAIN_CSV = os.path.join(LOCAL_DATA_DIR, "train.csv")
train_df = pd.read_csv(TRAIN_CSV)
print("\n\nTRAIN DATAFRAME\n\n")
display(train_df.head(3))

SS_CSV = os.path.join(LOCAL_DATA_DIR, "sample_submission.csv")
ss_df = pd.read_csv(SS_CSV)
print("\n\nSAMPLE SUBMISSION DATAFRAME\n\n")
display(ss_df.head(3))

# Create Single-Label Dataframe: keep only rows whose Label contains no "|"
# separator. The pattern is a raw string so that "\|" is passed to the regex
# engine as-is instead of being an invalid Python string escape.
sl_df = train_df[train_df.Label.str.count(r"\|")==0].reset_index(drop=True)
display(sl_df.head(3))

<a style="font-family: Verdana; font-size: 24px; font-style: normal; font-weight: bold; text-decoration: none; text-transform: none; letter-spacing: 3px; color: navy; background-color: #ffffff;" id="helper_functions">4&nbsp;&nbsp;HELPER FUNCTIONS</a>

In [None]:
def decode_img(image_data, resize_to=(512,512)):
    """ Decode PNG Bytes Into a 2D float32 Tensor

    Args:
        image_data: Encoded PNG contents (e.g. the result of tf.io.read_file)
        resize_to (tuple): Shape the decoded single-channel image is reshaped
            to. This is a reshape, not a resample, so the decoded image must
            already contain exactly that many pixels.

    Returns:
        float32 tensor of shape `resize_to`
    """
    single_channel = tf.image.decode_png(image_data, channels=1)
    # explicit size needed for TPU
    plane = tf.reshape(single_channel, resize_to)
    return tf.cast(plane, tf.float32)


def load_image(img_id, img_dir, resize_to=(512,512), tpu_style=False):
    """ Load An Image Using ID and Directory Path - Composes 4 Individual Images

    Args:
        img_id (str): Image ID; the channel files are `<img_id>_<colour>.png`
        img_dir (str): Directory (or path prefix) containing the channel files
        resize_to (tuple): Target size handed to PIL's resize, or the shape
            handed to decode_img when `tpu_style` is True
        tpu_style (bool): Read and decode with TensorFlow instead of PIL

    Returns:
        Channels-last RGBY image: a uint8 numpy array, or a float32 tensor
        when `tpu_style` is True
    """
    channel_paths = [
        os.path.join(img_dir, img_id+f"_{colour}.png")
        for colour in ["red", "green", "blue", "yellow"]
    ]

    if tpu_style:
        planes = [decode_img(tf.io.read_file(path), resize_to) for path in channel_paths]
        return tf.stack(planes, axis=-1)

    planes = [np.asarray(Image.open(path).resize(resize_to), np.uint8) for path in channel_paths]
    return np.stack(planes, axis=-1)

def plot_rgb(arr, figsize=(12,12)):
    """ Plot 3 Channel Microscopy Image

    Args:
        arr: Image accepted by matplotlib's imshow
        figsize (tuple): Matplotlib figure size

    Returns:
        None; The figure is shown with plt.show()
    """
    plt.figure(figsize=figsize)
    ax = plt.gca()
    ax.set_title(f"RGB Composite Image", fontweight="bold")
    plt.imshow(arr)
    ax.axis(False)
    plt.show()
    
    
def convert_rgby_to_rgb(arr, boost_green=False):
    """ Convert a 4 channel (RGBY) image to a 3 channel RGB image.
    
    Advice From Competition Host/User: lnhtrang

    For annotation (by experts) and for the model, I guess we agree that individual 
    channels with full range px values are better. 
    In annotation, we toggled the channels. 
    For visualization purpose only, you can try blending the channels. 
    For example, 
        - red = red + yellow
        - green = green + yellow/2
        - blue=blue.

    What this function actually computes:
        - boost_green=False: red = red, green = green + yellow/2, blue = blue
        - boost_green=True:  red = red/1.25, green = clip(2*green + yellow/5, 0, 255),
                             blue = blue/1.25
        
    Args:
        arr (numpy array): The RGBY, 4 channel numpy array for a given image
            (channels last). It is not modified.
        boost_green (bool): Whether to boost the intensity of the green channel
    
    Returns:
        RGB Image (channels last). uint8 when `boost_green` is True, otherwise
        the same dtype as `arr`; for integer dtypes the blended green channel
        saturates at the dtype's maximum instead of wrapping around.
    """
    arr = np.asarray(arr)

    # Blend in floating point: uint8 arithmetic such as `green*2` would wrap
    # around modulo 256 *before* any clipping could take effect.
    work = arr.astype(np.result_type(arr.dtype, np.float32))
    red, green, blue, yellow = work[..., 0], work[..., 1], work[..., 2], work[..., 3]

    if boost_green:
        boosted = np.clip(green*2 + yellow/5, 0, 255)
        rgb_arr = np.stack([red/1.25, boosted, blue/1.25], axis=-1)
        return rgb_arr.astype(np.uint8)

    rgb_arr = np.stack([red, green + yellow/2, blue], axis=-1)
    if np.issubdtype(arr.dtype, np.integer):
        # Saturate so the cast back to the integer dtype cannot overflow
        limits = np.iinfo(arr.dtype)
        rgb_arr = np.clip(rgb_arr, limits.min, limits.max)
    return rgb_arr.astype(arr.dtype)



def plot_ex(arr, figsize=(20,6), title=None, plot_merged=True, rgb_only=False):
    """ Plot Each Channel Side by Side, Optionally Followed by the Merged Image

    Args:
        arr (numpy array): Channels-last image – RGBY (4 channels), or RGB
            (3 channels) when `rgb_only` is True
        figsize (tuple): Matplotlib figure size
        title (str, optional): Figure super-title; ignored unless it is a string
        plot_merged (bool): Whether to add a final panel with the merged image
        rgb_only (bool): Whether `arr` is RGB, i.e. has no yellow channel

    Returns:
        None; The figure is shown with plt.show()
    """
    channel_names = ["Red Channel – Microtubles", "Green Channel – Protein of Interest", "Blue - Nucleus"]
    if not rgb_only:
        channel_names.append("Yellow – Endoplasmic Reticulum")

    # One slot per channel that is actually drawn, plus one for the merged view,
    # so the subplot grid always matches the number of panels (3, 4 or 5).
    n_images = len(channel_names) + (1 if plot_merged else 0)

    plt.figure(figsize=figsize)
    if isinstance(title, str):
        plt.suptitle(title, fontsize=20, fontweight="bold")

    # Every panel is a 3 channel image; for RGBY input the last channel is dropped
    panel_like = arr if rgb_only else arr[..., :-1]
    for i, c in enumerate(channel_names):
        ch_arr = np.zeros_like(panel_like)
        if i < 3:
            # Red/green/blue are drawn in their own colour
            ch_arr[..., i] = arr[..., i]
        else:
            # Yellow is drawn as equal amounts of red and green
            ch_arr[..., 0] = arr[..., i]
            ch_arr[..., 1] = arr[..., i]
        plt.subplot(1,n_images,i+1)
        plt.title(f"{c.title()}", fontweight="bold")
        plt.imshow(ch_arr)
        plt.axis(False)
        
    if plot_merged:
        plt.subplot(1,n_images,n_images)
        
        if rgb_only:
            plt.title(f"Merged RGB", fontweight="bold")
            plt.imshow(arr)
        else:
            plt.title(f"Merged RGBY into RGB", fontweight="bold")
            plt.imshow(convert_rgby_to_rgb(arr))
        plt.axis(False)
        
    plt.tight_layout(rect=[0, 0.2, 1, 0.97])
    plt.show()
    
    
def flatten_list_of_lists(l_o_l):
    """ Flatten One Level of Nesting

    Args:
        l_o_l (iterable of iterables): The nested collection

    Returns:
        A new list with the items of every inner iterable, in order
    """
    flat = []
    for inner in l_o_l:
        flat.extend(inner)
    return flat


def load_batch_of_images(df, id_list, img_dir, resize_to=(512,512), return_labels=True, tpu_style=False):
    """ Load a Batch of 4 Channel Images and (Optionally) Their Labels

    Args:
        df (pandas DataFrame): Frame with an `ID` column and a `Label` column
            holding "|"-separated integer class ids
        id_list (list of str): IDs of the images to load
        img_dir (str): Directory passed through to load_image
        resize_to (tuple): Size passed through to load_image
        return_labels (bool): Whether to also return the labels
        tpu_style (bool): Passed through to load_image (TensorFlow vs. PIL loading)

    Returns:
        A numpy array stacking the images along axis 0 and, when `return_labels`
        is True, a list with one list of integer class ids per entry of
        `id_list`, in the same order as the images.

    Raises:
        KeyError: If labels are requested for an ID that is not present in `df`
    """
    # tpu_style is honoured regardless of whether labels are requested
    img_batch = np.stack([load_image(ID, img_dir, resize_to, tpu_style=tpu_style) for ID in id_list], axis=0)
    if not return_labels:
        return img_batch

    # Look labels up per requested ID so they follow id_list order; filtering
    # the frame with isin() would return them in dataframe row order instead,
    # which misaligns labels and images for an unsorted id_list.
    label_by_id = dict(zip(df.ID, df.Label))
    lbls = [[int(l) for l in label_by_id[ID].split("|")] for ID in id_list]
    return img_batch, lbls
    
def plot_batch_of_images(img_batch, lbl_batch=None, pred_batch=None, n_cols=4, labels_as_strs=True, boost_green=True):
    """ Plot a Grid of RGBY Images Titled With Ground Truth and Predicted Labels

    Args:
        img_batch (numpy array): Batch of channels-last RGBY images (axis 0 is the batch)
        lbl_batch (list, optional): One collection of integer class ids per image
        pred_batch (list, optional): One prediction per image; only the first
            element of each prediction (the class ids) is shown
        n_cols (int): Number of grid columns
        labels_as_strs (bool): Show class names (via INT_2_STR) instead of ids
        boost_green (bool): Passed through to convert_rgby_to_rgb

    Returns:
        None; The figure is shown with plt.show()
    """
    n_imgs = img_batch.shape[0]
    n_rows = int(np.ceil(n_imgs/n_cols))

    # Missing labels / predictions become per-image placeholders so zip() still
    # walks every image
    lbl_batch = lbl_batch if lbl_batch else [None,]*n_imgs
    pred_batch = pred_batch if pred_batch else [None,]*n_imgs

    plt.figure(figsize=(19, int(5.5*np.ceil(n_imgs/n_cols))))
    for position, (img, lbl, pred) in enumerate(zip(img_batch, lbl_batch, pred_batch), start=1):
        plt.subplot(n_rows, n_cols, position)

        caption = ""
        if lbl:
            gt_shown = [INT_2_STR[l] for l in lbl] if labels_as_strs else lbl
            caption += f"GT LABEL: {gt_shown}"
        if pred:
            pred_shown = [INT_2_STR[p] for p in pred[0]] if labels_as_strs else pred[0]
            caption += f"\nPRED LABEL: {pred_shown}"
        if lbl or pred:
            plt.title(caption.strip("\n"), fontweight="bold")

        plt.imshow(convert_rgby_to_rgb(img, boost_green=boost_green))
        plt.axis(False)

    plt.tight_layout()
    plt.show()
    
def get_pred(model, img_batch, conf_thresh=0.3, drop_yellow=True):
    """Predict on a batch and keep the class indices scoring above a threshold.

    Args:
        model: object exposing ``predict(batch)`` that returns one score row per image.
        img_batch: image batch with channels on the last axis.
        conf_thresh: scores strictly greater than this value are kept.
        drop_yellow: when True, the last channel is removed before predicting.

    Returns:
        A list with one ``np.where`` result (a tuple of index arrays) per image.
    """
    model_input = img_batch[..., :-1] if drop_yellow else img_batch
    score_rows = model.predict(model_input)

    kept = []
    for row in score_rows:
        kept.append(np.where(row > conf_thresh))
    return kept

<a style="text-align: font-family: Verdana; font-size: 24px; font-style: normal; font-weight: bold; text-decoration: none; text-transform: none; letter-spacing: 3px; color: navy; background-color: #ffffff;" id="model">5&nbsp;&nbsp;LOAD AND DEMONSTRATE MODEL PREDICTION</a>

<h3 style="text-align: font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: uppercase; letter-spacing: 2px; color: navy; background-color: #ffffff;">5.1 LOAD THE TRAINED MODEL</h3>

---

In [None]:
# Load the trained classifier inside the distribution strategy's scope.
with strategy.scope():
    # Route SavedModel file I/O through the '/job:localhost' device.
    load_locally = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
    enet = tf.keras.models.load_model(MODEL_DIR, options=load_locally) # loading in Tensorflow's "SavedModel" format

<h3 style="text-align: font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: uppercase; letter-spacing: 2px; color: navy; background-color: #ffffff;">5.2 SHOW MODEL PREDICTIONS</h3>

---

- Grab a random batch of single label IDs
- Get images and labels for batch
- Get predictions for batch at a given confidence threshold
- Visualize

In [None]:
# Draw 8 IDs; sort_index() puts the sample back into `sl_df` row order.
# No random_state is given, so every run shows a different batch.
RNDM_SNGL_LBL_IDS = sl_df.ID.sample(8).sort_index().to_list()
# Images plus their ground-truth labels for those IDs
img_batch, lbl_batch = load_batch_of_images(sl_df, RNDM_SNGL_LBL_IDS, TRAIN_IMG_DIR, tpu_style=True)
# Keep classes scoring above 0.4; drop_yellow=True removes the last channel before predicting
pred_batch = get_pred(enet, img_batch, conf_thresh=0.4, drop_yellow=True)
plot_batch_of_images(img_batch, lbl_batch, pred_batch, n_cols=4, labels_as_strs=True, boost_green=True)

<a style="text-align: font-family: Verdana; font-size: 24px; font-style: normal; font-weight: bold; text-decoration: none; text-transform: none; letter-spacing: 3px; color: navy; background-color: #ffffff;" id="calc_ig">6&nbsp;&nbsp;CALCULATE INTEGRATED GRADIENTS</a>

<h3 style="text-align: font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: uppercase; letter-spacing: 2px; color: navy; background-color: #ffffff;">6.0 IG Demo Setup</h3>

---

We need some images to show how integrated gradients works. We will take the first  image for every class (with a few replacements) to demonstrate how the technique will work.

In [None]:
# These are better images to demo the IG technique (more representative)
REPLACEMENTS = {
    "Nucleoplasm":"3e4b0862-bba2-11e8-b2b9-ac1f6b6435d0",
    "Mitochondria":"643f73a4-bb99-11e8-b2b9-ac1f6b6435d0",
    "Nuclear Membrane":"0881d08c-bb9b-11e8-b2b9-ac1f6b6435d0",
    "Actin Filaments":"5fb9edb4-bb99-11e8-b2b9-ac1f6b6435d0",
    "Centrosome":"ca883cf4-bb99-11e8-b2b9-ac1f6b6435d0",
}

# First row of every distinct label in the single-label frame
unique_np = sl_df.drop_duplicates("Label").to_numpy()

# Walk the de-duplicated rows themselves. Using the keys of INT_2_STR as row
# positions only works when the table has exactly one row per key; it raises
# IndexError as soon as a class has no single-label image.
# NOTE(review): column 0 is read as the image ID and column 1 as the label,
# exactly as before -- confirm against the layout of `sl_df`.
unique_class_2_id = {INT_2_STR[int(row[1])]: row[0] for row in unique_np}

# Swap in the hand-picked, more representative images
unique_class_2_id.update(REPLACEMENTS)
unique_class_2_img = {k:load_image(v, TRAIN_IMG_DIR, tpu_style=True) for k,v in unique_class_2_id.items()}
single_demo_img = unique_class_2_img["Nucleoplasm"]
plot_batch_of_images(np.asarray(list(unique_class_2_img.values())), list(unique_class_2_img.keys()), pred_batch=None, n_cols=4, labels_as_strs=False, boost_green=True)

<h3 style="text-align: font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: uppercase; letter-spacing: 2px; color: navy; background-color: #ffffff;">6.1 Back To Basics</h3>

---

Our model is a learned function that describes a mapping between the input feature space – image pixel values – and an output space defined by class probabilities (19) ranging from 0 to 1.

Early interpretability methods for neural networks assigned feature importance scores using gradients, which tell us which pixels have the steepest local slope relative to the model's prediction at a given point along the model's prediction function. 

However, gradients only describe ***local*** changes in the model's prediction function with respect to pixel values and do not fully describe the entire model prediction function. As the model fully ***learns*** the relationship between the range of an individual pixel and the correct class, the gradient for this pixel will ***saturate***, meaning become increasingly small and even go to zero. 

<br>

**Consider the simple model function and the plots below:**

---

**DESCRIPTION OF NEXT CELL OUTPUT**

---

* **LEFT PLOT**: The model's gradients for pixel **`x`** are positive between **`0.0`** and **`0.8`** but go to **`0.0`** between **`0.8`** and **`1.0`**. Pixel **`x`** clearly has a significant impact on pushing the model toward **`80%`** predicted probability on the true class.<br>
&nbsp;&nbsp;&nbsp;&nbsp;– *Does it make sense that pixel **`x`**'s importance is small or discontinuous?*

* **RIGHT PLOT**: The intuition behind IG is to accumulate pixel **`x`**'s local gradients and attribute its importance as a score for how much it adds or subtracts to the model's overall output class probability. **You can break down and compute IG in 3 parts:**<br>
&nbsp;&nbsp;&nbsp;&nbsp;1. interpolate small steps along a straight line in the feature space between **`0` (a baseline or starting point)** and **`1` (input pixel's value)** <br>
&nbsp;&nbsp;&nbsp;&nbsp;2. compute gradients at each step between the model's predictions with respect to each step<br>
&nbsp;&nbsp;&nbsp;&nbsp;3. approximate the integral between the baseline and input by accumulating (cumulative average) these local gradients.

---

<br>

To reinforce this intuition, you will walk through these 3 parts by applying IG to an **example image from every class of our dataset**.

In [None]:
def f(x):
    """Toy model function: the identity below 0.8, clamped to 0.8 from there on."""
    plateau = 0.8
    still_rising = x < plateau
    return tf.where(still_rising, x, plateau)

def interpolated_path(x):
    """Return zeros shaped like ``x``: the flat y=0 line plotted as the baseline-to-input path."""
    return tf.zeros_like(x)

# Six evenly spaced "pixel values" in [0, 1] and the toy model's output for each
x = tf.linspace(start=0.0, stop=1.0, num=6)
y = f(x)

plt.figure(figsize=(15, 6))

# Left panel: F(x) rises, then flattens at 0.8 (zero gradient)
plt.subplot(1,2,1)
plt.plot(x, y, marker='o')
plt.title('Gradients saturate over F(x)', fontweight='bold')
plt.text(0.2, 0.5, 'Gradients > 0 = \n x is important')
plt.text(0.7, 0.85, 'Gradients = 0 \n x not important')
plt.yticks(tf.range(0, 1.5, 0.5))
plt.xticks(tf.range(0, 1.5, 0.5))
plt.ylabel('F(x) - model true class predicted probability')
plt.xlabel('x - (pixel value)')

# Right panel: the same curve plus the path from baseline (x=0) to input (x=1)
plt.subplot(1,2,2)
plt.plot(x, y, marker='o')
plt.plot(x, interpolated_path(x), marker='>')
plt.title('IG intuition', fontweight='bold')
plt.text(0.25, 0.1, 'Accumulate gradients along path')
plt.ylabel('F(x) - model true class predicted probability')
plt.xlabel('x - (pixel value)')
plt.yticks(tf.range(0, 1.5, 0.5))
plt.xticks(tf.range(0, 1.5, 0.5))
# Arrows mark the two ends of the path on the x axis
plt.annotate('Baseline', xy=(0.0, 0.0), xytext=(0.0, 0.2),
             arrowprops=dict(facecolor='black', shrink=0.1))
plt.annotate('Input', xy=(1.0, 0.0), xytext=(1, 0.2),
             arrowprops=dict(facecolor='black', shrink=0.1))

plt.show()

<h3 style="text-align: font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: uppercase; letter-spacing: 2px; color: navy; background-color: #ffffff;">6.2 Establish a Baseline</h3>

---

A baseline is an input image used as a starting point for calculating feature importance. 

Intuitively, you can think of the baseline's explanatory role as representing the impact of the absence of each pixel on a given prediction, to contrast with the impact of each pixel on the same prediction when present in the input image. 

As a result, the choice of the baseline plays a central role in interpreting and visualizing pixel feature importances. The process of choosing a valid baseline can be a complicated procedure and may warrant its own tutorial. For simplicity, you will use a black image whose pixel values are all zero.

Other choices you could experiment with include an all white image, or a random image, which you can create with **`tf.random.uniform(shape=(512,512,3), minval=0.0, maxval=1.0)`**

In [None]:
# Shape shared by the three candidate baselines below
INPUT_SHAPE = (512,512,3)
# All-zero (black) image; no dtype is passed here, unlike the two uint8 alternatives
baseline = tf.zeros(shape=INPUT_SHAPE)
# All-255 (white) uint8 image
white_baseline = tf.ones(shape=INPUT_SHAPE, dtype=tf.uint8)*255
# Uniform noise scaled to 0..255 and cast to uint8; no seed is set, so it differs on every run
random_baseline = tf.cast(tf.random.uniform(shape=INPUT_SHAPE, minval=0.0, maxval=1.0)*255, tf.uint8)

plt.figure(figsize=(15,5))

plt.subplot(1,3,1)
plt.imshow(baseline)
plt.title("Black Baseline (Our Choice)", fontweight="bold")
plt.axis('off')

# Axis deliberately left on for this panel (see the title text)
plt.subplot(1,3,2)
plt.imshow(white_baseline)
plt.title("White Baseline (Possible Alternate -- Axis on for Display Purposes)")

plt.subplot(1,3,3)
plt.imshow(random_baseline)
plt.title("Random Baseline (Possible Alternate)")
plt.axis('off')

plt.tight_layout()
plt.show()

<h3 style="text-align: font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: uppercase; letter-spacing: 2px; color: navy; background-color: #ffffff;">6.3 Unpack The IG Formulas Into Code</h3>

---

The formula for Integrated Gradients is as follows:

<center>

> $IntegratedGradients_{i}(x) ::= (x_{i} - x'_{i})\times\int_{\alpha=0}^1\frac{\partial F(x'+\alpha \times (x - x'))}{\partial x_i}{d\alpha}$

</center>

> 
> **where:**
> 
> * **$_{i}$ indicates the $_{i}th$ feature**
> * **$x$ is the input**
> * **$x'$ is the baseline**
> * **$\alpha$ is the interpolation constant to perturb features by**


---


In practice, computing a definite integral is not always numerically possible and can be computationally costly, so you compute the following numerical approximation as follows:

<center>

> $IntegratedGrads^{approx}_{i}(x)::=(x_{i}-x'_{i})\times\sum_{k=1}^{m}\frac{\partial F(x' + \frac{k}{m}\times(x - x'))}{\partial x_{i}} \times \frac{1}{m}$

</center>

<br>

> 
> **where:**
> * **$_{i}$ = feature (individual pixel)**
> * **$x$ = input (image tensor)**
> * **$x'$ = baseline (image tensor)**
> * **$k$ = scaled feature perturbation constant**
> * **$m$ = number of steps in the Riemann sum approximation of the integral** 
> * **$(x_{i}-x'_{i})$ = a term for the difference from the baseline.**<br>
> &nbsp;&nbsp;&nbsp;&nbsp;- ***This is necessary to scale the integrated gradients and keep them in terms of the original image.***<br>
> &nbsp;&nbsp;&nbsp;&nbsp;- *The path from the baseline image to the input is in pixel space.*<br>
> &nbsp;&nbsp;&nbsp;&nbsp;- *Since with **IG** you are integrating in a straight line (linear transformation) this ends up being<br>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;roughly equivalent to the integral term of the derivative of the interpolated image function with respect to* **$\alpha$** *with enough steps.*<br>
> &nbsp;&nbsp;&nbsp;&nbsp;- *The integral sums each pixel's gradient times the change in the pixel along the path.*<br>
> &nbsp;&nbsp;&nbsp;&nbsp;- *It's simpler to implement this integration as uniform steps from one image to the other, substituting* **$x = (x0 + a(x1-x0))$.**<br>
> &nbsp;&nbsp;&nbsp;&nbsp;- *So the change of variables gives* **$dx = (x1-x0)da$**.<br>
> &nbsp;&nbsp;&nbsp;&nbsp;- *The* **$(x1-x0)$** *term is constant and is factored out of the integral.*

<h3 style="text-align: font-family: Verdana; font-size: 15px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: None; letter-spacing: 2px; color: navy; background-color: #ffffff;">6.3.1 Linear Interpolation</h3>

---

First, we will generate a **[linear interpolation](https://en.wikipedia.org/wiki/Linear_interpolation)** between the baseline and the original image. We can think of interpolated images as small steps in the feature space between your baseline and input, represented by <b>$\alpha$</b> in the original equation.

We will first generate the interpolation for one image and visualize all the steps. Following that we will generate the interpolations for an image from each class and visualize only **`6`** of those steps.

In [None]:
TMP_ALPHAS = tf.linspace(start=0.0, stop=1.0, num=25)

def interpolate_images(baseline, image, alphas=None):
    """Linearly blend ``baseline`` into ``image``, one output image per alpha.

    Both inputs are cast to the dtype of ``alphas``. An input whose maximum
    exceeds 1 is divided by 255 before blending.

    Args:
        baseline: starting image.
        image: target image, same shape as ``baseline``.
        alphas: 1-D tensor of interpolation fractions; defaults to 25 evenly
            spaced values from 0 to 1.

    Returns:
        ``baseline + alpha * (image - baseline)`` for every alpha, stacked on a
        new leading axis and multiplied by 255.
    """
    if alphas is None:
        alphas = tf.linspace(start=0.0, stop=1.0, num=25)

    # Bring both endpoints to the alphas' dtype and, when needed, to the 0..1 range
    start_img = tf.cast(baseline, alphas.dtype)
    if tf.math.reduce_max(start_img) > 1.:
        start_img = start_img / 255.

    end_img = tf.cast(image, alphas.dtype)
    if tf.math.reduce_max(end_img) > 1.:
        end_img = end_img / 255.

    # One leading step axis on every operand so the blend broadcasts per alpha
    fractions = alphas[:, tf.newaxis, tf.newaxis, tf.newaxis]
    start_img = tf.expand_dims(start_img, axis=0)
    end_img = tf.expand_dims(end_img, axis=0)

    # alpha = 0 reproduces the baseline, alpha = 1 reproduces the image
    blended = start_img + fractions * (end_img - start_img)
    return blended*255

# Let's test with a single class
# NOTE(review): this demo starts from `random_baseline`, while the black `baseline`
# is what gets passed to integrated_gradients further down -- confirm this is intended.
# `[..., :-1]` drops the last channel of the demo image.
interpolated_images_demo = interpolate_images(
    random_baseline, image=single_demo_img[..., :-1], alphas=TMP_ALPHAS,
)

# 5x5 grid: one panel per entry of TMP_ALPHAS (25 values)
plt.figure(figsize=(22, 22))
for i in range(25):
    plt.subplot(5, 5, i+1)
    plt.title("alpha={:.4f}".format(TMP_ALPHAS[i].numpy()), fontweight="bold")
    plt.axis('off')
    # interpolate_images returns values scaled by 255; bring them back to 0..1 for imshow
    plt.imshow(interpolated_images_demo[i]/255)
plt.tight_layout()
plt.show()

In [None]:
# Initialize
interpolated_images_by_class = {}

# Grid geometry: one row per demo image, six interpolation steps per row
n_demo_classes = len(unique_class_2_img)
n_steps_shown = 6
last_alpha_idx = int(TMP_ALPHAS.shape[0]) - 1

# Now for every class in the dataset
plt.figure(figsize=(18, int(n_demo_classes*3.4)))
for i, (c, image) in tqdm(enumerate(unique_class_2_img.items()), total=n_demo_classes):
    # Interpolate from the black `baseline`: it is the baseline this notebook
    # settles on and the one later passed to integrated_gradients, so the path
    # gradients derived from these images describe the same path. It is also
    # deterministic, unlike the unseeded random baseline.
    interpolated_images = interpolate_images(baseline=baseline,
                                             image=image[..., :-1],
                                             alphas=TMP_ALPHAS)
    # Save for later
    interpolated_images_by_class[c]=interpolated_images

    # Visualize every 5th step, clamped to the last available alpha
    for j in range(n_steps_shown):
        step = min(j*5, last_alpha_idx)
        plt.subplot(n_demo_classes, n_steps_shown, i*n_steps_shown+(j+1))
        plt.title("{}\nAt alpha={:.2f}".format(c, TMP_ALPHAS[step].numpy()), fontweight="bold")
        plt.axis('off')
        plt.imshow(interpolated_images[step]/255)

plt.tight_layout()
plt.show()

<h3 style="text-align: font-family: Verdana; font-size: 15px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: None; letter-spacing: 2px; color: navy; background-color: #ffffff;">6.3.2 Compute Gradients</h3>

---

Now let's take a look at how to calculate gradients in order to measure the relationship between changes to a feature and changes in the model's predictions. 

In the case of images, the gradient tells us which pixels have the strongest effect on the models predicted class probabilities.

<center>

> $IntegratedGrads^{approx}_{i}(x)::=(x_{i}-x'_{i})\times\sum_{k=1}^{m}\frac{\overbrace{\partial F(\text{interpolated images})}^\text{compute gradients}}{\partial x_{i}} \times \frac{1}{m}$

</center>

<br>

> **where:**
> * **$F(...)$ is your model's prediction function**
> * **$\frac{\partial{F}}{\partial{x_i}}$ is the gradient (vector of partial derivatives $\partial$) of your model** **$F$** **'s prediction function relative to each feature $x_i$**

<br>

To compute the gradients we use **[`tf.GradientTape`](https://www.tensorflow.org/api_docs/python/tf/GradientTape)**

**Let's compute the gradients for each image along the interpolation path with respect to the correct output.**<br>
&nbsp;&nbsp;&nbsp;&nbsp;- Recall that your model returns a **`(1, n_classes)`** shaped **`Tensor`** with logits for each label<br>
&nbsp;&nbsp;&nbsp;&nbsp;- We need to pass the correct target class index to the **`compute_gradients`** function for your image.<br>
&nbsp;&nbsp;&nbsp;&nbsp;- Note the output shape of **`(n_interpolated_images, img_height, img_width, RGB)`**<br>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;which gives us the gradient for every pixel of every image along the interpolation path. <br>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- *You can think of these gradients as measuring the change in your model's predictions for each small step in the feature space.*

In [None]:
def compute_gradients(model, images, target_class_idx):
    """Gradient of the target-class output with respect to every input pixel.

    Args:
        model: callable mapping an image batch to per-class outputs.
        images: image batch; cast to float32 and, when its maximum is at most 1,
            multiplied by 255 before being fed to the model.
        target_class_idx: index of the output column to differentiate.

    Returns:
        A tensor shaped like the (rescaled) image batch holding
        d output[:, target_class_idx] / d pixel.
    """
    pixels = images if images.dtype == tf.float32 else tf.cast(images, tf.float32)

    if tf.math.reduce_max(pixels) <= 1.:
        pixels = pixels * 255.

    with tf.GradientTape() as tape:
        # Plain tensors are not tracked automatically, so watch the input explicitly
        tape.watch(pixels)
        class_outputs = model(pixels)
        target_pred = class_outputs[:, target_class_idx]

    return tape.gradient(target_pred, pixels)

In [None]:
# Gradients along the interpolation path of the Nucleoplasm demo image,
# taken with respect to output column 0.
demo_path_gradients = compute_gradients(
    enet, 
    images=interpolated_images_by_class["Nucleoplasm"],
    target_class_idx=tf.constant(0)
)
print("Demo Path Gradients - Shape : {}".format(demo_path_gradients.shape))

# Same computation for every demo class, each against its own class index (STR_2_INT[k])
path_gradients_by_class = {k:compute_gradients(enet, tf.cast(v, tf.float32), tf.constant(STR_2_INT[k])) for k,v in interpolated_images_by_class.items()}

<h3 style="text-align: font-family: Verdana; font-size: 15px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: None; letter-spacing: 2px; color: navy; background-color: #ffffff;">6.3.3 Visualizing Gradient Saturation</h3>

---


Recall that the gradients we just calculated above describe ***local*** changes to our model's predicted probability for each available class and can ***saturate***.

These concepts are visualized using the gradients we calculated above.<br>
&nbsp;&nbsp;&nbsp;&nbsp;- Each class requires 2 plots.

---

**DESCRIPTION OF NEXT CELL OUTPUT. THE DESCRIPTIONS ARE WHAT SHOULD HAPPEN. NOT NECESSARILY WHAT DOES**

---

* **LEFT PLOTS**: These plots show how your model's confidence for a given class varies across alphas. <br>
&nbsp;&nbsp;&nbsp;&nbsp;- Notice how the gradients, or slope of the line, largely flattens or saturates before reaching a value of **`1.0`** at the max probability.

* **RIGHT PLOTS**: These plots show the average gradient magnitudes over alpha more directly.<br>
&nbsp;&nbsp;&nbsp;&nbsp;- Notice how the values sometimes sharply approach and even briefly dip below zero.<br>
&nbsp;&nbsp;&nbsp;&nbsp;- In fact, the model ***learns*** the most from gradients at lower values of alpha before saturation occurs

---

To make sure that the important pixels for a given class are reflected as important in the respective prediction, we will continue on below to learn how to accumulate these gradients to accurately approximate how each pixel impacts the model's predicted probability score.

In [None]:
def viz_grad_saturation(model, interpolated_images, path_gradients, target_class):
    """Plot model output and mean gradient for one class along the interpolation path.

    Both curves are drawn against the module-level ``TMP_ALPHAS``, so the inputs
    must have been produced with exactly those alphas.

    Args:
        model: callable mapping an image batch to per-class outputs.
        interpolated_images: images along the path, one per alpha.
        path_gradients: gradients for those images, one per alpha.
        target_class: class name; converted to an index through ``STR_2_INT``.
    """
    target_class_idx = STR_2_INT[target_class]
    
    # Same range heuristic as elsewhere: a batch whose max is <= 1 is scaled by 255
    if tf.math.reduce_max(interpolated_images)<=1.:
        interpolated_images *= 255.
    
    target_prob = model(interpolated_images)[:, target_class_idx]
    plt.figure(figsize=(15, 5))
    plt.suptitle("Gradient Visualization for the {} Class".format(target_class), fontsize=16, fontweight="bold")
    plt.subplot(1, 2, 1)
    plt.plot(TMP_ALPHAS, target_prob)
    plt.title('\nTarget class predicted probability over alpha', fontweight="bold")
    plt.ylabel('model p({} class)'.format(target_class))
    plt.xlabel('alpha')
    plt.ylim([-0.1, 1.1])

    plt.subplot(1, 2, 2)
    
    # Average over axes 1-3 (everything but the step axis): one value per interpolation step
    average_grads = tf.reduce_mean(path_gradients, axis=[1, 2, 3])
    
    # Normalize gradients to 0 to 1 scale. E.g. (x - min(x))/(max(x)-min(x))
    # (the denominator is zero when every step has the same average gradient)
    average_grads_norm = ((average_grads-tf.math.reduce_min(average_grads)) / 
                          (tf.math.reduce_max(average_grads)-tf.reduce_min(average_grads)))
    
    plt.plot(TMP_ALPHAS, average_grads_norm)
    plt.title('Average pixel gradients (normalized) over alpha', fontweight="bold")
    plt.ylabel('Average pixel gradients')
    plt.xlabel('alpha')
    plt.ylim([0, 1]);

    plt.show()

# One figure per class. The two dicts are zipped positionally, which relies on
# path_gradients_by_class having been built by iterating interpolated_images_by_class.
for (c, interpolated_images), path_gradients in zip(interpolated_images_by_class.items(), path_gradients_by_class.values()):
    viz_grad_saturation(enet, interpolated_images, path_gradients, c)

<h3 style="text-align: font-family: Verdana; font-size: 15px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: None; letter-spacing: 2px; color: navy; background-color: #ffffff;">6.3.4 Accumulate Gradients (Integral Approximation)</h3>

---


There are many different ways we can go about computing the numerical approximation of an integral for **IG** with different tradeoffs in accuracy and convergence across varying functions.<br>
&nbsp;&nbsp;&nbsp;&nbsp;- A popular class of methods is called **[Riemann sums](https://en.wikipedia.org/wiki/Riemann_sum)**.<br>
&nbsp;&nbsp;&nbsp;&nbsp;- Here, we will use the Trapezoidal rule

<center>

> $IntegratedGrads^{approx}_{i}(x)::=(x_{i}-x'_{i})\times \overbrace{\sum_{k=1}^{m}}^\text{Sum m local gradients}
\text{gradients(interpolated images)} \times \overbrace{\frac{1}{m}}^\text{Divide by m steps}$

</center>

<br>

From the equation, we can see that we are summing over **`m`** gradients and dividing by **`m`** steps. We can implement the two operations together as an ***average of the local gradients of `m` interpolated predictions and input images***

In [None]:
def integral_approximation(gradients):
    """Average the gradients along axis 0 using the trapezoidal rule.

    Args:
        gradients: tensor whose first axis indexes consecutive path steps.

    Returns:
        The mean, over all pairs of neighbouring steps, of the pair's midpoint
        gradient; the leading axis is reduced away.
    """
    left_edges, right_edges = gradients[:-1], gradients[1:]
    # Midpoint of each neighbouring pair = height of that trapezoid
    trapezoid_heights = (left_edges + right_edges) / tf.constant(2.0)
    return tf.math.reduce_mean(trapezoid_heights, axis=0)

# Path-averaged gradients per class. These are not yet multiplied by
# (image - baseline), so they are averaged gradients rather than final attributions.
ig_by_class = {k:integral_approximation(gradients=v) for k,v in path_gradients_by_class.items()}

<h3 style="text-align: font-family: Verdana; font-size: 20px; font-style: normal; font-weight: normal; text-decoration: none; text-transform: uppercase; letter-spacing: 2px; color: navy; background-color: #ffffff;">6.4 Putting It All Together</h3>

---

Now we will combine the previous general parts together into an **`IntegratedGradients`** function and utilize a **[@tf.function](https://www.tensorflow.org/guide/function)** decorator to compile it into a high performance callable Tensorflow graph. 

This is implemented as 5 smaller steps below:

<center>

> $IntegratedGrads^{approx}_{i}(x)::=\overbrace{(x_{i}-x'_{i})}^\text{5.}\times \overbrace{\sum_{k=1}^{m}}^\text{4.} \frac{\partial \overbrace{F(\overbrace{x' + \overbrace{\frac{k}{m}}^\text{1.}\times(x - x'))}^\text{2.}}^\text{3.}}{\partial x_{i}} \times \overbrace{\frac{1}{m}}^\text{4.}$

</center>

<br>

**Step 1.** Generate alphas $\alpha$<br>
**Step 2.** Generate interpolated images = $(x' + \frac{k}{m}\times(x - x'))$<br>
**Step 3.** Compute gradients between model $F$ output predictions with respect to input features = $\frac{\partial F(\text{interpolated path inputs})}{\partial x_{i}}$<br>
**Step 4.** Average integral approximation = $\sum_{k=1}^m \text{gradients} \times \frac{1}{m}$<br>
**Step 5.** Scale integrated gradients with respect to original image = $(x_{i}-x'_{i}) \times \text{integrated gradients}$. The reason this step is necessary is to make sure that the attribution values accumulated across multiple interpolated images are in the same units and faithfully represent the pixel importances on the original image.

---

**NOTE**

The integrated gradients paper suggests the number of steps to range between 20 to 300 depending upon the example (although in practice this can be higher in the 1,000s to accurately approximate the integral).

**We will use `128 steps` (4 batches of 32)**

In [None]:
# VERBOSE VERSION
# @tf.function(experimental_relax_shapes=False)
# def integrated_gradients(model, 
#                          baseline,
#                          image,
#                          target_class_idx,
#                          m_steps=128,
#                          batch_size=32):
    
#     t1 = time.time()
#     # Step 0. Format Check
#     if tf.math.reduce_max(image)<=1.:
#         image *= 255.
#     if tf.math.reduce_max(baseline)<=1.:
#         baseline *= 255.
#     print(f"STEP 0 TOOK {time.time()-t1:.3f} SECONDS")
    
#     t1 = time.time()
#     # Step 1. Generate alphas
#     alphas = tf.linspace(start=tf.constant(0, dtype=tf.float32), 
#                          stop=tf.constant(1, dtype=tf.float32), 
#                          num=m_steps)
    
#     # Accumulate gradients across batches
#     integrated_gradients = 0.0

#     # Batch alpha images
#     ds = tf.data.Dataset.from_tensor_slices(alphas).batch(batch_size)
#     print(f"STEP 1 TOOK {time.time()-t1:.3f} SECONDS")
    
#     print("\n----------------------- LOOPING -----------------------\n")
#     for i, batch in enumerate(ds):
#         t1 = time.time()
#         # Step 2. Generate interpolated images
#         batch_interpolated_inputs = interpolate_images(baseline=baseline/255,
#                                                        image=image/255,
#                                                        alphas=batch)
#         print(f"BATCH {i+1}, STEP 2, IN-LOOP, TOOK {time.time()-t1:.3f} SECONDS")
        
#         t1 = time.time()
#         # Step 3. Compute gradients between model outputs and interpolated inputs
#         batch_gradients = compute_gradients(model,
#                                             images=batch_interpolated_inputs,
#                                             target_class_idx=target_class_idx)
#         print(f"BATCH {i+1}, STEP 3, IN-LOOP, TOOK {time.time()-t1:.3f} SECONDS")
        
#         t1 = time.time()
#         # Step 4. Average integral approximation. Summing integrated gradients across batches.
#         integrated_gradients += integral_approximation(gradients=batch_gradients)
#         print(f"BATCH {i+1}, STEP 4, IN-LOOP, TOOK {time.time()-t1:.3f} SECONDS")

#     # Step 5. Scale integrated gradients with respect to input
#     scaled_integrated_gradients = (image - baseline) * integrated_gradients
#     return scaled_integrated_gradients

In [None]:
@tf.function(experimental_relax_shapes=False)
def integrated_gradients(model, 
                         baseline,
                         image,
                         target_class_idx,
                         m_steps=128,
                         batch_size=32):
    """Approximate the Integrated Gradients attributions of ``image``.

    The path integral of the gradients from ``baseline`` to ``image`` is
    approximated with the trapezoidal rule over ``m_steps`` evenly spaced
    alphas, evaluated ``batch_size`` alphas at a time.

    Args:
        model: callable mapping an image batch to per-class outputs.
        baseline: starting image of the path.
        image: image to explain, same shape as ``baseline``.
        target_class_idx: index of the model output to attribute.
        m_steps: number of interpolation points along the path.
        batch_size: number of interpolated images per gradient computation.

    Returns:
        A tensor shaped like ``image``: (image - baseline) times the
        path-averaged gradient.
    """
    # Step 0. Format Check
    if tf.math.reduce_max(image)<=1.:
        image *= 255.
    if tf.math.reduce_max(baseline)<=1.:
        baseline *= 255.
    
    # Step 1. Generate alphas
    alphas = tf.linspace(start=tf.constant(0, dtype=tf.float32), 
                         stop=tf.constant(1, dtype=tf.float32), 
                         num=m_steps)

    # Trapezoidal-rule weight of every alpha over the WHOLE path: the two end
    # points count half, interior points count fully, and the weights sum to 1.
    # Summing per-batch trapezoid means instead over-counts by the number of
    # batches and drops the trapezoids that span a batch boundary.
    n_alphas = tf.size(alphas)
    positions = tf.range(n_alphas)
    is_endpoint = tf.logical_or(tf.equal(positions, 0),
                                tf.equal(positions, n_alphas - 1))
    weights = tf.where(is_endpoint, 0.5, 1.0) / tf.cast(tf.maximum(n_alphas - 1, 1), tf.float32)

    # Weighted gradients accumulated across batches
    averaged_gradients = 0.0

    # Batch the alphas together with their weights
    ds = tf.data.Dataset.from_tensor_slices((alphas, weights)).batch(batch_size)

    for alpha_batch, weight_batch in ds:
        # Step 2. Generate interpolated images
        batch_interpolated_inputs = interpolate_images(baseline=baseline/255,
                                                       image=image/255,
                                                       alphas=alpha_batch)

        # Step 3. Compute gradients between model outputs and interpolated inputs
        batch_gradients = compute_gradients(model,
                                            images=batch_interpolated_inputs,
                                            target_class_idx=target_class_idx)

        # Step 4. Integral approximation: add this batch's share of the weighted average
        averaged_gradients += tf.reduce_sum(
            batch_gradients * weight_batch[:, tf.newaxis, tf.newaxis, tf.newaxis],
            axis=0)

    # Step 5. Scale integrated gradients with respect to input
    scaled_integrated_gradients = (image - baseline) * averaged_gradients
    return scaled_integrated_gradients

In [None]:
# If you just want to be able to play with the attributions... load them from file
# if os.path.isfile("/kaggle/input/ig_by_class.pickle"):
#     with open("/kaggle/input/ig_by_class.pickle", "rb") as input_file:
#         ig_attributions_by_class = pickle.load(input_file)
# else:

# Attributions for every demo image against its own class index, starting from
# the black `baseline`. `v[..., :-1]` drops the last channel of each image.
# Note we cast all input arguments as tensors to avoid retracing the graph
ig_attributions_by_class = {k:integrated_gradients(enet, 
                                                   tf.cast(baseline, tf.float32), 
                                                   image=tf.cast(v[..., :-1], tf.float32), 
                                                   target_class_idx=tf.constant(STR_2_INT[k])) \
                            for k,v in tqdm(unique_class_2_img.items(), total=len(INT_2_STR))}

In [None]:
print("\n... SAVING TO FILE ...\n")
try:
    with open("/kaggle/working/ig_by_class.pickle", "wb") as output_file:
        pickle.dump(ig_attributions_by_class, output_file)
except Exception as err:
    # Saving stays best-effort (the notebook keeps running), but the failure is
    # reported instead of being swallowed. Catching Exception rather than
    # using a bare except also lets KeyboardInterrupt through.
    print(f"... COULD NOT SAVE ATTRIBUTIONS ({type(err).__name__}: {err}) ...")

<a style="text-align: font-family: Verdana; font-size: 24px; font-style: normal; font-weight: bold; text-decoration: none; text-transform: none; letter-spacing: 3px; color: navy; background-color: #ffffff;" id="viz_ig">7&nbsp;&nbsp;VISUALIZE INTEGRATED GRADIENTS</a>

In [None]:
def plot_img_attributions(attributions, baseline, image, class_label, overlay_alpha=0.5, save_fig=True, rescale_power=1):
    """Plot a 2x3 grid visualising attributions for one image and return the mask.

    Panels: baseline, positive attribution mask, negative attribution mask,
    original image, combined mask, and the combined mask with the image overlaid.

    Args:
        attributions: Rank-3 attribution tensor with channels on the last axis;
            it needs at least three channels because the combined mask writes
            to channel indices 0 and 2.
        baseline: Baseline image, shown as-is with `plt.imshow`.
        image: Original image; it is divided by 255 before being displayed.
        class_label: Class name used in the figure title and in the saved
            file name.
        overlay_alpha: Opacity of the image drawn on top of the combined mask.
        save_fig: If True, write the figure to
            `/kaggle/working/<class_label>.jpg` at 400 dpi.
        rescale_power: The normalised masks are raised to `1/rescale_power`,
            so values above 1 brighten weak attributions. Must be non-zero.

    Returns:
        A numpy array shaped like `attributions` whose channel 0 holds the
        normalised positive mask and channel 2 the normalised negative mask;
        the remaining channels are zero.
    """
    def _min_max_tensor(tensor, exp=1):
        # Min-max scale to [0, 1] and raise to `exp`. A constant tensor (for
        # example a mask with no attributions at all) has a zero value range;
        # dividing by 1 instead yields all zeros rather than NaN.
        lowest = tf.reduce_min(tensor)
        value_range = tf.subtract(tf.reduce_max(tensor), lowest)
        safe_range = tf.where(value_range > 0, value_range, tf.ones_like(value_range))
        return tf.math.divide(tf.subtract(tensor, lowest), safe_range)**exp

    def _make_special_cmap():
        # Seaborn dark palettes for red/blue with pure black prepended as the
        # lowest colour.
        pos_cmap = sns.dark_palette("red")
        pos_cmap.insert(0, (0.,0.,0.))
        pos_cmap = ListedColormap(pos_cmap)

        neg_cmap = sns.dark_palette("blue")
        neg_cmap.insert(0, (0.,0.,0.))
        neg_cmap = ListedColormap(neg_cmap)

        return pos_cmap, neg_cmap

    def _panel(position, title, *layers):
        # Draw one cell of the 2x3 grid; `layers` are (data, imshow_kwargs)
        # pairs drawn in order, so later layers sit on top of earlier ones.
        plt.subplot(2, 3, position)
        plt.title(title, fontweight="bold")
        for data, imshow_kwargs in layers:
            plt.imshow(data, **imshow_kwargs)
        plt.axis('off')

    pos_cmap, neg_cmap = _make_special_cmap()

    # Split the attributions by sign and sum the magnitudes across the colour
    # channels, giving one height x width mask per sign.
    pos_attribution_mask = \
        tf.reduce_sum(tf.math.abs(tf.clip_by_value(attributions, 0, 10000)), axis=-1)
    neg_attribution_mask = \
        tf.reduce_sum(tf.math.abs(tf.clip_by_value(attributions, -10000, 0)), axis=-1)

    pos_attribution_mask = _min_max_tensor(pos_attribution_mask, 1/rescale_power)
    neg_attribution_mask = _min_max_tensor(neg_attribution_mask, 1/rescale_power)

    # Positive mask goes to channel 0 and negative mask to channel 2 (shown as
    # red and blue when the array is displayed as an RGB image).
    combined_masks = np.zeros_like(attributions)
    combined_masks[:, :, 0] = pos_attribution_mask
    combined_masks[:, :, 2] = neg_attribution_mask

    # Plotting
    fig = plt.figure(figsize=(19,16))

    plt.suptitle("Attribution Visualization for the {} Class\n\n\n" \
                 "".format(class_label), fontsize=16, fontweight="bold")

    _panel(1, '\nBaseline image', (baseline, {}))
    _panel(2, '\nPositive Attribution Mask', (pos_attribution_mask, {"cmap": pos_cmap}))
    _panel(3, '\nNegative Attribution Mask', (neg_attribution_mask, {"cmap": neg_cmap}))
    _panel(4, 'Original image', (image/255, {}))
    _panel(5, 'Full Attribution mask', (combined_masks, {}))
    _panel(6, 'Overlay', (combined_masks, {}), (image/255, {"alpha": overlay_alpha}))

    # Apply the layout before saving so the file on disk matches what is shown.
    plt.tight_layout()

    if save_fig:
        # Name the file after the `class_label` argument rather than a global
        # loop variable, so the function works wherever it is called from.
        plt.savefig(f"/kaggle/working/{class_label}.jpg" ,dpi=400)

    plt.show()
    # Release the figure; this function is called once per class in loops.
    plt.close(fig)

    return combined_masks

In [None]:
# Plot (and save) the attribution figure of every class with rescale_power=1.
# The source image is looked up by class key, so the pairing does not depend on
# both dictionaries sharing the same iteration order.
for c, attributions in tqdm(ig_attributions_by_class.items(), total=len(ig_attributions_by_class)):
    image = unique_class_2_img[c]
    plot_img_attributions(attributions, tf.cast(baseline, tf.float32), image=tf.cast(image[..., :-1], tf.float32), class_label=c, save_fig=True, rescale_power=1)

In [None]:
# Same visualisation with rescale_power=2 (masks are raised to the power 1/2).
# The source image is looked up by class key, so the pairing does not depend on
# both dictionaries sharing the same iteration order.
# NOTE: with save_fig=True every rescale_power cell writes to the same per-class
# JPEG path, so the file on disk reflects whichever of these cells ran last.
for c, attributions in tqdm(ig_attributions_by_class.items(), total=len(ig_attributions_by_class)):
    image = unique_class_2_img[c]
    plot_img_attributions(attributions, tf.cast(baseline, tf.float32), image=tf.cast(image[..., :-1], tf.float32), class_label=c, save_fig=True, rescale_power=2)

In [None]:
# Same visualisation with rescale_power=3 (masks are raised to the power 1/3).
# The source image is looked up by class key, so the pairing does not depend on
# both dictionaries sharing the same iteration order.
# NOTE: with save_fig=True every rescale_power cell writes to the same per-class
# JPEG path, so the file on disk reflects whichever of these cells ran last.
for c, attributions in tqdm(ig_attributions_by_class.items(), total=len(ig_attributions_by_class)):
    image = unique_class_2_img[c]
    plot_img_attributions(attributions, tf.cast(baseline, tf.float32), image=tf.cast(image[..., :-1], tf.float32), class_label=c, save_fig=True, rescale_power=3)