# In [1]:
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from numpy import *
import argparse
from PIL import Image
import imageio
import os
from tqdm import tqdm
from attvis.utils.metrices import *

from attvis.utils import render
from attvis.utils.saver import Saver
from attvis.utils.iou import IoU

from attvis.data.Imagenet import Imagenet_Segmentation
from maskgen.models.random_mask import RandomMaskSaliency

# from baselines.ViT.ViT_explanation_generator import Baselines, LRP
# from baselines.ViT.ViT_new import vit_base_patch16_224
# from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
# from baselines.ViT.ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP

from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt

import torch.nn.functional as F

plt.switch_backend('agg')


# ---------------------------------------------------------------------------
# Evaluation settings
# ---------------------------------------------------------------------------

# Loader settings. batch_size is also used by eval_batch() when it reshapes
# the patch-level saliency maps, so the two must stay in sync.
num_workers = 0
batch_size = 1

# Twenty object-category names (not referenced elsewhere in the visible
# script). NOTE(review): 'motobike' looks like a misspelling of 'motorbike';
# left untouched in case other code matches on the exact string.
cls = [
    'airplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'dining table', 'dog', 'horse', 'motobike', 'person',
    'potted plant', 'sheep', 'sofa', 'train', 'tv',
]

# Not referenced in the visible part of this script.
alpha = 2

# Location of the segmentation ground-truth file handed to the dataset class.
imagenet_seg_path = "data/gtsegs_ijcv.mat"

# Attribution method to evaluate; eval_batch() has branches for
# 'ig', 'rise', 'random' and 'ours'.
method = 'ours'

# Lower clamp applied to the normalised saliency map before it is collected
# for the precision/recall curve.
thr = 0.0

# Run on the GPU when one is present, otherwise fall back to the CPU.
cuda = torch.cuda.is_available()
device = torch.device("cuda") if cuda else torch.device("cpu")


# Data
# Input images: resize to 224x224, convert to a tensor, then normalise every
# channel with mean 0.5 / std 0.5.
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
test_img_trans = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize,
])
# Segmentation masks: nearest-neighbour resize so that label values are never
# blended into values that do not exist in the ground truth.
test_lbl_trans = transforms.Compose([
    transforms.Resize((224, 224), Image.NEAREST),
])

ds = Imagenet_Segmentation(imagenet_seg_path,
                           transform=test_img_trans, target_transform=test_lbl_trans)
# Honour the `num_workers` hyperparameter declared with the other settings
# (it used to be ignored in favour of a hard-coded value of 1). Order is
# preserved (shuffle=False) and no sample is dropped.
dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=False)

# # Model
# model = vit_base_patch16_224(pretrained=True).cuda()
# baselines = Baselines(model)

# # LRP
# model_LRP = vit_LRP(pretrained=True).cuda()
# model_LRP.eval()
# lrp = LRP(model_LRP)

# # orig LRP
# model_orig_LRP = vit_orig_LRP(pretrained=True).cuda()
# model_orig_LRP.eval()
# orig_lrp = LRP(model_orig_LRP)
from transformers import ViTImageProcessor, ViTForImageClassification, ViTModel, ViTConfig, TrainingArguments, Trainer

pretrained_name = 'google/vit-base-patch16-224'
# pretrained_name = 'vit-base-patch16-224-finetuned-imageneteval'
# pretrained_name = 'openai/clip-vit-base-patch32'
config = ViTConfig.from_pretrained(pretrained_name)
processor = ViTImageProcessor.from_pretrained(pretrained_name)
# get mean and std to unnormalize the processed images
mean, std = processor.image_mean, processor.image_std

# Classifier whose predictions are being explained.
pred_model = ViTForImageClassification.from_pretrained(pretrained_name)
pred_model.to(device)


def model(x):
    """Return the raw class logits of ``pred_model`` for pixel tensor ``x``."""
    return pred_model(pixel_values=x).logits


from captum.attr import IntegratedGradients


def _softmax_probs(x):
    """Return the class probabilities of ``pred_model`` for pixel tensor ``x``."""
    return torch.softmax(pred_model(pixel_values=x).logits, dim=-1)


# Integrated Gradients is computed on probabilities rather than raw logits.
ig = IntegratedGradients(_softmax_probs)

random_mask = RandomMaskSaliency(model, num_classes=1000)

from maskgen.models.vision_maskgen_model9 import MaskGeneratingModel

mask_gen_model = MaskGeneratingModel(pred_model, hidden_size=config.hidden_size, num_classes=config.num_labels)
mask_gen_model.to(device)
# mask_gen_model.load_state_dict(torch.load('trained/vision_maskgen_model3/mask_gen_model_2_90.pth'))
# mask_gen_model.load_state_dict(torch.load('mask_gen_model/mask_gen_model_final_9_195.pth'))
# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint written on a GPU machine still loads on a CPU-only host.
mask_gen_model.load_state_dict(
    torch.load('mask_gen_model/mask_gen_model_1_150.pth', map_location=device)
)
# mask_gen_model.load_state_dict(torch.load('trained/mask_gen_model12/mask_gen_model_2_90.pth'))
mask_gen_model.eval()


