In [1]:
from inspect import getsource
import math
import json
from PIL import Image
from tqdm.notebook import tqdm
from pathlib import Path
from glob import glob

import numpy as np

import torch
from torch import nn
from torchvision import models
from torchvision import transforms
from torchvision.models import ResNet34_Weights


import matplotlib.pyplot as plt
plt.style.use('ggplot')


class Flatten(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x):
        return torch.flatten(x, 1)
    
    
class TillConv(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.base_module = module

    def forward(self, x):
        self = self.base_module
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        return out, identity
    
class FromConvOn(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.base_module = module   
        
    def forward(self, x, identity):
        self = self.base_module
        
        out = x
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(identity)
        out += identity
        out = self.relu(out)
        return out


model = models.resnet34(weights='DEFAULT')
model.eval()

# layers = [
#     model.conv1,
#     model.bn1,
#     model.relu,
#     model.maxpool,
#     model.layer1,
#     model.layer2,
#     model.layer3,
#     model.layer4[0],
#     model.layer4[1],
#     TillConv(model.layer4[2]),
#     FromConvOn(model.layer4[2]),
#     model.avgpool,
#     Flatten(),
#     model.fc
# ]
layers = [
    model.conv1,
    model.bn1,
    model.relu,
    model.maxpool,
    model.layer1,
    model.layer2,
    model.layer3,
    model.layer4,
    model.avgpool,
    Flatten(),
    model.fc
]

# with open('imagenet_classes.json') as f:
#     imagenet_classes = json.load(f)
with open('./imagenet_class_index.json') as f:
    imagenet_classes = json.load(f)    
wnid2label = {v[0]:v[1] for i,v in imagenet_classes.items()}
wnid2index = {v[0]:i for i,v in imagenet_classes.items()}
imagenet2label = {i:v[1] for i,v in imagenet_classes.items()}



In [2]:
transform = ResNet34_Weights.IMAGENET1K_V1.transforms()


crop = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
])
totensor = transforms.ToTensor()

def normalize(x):
    vmin = x.min()
    vmax = x.max()
    return (x-vmin)/(vmax-vmin)

In [3]:
# normalize = transforms.Compose([
#     transforms.Normalize(
#         mean=[0.485, 0.456, 0.406], 
#         std=[0.229, 0.224, 0.225]
#     ),
# ])


avg = nn.AdaptiveAvgPool2d(1)
relu = nn.ReLU()
softmax = nn.Softmax(dim=1)

def compute_gradcam(img, target_class, l0, l1, shouldRelu=True):
    act = l0(img)
    act = act.detach().requires_grad_(True)
    pred = l1(act)
    
    if act.grad is not None:
        act.grad.data.fill_(0)
    pred[:,target_class].sum().backward(retain_graph=True)
    
    with torch.no_grad():
        if shouldRelu:
            gradcam = relu(
                (act * avg(act.grad)).sum(dim=1)
            )
        else:
            gradcam = (act * avg(act.grad)).sum(dim=1)
    return gradcam




In [4]:
root = '/home/jack/data/dataset/effectiveness-of-feature-attribution/Human_experiments/Dataset/'
img_fns = (
    glob(root + 'Natural/*/*')
    + glob(root + 'Dog/*/*')
)
# img_fns = glob(root + '*/*/*')
print(len(img_fns), img_fns[:1])

imageid2imgfn = {'_'.join(fn.split('/')[-1].split('_')[1:-1]):fn for fn in img_fns}
feature_attribution_imageids = set(['_'.join(fn.split('/')[-1].split('_')[1:-1]) for fn in img_fns])

1800 ['/home/jack/data/dataset/effectiveness-of-feature-attribution/Human_experiments/Dataset/Natural/correct_images/n03658185_ILSVRC2012_val_00022870_n03658185.jpeg']


In [5]:
seg_dir = '/home/jack/data/dataset/ImageNet-S/ImageNetS919/validation-segmentation/'
seg_fns = glob(seg_dir + '*/*')

imageid2segfn = {
    fn.split('/')[-1][:-4]:fn for fn in seg_fns
}

imageid2seg = {}
for imageid in tqdm(feature_attribution_imageids):
    if imageid in imageid2segfn:
        fn = imageid2segfn[imageid]
        wnid = fn.split('/')[-2]
        imageid = fn.split('/')[-1][:-4]

        pil_seg = Image.open(fn)
        seg = totensor(crop(pil_seg))*256
        seg = (seg[0] > 10).float()

        imageid2seg[imageid] = seg
    

  0%|          | 0/595 [00:00<?, ?it/s]

In [6]:
take one images from feature attribution
take resnet34
compute gradcam
compare the visualization to theirs




SyntaxError: invalid syntax (623866095.py, line 1)

In [39]:
output = []


for imageid, seg in tqdm(imageid2seg.items()):
    fn = imageid2imgfn[imageid]
    
    ## get argmax gradcam of fn
    pil_image = Image.open(fn)
    img = transform(pil_image).unsqueeze(0)
