# YOLO_v5 XAI example

## 1. Import libraries

In [None]:
from typing import Final, List, Tuple
import os

import cv2
import numpy as np
import torch
import torchvision
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from PIL import Image
import IPython

from autoxai.context_manager import AutoXaiExplainer, Explainers, ExplainerWithParams
from autoxai.explainer.base_explainer import CVExplainer
from example.yolov5_exmaple.yolo_utils import (  # scale_boxes,
    get_variables,
    letterbox,
    make_divisible,
    non_max_suppression,
    xywh2xyxy,
    MAXIMUM_BBOX_WIDTH_HEIGHT,
    MAXIMUM_NUMBER_OF_BOXES_TO_NMS,
)

## 2. Define target class to be explained.

In [None]:
# The target class to be explained with XAI.
#
# For YOLO it takes all predictions belonging to the given class.
# It means that if 2 persons were detected, the XAI will be computed
# for both of them. Instance specific XAI requires code modification.
TARGET: Final[int] = 0

## 3. Create model and open sample image.

Download the YOLO model from Torch Hub and open a sample image.
Check for CUDA device availability and move the model to CUDA if available.
Get number of classes and stride, used by the given YOLO model.

In [None]:
# Fetch the pretrained "yolov5s" model from torch hub.
model = torch.hub.load("ultralytics/yolov5", "yolov5s", pretrained=True)

# The sample image is looked up relative to the parent of the current
# working directory.
image = Image.open(os.path.join(os.path.dirname(os.getcwd()), "images/zidane.jpg"))

# Run on CUDA when it is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device=device)

# Collect the "names" and "stride" attributes exposed by the model.
params = dict(get_variables(model=model, include=("names", "stride")))

## 4. Custom YOLO wrapper for XAI

By default, the YOLO model downloaded from Torch Hub is already a wrapper around the underlying YOLO model.
It exposes model.model(...), which performs the regular YOLO prediction, and a custom model(...) method (the forward(...) method), which performs the following steps:  
- image pre-processing (image, resizing, scaling, padding and normalization)
- model.model(...) call YOLO model on the pre-processed image
- non-max suppression algorithm, that removes duplicated bounding boxes and predictions with low confidence
- output bounding-box scaling
- optional visualization

The output of the model(...) function is already postprocessed and contains only classes with high confidence and reduced number of bboxes.

YOLO models output.shape:

a) model(...)  
* if we feed a `torch.Tensor`, it only runs model.model(...) and the output is of shape <span style="color:orange">[B,15120,85]</span>, where 15120 is the number of object predictions and 85 is:
    * bbox = [:4]
    * object confidence [4]
    * class confidence [5:] (80 classes)
* if we feed a numpy/PIL image, it runs the whole pre- and post-processing pipeline described above. The output object is of class `model.common.Detections` and contains multiple fields such as `pred:List[torch.Tensor] = [[num_detections,6]]`, `xywh:List[torch.Tensor] = [[num_detections,6]]`, `xyxy:List[torch.Tensor] = [[num_detections,6]]` and other fields. `6` is the row width produced by the non-max suppression algorithm used in YOLO: [:4] are bboxes, [4] is confidence and [5] is class number.

b) model.model(...)
* we get same results as running model(...) with `torch.Tensor`. See above.


**However, most XAI algorithms require the output to be of shape:** <span style="color:orange">[B,number_of_classes]</span>

, where B is the batch size. It also has to be fully differentiable.

In order to get an explanation for the YOLO model we need to reduce the YOLO model output shape `[B,15120,85] -> [B,80]`. The `XaiYoloWrapper` below does exactly this.