# Two-class (background / foreground) IoU helper.
metric = IoU(2, ignore_index=-1)

# Progress bar over the evaluation loader; its description is updated with
# the running metrics inside the main loop.
iterator = tqdm(dl)

pred_model.eval()


def compute_pred(output, num_classes=1000):
    """Return a one-hot encoding of the top-scoring class of every row.

    Args:
        output: 2-D tensor of per-class scores, one row per sample. The
            arg-max is taken along dimension 1.
        num_classes: Width of the one-hot encoding. Defaults to 1000, the
            value that used to be hard-coded.

    Returns:
        A float32 tensor of shape ``(output.shape[0], num_classes)`` holding
        1.0 at the predicted class index of each row and 0.0 elsewhere (a row
        is all zeros if its predicted index is >= ``num_classes``). The
        tensor lives on the GPU when CUDA is available and on the CPU
        otherwise, so the function no longer crashes on CPU-only hosts.
    """
    # Index of the maximum score per sample, shape (batch, 1).
    pred = output.data.max(1, keepdim=True)[1]
    # Broadcasting (batch, 1) against (num_classes,) yields a boolean one-hot
    # matrix for any batch size, not only for a single sample.
    one_hot = (pred.cpu() == torch.arange(num_classes)).type(torch.FloatTensor)
    if torch.cuda.is_available():
        return one_hot.cuda()
    return one_hot


def _saliency_map(image, logits):
    """Return the patch-level relevance map for the globally selected ``method``.

    The result is reshaped to ``(batch_size, 1, 14, 14)``.

    Raises:
        ValueError: if ``method`` is not one of 'ig', 'rise', 'ours', 'random'.
    """
    if method == 'ig':
        # Attribute with respect to the class the evaluator predicts.
        predicted_class_idx = logits.argmax(-1).item()
        attribution = ig.attribute(image, target=predicted_class_idx, n_steps=200).sum(dim=1)
        # Average pixel attributions over 16x16 windows to get one value per patch.
        attribution = F.avg_pool2d(attribution, kernel_size=16, stride=16)
        return attribution.reshape(batch_size, 1, 14, 14)

    if method == 'rise':
        attribution = random_mask.attribute_img(image,
                                                image_size=config.image_size,
                                                patch_size=config.patch_size,
                                                n_samples=100,
                                                mask_prob=0.5)
        return attribution.reshape(batch_size, 1, 14, 14)

    if method == 'ours':
        return mask_gen_model.attribute_img(image).reshape(batch_size, 1, 14, 14)

    if method == 'random':
        # Sampled on the CPU and then moved, as before, so the random stream
        # does not depend on the target device.
        return torch.rand(batch_size, 1, 14, 14).to(image.device)

    raise ValueError(
        "Unsupported attribution method %r; expected one of "
        "'ig', 'rise', 'ours', 'random'" % (method,)
    )


def eval_batch(image, labels, evaluator, index):
    """Score the saliency map of one batch against its segmentation mask.

    The map produced by the globally selected ``method`` is upsampled by a
    factor of 16, min-max normalised, and split into foreground/background at
    its mean value. The hard split feeds pixel accuracy, IoU and F1; the soft
    (normalised) map feeds average precision.

    Args:
        image: Input image batch, already on the device the models run on.
            Gradients are enabled on it in place.
        labels: Ground-truth segmentation labels for the batch
            (presumably binary foreground/background; confirm against the
            dataset class).
        evaluator: Callable mapping ``image`` to class logits.
        index: Batch index. Currently unused; kept for interface
            compatibility.

    Returns:
        Tuple ``(correct, labeled, inter, union, ap, f1, pred, target)``
        where the first six entries are the metric helpers' outputs and
        ``pred`` / ``target`` are the flattened soft prediction and labels as
        numpy arrays for the precision/recall curve.

    Raises:
        ValueError: if the global ``method`` names an unsupported method.
    """
    # Integrated Gradients needs gradients with respect to the input.
    image = image.requires_grad_()
    logits = evaluator(image)

    Res = _saliency_map(image, logits)

    # interpolate to full image size (224,224)
    Res = torch.nn.functional.interpolate(Res, scale_factor=16, mode='bilinear').to(image.device)

    # Min-max normalise; a constant map gives 0/0 = NaN, which is zeroed below.
    Res = (Res - Res.min()) / (Res.max() - Res.min())

    # threshold between FG and BG is the mean
    ret = Res.mean()
    Res_1 = Res.gt(ret).type(Res.type())
    Res_0 = Res.le(ret).type(Res.type())

    # Soft scores for average precision. Res_0_AP is derived before the NaN
    # clean-up so that a degenerate map yields zeros in both channels.
    Res_0_AP = 1 - Res
    Res[Res != Res] = 0
    Res_0_AP[Res_0_AP != Res_0_AP] = 0
    Res_1_AP = Res

    # Flattened soft prediction / target for the global precision-recall curve.
    pred = (Res.clamp(min=thr) / Res.max()).view(-1).data.cpu().numpy()
    # An all-zero map makes the division above 0/0; NaNs would otherwise make
    # the final precision/recall computation fail.
    pred = np.where(np.isnan(pred), 0.0, pred)
    target = labels.view(-1).data.cpu().numpy()

    # Channel 0 = background, channel 1 = foreground.
    output = torch.cat((Res_0, Res_1), 1)
    output_AP = torch.cat((Res_0_AP, Res_1_AP), 1)

    # Evaluate Segmentation
    correct, labeled = batch_pix_accuracy(output[0].data.cpu(), labels[0])
    inter, union = batch_intersection_union(output[0].data.cpu(), labels[0], 2)
    ap = np.nan_to_num(get_ap_scores(output_AP, labels))
    f1 = np.nan_to_num(get_f1_scores(output[0, 1].data.cpu(), labels[0]))

    return correct, labeled, inter, union, ap, f1, pred, target