#     gradcam_target_class = model(img).argmax()
    gradcam_class = fn.split('/')[-1][:-5].split('_')[-1]
    gradcam_class = int(wnid2index[gradcam_class])
    gradcam = compute_gradcam(
        img, gradcam_class,
        nn.Sequential(*layers[:8]), 
        nn.Sequential(*layers[8:]), 
    )

    upsample = torch.nn.Upsample(size=224, mode='bilinear')
    up = upsample(gradcam.unsqueeze(0)).squeeze()
    
    saliency = (normalize(up)>0.5).float()
    iou = (  (saliency * seg).sum() / (saliency+seg).clip(0,1).sum()  ).item()
    gtc = (  (saliency * seg).sum() / (1e-4+seg.sum())  ).item()
    sc = (  (saliency * seg).sum() / (1e-4+saliency).sum()  ).item()
    
    alignment_type = 0
    if gtc < 0.25 and sc > 0.75:
        alignment_type = 1 ## under
        print('under')
    elif gtc > 0.75 and sc < 0.25:
        alignment_type = 2 ## over
        print('over')
    elif 0.35 < gtc < 0.65 and 0.35 < sc < 0.65:
        alignment_type = 3 ## partial
        print('partial')
    else:
        alignment_type = -1
        print(f'no alignment type, gtc={gtc:.4f}, sc={sc:.4f}')
        
#     'over' or 'under' or 'partial' or 'none'
    entro = entropy(up).item()
    out = dict(
        fn=fn,
        gradcam=gradcam,
#         segmentation=seg,
        iou=iou,
        gtc=gtc,
        sc=sc,
        entropy=entro,
        alignment_type=alignment_type,
    )
    output.append(out)

  0%|          | 0/114 [00:00<?, ?it/s]

no alignment type, gtc=0.7561, sc=0.9015
no alignment type, gtc=0.7342, sc=0.9880
no alignment type, gtc=0.9213, sc=0.5088
no alignment type, gtc=0.9309, sc=0.5716
no alignment type, gtc=0.8489, sc=0.8988
partial
no alignment type, gtc=0.8533, sc=0.6929
no alignment type, gtc=0.6586, sc=0.8676
no alignment type, gtc=0.0000, sc=0.0000
no alignment type, gtc=0.5583, sc=0.2270
partial
no alignment type, gtc=0.5808, sc=0.8118
no alignment type, gtc=0.0000, sc=0.0000
no alignment type, gtc=0.8846, sc=0.4588
no alignment type, gtc=0.7907, sc=0.6769
no alignment type, gtc=0.4457, sc=0.7809
over
no alignment type, gtc=0.9946, sc=0.3841
no alignment type, gtc=0.8043, sc=0.3425
no alignment type, gtc=0.4584, sc=0.9412
no alignment type, gtc=0.3060, sc=0.6241
no alignment type, gtc=0.9529, sc=0.6775
no alignment type, gtc=0.4099, sc=0.7881
no alignment type, gtc=0.6541, sc=0.9524
no alignment type, gtc=0.9349, sc=0.3108
no alignment type, gtc=0.6578, sc=0.1577
no alignment type, gtc=0.2995, sc=0.

In [40]:
output[0]['gtc'], output[0]['sc']

(0.7560861706733704, 0.9015083312988281)

In [41]:
def entropy(x):
    p = x / x.sum()
    p = p.flatten()
    p = p[p>0]
    res = (-p * np.log(p)).sum()
    return res

In [42]:
torch.save(output, 'output.pth')

In [47]:
column_names = [
    'user_id',
    'gtc_sc_group',
    'iou',
    'entropy',
    'shape_simplicity',
    'p_cat',
    'p0',
    'p1',
]

csv_data = []
for o in output:
    row = [
        'default_user_id',
        o['alignment_type'], ##gtc_sc_group, 1=under, 2=over, 3=partial',
        o['iou'],
        o['entropy'],
        0, ## shape_simplicity (not used) 
        2, ## p_cat, aka the bin of the parameter (p0,p1)
        0, ## p0,
        1, ## p1,
    ]
    csv_data.append(row)

In [48]:
csv_data

[['default_user_id', -1, 0.6984968185424805, 10.633356094360352, 0, 2, 0, 1],
 ['default_user_id', -1, 0.7278337478637695, 10.684083938598633, 0, 2, 0, 1],
 ['default_user_id', -1, 0.4878067672252655, 10.509113311767578, 0, 2, 0, 1],
 ['default_user_id', -1, 0.5486190915107727, 10.552360534667969, 0, 2, 0, 1],
 ['default_user_id', -1, 0.7750210762023926, 10.660246849060059, 0, 2, 0, 1],
 ['default_user_id', 3, 0.44301387667655945, 10.607051849365234, 0, 2, 0, 1],
 ['default_user_id', -1, 0.6193311214447021, 10.567224502563477, 0, 2, 0, 1],
 ['default_user_id', -1, 0.5985008478164673, 10.714741706848145, 0, 2, 0, 1],
 ['default_user_id', -1, 0.0, 10.559011459350586, 0, 2, 0, 1],
 ['default_user_id', -1, 0.19250993430614471, 10.613861083984375, 0, 2, 0, 1],
 ['default_user_id', 3, 0.2430146187543869, 10.384044647216797, 0, 2, 0, 1],
 ['default_user_id', -1, 0.5120531320571899, 10.589523315429688, 0, 2, 0, 1],
 ['default_user_id', -1, 0.0, 10.65661334991455, 0, 2, 0, 1],
 ['default_user_i

In [49]:
import csv

with open('test_data.csv', 'w') as f:
      
    write = csv.writer(f)
    write.writerow(column_names)
    write.writerows(csv_data)

In [None]:
# from torchvisionsion.utils import save_image
# save_image(transformed_image, 'image_test.png')