In [None]:
class XaiYoloWrapper(torch.nn.Module):
    """The XAI wrapper for the YOLO model.

    Most explainers expect the model output to consist only of class
    scores. However, many models have custom outputs. The wrapped YOLO
    model is indexed here as one row per anchor laid out as
    ``[x, y, w, h, object_confidence, cls_1, ..., cls_M]``.

    In order to make the output usable by regular XAI explainers, the
    output is converted to the following shape:
        - [cls1, cls2, ....., clsM], where M is the total number of classes.

    We lose the information about the number of predictions and their
    locations. In order to get those data the regular inference has to be
    run separately.
    """

    def __init__(
        self, model, conf: float = 0.25, iou: float = 0.45
    ) -> None:
        """
        Args:
            model: the yolo model to be used.
            conf: confidence threshold for predicted objects
            iou: iou threshold for predicted bboxes for nms algorithm
        """
        super().__init__()
        self.model = model
        # mirror the train/eval flag of the wrapped model
        self.training = model.training

        # ``include`` has to be a real tuple; ("names") without the trailing
        # comma is just the string "names".
        params = dict(get_variables(model=model, include=("names",)))
        self.number_of_classes: int = len(params["names"])
        self.conf: float = conf
        self.iou: float = iou

    def xai_non_max_suppression(
        self,
        prediction: torch.Tensor,
        conf_thres: float = 0.25,
        iou_thres: float = 0.45,
        agnostic: bool = False,
        max_det: int = 300,
    ) -> torch.Tensor:
        """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections

        The anchors that survive the confidence thresholds and NMS are
        selected with a 0/1 mask, and their ``class_conf * object_conf``
        scores are summed per class. The model output itself is never
        modified in place, so the result stays differentiable with respect
        to ``prediction``.

        Args:
            prediction: the model prediction; when a list/tuple is given,
                only its first element is used
            conf_thres: confidence threshold, for counting the detection as valid
            iou_thres: intersection over union threshold for non max suppression algorithm
            agnostic: if True, non max suppression algorithm is run on raw bboxes. However it may
                happen that different classes have bboxes in similar place. NMS would discard one of
                those bboxes and keep only the one with higher confidence. If we want to keep bboxes
                that are in similar place, but have different class label, we should set agnostic to False.
            max_det: maximum number of detections

        Returns:
            batch of detections of shape (B, number_of_classes) holding the
            summed class confidences of the selected anchors

        Raises:
            AssertionError: if ``conf_thres`` or ``iou_thres`` is outside [0, 1].

        Example output:
                cls0_conf   cls1_conf   ....    cls79_conf
            0  0.005       0.00002     ...     0.87002
            1  0.535       0.20002     ...     0.08002
            .  ...         ...         ...     ...
            B  0.00008     0.10302     ...     0.0289
        """

        # Checks
        assert (
            0 <= conf_thres <= 1
        ), f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
        assert (
            0 <= iou_thres <= 1
        ), f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
        if isinstance(
            prediction, (list, tuple)
        ):  # YOLOv5 model in validation model, output = (inference_out, loss_out)
            prediction = prediction[0]  # select only inference output

        device = prediction.device
        mps = "mps" in device.type  # Apple MPS
        if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
            prediction = prediction.cpu()

        batch_size = prediction.shape[0]

        # True/False per anchor: does it meet the object confidence threshold
        xc = prediction[..., 4] > conf_thres  # candidates

        # class confidences occupy the columns [5, 5 + number_of_classes)
        mi = 5 + self.number_of_classes  # mask start index
        class_columns = slice(5, mi)

        # result tensor - summed confidence of each class per sample
        out_predictions = torch.zeros(
            (batch_size, self.number_of_classes), device=prediction.device
        )

        for xi in range(batch_size):
            # get sample prediction
            x = prediction[xi]

            # row indices (into x) of the anchors that passed the object
            # confidence threshold; kept so that NMS results can be mapped
            # back to the rows of x later on
            candidate_rows = xc[xi].nonzero().squeeze(1)

            # detached copy: the selection logic below must not be tracked
            # by autograd and may safely be modified in place
            x_high_conf = x.detach()[candidate_rows]

            # if none of the anchors meet threshold criteria
            if not x_high_conf.shape[0]:
                # all-zero class scores that are still connected to the
                # graph (zero gradient); no in-place write to the output
                out_predictions[xi] = (x[:, class_columns] * 0).sum(dim=0)
                continue

            # multiply class probability by confidence score
            x_high_conf[:, 5:] *= x_high_conf[:, 4:5]  # conf = obj_conf * cls_conf

            # get bounding box dimensions
            box = xywh2xyxy(x_high_conf[:, :4])

            # get confidence and argmax of classes for all anchors
            conf, j = x_high_conf[:, class_columns].max(1, keepdim=True)

            # keep only anchors whose best class score passes the threshold;
            # the same filter is applied to the row indices so that both
            # stay aligned
            keep = conf.view(-1) > conf_thres
            x_high_conf = torch.cat((box, conf, j.float()), dim=1)[keep]
            candidate_rows = candidate_rows[keep]

            # if no anchors are left
            if not x_high_conf.shape[0]:  # no boxes
                out_predictions[xi] = (x[:, class_columns] * 0).sum(dim=0)
                continue

            # indices of the remaining anchors sorted by confidence (descending)
            order = x_high_conf[:, 4].argsort(descending=True)[
                :MAXIMUM_NUMBER_OF_BOXES_TO_NMS
            ]
            x_high_conf = x_high_conf[order]

            # per-class box offset, so that boxes of different classes never
            # overlap during NMS (disabled in the class-agnostic mode)
            c = x_high_conf[:, 5:6] * (0 if agnostic else MAXIMUM_BBOX_WIDTH_HEIGHT)
            boxes, scores = (
                x_high_conf[:, :4] + c,
                x_high_conf[:, 4],
            )

            # get bounding boxes indices to keep from NMS
            selected_indices = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS

            # limit detections to specified number
            selected_indices = selected_indices[:max_det]

            # map: NMS index -> index in the filtered tensor -> row of x
            pick_rows = candidate_rows[order[selected_indices]]

            # class confidence multiplied by object confidence, computed
            # out of place to keep the model output untouched
            class_scores = x[:, class_columns] * x[:, 4:5]

            # create mask of anchors and mark selected
            mask = torch.zeros_like(class_scores)
            mask[pick_rows] = 1

            # erase non-selected anchors and sum probabilities of classes
            # over all anchors: instance confidence to semantic confidence
            out_predictions[xi] = (class_scores * mask).sum(dim=0)

        # predictions were moved to the CPU for MPS; hand the result back on
        # the device the caller provided
        return out_predictions.to(device) if mps else out_predictions

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model and reduce its output to per-class scores.

        Args:
            x: the input batch passed to the wrapped model

        Returns:
            tensor of shape (B, number_of_classes)
        """
        x = self.model(x)
        x = self.xai_non_max_suppression(
            x,
            conf_thres=self.conf,
            iou_thres=self.iou,
            agnostic=False,
            max_det=1000,
        )  # NMS

        return x

## 5. Create a wrapped model

In [None]:
# Wrap the raw YOLO network, reusing the thresholds configured on the hub
# model, then move the wrapper to the selected device.
yolo_model = XaiYoloWrapper(model=model.model, conf=model.conf, iou=model.iou)
yolo_model = yolo_model.to(device=device)

## 6. Pre-process image

In [None]:
def pre_process(
    image: np.ndarray, sample_model_parameter: torch.Tensor, stride: int
) -> Tuple[torch.Tensor, List[int], List[int]]:
    """Prepare a single image as an input batch for the yolo network.

    The image is brought to HWC layout with 3 channels, letterboxed to a
    stride-divisible shape derived from a 640x640 target size, converted to
    a batch of one in BCHW layout and scaled by 1/255.

    Args:
        image: the input image to the network
        sample_model_parameter: a model parameter; the output tensor is
            moved to its device and cast to its type
        stride: the yolo network stride

    Returns:
        a tuple of:
            - the tensor image ready to be fed into the network,
            - the height and width of the image before letterboxing,
            - the stride-divisible shape passed to letterbox.
    """
    size: Tuple[int, int] = (640, 640)

    # a first axis shorter than 5 is treated as the channel axis (CHW),
    # reverse dataloader .transpose(2, 0, 1)
    if image.shape[0] < 5:
        image = image.transpose((1, 2, 0))

    # enforce 3ch input
    if image.ndim == 3:
        image = image[..., :3]
    else:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)

    shape0 = image.shape[:2]  # HWC
    gain = max(size) / max(shape0)
    scaled_shape = [int(side * gain) for side in shape0]
    shape1 = [make_divisible(side, stride) for side in np.array(scaled_shape)]

    if image.data.contiguous:
        np_image = image
    else:
        np_image = np.ascontiguousarray(image)

    # pad to the inference shape
    padded = letterbox(np_image, shape1, auto=False)[0]

    # add the batch axis and go from BHWC to BCHW
    batch = np.expand_dims(np.array(padded), axis=0)
    batch = np.ascontiguousarray(batch.transpose((0, 3, 1, 2)))

    # uint8 to fp16/32
    tensor = torch.from_numpy(batch)
    tensor = tensor.to(sample_model_parameter.device)
    tensor = tensor.type_as(sample_model_parameter) / 255

    return tensor, shape0, shape1

# Only the tensor is needed here; the two shapes returned alongside it are
# dropped. The tensor is then placed on the selected device.
input_image = pre_process(
    image=np.asarray(image),
    sample_model_parameter=next(model.parameters()),
    stride=params["stride"],
)[0].to(device)

## 7. Run XAI

In [None]:
# Explain the class selected by TARGET (instead of a hard-coded class id),
# so that the explanation matches the label drawn on the final figure.
with AutoXaiExplainer(
    model=yolo_model,
    explainers=[
        ExplainerWithParams(
            explainer_name=Explainers.CV_LAYER_GRADCAM_EXPLAINER,
            target=TARGET,
        )
    ],
) as xai_model:
    _, attributions = xai_model(input_image)

# keep only the attributions of the GradCAM explainer
attributions = attributions["CV_LAYER_GRADCAM_EXPLAINER"]

## 8. Run regular inference

Because the XAI was run using the wrapped model, the inference results from XAI are of shape `[B,80]`. The information about bboxes as well as about each instance is lost. We only have information about the semantic classes detected in an image. Therefore, in order to get instance-level information with bboxes, we need to run the regular model once again.

In [None]:
# standard inference
# gradients are not needed for plain inference, so no autograd graph is built
with torch.no_grad():
    y = model.model(input_image)
    y = non_max_suppression(
        y,
        conf_thres=model.conf,
        iou_thres=model.iou,
        classes=None,
        agnostic=False,
        multi_label=False,
        max_det=1000,
    )  # NMS

# uncomment in order to scale boxes to the original image size
# scale_boxes(shape1, y[0][:, :4], shape0)
# detections of the first (and only) image in the batch
y = y[0].detach().cpu()
bboxes = y[:, :4]

# min-max normalize the network input to [0, 255] so it can be drawn on
normalized_image: torch.Tensor = (
    (
        (input_image - torch.min(input_image))
        / (torch.max(input_image) - torch.min(input_image))
        * 255
    )
    .type(torch.uint8)
    .squeeze()
)

# the class id is stored as a float in the last column; cast it to int so
# the lookup works no matter whether ``names`` is a dict or a list
labels: List[str] = [params["names"][int(label.item())] for label in y[:, -1]]
pred_image = torchvision.utils.draw_bounding_boxes(
    image=normalized_image,
    boxes=bboxes,
    labels=labels,
)

## 9. Visualize results

In [None]:
# Render the attribution figure into an RGBA pixel buffer.
figure = CVExplainer.visualize(
    attributions=attributions.squeeze(), transformed_img=pred_image
)
canvas = FigureCanvas(figure)
canvas.draw()
buf = canvas.buffer_rgba()
# NOTE: this rebinds ``image`` (previously the opened sample image) to the
# rendered figure.
image = np.asarray(buf)

# Write the name of the explained class onto the figure.
# ``org`` is (x, y): x is derived from the width (shape[1]) and y from the
# height (shape[0]).
cv2.putText(
    img=image,
    text=params["names"][TARGET],
    org=(image.shape[1] // 2, image.shape[0] // 8),
    fontFace=cv2.FONT_HERSHEY_SIMPLEX,
    fontScale=image.shape[0] / 500,
    color=(0, 0, 0),
)
IPython.display.display(Image.fromarray(image))