In [None]:
!pip install foxai yolov5

In [None]:
"""Example of running XAI on YOLOv5 on Object Detection task."""

import os
import cv2
import numpy as np
import torch
import torchvision
from typing import Tuple
from custom_models.ssd.ssd_object_detector import SSDObjectDetector
from custom_models.yolov5.model import WrapperYOLOv5ObjectDetectionModel
from custom_models.yolov5.yolo_object_detector import (
    YOLOv5ObjectDetector,
    get_yolo_layer,
)
from PIL import Image
from torchvision.models._meta import _COCO_CATEGORIES
from torchvision.models.detection import SSD300_VGG16_Weights

from foxai.explainer.computer_vision.algorithm.gradcam import (
    LayerGradCAMObjectDetectionExplainer,
    ObjectDetectionOutput,
)
from foxai.visualizer import object_detection_visualization


def _build_ssd(preprocessed_image: torch.Tensor, device: torch.device):
    """Load COCO SSD300-VGG16 and return ``(model_wrapper, target_layer, input_image)``."""
    weights = SSD300_VGG16_Weights.COCO_V1
    # `weights` already selects the pretrained checkpoint; the legacy
    # `pretrained=` flag is deprecated and redundant next to it.
    model = torchvision.models.detection.ssd300_vgg16(weights=weights).eval().to(device)
    # Cap the number of detections returned per image at 10.
    model.detections_per_img = 10
    input_image = weights.transforms()(preprocessed_image).to(device)

    model_wrapper = SSDObjectDetector(model=model, class_names=_COCO_CATEGORIES)
    # For SSD the explained layer is always the last backbone feature layer.
    target_layer = model_wrapper.model.backbone.features[-1]
    return model_wrapper, target_layer, input_image


def _build_yolov5(model_name: str, layer_name: str, preprocessed_image: torch.Tensor, device: torch.device):
    """Load a YOLOv5 hub model and return ``(model_wrapper, target_layer, input_image)``.

    Raises:
        ValueError: If the preprocessed image's height or width is not divisible by 32.
    """
    image_shape = preprocessed_image.shape[-2:]
    # Use a real exception rather than `assert`, which is stripped under `python -O`.
    if image_shape[0] % 32 != 0 or image_shape[1] % 32 != 0:
        raise ValueError(
            f"Image height and width have to be divisible by 32. Current shape is {preprocessed_image.shape}"
        )

    model = torch.hub.load("ultralytics/yolov5", model_name, pretrained=True)
    wrapper_model = WrapperYOLOv5ObjectDetectionModel(
        model=model.model.model,
        device=device,
    )
    model_wrapper = YOLOv5ObjectDetector(
        model=wrapper_model,
        img_size=image_shape,
        names=model.model.names,
    )
    target_layer = get_yolo_layer(
        model=model_wrapper,
        layer_name=layer_name,
    )
    # YOLOv5 consumes the preprocessed tensor directly.
    return model_wrapper, target_layer, preprocessed_image


def main(model_name: str, img_path: str, output_dir: str, layer_name: str, img_size: Tuple[int, int]):
    """Run Layer GradCAM on an object detector and save the visualization.

    Args:
        model_name: Model identifier. Any name containing ``"ssd"`` selects
            torchvision's SSD300-VGG16 with COCO weights; anything else is
            passed to ``torch.hub.load("ultralytics/yolov5", model_name)``.
        img_path: Path to the input image. It is converted to RGB before
            preprocessing.
        output_dir: Directory to write ``yolo_gradcam.png`` into (the same
            file name is used for both model families). It is created if
            missing; an empty string means the current directory.
        layer_name: YOLOv5 layer to explain, resolved with ``get_yolo_layer``.
            Ignored for SSD, which always explains the last backbone feature
            layer.
        img_size: Target size passed to ``YOLOv5ObjectDetector.preprocessing``
            as ``new_shape``.

    Raises:
        ValueError: For YOLOv5 models, if the preprocessed height or width is
            not divisible by 32.
        OSError: If OpenCV fails to write the output image.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Close the file promptly, and force 3 channels so grayscale, RGBA or
    # palette images do not reach the 3-channel preprocessing pipeline.
    with Image.open(img_path) as image:
        image_array = np.asarray(image.convert("RGB"))

    preprocessed_image = YOLOv5ObjectDetector.preprocessing(
        img=image_array,
        new_shape=img_size,
        change_original_ratio=True,
    ).to(device)

    if "ssd" in model_name:
        model_wrapper, target_layer, input_image = _build_ssd(preprocessed_image, device)
    else:
        model_wrapper, target_layer, input_image = _build_yolov5(
            model_name, layer_name, preprocessed_image, device
        )

    saliency_method = LayerGradCAMObjectDetectionExplainer(
        model=model_wrapper,
        target_layer=target_layer,
    )
    outputs: ObjectDetectionOutput = saliency_method(input_img=input_image)
    final_image = object_detection_visualization(
        detections=outputs,
        input_image=input_image,
    )

    # Create the real target directory (not ".") and build the path with
    # os.path.join so absolute output directories are honoured.
    output_dir = output_dir or "."
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "yolo_gradcam.png")
    print(f"[INFO] Saving the final image at {output_path}")
    # cv2.imwrite reports failure by returning False instead of raising.
    if not cv2.imwrite(output_path, final_image):
        raise OSError(f"Failed to write image to {output_path}")

In [None]:
# Demo configuration: explain YOLOv5-small predictions on the "zidane" sample image.
model_name = "yolov5s"  # torch.hub model id; a name containing "ssd" would select SSD300 instead
img_path = "../images/zidane.jpg"  # resolved relative to the notebook's working directory
output_dir = "./"
layer_name = "model_23_cv3_act"  # YOLOv5 layer whose activations GradCAM attributes
resize_to_img_size = (640, 640)  # both sides are multiples of 32, as main() requires for YOLOv5

main(
    img_size=resize_to_img_size,
    layer_name=layer_name,
    output_dir=output_dir,
    img_path=img_path,
    model_name=model_name,
)