In [1]:
import torch
import torchvision.models as models

from idsds.models.resnet import resnet18, resnet50, resnet101, resnet152, wide_resnet50_2
from idsds.models.vgg import vgg16, vgg16_bn, vgg13, vgg19, vgg11
from idsds.models.ViT.ViT_new import vit_base_patch16_224
from idsds.models.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from idsds.models.bagnets.pytorchnet import bagnet33
from idsds.models.xdnns.xfixup_resnet import xfixup_resnet50, fixup_resnet50
from idsds.models.xdnns.xvgg import xvgg16
from idsds.models.bcos_v2.bcos_resnet import resnet50 as bcos_resnet50
from idsds.models.bcos_v2.bcos_resnet import resnet18 as bcos_resnet18

import utils

import os

# Checkpoint directories. Each can be overridden with an environment variable so the
# notebook runs outside the original workspace; the defaults are the original paths.
# Kept as plain strings with a trailing slash because checkpoint paths are built by
# string concatenation (e.g. `tuned_models + "<file>.pth.tar"`).
original_models = os.environ.get("IDSDS_ORIGINAL_MODELS", "/workspace/hd/original/")
tuned_models = os.environ.get("IDSDS_TUNED_MODELS", "/workspace/hd/tuned/")

test_loader = utils.get_loader()

In [86]:
import torch.nn.functional as F
import matplotlib.pyplot as plt

def _explain_top1(model, cam, input_tensor):
    """Run ``model`` on ``input_tensor`` and return ``cam``'s first map for the top-1 class."""
    output = model(input_tensor)
    pred_class = output.argmax(dim=1)[0].item()
    return cam(pred_class, output)[0]


def _resize_attribution(attribution_map, size):
    """Squeeze a (1, H, W) map to (H, W), then bilinearly resize it to ``size``."""
    if attribution_map.ndim == 3 and attribution_map.shape[0] == 1:
        attribution_map = attribution_map.squeeze(0)
    return F.interpolate(
        attribution_map.unsqueeze(0).unsqueeze(0),
        size=size,
        mode="bilinear",
        align_corners=False,
    ).squeeze()


def compare_mad_attr(model1, model2, attr_fn, loader, size=(224, 224), device=None):
    """Mean absolute difference (MAD) between two models' attribution maps.

    Each image from ``loader`` is fed to both models one at a time. For each model,
    an attribution map is computed for that model's own top-1 predicted class, so
    the two classes may differ. Both maps are bilinearly resized to ``size``, and
    the pixel-wise absolute difference is averaged. The per-image values are then
    averaged over all images.

    Args:
        model1, model2: Models to compare. Each is moved to ``device`` and switched
            to eval mode in place.
        attr_fn: Factory called as ``attr_fn(model)``. It must return an extractor
            callable as ``extractor(class_idx, scores)`` whose result's first element
            is a (H, W) or (1, H, W) map (e.g. ``torchcam.methods.GradCAM``). If the
            extractor has a ``remove_hooks()`` method, it is called before returning,
            so the models are not left with extractor hooks attached.
        loader: Iterable of ``(images, labels)`` batches. Labels are ignored.
        size: Spatial size both maps are resized to before comparison.
        device: Device to run on. Defaults to ``"cuda:0"`` when CUDA is available,
            else ``"cpu"``.

    Returns:
        The mean absolute difference as a float, or ``nan`` if ``loader`` yielded
        no images.
    """
    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model1.to(device).eval()
    model2.to(device).eval()

    total_mad = 0.0
    num_samples = 0

    cams = []
    try:
        cams.append(attr_fn(model1))
        cams.append(attr_fn(model2))
        cam1, cam2 = cams

        for images, _ in loader:
            images = images.to(device)
            for i in range(images.size(0)):
                # One image at a time, so the extractor receives a single int class index.
                # Gradients stay enabled: gradient-based extractors need a backward pass.
                input_tensor = images[i].unsqueeze(0)
                map1 = _resize_attribution(_explain_top1(model1, cam1, input_tensor), size)
                map2 = _resize_attribution(_explain_top1(model2, cam2, input_tensor), size)
                total_mad += torch.abs(map1 - map2).mean().item()
            num_samples += images.size(0)
            # Progress: images seen so far and the running (un-normalized) MAD sum.
            print(f"{num_samples} {total_mad}")
    finally:
        # Extractors such as torchcam's hook into the model; detach them even on error.
        for cam in cams:
            remove_hooks = getattr(cam, "remove_hooks", None)
            if callable(remove_hooks):
                remove_hooks()

    mad = total_mad / num_samples if num_samples else float("nan")
    print(f"Mean Absolute Difference between attribution maps: {mad}")
    return mad

# `get_loader` lives in the `utils` module imported in the first cell; a bare
# `get_loader` is not defined in any visible cell and fails on a fresh kernel.
# NOTE(review): assumes utils.get_loader accepts `batch_size` — confirm its signature.
test_loader = utils.get_loader(batch_size=100)

In [87]:
from torchcam.methods import GradCAM

# `resnet50` is the constructor imported from idsds.models.resnet in the first cell;
# no visible cell defines a `resnet` module alias.
resnet50_ood = resnet50(pretrained=True)
resnet50_id = resnet50(pretrained=True)
# NOTE(review): `lsd` is not defined in any visible cell. It is presumably a checkpoint
# loader taking (path, model) and returning the model — confirm where it is imported from.
resnet50_id = lsd(
    tuned_models + "resnet50_imagenet1000_lr0.001_epochs30_step10_checkpoint_best.pth.tar",
    resnet50_id,
)
compare_mad_attr(resnet50_ood, resnet50_id, GradCAM, test_loader)

model loaded
model loaded


  checkpoint = torch.load(path)


100 4.889074240811169
200 8.708673507906497
300 13.189132812432945
400 18.303034431301057
500 23.023825705051422
600 28.772354780696332
700 33.559516292996705
800 38.5736175570637
900 43.79189453180879
1000 48.759562991559505
1100 54.14293042104691
1200 60.50546339992434
1300 65.45150360558182
1400 69.92538102716208
1500 74.60425293352455
1600 79.92344077676535
1700 84.79647032823414
1800 89.7522300593555
1900 93.77192197926342
2000 97.48797635454684
2100 102.91816969402134
2200 108.09968297276646
2300 113.0479286890477
2400 117.74490522779524
2500 123.32907906360924
2600 127.93321824260056
2700 132.50750592537224
2800 136.89259670861065
2900 141.8336002510041
3000 146.1190961105749
3100 150.85396315902472
3200 155.17801814898849
3300 159.8392039416358
3400 165.02810506056994
3500 170.41931820567697
3600 174.66147010121495
3700 179.55783457309008
3800 183.8782892683521
3900 188.3820141190663
4000 194.48068280518055
4100 200.21075364761055
4200 208.47729245759547
4300 214.8367188833654
