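The ImageNet-Mini dataset used in this notebook can be downloaded from Kaggle: https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000. After extracting it, set `data_dir` to the directory that contains the `train/` and `val/` folders.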

In [1]:
import os
import torch
from pytorch_lightning import LightningModule
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
import torchvision
from torchvision import transforms
from torchvision.datasets import MNIST, ImageNet

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
!pip install scipy



In [3]:
# Debugging aid: CUDA_LAUNCH_BLOCKING=1 asks CUDA to run kernel launches
# synchronously, so device-side errors surface at the offending call.
# NOTE(review): this slows GPU execution and is presumably only honoured if set
# before CUDA is initialised in this process — confirm it is still needed.
os.environ.update(CUDA_LAUNCH_BLOCKING="1")

In [4]:
# --- Configuration -----------------------------------------------------------
batch_size: int = 256 if torch.cuda.is_available() else 64
max_epochs: int = 3  # NOTE(review): not used by any cell visible here
max_samples_explained: int = 10  # number of validation images to explain
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Root of the extracted ImageNet-Mini dataset; must contain `train/` and `val/`.
# Set the IMAGENET_MINI_DIR environment variable instead of editing this local path.
data_dir: str = os.environ.get("IMAGENET_MINI_DIR", "/home/user/Downloads/imagenet-mini")

# --- Pretrained model --------------------------------------------------------
weights = torchvision.models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
# eval() switches the model to inference behaviour before it is explained.
model = torchvision.models.mobilenet_v3_small(weights=weights).eval()
model.to(device)
categories = weights.meta["categories"]
# Use the preprocessing pipeline bundled with the pretrained weights.
transform = weights.transforms()

# --- Datasets ----------------------------------------------------------------
# NOTE(review): the explanation cell indexes `categories` with ImageFolder
# labels; this assumes ImageFolder's sorted class-folder order matches the
# class order of `weights.meta["categories"]` — confirm for this dataset.
imagenet_train = torchvision.datasets.ImageFolder(root=os.path.join(data_dir, "train"), transform=transform)
imagenet_val = torchvision.datasets.ImageFolder(root=os.path.join(data_dir, "val"), transform=transform)

In [5]:
from autoxai.explainer.base_explainer import CVExplainer
from autoxai.context_manager import AutoXaiExplainer, ExplainerWithParams, Explainers

In [6]:
# Explain the first `max_samples_explained` validation images with every
# explainer in `explainer_list`, saving one figure per (explainer, image) to
# `artifact_dir/<explainer key>/artifact_<index>_<category>.png`.
explainer_list = [
    ExplainerWithParams(explainer_name=Explainers.CV_GRADIENT_SHAP_EXPLAINER),
    ExplainerWithParams(explainer_name=Explainers.CV_NOISE_TUNNEL_EXPLAINER),
]

val_dataloader = DataLoader(imagenet_val, batch_size=batch_size)
artifact_dir: str = "artifacts/mobilenetv3/"


def _iter_samples(dataloader, limit: int):
    """Yield at most ``limit`` individual ``(sample, label)`` pairs from a batched dataloader.

    Iteration continues across batch boundaries, so fewer than ``limit`` pairs are
    produced only when the dataloader is exhausted. No further batch is fetched
    once ``limit`` pairs have been yielded.
    """
    if limit <= 0:
        return
    yielded = 0
    for batch_samples, batch_labels in dataloader:
        for item_sample, item_label in zip(batch_samples, batch_labels):
            yield item_sample, item_label
            yielded += 1
            if yielded >= limit:
                return


counter: int = 0
for sample, label in _iter_samples(val_dataloader, max_samples_explained):
    # `label` comes out of the batched labels as a tensor element; make the
    # integer class index explicit before using it to index `categories`.
    label_idx = int(label)
    # Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
    input_data = sample.unsqueeze(0).to(device)
    with AutoXaiExplainer(
        model=model,
        explainers=explainer_list,
    ) as xai_model:
        _, attributes_dict = xai_model(input_data)

    for key, value in attributes_dict.items():
        # One directory per explainer's artifacts.
        artifact_explainer_dir = os.path.join(artifact_dir, key)
        os.makedirs(artifact_explainer_dir, exist_ok=True)

        figure = CVExplainer.visualize(attributions=value, transformed_img=sample)
        figure.savefig(os.path.join(artifact_explainer_dir, f"artifact_{counter}_{categories[label_idx]}.png"))
        # NOTE(review): if `figure` is a pyplot-managed matplotlib Figure it stays
        # open after saving; consider closing it when explaining many samples.

    counter += 1


2023-01-10 16:07:41,586 INFO autoxai.explainer.base_explainer - No negative attributes in the explained model.
2023-01-10 16:07:51,518 INFO autoxai.explainer.base_explainer - No negative attributes in the explained model.
2023-01-10 16:08:02,314 INFO autoxai.explainer.base_explainer - No negative attributes in the explained model.
2023-01-10 16:08:10,554 INFO autoxai.explainer.base_explainer - No negative attributes in the explained model.
2023-01-10 16:08:16,722 INFO autoxai.explainer.base_explainer - No negative attributes in the explained model.
2023-01-10 16:08:22,787 INFO autoxai.explainer.base_explainer - No negative attributes in the explained model.
2023-01-10 16:08:28,866 INFO autoxai.explainer.base_explainer - No negative attributes in the explained model.
2023-01-10 16:08:35,033 INFO autoxai.explainer.base_explainer - No negative attributes in the explained model.
2023-01-10 16:08:41,042 INFO autoxai.explainer.base_explainer - No negative attributes in the explained model.
2