# Example notebook for Image Classification (with ImageNet-Mini)

In this tutorial, you will see how to use explainable AI (XAI) algorithms to study the decisions of a pre-trained model. We will take a pre-trained model and sample images, and run several explanation methods on them.

### Setup 

#### Imports

First, we have to import all the necessary libraries.

In [None]:
# import necessary libraries
import os
import torch
from torch.utils.data import DataLoader
import torchvision

from foxai.context_manager import FoXaiExplainer, ExplainerWithParams, CVClassificationExplainers
from foxai.visualizer import mean_channels_visualization

Install the missing libraries required by `YOLOv5` that are not part of the `foxai` package.

In [None]:
!pip install scipy opencv-python seaborn

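Set `CUDA_LAUNCH_BLOCKING=1` so that CUDA kernel launches run synchronously. CUDA errors then surface at the operation that caused them, which makes GPU issues in the notebook much easier to debug, at the cost of slower execution. The variable only takes effect if it is set before CUDA is initialized.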

In [None]:
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

#### Downloading missing models

Download the `YOLOv5` classification weights (`yolov5s-cls.pt`) and the `ImageNet.yaml` file from https://github.com/ultralytics/yolov5 if they are not already present locally.

In [None]:
# download the YOLOv5 weights and ImageNet.yaml only if they are not already present locally
![ ! -f "yolov5s-cls.pt" ] && wget https://github.com/ultralytics/yolov5/releases/download/v6.2/yolov5s-cls.pt
![ ! -f "ImageNet.yaml" ] && wget https://raw.githubusercontent.com/ultralytics/yolov5/master/data/ImageNet.yaml
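If `wget` is not available (for example on Windows), a small Python helper can do the same check-then-download. The sketch below uses only the standard library; `download_if_missing` and `FILES` are hypothetical names introduced here, not part of `foxai`. It downloads to a temporary `.part` file and renames it only on success, so an interrupted download never leaves a truncated file that the existence check would later mistake for a complete one.

```python
import os
import urllib.request

# hypothetical mapping of local file names to the URLs used in the wget cell above
FILES = {
    "yolov5s-cls.pt": "https://github.com/ultralytics/yolov5/releases/download/v6.2/yolov5s-cls.pt",
    "ImageNet.yaml": "https://raw.githubusercontent.com/ultralytics/yolov5/master/data/ImageNet.yaml",
}


def download_if_missing(url: str, path: str) -> bool:
    """Download `url` to `path` unless `path` already exists.

    Returns True if a download happened and False if the file was already present.
    """
    if os.path.isfile(path):
        return False
    tmp_path = path + ".part"
    try:
        # write to a temporary file first so a failed download never
        # leaves a truncated file at `path`
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # the temporary file only remains if the download or rename failed
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return True
```

Running `for name, url in FILES.items(): download_if_missing(url, name)` in a cell would reproduce the behavior of the two `wget` lines.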

#### Define custom functions

Define a helper function that displays a `matplotlib` `Figure` object inline in the notebook.

In [None]:
import matplotlib.pyplot as plt

# function to enable displaying matplotlib Figures in notebooks
def show_figure(fig): 
    dummy = plt.figure()
    new_manager = dummy.canvas.manager
    new_manager.canvas.figure = fig
    new_manager.set_window_title("Test")
    fig.set_canvas(new_manager.canvas)
    plt.show()
    return dummy

Define a function that loads the model, its list of class labels and the transformation function used for preprocessing. In this notebook, `load_model` uses `EfficientNet-B0` from the `torchvision` model zoo; you can easily swap in other `torchvision` models, such as `VGG11`, `ResNet50`, `ViT` or `MobileNetV3`, or even define your own model.

In [None]:
from typing import List, Tuple

from torchvision.transforms._presets import ImageClassification


def load_model() -> Tuple[torch.nn.Module, List[str], ImageClassification]:
    """Load model, label list and transformation function used in data preprocessing.

    Returns:
        Tuple of model, list of labels and transformation function.
    """
    # pre-trained ImageNet-1k weights for EfficientNet-B0
    weights = torchvision.models.EfficientNet_B0_Weights.IMAGENET1K_V1

    # load model from torchvision model zoo
    model = torchvision.models.efficientnet_b0(weights=weights)

    # get class names and the preprocessing transform bundled with the weights
    categories = weights.meta["categories"]
    transform = weights.transforms()

    return model, categories, transform

### Configuration

The cell below contains the configuration of this notebook: the batch size, the maximum number of samples to be explained and saved in the artifact directory, the device to use, the artifact directory itself, and the path to the `ImageNet-Mini` dataset downloaded from [Kaggle](https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000).

In [None]:
batch_size: int = 1
max_samples_explained: int = 10
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# define directory where explanation artifacts will be stored
artifact_dir: str = "artifacts/"

# `data_dir` variable contains path to dataset downloaded from https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000.
# You need a Kaggle account to be able to download this dataset.
data_dir: str = "/home/user/Downloads/imagenet-mini"
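`torchvision.datasets.ImageFolder` expects one subdirectory per class under `{data_dir}/val` and assigns class indices in sorted folder-name order, so a wrong `data_dir` only shows up as an error at load time. A quick check can catch it earlier. This is a minimal sketch; `check_imagefolder_split` is a hypothetical helper, and it assumes the Kaggle archive was extracted so that `val/` sits directly under `data_dir`.

```python
import os
from typing import List


def check_imagefolder_split(data_dir: str, split: str = "val") -> List[str]:
    """Return the sorted class-folder names found under `data_dir/split`.

    Sorted order mirrors how `ImageFolder` assigns class indices.
    Raises FileNotFoundError if the split directory is missing and
    ValueError if it contains no class subdirectories.
    """
    split_dir = os.path.join(data_dir, split)
    if not os.path.isdir(split_dir):
        raise FileNotFoundError(f"dataset split not found: {split_dir!r}")
    # ignore stray files; only directories are treated as classes
    classes = sorted(entry.name for entry in os.scandir(split_dir) if entry.is_dir())
    if not classes:
        raise ValueError(f"no class subdirectories in {split_dir!r}")
    return classes
```

The `ImageNet-Mini` folders are named by WordNet IDs, and in the usual PyTorch convention ImageNet-1k class indices follow the sorted WordNet-ID order. Assuming the split contains all 1000 class folders, comparing `len(check_imagefolder_split(data_dir))` with `len(categories)` after loading the model is a cheap sanity check that `categories[label.item()]` maps labels to the right names.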


### Loading the model

Load the model, put it in evaluation mode, move it to the selected device (CPU or GPU), and load the validation split of the `ImageNet-Mini` dataset from `data_dir`. The transformation function applies the same preprocessing steps that were used on the model's training data.

In [None]:
# load model, classes and transformation function
model, categories, transform = load_model()

# put model in evaluation mode
model.eval()

# place model on specified device (CPU or GPU)
model.to(device)

# load test dataset - ImageNet-Mini downloaded from Kaggle: https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000
imagenet_val = torchvision.datasets.ImageFolder(root=f"{data_dir}/val", transform=transform)
val_dataloader = DataLoader(imagenet_val, batch_size=batch_size)
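The transformation bundled with the weights resizes and crops every image before normalizing it. Assuming the preset resizes the shorter side to 256 px and then center-crops a 224×224 px window (printing `transform` should show the actual sizes for your `torchvision` version), the geometry works out as in this pure-Python sketch. It approximates the arithmetic only and resamples no pixels; `resize_shorter_side` and `center_crop_box` are hypothetical helpers.

```python
from typing import Tuple


def resize_shorter_side(width: int, height: int, size: int) -> Tuple[int, int]:
    """Return (new_width, new_height) after scaling the shorter side to `size`.

    The aspect ratio is kept; the longer side is truncated to an int.
    """
    short_side, long_side = min(width, height), max(width, height)
    new_long = int(size * long_side / short_side)
    return (size, new_long) if width <= height else (new_long, size)


def center_crop_box(width: int, height: int, crop: int) -> Tuple[int, int, int, int]:
    """Return the (left, top, right, bottom) box of a centered crop x crop window."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


# assumed preset sizes: shorter side -> 256 px, then a 224 x 224 px center crop
resized = resize_shorter_side(500, 375, 256)
box = center_crop_box(*resized, 224)
```

For a 500×375 px image this gives a 341×256 px resized image and the crop box `(58, 16, 282, 240)`. Knowing the box is useful when mapping an attribution map back onto the original, uncropped photo.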

In [None]:
# instruct notebook to display figures inline
%matplotlib inline

Let's see what images from `ImageNet-Mini` look like. We will display the first few samples of the dataset and, in the following steps, use them to explain the model's predictions with different explainable algorithms.
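The dataset tensors are normalized per channel, so they have to be un-normalized before display, otherwise the colors look wrong. The sketch below shows the arithmetic on nested lists in `C x H x W` order. It assumes the standard ImageNet mean and standard deviation, which `torchvision` classification presets typically use; `normalize` and `denormalize` are hypothetical helpers.

```python
from typing import List, Sequence

# C x H x W image stored as nested lists of floats
Image = List[List[List[float]]]

# standard ImageNet statistics (assumed to match the notebook's transform)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize(
    img: Image, mean: Sequence[float] = IMAGENET_MEAN, std: Sequence[float] = IMAGENET_STD
) -> Image:
    """Apply per-channel (value - mean) / std, as done during preprocessing."""
    return [
        [[(v - m) / s for v in row] for row in channel]
        for channel, m, s in zip(img, mean, std)
    ]


def denormalize(
    img: Image, mean: Sequence[float] = IMAGENET_MEAN, std: Sequence[float] = IMAGENET_STD
) -> Image:
    """Invert `normalize` and clip to [0, 1] so the result can be displayed."""
    return [
        [[min(max(v * s + m, 0.0), 1.0) for v in row] for row in channel]
        for channel, m, s in zip(img, mean, std)
    ]
```

With the notebook's tensors, the same inverse can be written with broadcasting, e.g. `sample * torch.tensor(IMAGENET_STD)[:, None, None] + torch.tensor(IMAGENET_MEAN)[:, None, None]`, followed by `.clamp(0, 1).permute(1, 2, 0)` before passing the result to `plt.imshow`.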

## Demo for general algorithms 

In [None]:
# define list of explainers we want to use
# the full list of supported explainers is defined in the `CVClassificationExplainers` enum class
explainer_list = [
    ExplainerWithParams(explainer_name=CVClassificationExplainers.CV_GRADIENT_SHAP_EXPLAINER),
    ExplainerWithParams(explainer_name=CVClassificationExplainers.CV_INPUT_X_GRADIENT_EXPLAINER),
    ExplainerWithParams(explainer_name=CVClassificationExplainers.CV_INTEGRATED_GRADIENTS_EXPLAINER),
]
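To build intuition for two of the listed methods, here is a toy, pure-Python sketch of Input × Gradient and Integrated Gradients on a hand-written function with a known gradient. It is not how `foxai` computes attributions (those run on the real model); `f`, `grad_f` and the step count are assumptions chosen for illustration. Integrated Gradients averages the gradient along the straight path from a baseline to the input, approximated here with a midpoint Riemann sum.

```python
from typing import Callable, List, Sequence

Gradient = Callable[[Sequence[float]], List[float]]


def f(x: Sequence[float]) -> float:
    """Toy nonlinear 'model': f(x) = x0**2 + 3 * x0 * x1."""
    return x[0] ** 2 + 3 * x[0] * x[1]


def grad_f(x: Sequence[float]) -> List[float]:
    """Analytical gradient of `f`."""
    return [2 * x[0] + 3 * x[1], 3 * x[0]]


def input_x_gradient(x: Sequence[float], grad: Gradient) -> List[float]:
    """Attribution_i = x_i * df/dx_i evaluated at the input."""
    return [xi * gi for xi, gi in zip(x, grad(x))]


def integrated_gradients(
    x: Sequence[float], baseline: Sequence[float], grad: Gradient, steps: int = 50
) -> List[float]:
    """Attribution_i = (x_i - b_i) * average gradient along the path b -> x."""
    totals = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps  # midpoint of the k-th interval
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad(point)):
            totals[i] += g
    return [(xi - b) * t / steps for xi, b, t in zip(x, baseline, totals)]


x = [1.0, 2.0]
baseline = [0.0, 0.0]
ig = integrated_gradients(x, baseline, grad_f)
ixg = input_x_gradient(x, grad_f)
```

Here Integrated Gradients returns approximately `[4.0, 3.0]`, whose sum equals `f(x) - f(baseline) = 7` (the completeness property), while Input × Gradient returns `[8.0, 6.0]`, which sums to 14 because the function is not linear.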

## Demo for layer-specific algorithms

In [None]:
# select the last convolutional layer of the model as the target layer
layer = [module for module in model.modules() if isinstance(module, torch.nn.Conv2d)][-1]
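The last convolutional layer is a common target for Grad-CAM because it still keeps a spatial layout while its channels encode high-level features. Grad-CAM weights each channel of that layer's activation map by the spatial average of the gradient of the class score with respect to it, sums the weighted channels, and keeps only positive evidence with a ReLU. The toy, pure-Python sketch below shows that computation on hand-made 2×2 maps; the inputs are invented for illustration, whereas in the notebook the explainer works with the real activations and gradients of `layer`.

```python
from typing import List

# one H x W map per channel
Map = List[List[float]]


def grad_cam(activations: List[Map], gradients: List[Map]) -> Map:
    """Grad-CAM heatmap from per-channel activations A_k and gradients dy/dA_k."""
    height, width = len(activations[0]), len(activations[0][0])
    # channel weight = global average of that channel's gradient
    weights = [sum(sum(row) for row in g) / (height * width) for g in gradients]
    cam = [[0.0] * width for _ in range(height)]
    for w, a in zip(weights, activations):
        for i in range(height):
            for j in range(width):
                cam[i][j] += w * a[i][j]
    # ReLU keeps only regions that increase the class score
    return [[max(v, 0.0) for v in row] for row in cam]


activations = [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 2.0], [2.0, 0.0]]]
gradients = [[[0.5, 0.5], [0.5, 0.5]], [[-0.25, -0.25], [-0.25, -0.25]]]
heatmap = grad_cam(activations, gradients)
```

The second channel gets a negative weight, so the diagonal it activates is suppressed and `heatmap` is `[[0.5, 0.0], [0.0, 0.5]]`. Real implementations then upsample this coarse map to the input resolution before overlaying it on the image.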

In [None]:
from foxai.metrics import insertion, deletion
from foxai.visualizer import visualize_metric

In [None]:
type(model)

In [None]:

# take the first batch from the dataloader
sample_batch = next(iter(val_dataloader))
# take the first sample and its label from the batch
sample, label = sample_batch[0][0], sample_batch[1][0]
# add batch size dimension to the data sample
input_data = sample.unsqueeze(0).to(device)
category_name = categories[label.item()]
with FoXaiExplainer(
    model=model,
    explainers=[ExplainerWithParams(explainer_name=CVClassificationExplainers.CV_LAYER_GRADCAM_EXPLAINER, layer=layer)],
    target=label,
) as xai_model:
    # calculate attributes for every explainer
    first_output, attributes_dict = xai_model(input_data)
    # Grad-CAM attributions of the first (and only) sample in the batch
    gradcam_maps = attributes_dict["CV_LAYER_GRADCAM_EXPLAINER"]
    value = gradcam_maps[0]
    figure = mean_channels_visualization(attributions=value, transformed_img=sample, title="Mean of channels")
    # display the figure inline
    show_figure(figure)

    # evaluate the attributions with insertion and deletion metrics for the predicted class
    chosen_class = first_output.argmax()
    insertion_result, importance_lst = insertion(value, sample, model, chosen_class)
    visualize_metric(importance_lst, insertion_result, metric_type="Insertion")
    deletion_result, importance_lst = deletion(value, sample, model, chosen_class)
    visualize_metric(importance_lst, deletion_result, metric_type="Deletion")
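Insertion and deletion metrics are commonly defined as follows: rank input features by attribution, then either remove them from the input in that order (deletion) or add them to a baseline in that order (insertion), recording the model score after each step. A good attribution makes the score drop quickly under deletion (low area under the curve) and rise quickly under insertion (high area). The sketch below shows the idea on a toy linear scoring function over three features. It is not `foxai`'s implementation, whose baseline, step size and normalization may differ, and all names in it are hypothetical.

```python
from typing import Callable, List, Sequence

Score = Callable[[Sequence[float]], float]


def ranking(attribution: Sequence[float]) -> List[int]:
    """Feature indices ordered from most to least important."""
    return sorted(range(len(attribution)), key=lambda i: attribution[i], reverse=True)


def deletion_curve(x: Sequence[float], attribution: Sequence[float], score: Score, baseline: float = 0.0) -> List[float]:
    """Scores after replacing features with `baseline`, most important first."""
    current = list(x)
    scores = [score(current)]
    for i in ranking(attribution):
        current[i] = baseline
        scores.append(score(current))
    return scores


def insertion_curve(x: Sequence[float], attribution: Sequence[float], score: Score, baseline: float = 0.0) -> List[float]:
    """Scores after restoring features from `x` into a baseline, most important first."""
    current = [baseline] * len(x)
    scores = [score(current)]
    for i in ranking(attribution):
        current[i] = x[i]
        scores.append(score(current))
    return scores


def auc(scores: Sequence[float]) -> float:
    """Trapezoidal area under a curve sampled at evenly spaced fractions of [0, 1]."""
    steps = len(scores) - 1
    return sum((scores[i] + scores[i + 1]) / 2 for i in range(steps)) / steps


weights = [5, 1, 3]


def linear_score(v: Sequence[float]) -> float:
    """Toy 'model': a weighted sum of the features."""
    return sum(w * vi for w, vi in zip(weights, v))


x = [1, 1, 1]
good_attr = [5, 1, 3]  # ranks features by their true contribution
bad_attr = [1, 5, 3]   # ranks the least useful feature first
```

With `good_attr` the deletion curve is `[9, 4, 1, 0]` and the insertion curve is `[0, 5, 8, 9]`; swapping in `bad_attr` raises the deletion AUC and lowers the insertion AUC, which is exactly the signal the two plots above are meant to convey.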
        
