# Example notebook for Image Classification (with imagenet-mini)

In this tutorial, you will see how to use explainability algorithms to study the decisions of a pre-trained model. We will take a pre-trained model, sample a few images and run several explanation methods on them.

### Setup 

#### Imports

First we have to import all necessary libraries.

In [None]:
# import necessary libraries
import os
import torch
from torch.utils.data import DataLoader
import torchvision

from foxai.context_manager import FoXaiExplainer, ExplainerWithParams, CVClassificationExplainers
from foxai.visualizer import mean_channels_visualization

Install the libraries required by `YOLOv5` that are not part of the `foxai` package.

In [None]:
!pip install scipy opencv-python seaborn ultralytics

Set `CUDA_LAUNCH_BLOCKING=1` so that `CUDA` kernel launches run synchronously. Errors in GPU-accelerated computations then surface at the call that caused them, which makes them easier to trace in a notebook.

In [None]:
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

#### Downloading missing models

Download `YOLOv5` and `ImageNet.yaml` files from https://github.com/ultralytics/yolov5 if not present in local storage.

In [None]:
# check whether the YOLOv5 model and ImageNet.yaml files are present in local storage and download them if they are not
![ ! -f "yolov5s-cls.pt" ] && wget https://github.com/ultralytics/yolov5/releases/download/v6.2/yolov5s-cls.pt
![ ! -f "ImageNet.yaml" ] && wget https://raw.githubusercontent.com/ultralytics/yolov5/master/data/ImageNet.yaml
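
The shell one-liners above only work where `wget` is available. As a minimal sketch of the same "download only if missing" check in pure Python, the hypothetical helper `download_if_missing` below (not part of `foxai` or `YOLOv5`) uses only the standard library. The `fetch` parameter is an assumption added so the download step can be swapped out, for example in tests.

```python
from pathlib import Path
from typing import Callable
from urllib.request import urlretrieve


def download_if_missing(
    url: str,
    destination: str,
    fetch: Callable[[str, str], object] = urlretrieve,
) -> bool:
    """Download `url` to `destination` unless the file already exists.

    Returns True if a download was performed and False if the file was present.
    """
    path = Path(destination)
    if path.is_file():
        # nothing to do, the file is already in local storage
        return False
    fetch(url, str(path))
    return True
```

It could be called with the same URLs and file names as in the cell above, for example `download_if_missing("https://github.com/ultralytics/yolov5/releases/download/v6.2/yolov5s-cls.pt", "yolov5s-cls.pt")`.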

#### Define custom functions

Define a custom function to visualize figures.

In [None]:
import matplotlib.pyplot as plt

# function to enable displaying matplotlib Figures in notebooks
def show_figure(fig): 
    dummy = plt.figure()
    new_manager = dummy.canvas.manager
    new_manager.canvas.figure = fig
    new_manager.set_window_title("Test")
    fig.set_canvas(new_manager.canvas)
    return dummy

Define a function that loads the model, the list of labels and the transformation function for a desired model. In this notebook we currently support only a few models: `VGG11`, `ResNet50`, `ViT`, `MobileNetV3` and `YOLOv5`. You can easily add new models from the `torchvision` model zoo or even define your own model.

In [None]:
import yaml
from yaml.loader import SafeLoader
from torchvision.transforms._presets import ImageClassification
from typing import Tuple, List


def load_model(
    model_name: str,
) -> Tuple[torch.nn.Module, List[str], ImageClassification]:
    """Load model, label list and transformation function used in data preprocessing.

    Args:
        model_name: Model name. Recognized models are: `vgg11`, `resnet50`, `yolov5`,
            `vit` and `mobilenetv3`.

    Raises:
        ValueError: raised if the provided model name is not supported.

    Returns:
        Tuple of model, list of labels and transformation function.
    """
    # normalize model name to match recognized models
    model_name_normalized: str = model_name.lower().strip()
    if model_name_normalized == "yolov5":
        # load YOLOv5 from torch Hub according to https://github.com/ultralytics/yolov5
        model = torch.hub.load('ultralytics/yolov5', 'custom', 'yolov5s-cls.pt')

        # apply transformations just like in MobileNetV3
        transform = torchvision.models.MobileNet_V3_Small_Weights.IMAGENET1K_V1.transforms()
        
        # load YOLOv5 configuration
        with open("ImageNet.yaml") as file:
            data = yaml.load(file, Loader=SafeLoader)

        # and get only class names
        categories = list(data["names"].values())
    elif model_name_normalized == "vgg11":
        weights = torchvision.models.VGG11_Weights.IMAGENET1K_V1

        # load model from torchvision model zoo
        model = torchvision.models.vgg11(weights=weights)

        # get class names
        categories = weights.meta["categories"]
        transform = weights.transforms()
    elif model_name_normalized == "vit":
        weights = torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1

        # load model from torchvision model zoo
        model = torchvision.models.vit_b_16(weights=weights)

        # get class names
        categories = weights.meta["categories"]
        transform = weights.transforms()
    elif model_name_normalized == "resnet50":
        weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V1

        # load model from torchvision model zoo
        model = torchvision.models.resnet50(weights=weights)

        # get class names
        categories = weights.meta["categories"]
        transform = weights.transforms()
    elif model_name_normalized == "mobilenetv3":
        weights = torchvision.models.MobileNet_V3_Small_Weights.IMAGENET1K_V1

        # load model from torchvision model zoo
        model = torchvision.models.mobilenet_v3_small(weights=weights)

        # get class names
        categories = weights.meta["categories"]
        transform = weights.transforms()
    else:
        raise ValueError(f"Unrecognized model name: {model_name}")

    return model, categories, transform
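
As the number of supported models grows, the `if`/`elif` chain gets long. One possible alternative, shown here only as a sketch and not as how this notebook is implemented, is a dictionary that maps normalized model names to loader functions. The loaders below are hypothetical stubs that return strings. In the notebook each of them would return the `(model, categories, transform)` tuple built in the corresponding branch of `load_model`.

```python
from typing import Callable, Dict


# Hypothetical loader stubs standing in for the branches of `load_model`.
def _load_resnet50() -> str:
    return "resnet50 loaded"


def _load_vgg11() -> str:
    return "vgg11 loaded"


# registry of recognized model names; adding a model means adding one entry
MODEL_LOADERS: Dict[str, Callable[[], str]] = {
    "resnet50": _load_resnet50,
    "vgg11": _load_vgg11,
}


def load_from_registry(model_name: str) -> str:
    """Look up a loader by normalized model name and call it."""
    key = model_name.lower().strip()
    try:
        loader = MODEL_LOADERS[key]
    except KeyError:
        raise ValueError(f"Unrecognized model name: {model_name}") from None
    return loader()
```

The normalization step and the `ValueError` mirror the behaviour of `load_model`, so callers would not notice the difference.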

### Configuration

The cell below contains the configuration of this notebook: the batch size, the maximum number of samples to explain and save in the artifact directory, the device to use, the name of the model, and the path to the `ImageNet-Mini` dataset downloaded from [Kaggle](https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000).

In [None]:
batch_size: int = 1
max_samples_explained: int = 10
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model_name: str = "yolov5"

# define directory where explanation artifacts will be stored
artifact_dir: str = f"artifacts/{model_name}/"

# `data_dir` variable contains path to dataset downloaded from https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000.
# You have to register in Kaggle to be able to download this dataset.
data_dir: str = "/home/user/Downloads/imagenet-mini"
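
A wrong `data_dir` only shows up later, when the dataset is loaded. As an optional sketch, the hypothetical `NotebookConfig` class below (an illustration, not part of `foxai`) groups the same settings and checks early that the `val` split exists. The default values simply repeat the ones from the cell above.

```python
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class NotebookConfig:
    """Hypothetical container mirroring the variables defined above."""

    model_name: str = "yolov5"
    data_dir: str = "/home/user/Downloads/imagenet-mini"
    batch_size: int = 1
    max_samples_explained: int = 10

    @property
    def artifact_dir(self) -> str:
        # directory where explanation artifacts will be stored
        return f"artifacts/{self.model_name}/"

    def validation_dir(self) -> Path:
        """Return the `val` split directory, failing early if it is missing."""
        path = Path(self.data_dir) / "val"
        if not path.is_dir():
            raise FileNotFoundError(f"Validation split not found: {path}")
        return path
```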


### Loading the model

Load the specified model, put it in evaluation mode, place it on the specified device, then load and preprocess the `ImageNet-Mini` dataset. The transformation function is used to match the preprocessing steps applied to the training dataset.

In [None]:
# load model, classes and transformation function
model, categories, transform = load_model(model_name=model_name)

# put model in evaluation mode
model.eval()

# place model on specified device (CPU or GPU)
model.to(device)

# load test dataset - ImageNet-Mini downloaded from Kaggle: https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000
imagenet_val = torchvision.datasets.ImageFolder(root=f"{data_dir}/val", transform=transform)
val_dataloader = DataLoader(imagenet_val, batch_size=batch_size)
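
`ImageFolder` treats every sub-directory of `root` as one class and assigns label indices following the sorted order of the directory names. The sketch below reproduces that mapping with the standard library only. It is an illustration, not the `torchvision` implementation, and `class_to_index` is a hypothetical helper name. It can help to verify that a label index lines up with the corresponding entry in `categories`.

```python
import os
from typing import Dict


def class_to_index(root: str) -> Dict[str, int]:
    """Map each class sub-directory of `root` to an index in sorted order."""
    # plain files directly under `root` are ignored, only directories count as classes
    class_names = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return {name: index for index, name in enumerate(class_names)}
```

For the dataset used here it could be called as `class_to_index(f"{data_dir}/val")`.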

In [None]:
# instruct notebook to display figures inline
%matplotlib inline

## Data sample

Let's see what images from `ImageNet-Mini` look like. We will display the first few samples of the dataset. In the following steps we will use them to explain model predictions using different explainability algorithms.

In [None]:
from matplotlib import pyplot as plt
from foxai.array_utils import convert_standardized_float_to_uint8, standardize_array

counter: int = 0
    
# create subplot with given number of samples to display
fig, axes = plt.subplots(max_samples_explained, 1, figsize=(25, 25))

# iterate over dataloader
for batch in val_dataloader:
    for sample, label in zip(*batch):
        # change image shape from (C x H x W) to (H x W x C)
        # where C is the colour channel, H the height and W the width dimension
        sample_np = sample.permute((1, 2, 0)).numpy().astype(float)
        
        # set title
        axes.flat[counter].set_title(f"Label: {categories[label.item()]}")
        
        # disable visualizing X and Y axes
        axes.flat[counter].get_xaxis().set_visible(False)
        axes.flat[counter].get_yaxis().set_visible(False)

        # convert image from float to uint8 and display it
        axes.flat[counter].imshow(convert_standardized_float_to_uint8(standardize_array(sample_np)))
        counter += 1

        if counter >= max_samples_explained:
            break

    if counter >= max_samples_explained:
        break
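
The display step above boils down to two operations: moving the channel axis to the end and mapping float values to the `uint8` range expected for images. The `NumPy` sketch below illustrates the idea with min-max scaling. That the `foxai` helpers work this way is an assumption, and `to_displayable` is a hypothetical name, not a library function.

```python
import numpy as np


def to_displayable(image_chw: np.ndarray) -> np.ndarray:
    """Convert a (C, H, W) float image to an (H, W, C) uint8 image."""
    # move the channel axis to the end, as matplotlib expects
    image_hwc = np.transpose(image_chw, (1, 2, 0)).astype(float)
    low, high = image_hwc.min(), image_hwc.max()
    # guard against division by zero for constant images
    scale = high - low if high > low else 1.0
    # min-max scaling to [0, 1] (assumed behaviour, see the note above)
    scaled = (image_hwc - low) / scale
    return (scaled * 255).astype(np.uint8)
```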

## Demo for general algorithms 

### Choosing foxai explainers (general algorithms)

Define the list of explainers from the `foxai` package that you want to use on the specified model. The full list of supported explainers can be found in the definition of the `CVClassificationExplainers` enum class.

In [None]:
# define list of explainers we want to use
# full list of supported explainers is present in `CVClassificationExplainers` enum class.
explainer_list = [
    ExplainerWithParams(explainer_name=CVClassificationExplainers.CV_GRADIENT_SHAP_EXPLAINER),
    ExplainerWithParams(explainer_name=CVClassificationExplainers.CV_INPUT_X_GRADIENT_EXPLAINER),
    ExplainerWithParams(explainer_name=CVClassificationExplainers.CV_INTEGRATED_GRADIENTS_EXPLAINER),
]
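
To give an intuition for two of the selected methods, the sketch below computes Input X Gradient and Integrated Gradients by hand for a toy function with a known gradient. It is a simplified illustration under stated assumptions (a hypothetical three-feature "model", an all-zero baseline, a midpoint Riemann sum) and not the implementation used by `foxai`.

```python
import numpy as np

# Toy differentiable "model": f(x) = sum(w * x**2), with an analytic gradient.
WEIGHTS = np.array([0.5, -1.0, 2.0])


def model(x: np.ndarray) -> float:
    return float(np.sum(WEIGHTS * x**2))


def model_gradient(x: np.ndarray) -> np.ndarray:
    return 2.0 * WEIGHTS * x


def input_x_gradient(x: np.ndarray) -> np.ndarray:
    """Attribution = input multiplied element-wise by the gradient at the input."""
    return x * model_gradient(x)


def integrated_gradients(x: np.ndarray, baseline: np.ndarray, steps: int = 50) -> np.ndarray:
    """Approximate Integrated Gradients with a midpoint Riemann sum."""
    # interpolation coefficients at the midpoints of `steps` equal intervals
    alphas = (np.arange(steps) + 0.5) / steps
    # gradients along the straight path from the baseline to the input
    path_gradients = np.stack(
        [model_gradient(baseline + alpha * (x - baseline)) for alpha in alphas]
    )
    return (x - baseline) * path_gradients.mean(axis=0)


x = np.array([1.0, 2.0, -1.0])
baseline = np.zeros_like(x)
attributions = integrated_gradients(x, baseline)

# completeness: attributions sum to the difference in model output
assert np.isclose(attributions.sum(), model(x) - model(baseline))
```

The final check illustrates the completeness property of Integrated Gradients: the attributions add up to the change in the model output between the baseline and the input.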

### Explaining the predictions (general algorithms)

Iterate over the dataset and explain the predictions given by the selected model using all specified explainers. This can take a long time, depending on the number of selected explainers and the number of samples to explain. During this process new artifacts will be saved in the artifact directory.

In [None]:
sample: torch.Tensor
label: torch.Tensor

sample_counter: int = 0
    
# iterate over dataloader
for sample_batch in val_dataloader:
    sample_list, label_list = sample_batch
    # iterate over all samples in batch
    for sample, label in zip(sample_list, label_list):
        # add batch size dimension to the data sample and move it to the specified device
        input_data = sample.reshape(1, sample.shape[0], sample.shape[1], sample.shape[2]).to(device)
        category_name = categories[label.item()]
        with FoXaiExplainer(
            model=model,
            explainers=explainer_list,
            target=label,
        ) as xai_model:
            # calculate attributes for every explainer
            _, attributes_dict = xai_model(input_data)

        for key, value in attributes_dict.items():
            # create directory for every explainer artifacts
            artifact_explainer_dir = os.path.join(artifact_dir, key)
            if not os.path.exists(artifact_explainer_dir):
                os.makedirs(artifact_explainer_dir)

            # create figure from attributes and original image
            figure = mean_channels_visualization(
                attributions=value[0],
                transformed_img=sample,
                title=f"Mean of channels ({key})",
            )

            # save figure to artifact directory
            figure.savefig(os.path.join(artifact_explainer_dir, f"artifact_{sample_counter}_{category_name}.png"))
            show_figure(figure)
            
        sample_counter += 1
        # if we processed the desired number of samples, break the loop
        if sample_counter >= max_samples_explained:
            break

    # if we processed the desired number of samples, break the loop
    if sample_counter >= max_samples_explained:
        break
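
The artifact handling inside the loop can be factored into a small helper. The sketch below is one possible version. `artifact_path` is a hypothetical name, and replacing unusual characters in the category name is a precaution added here that the loop above does not take.

```python
import os
import re


def artifact_path(
    artifact_dir: str, explainer_name: str, sample_index: int, category_name: str
) -> str:
    """Create the explainer sub-directory and return the PNG path for one sample."""
    explainer_dir = os.path.join(artifact_dir, explainer_name)
    # exist_ok=True makes a separate os.path.exists() check unnecessary
    os.makedirs(explainer_dir, exist_ok=True)
    # replace characters that are awkward in file names with underscores
    safe_category = re.sub(r"[^\w\-]+", "_", category_name)
    return os.path.join(explainer_dir, f"artifact_{sample_index}_{safe_category}.png")
```

Inside the loop it could be used as `figure.savefig(artifact_path(artifact_dir, key, sample_counter, category_name))`.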

## Demo for layer specific algorithms

Some algorithms compute explanations at the level of a single layer, so you have to select the layer to explain against. Many of these algorithms work only with `Conv2d` layers. In the cell below we fetch the last convolutional layer of the network.

In [None]:
layer = [module for module in model.modules() if isinstance(module, torch.nn.Conv2d)][-1]
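
Note that `model.modules()` yields layers in the order in which they were registered, which is not necessarily the order in which they are executed. To check which layer was actually picked, it can help to look at the layer names as well. The sketch below does this with `named_modules()` on a tiny stand-in network. The toy model is hypothetical and is not one of the models used in this notebook.

```python
import torch

# Tiny stand-in network (hypothetical, for illustration only).
toy_model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.ReLU(),
    torch.nn.Conv2d(8, 16, kernel_size=3),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(16, 10),
)

# collect (name, module) pairs for every convolutional layer, in registration order
conv_layers = [
    (name, module)
    for name, module in toy_model.named_modules()
    if isinstance(module, torch.nn.Conv2d)
]

# the last entry corresponds to the layer selected in the cell above
last_name, last_conv = conv_layers[-1]
print(last_name, last_conv.out_channels)
```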

Next, you have to pass additional parameters to the selected explainers. Our context manager accepts `ExplainerWithParams` objects, which store the additional parameters for each explainer. In the cell below we create two explainers with an additional `layer` argument.

### Choosing foxai explainers (layer-specific algorithms)

To explain the operation of the model at the level of a single network layer, pass an object representing the model layer to the `ExplainerWithParams` class, as shown in the cell below.

In [None]:
# define list of explainers we want to use
# full list of supported explainers is present in `CVClassificationExplainers` enum class.
explainer_list = [
    ExplainerWithParams(explainer_name=CVClassificationExplainers.CV_GUIDEDGRADCAM_EXPLAINER, layer=layer),
    ExplainerWithParams(explainer_name=CVClassificationExplainers.CV_LAYER_GRADCAM_EXPLAINER, layer=layer),
]
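
To show why these methods need a layer, the sketch below computes a Grad-CAM heatmap by hand from the activations of one layer and the gradients of the target score with respect to those activations. It is a simplified `NumPy` illustration of the published algorithm with hypothetical inputs, not the code that `foxai` runs.

```python
import numpy as np


def grad_cam(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Compute a Grad-CAM heatmap from (K, H, W) activations and gradients."""
    # one weight per channel: the gradient averaged over spatial positions
    channel_weights = gradients.mean(axis=(1, 2))
    # weighted sum of the activation maps over the channel axis, shape (H, W)
    weighted_sum = np.tensordot(channel_weights, activations, axes=1)
    # ReLU keeps only locations with a positive influence on the target class
    return np.maximum(weighted_sum, 0.0)
```

The heatmap has the spatial size of the chosen layer's output, which is usually smaller than the input image.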

### Explaining the predictions (layer-specific algorithms)

The explanation code is the same as before. We don't have to change anything here.

In [None]:
sample_counter = 0

# iterate over dataloader
for sample_batch in val_dataloader:
    sample_list, label_list = sample_batch
    # iterate over all samples in batch
    for sample, label in zip(sample_list, label_list):
        # add batch size dimension to the data sample
        input_data = sample.reshape(1, sample.shape[0], sample.shape[1], sample.shape[2]).to(device)
        category_name = categories[label.item()]
        with FoXaiExplainer(
            model=model,
            explainers=explainer_list,
            target=label,
        ) as xai_model:
            # calculate attributes for every explainer
            _, attributes_dict = xai_model(input_data)

        for key, value in attributes_dict.items():
            # create directory for every explainer artifacts
            artifact_explainer_dir = os.path.join(artifact_dir, key)
            if not os.path.exists(artifact_explainer_dir):
                os.makedirs(artifact_explainer_dir)

            # create figure from attributes and original image
            figure = mean_channels_visualization(
                attributions=value[0],
                transformed_img=sample,
                title=f"Mean of channels ({key})",
            )

            # save figure to artifact directory
            figure.savefig(os.path.join(artifact_explainer_dir, f"artifact_{sample_counter}_{category_name}.png"))
            show_figure(figure)
            
        sample_counter += 1
        # if we processed the desired number of samples, break the loop
        if sample_counter >= max_samples_explained:
            break

    # if we processed the desired number of samples, break the loop
    if sample_counter >= max_samples_explained:
        break

## Congratulations

You have learned to use the basic functionality of the library. You can now experiment and gain confidence in your ML models.