In [16]:
import torch
import torchvision.models as models

from idsds.models.resnet import resnet18, resnet50, resnet101, resnet152, wide_resnet50_2
from idsds.models.vgg import vgg16, vgg16_bn, vgg13, vgg19, vgg11
from idsds.models.ViT.ViT_new import vit_base_patch16_224
from idsds.models.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from idsds.models.bagnets.pytorchnet import bagnet33
from idsds.models.xdnns.xfixup_resnet import xfixup_resnet50, fixup_resnet50
from idsds.models.xdnns.xvgg import xvgg16
from idsds.models.bcos_v2.bcos_resnet import resnet50 as bcos_resnet50
from idsds.models.bcos_v2.bcos_resnet import resnet18 as bcos_resnet18

import utils

# Checkpoint root directories (absolute paths on the training machine; adjust when running elsewhere).
# Later cells build file paths by plain string concatenation, so both keep their trailing slash.
# NOTE(review): `original_models` is not referenced by any cell visible in this notebook.
original_models, tuned_models = "/workspace/hd/original/", "/workspace/hd/tuned/"

# Shared evaluation loader; compare_mad unpacks each batch as (images, labels).
test_loader = utils.get_loader()

In [9]:
import torch.nn.functional as F

def compare_mad(model1, model2, loader, device=None, max_batches=None):
    """Compare two classifiers by the mean absolute difference (MAD) of their outputs.

    Both models are moved to ``device`` and put in eval mode (``nn.Module.to`` / ``eval``
    act in place), then run on each batch of ``loader`` under ``torch.no_grad()``.
    For every sample, the absolute difference between the two output vectors is averaged
    over dim 1, once on the raw logits and once after softmax. Those per-sample values
    are then averaged over all samples seen.

    Args:
        model1, model2: ``torch.nn.Module`` classifiers. Assumed to return logits of shape
            (batch, num_classes) with the same num_classes — TODO confirm for every model.
        loader: iterable of ``(images, labels)`` batches; labels are ignored.
        device: device to evaluate on. Defaults to ``"cuda:0"`` when CUDA is available,
            otherwise ``"cpu"``.
        max_batches: if given, stop after this many batches (quick checks on slow
            models). ``None`` evaluates the whole loader.

    Returns:
        dict with float entries ``"softmax_mad"`` and ``"logit_mad"``. The same values
        are also printed.

    Raises:
        ValueError: if no samples were evaluated (empty loader or ``max_batches=0``).
    """
    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"

    model1.to(device).eval()
    model2.to(device).eval()

    num_samples = 0
    logit_mad_total = 0.0
    softmax_mad_total = 0.0

    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(loader):
            if max_batches is not None and batch_idx >= max_batches:
                break
            images = images.to(device)

            logits1 = model1(images)
            logits2 = model2(images)

            # Per-sample MAD over the class dim, summed over the batch, so the final
            # division yields a per-sample mean even when the last batch is smaller.
            logit_mad_total += torch.abs(logits1 - logits2).mean(dim=1).sum().item()

            probs1 = F.softmax(logits1, dim=1)
            probs2 = F.softmax(logits2, dim=1)
            softmax_mad_total += torch.abs(probs1 - probs2).mean(dim=1).sum().item()

            num_samples += images.size(0)

    # Without this guard an empty loader fails with an opaque ZeroDivisionError.
    if num_samples == 0:
        raise ValueError("compare_mad: no samples were evaluated (empty loader or max_batches=0)")

    logit_mad_score = logit_mad_total / num_samples
    softmax_mad_score = softmax_mad_total / num_samples
    print(f"Softmax MAD: {softmax_mad_score}")
    print(f"Logit MAD:  {logit_mad_score}")
    return {"softmax_mad": softmax_mad_score, "logit_mad": logit_mad_score}

In [10]:
# `_ood` model: ResNet-18 built with pretrained=True and left as-is.
resnet18_ood = resnet18(pretrained=True)
# `_id` model: same construction, then weights replaced from the tuned checkpoint.
resnet18_id = resnet18(pretrained=True)
resnet18_id = utils.load_state_dict(
    f"{tuned_models}resnet18_imagenet1000_lr0.001_epochs30_step10_checkpoint_best.pth.tar",
    resnet18_id,
)
compare_mad(resnet18_ood, resnet18_id, test_loader)

model loaded
model loaded


  checkpoint = torch.load(path)


Softmax MAD: 0.0002470408531837165
Logit MAD:  0.4373722342967987


In [13]:
# `_ood` model: Fixup-ResNet-50 constructed with no arguments.
# NOTE(review): unlike the resnet18/xvgg16/bagnet33 cells there is no `pretrained=True`
# here, and `original_models` is never used. Unless fixup_resnet50() loads weights by
# default, this baseline is randomly initialised — TODO confirm before comparing.
fixup_resnet50_ood = fixup_resnet50()
# `_id` model: fresh instance whose weights are replaced from the tuned checkpoint.
fixup_resnet50_id = fixup_resnet50()
fixup_resnet50_id = utils.load_state_dict(
    f"{tuned_models}fixup_resnet50_imagenet1000_lr0.001_epochs30_step10_checkpoint_best.pth.tar",
    fixup_resnet50_id,
)
# TODO(review): this cell never calls compare_mad on the pair.

  checkpoint = torch.load(path)


In [19]:
# `_ood` model: X-Fixup-ResNet-50 constructed with no arguments.
# NOTE(review): no `pretrained=True` (unlike other cells) and `original_models` is unused.
# Unless xfixup_resnet50() loads weights by default, this baseline is randomly
# initialised — TODO confirm before comparing.
x_resnet50_ood = xfixup_resnet50()
# `_id` model: fresh instance whose weights are replaced from the tuned checkpoint.
x_resnet50_id = xfixup_resnet50()
x_resnet50_id = utils.load_state_dict(
    f"{tuned_models}xresnet50_imagenet1000_lr0.001_epochs30_step10_checkpoint_best.pth.tar",
    x_resnet50_id,
)
# TODO(review): this cell never calls compare_mad on the pair.

  checkpoint = torch.load(path)


In [21]:
# `_ood` model: X-VGG16 built with pretrained=True and left as-is.
x_vgg16_ood = xvgg16(pretrained=True)
# `_id` model: same construction, then weights replaced from the tuned checkpoint.
x_vgg16_id = xvgg16(pretrained=True)
x_vgg16_id = utils.load_state_dict(
    f"{tuned_models}xvgg16_imagenet1000_lr0.001_epochs30_step10_checkpoint_best.pth.tar",
    x_vgg16_id,
)
# TODO(review): unlike the resnet18 and bagnet33 cells, compare_mad is never called here.

  checkpoint = torch.load(path)


In [23]:
# `_ood` model: BagNet-33 built with pretrained=True and left as-is.
bagnet33_ood = bagnet33(pretrained=True)
# `_id` model: same construction, then weights replaced from the tuned checkpoint.
bagnet33_id = bagnet33(pretrained=True)
bagnet33_id = utils.load_state_dict(
    f"{tuned_models}bagnet33_imagenet1000_lr0.001_epochs30_step10_checkpoint_best.pth.tar",
    bagnet33_id,
)
# NOTE(review): the saved run of this cell ended in KeyboardInterrupt, so no MAD was recorded.
compare_mad(bagnet33_ood, bagnet33_id, test_loader)

MODEL LOADED
MODEL LOADED


  checkpoint = torch.load(path)


KeyboardInterrupt: 