# Running totals over the whole dataset. The pixel/IoU counters start as
# int64 scalars and are rebound to whatever eval_batch() returns on the first
# in-place addition.
total_inter, total_union, total_correct, total_label = np.int64(0), np.int64(0), np.int64(0), np.int64(0)
total_ap, total_f1 = [], []

# Flattened soft predictions / labels of every batch, for the global PR curve.
predictions, targets = [], []
for batch_idx, (image, labels) in enumerate(iterator):

    # Move the batch to the device the models were placed on.
    if method == "blur":
        images = (image[0].to(device), image[1].to(device))
    else:
        images = image.to(device)
    labels = labels.to(device)
    # print("image", image.shape)
    # print("lables", labels.shape)

    correct, labeled, inter, union, ap, f1, pred, target = eval_batch(images, labels, model, batch_idx)

    predictions.append(pred)
    targets.append(target)

    total_correct += correct.astype('int64')
    total_label += labeled.astype('int64')
    total_inter += inter.astype('int64')
    total_union += union.astype('int64')
    total_ap.append(ap)
    total_f1.append(f1)

    # np.spacing(1) is a tiny epsilon that keeps the divisions finite when a
    # denominator is zero.
    pixAcc = np.float64(1.0) * total_correct / (np.spacing(1, dtype=np.float64) + total_label)
    # Named `iou_per_class` so the imported `IoU` class is not overwritten.
    iou_per_class = np.float64(1.0) * total_inter / (np.spacing(1, dtype=np.float64) + total_union)
    mIoU = iou_per_class.mean()
    mAp = np.mean(total_ap)
    mF1 = np.mean(total_f1)
    iterator.set_description('pixAcc: %.4f, mIoU: %.4f, mAP: %.4f, mF1: %.4f' % (pixAcc, mIoU, mAp, mF1))

predictions = np.concatenate(predictions)
targets = np.concatenate(targets)
# The thresholds get their own name: binding them to `thr` would overwrite the
# clamp threshold that eval_batch() reads from module scope.
pr, rc, pr_thresholds = precision_recall_curve(targets, predictions)
# np.save(os.path.join(saver.experiment_dir, 'precision.npy'), pr)
# np.save(os.path.join(saver.experiment_dir, 'recall.npy'), rc)

plt.figure()
plt.plot(rc, pr)
# plt.savefig(os.path.join(saver.experiment_dir, 'PR_curve_{}.png'.format(args.method)))

# txtfile = os.path.join(saver.experiment_dir, 'result_mIoU_%.4f.txt' % mIoU)
# txtfile = 'result_mIoU_%.4f.txt' % mIoU
# fh = open(txtfile, 'w')
print("Mean IoU over %d classes: %.4f\n" % (2, mIoU))
print("Pixel-wise Accuracy: %2.2f%%\n" % (pixAcc * 100))
print("Mean AP over %d classes: %.4f\n" % (2, mAp))
print("Mean F1 over %d classes: %.4f\n" % (2, mF1))

# fh.write("Mean IoU over %d classes: %.4f\n" % (2, mIoU))
# fh.write("Pixel-wise Accuracy: %2.2f%%\n" % (pixAcc * 100))
# fh.write("Mean AP over %d classes: %.4f\n" % (2, mAp))
# fh.write("Mean F1 over %d classes: %.4f\n" % (2, mF1))
# fh.close()


# Captured notebook output (not part of the script):
#   from .autonotebook import tqdm as notebook_tqdm
# pixAcc: 0.8218, mIoU: 0.6679, mAP: 0.9033, mF1: 0.4778:   8%|▊         | 337/4276 [01:33<17:45,  3.70it/s]