In [1]:
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
import torchxai.model as xaim
from collections import defaultdict
# from src.models import CnnMnist, ResNetMnist, ResNetMnistCBAM
from src.utils import *
from src.trainsettings import ModelTranier

In [2]:
parser = argument_parsing(preparse=True)
# Command-line style arguments for this run: one flag and its value(s) per line.
# Flag meanings are defined by `argument_parsing` in src.utils.
# str.split() with no separator splits on any run of whitespace (newlines
# included) and drops empty tokens, so the spacing inside the string is not
# load-bearing. Splitting on an exact two-space separator breaks as soon as a
# line uses one or three spaces or carries trailing whitespace.
# `args_string` still holds the resulting token list, as before.
args_string = \
"""-pp  ../XAI
-dp  ../data
-rf  mnist-resnet-roar
-dt  mnist
-et  roar
-at  vanillagrad  inputgrad  guidedgrad  gradcam
-mt  resnet  resnetcbam
-bs  256
-ns  10
-sd  73""".split()
args = parser.parse_args(args_string)

In [3]:
trainer = ModelTranier()
# build_dataset returns four items; this notebook only needs the 2nd and 4th
# (the test dataset and its loader), so the other two are discarded.
dataset_splits = trainer.build_dataset(args)
_, test_dataset, _, test_loader = dataset_splits

# Attention Methods in XAI

In [4]:
# Seed every global RNG the notebook has imported (torch and numpy), so the
# sampled images are repeatable on a fresh kernel whichever library the helpers
# draw from.
# NOTE(review): which RNG build_img_dict / get_samples actually use is not
# visible here. Confirm in src.utils.
# int() because np.random.seed rejects strings, whereas torch.manual_seed coerces.
torch.manual_seed(args.seed)
np.random.seed(int(args.seed))
imgs_dict = build_img_dict(test_dataset)
imgs, labels = get_samples(imgs_dict, cat="all", sample_size=1)

In [12]:
# For every model type, build one attribution wrapper per attribution type and
# record the test-set metrics once.
# Resulting layout: models_attrs[model_type] = {
#     "attrs": [attr_model, ...],   # same order as args.attr_type
#     "model": model,               # the model returned by the last create_attr_model call
#     "test":  (test_loss, test_acc),
# }
# A plain dict is used: the previous `defaultdict()` had no default_factory, so
# it already behaved like a dict and only suggested auto-vivification that never happened.
models_attrs = {}
for m_type in args.model_type:
    m_class = trainer.model_dict[args.data_type][m_type]
    # The checkpoint path depends only on the model type, so compute it once per model.
    lpath = Path(args.prj_path) / "trained" / args.data_type / args.eval_type / f"{m_type}-first.pt"
    entry = {"attrs": []}
    models_attrs[str(m_type)] = entry
    for a_type in args.attr_type:
        a_class = trainer.attr_dict[a_type]
        kwargs = trainer.get_kwargs_to_attr_model(m_type, a_type)
        # A fresh model is loaded from the same checkpoint for each attribution type.
        # NOTE(review): presumably so each attribution method gets its own model
        # instance (e.g. if it installs hooks). Confirm in create_attr_model.
        model, attr_model = trainer.create_attr_model(m_class, a_class, str(lpath), **kwargs)
        entry["model"] = model  # overwritten each iteration; the last model is kept
        if entry.get("test") is None:
            # Every model comes from the same checkpoint, so evaluating once suffices.
            test_loss, test_acc = trainer.test(model, test_loader, device="cpu")
            entry["test"] = (test_loss, test_acc)
        entry["attrs"].append(attr_model)

In [15]:
# (test_loss, test_acc) stored for the plain ResNet by the model-loading loop.
resnet_test_metrics = models_attrs["resnet"]["test"]
resnet_test_metrics

(0.030985720507637596, 99.14)

In [16]:
# (test_loss, test_acc) stored for the CBAM ResNet by the model-loading loop.
resnetcbam_test_metrics = models_attrs["resnetcbam"]["test"]
resnetcbam_test_metrics

(0.027761919576558283, 99.2)

In [17]:
# Shorthand for the ResNet entry built by the model-loading loop
# (its "attrs" list, plus "model" and "test" once an attribution type was processed).
# NOTE(review): an assignment displays nothing, so the torch.Size output saved
# under this cell is stale. It must come from an earlier version of the cell
# (presumably something like `imgs.shape`). Re-run the notebook to clear it.
resnet = models_attrs["resnet"]

torch.Size([10, 1, 28, 28])