## 1. Load Libraries


In [1]:
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from numpy import *
import argparse
from PIL import Image
import imageio
import os
from tqdm import tqdm
from utilities.metrices import *

from utilities import render
from utilities.saver import Saver
from utilities.iou import IoU

from data.Imagenet import Imagenet_Segmentation


from baselines.ViT.ViT_explanation_generator import Baselines, LRP
from baselines.ViT.ViT_new import vit_base_patch16_224
from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP

from baselines.ViT.ViT_new import deit_base_patch16_224
from baselines.ViT.ViT_LRP import deit_base_patch16_224 as deit_LRP
from baselines.ViT.ViT_orig_LRP import deit_base_patch16_224 as deit_orig_LRP

from baselines.ViT.ViT_new import swin_base_patch16_224
from baselines.ViT.ViT_LRP import swin_base_patch16_224 as swin_LRP
from baselines.ViT.ViT_orig_LRP import swin_base_patch16_224 as swin_orig_LRP


from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt

import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


# 2. Set random seed

In [2]:
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's built-in ``random`` module, NumPy's global generator and
    PyTorch (CPU plus all CUDA devices), then switches cuDNN to its
    deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    # ``from numpy import *`` in the import cell rebinds the bare name
    # ``random`` to ``numpy.random``, so the standard-library module has to be
    # imported explicitly here; otherwise it would never be seeded.
    import random as py_random

    py_random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed = 42
set_seed(seed)

# 3. Set arguments for the experiment

In [3]:
"""
In class Args below, between the two rows of '#' characters:

# STEP 1
Change self.model to one of the below:
-   'ViT'
-   'DeiT'
-   'Swin'

# STEP 2
Change self.method to one of the below :

- for "VTA" ->         self.method ='transformer_attribution'
- for "LRP" ->                      'full_lrp'
- for "raw_attn" ->                 'attn_last_layer'
- for "rollout" ->                  'rollout'
- for "GradCAM" ->                  'attn_gradcam'
- for "DDS+VTA" ->                  'ours'

# STEP 3
Run all cells from section 3 to the end.
After the run finishes, the results appear at the end of this notebook or in the job's output file.

"""
##############################################################################################################################################

# Args
class Args:
    """Plain container for the experiment settings used by the rest of the notebook.

    Edit ``self.model`` and ``self.method`` below to choose what is evaluated.
    """
    def __init__(self):
        # Only used to build the experiment name (args.checkname) in this notebook.
        self.arc = 'vgg'
        self.train_dataset = 'imagenet'
        # Backbone to load: 'ViT', 'DeiT' or 'Swin'.
        self.model = 'ViT'
        # Explanation method evaluated by eval_batch.
        self.method = 'attn_last_layer'
        # 'PGD' makes the evaluation loop attack every image before explaining it.
        self.att_strategy = 'PGD'
        # Lower clamp applied to the relevance map when building the flat prediction vector.
        self.thr = 0.0
        self.K = 1
        # When True, eval_batch writes input images, masks and heatmaps to disk.
        self.save_img = False
        # NOTE(review): the no_* flags and K are not read anywhere in the visible
        # notebook; presumably they are consumed by Saver or kept from the original
        # script — confirm before removing.
        self.no_ia = False
        self.no_fx = False
        self.no_fgx = False
        self.no_m = False
        self.no_reg = False
        # Forwarded to generate_LRP for the last-layer methods.
        self.is_ablation = False
        # Path of the segmentation ground-truth file passed to Imagenet_Segmentation.
        self.imagenet_seg_path = 'datasets/gtsegs_ijcv.mat'
##############################################################################################################################################
args = Args()

# Experiment name built from the chosen method and architecture tag;
# presumably used by Saver to name the output directory — TODO confirm.
args.checkname = args.method + '_' + args.arc

# NOTE(review): alpha is not referenced anywhere in the visible notebook.
alpha = 2

cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

# Data-loading hyperparameters (previously declared twice with identical values).
num_workers = 0
batch_size = 1

# Use the non-interactive Agg backend for matplotlib.
plt.switch_backend('agg')

# Twenty object-category names. The list is not referenced anywhere else in
# the visible part of the notebook; spellings are kept exactly as they were.
cls = [
    'airplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'dining table', 'dog', 'horse', 'motobike', 'person',
    'potted plant', 'sheep', 'sofa', 'train', 'tv',
]



# Define Saver and make sure every output directory it needs exists.
saver = Saver(args)
saver.results_dir = os.path.join(saver.experiment_dir, 'results')

# exist_ok=True replaces the separate "exists?" checks: it is shorter and has
# no window between the check and the creation. Parent directories are
# created automatically.
os.makedirs(saver.results_dir, exist_ok=True)
os.makedirs(os.path.join(saver.results_dir, 'input'), exist_ok=True)
os.makedirs(os.path.join(saver.results_dir, 'explain'), exist_ok=True)

# Destination folders for saved explanation images and arrays.
args.exp_img_path = os.path.join(saver.results_dir, 'explain/img')
os.makedirs(args.exp_img_path, exist_ok=True)
args.exp_np_path = os.path.join(saver.results_dir, 'explain/np')
os.makedirs(args.exp_np_path, exist_ok=True)

# Data
# Images: resize to 224x224, convert to a tensor, then map each channel from
# [0, 1] to [-1, 1] (mean 0.5, std 0.5).
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
test_img_trans = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize,
])
# Labels: nearest-neighbour resize so no new label values are interpolated in.
# The InterpolationMode enum replaces the integer PIL constant (Image.NEAREST),
# which torchvision flags as deprecated for this argument.
test_lbl_trans = transforms.Compose([
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.NEAREST),
])

ds = Imagenet_Segmentation(args.imagenet_seg_path,
                           transform=test_img_trans, target_transform=test_lbl_trans)
# NOTE(review): num_workers is the literal 1 here, not the `num_workers`
# hyperparameter declared earlier in this cell (0) — confirm which is intended.
dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=False)

# Constructors for each supported backbone, in the order
# (plain model, LRP-enabled model, original-LRP model).
_MODEL_BUILDERS = {
    'ViT': (vit_base_patch16_224, vit_LRP, vit_orig_LRP),
    'DeiT': (deit_base_patch16_224, deit_LRP, deit_orig_LRP),
    'Swin': (swin_base_patch16_224, swin_LRP, swin_orig_LRP),
}

# Fail right here on a typo instead of with a NameError in a later cell.
if args.model not in _MODEL_BUILDERS:
    raise ValueError(
        "Unsupported args.model %r; expected one of %s"
        % (args.model, sorted(_MODEL_BUILDERS))
    )
_build_model, _build_lrp, _build_orig_lrp = _MODEL_BUILDERS[args.model]

# Plain model, wrapped by the baseline explainers (rollout, GradCAM).
model = _build_model(pretrained=True).cuda()
model.eval()
baselines = Baselines(model)

# LRP
model_LRP = _build_lrp(pretrained=True).cuda()
model_LRP.eval()
lrp = LRP(model_LRP)

# orig LRP
model_orig_LRP = _build_orig_lrp(pretrained=True).cuda()
model_orig_LRP.eval()
orig_lrp = LRP(model_orig_LRP)

# NOTE(review): `metric` is not used in the visible part of the notebook.
metric = IoU(2, ignore_index=-1)

# Progress-bar wrapper around the data loader; consumed by the evaluation loop.
iterator = tqdm(dl)


  transforms.Resize((224, 224), Image.NEAREST),
  0%|          | 0/4276 [00:00<?, ?it/s]

In [4]:
import sys
sys.path.append("./guided_diffusion")
import argparse
import os

import numpy as np
import torch as th
import torch.distributed as dist

from guided_diffusion import dist_util, logger
from guided_diffusion.script_util import (
    NUM_CLASSES,
    # model_and_diffusion_defaults,
    create_model_and_diffusion,
    add_dict_to_argparser,
    args_to_dict,
)

def diffusion_defaults():
    """Return a fresh dict with the default diffusion-process settings.

    A new dict is built on every call, so callers may modify the result.
    """
    return {
        "learn_sigma": True,
        "diffusion_steps": 1000,
        "noise_schedule": "linear",
        "timestep_respacing": "250",
        "use_kl": False,
        "predict_xstart": False,
        "rescale_timesteps": False,
        "rescale_learned_sigmas": False,
    }


def classifier_defaults():
    """Return a fresh dict with the default classifier-model settings.

    The trailing comments on three entries are kept from the original code;
    presumably they record an alternative value for that entry — TODO confirm.
    """
    return {
        "image_size": 256,
        "classifier_use_fp16": False,
        "classifier_width": 128,
        "classifier_depth": 2,
        "classifier_attention_resolutions": "32,16,8",  # 16
        "classifier_use_scale_shift_norm": True,  # False
        "classifier_resblock_updown": True,  # False
        "classifier_pool": "attention",
    }


def model_and_diffusion_defaults():
    """Return the default image-model settings merged with the diffusion defaults.

    The image-model entries come first, followed by everything returned by
    ``diffusion_defaults()``.
    """
    image_model_settings = {
        "image_size": 256,
        "num_channels": 256,
        "num_res_blocks": 2,
        "num_heads": 4,
        "num_heads_upsample": -1,
        "num_head_channels": 64,
        "attention_resolutions": "32,16,8",
        "channel_mult": "",
        "dropout": 0.0,
        "class_cond": False,
        "use_checkpoint": False,
        "use_scale_shift_norm": True,
        "resblock_updown": True,
        "use_fp16": True,
        "use_new_attention_order": False,
    }
    return {**image_model_settings, **diffusion_defaults()}


def create_argparser():
    """Build an ``argparse`` parser pre-populated with the sampling defaults.

    The sampling options defined here are combined with
    ``model_and_diffusion_defaults()`` and handed to ``add_dict_to_argparser``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    sampling_options = {
        "clip_denoised": True,
        "num_samples": 1,
        "batch_size": 4,
        "use_ddim": False,
        "model_path": "./guided_diffusion/models/256x256_diffusion_uncond.pt",
        **model_and_diffusion_defaults(),
    }
    arg_parser = argparse.ArgumentParser()
    add_dict_to_argparser(arg_parser, sampling_options)
    return arg_parser

# Parse an empty argument list so only the defaults defined above are used
# (the notebook kernel's own command line is ignored).
args_diff = create_argparser().parse_args([])

dist_util.setup_dist()
logger.configure()

logger.log("creating model and diffusion...")
# Build the denoising network and the diffusion process from the default settings.
d_model, diffusion = create_model_and_diffusion(
    **args_to_dict(args_diff, model_and_diffusion_defaults().keys())
)
# Load the checkpoint onto the CPU first, then move the whole model to the working device.
d_model.load_state_dict(
    dist_util.load_state_dict(args_diff.model_path, map_location="cpu")
)
d_model.to(dist_util.dev())
if args_diff.use_fp16:
    d_model.convert_to_fp16()
d_model.eval()
# NOTE: this rebinds the module-level `device` that the arguments cell set earlier;
# from here on it is whatever device the diffusion model's parameters live on.
device = next(d_model.parameters()).device

# Batch shape assumed by denoise(); only shape[0] (the batch size) is read there.
shape = (1, 3, 256, 256)
# Linear beta-schedule settings used by beta(), get_opt_t() and range_of_delta().
steps =  1000
start = 0.0001
end = 0.02
# Number of denoise-and-explain repetitions fused per image by the 'ours' method.
trial_num = 2

def range_of_delta(beta_s, beta_e, steps):
    """Convert the two end points of a beta schedule into delta values.

    Each beta is mapped to ``sqrt(beta / (1 - beta))``.

    Args:
        beta_s: Beta at the start of the schedule.
        beta_e: Beta at the end of the schedule.
        steps: Accepted for signature compatibility; not used.

    Returns:
        Tuple ``(delta_for_beta_s, delta_for_beta_e)``.
    """
    return tuple((b / (1 - b)) ** (0.5) for b in (beta_s, beta_e))

def beta(t, steps, start, end):
    """Linearly interpolate beta for 1-based step ``t``.

    ``t == 1`` yields ``start`` and ``t == steps`` yields ``end``.
    """
    progress = (t - 1) / (steps - 1)
    return progress * (end - start) + start

def add_noise(x, delta, opt_t, steps, start, end):
    """Add Gaussian noise of scale ``delta`` to ``x`` and rescale the result.

    The noisy tensor ``x + randn_like(x) * delta`` is multiplied by
    ``sqrt(1 - beta(opt_t, steps, start, end))``.
    """
    scale = np.sqrt(1 - beta(opt_t, steps, start, end))
    perturbed = x + th.randn_like(x) * delta
    return scale * perturbed

def get_opt_t(delta, start, end, steps):
    """Map a noise level ``delta`` to a step index of the linear beta schedule.

    ``delta`` is first converted to ``1 - 1 / (1 + delta**2)``, then the linear
    formula used by ``beta`` is solved for the step. The result is rounded to
    the nearest integer and clipped to ``[0, steps]``.
    """
    equivalent_beta = 1 - 1 / (1 + delta ** 2)
    raw_step = 1 + (steps - 1) / (end - start) * (equivalent_beta - start)
    return np.clip(int(np.around(raw_step)), 0, steps)

# opt_t = get_opt_t(delta, start, end, steps)

def denoise(img, opt_t, steps, start, end, delta, direct_pred=False):
    """Noise ``img`` and run the reverse diffusion process for ``opt_t`` steps.

    Relies on the module-level ``diffusion``, ``d_model``, ``args_diff``,
    ``device`` and ``shape`` objects.

    Args:
        img: Image tensor without a batch dimension (one is added here).
        opt_t: Number of reverse steps; timesteps ``opt_t - 1`` down to ``0``
            are sampled.
        steps, start, end: Beta-schedule settings forwarded to ``add_noise``.
        delta: Scale of the Gaussian noise added before denoising.
        direct_pred: When True, return the ``'pred_xstart'`` entry of the very
            first sampling step instead of running the whole chain.

    Returns:
        The final ``'sample'`` tensor. With ``direct_pred`` set it is the first
        step's ``'pred_xstart'`` instead, and if ``opt_t`` is 0 the noised
        input is returned unchanged.
    """
    from tqdm.auto import tqdm

    current = add_noise(img, delta, opt_t, steps, start, end).unsqueeze(0).to(device)

    # Walk the timesteps from high to low with a progress bar.
    countdown = tqdm(list(range(opt_t))[::-1])
    for step in countdown:
        timestep = th.tensor([step] * shape[0], device=device)
        with th.no_grad():
            out = diffusion.p_sample(
                d_model,
                current,
                timestep,
                clip_denoised=args_diff.clip_denoised,
                denoised_fn=None,
                cond_fn=None,
                model_kwargs={},
            )
        current = out['sample']
        if direct_pred:
            return out['pred_xstart']
    return current
# Resize helpers: images are scaled to 256x256 before being passed to denoise()
# and back to 224x224 afterwards (see the 'ours' branch of eval_batch).
trans_to_256 = transforms.Compose([
    transforms.Resize((256, 256)),
])
trans_to_224 = transforms.Compose([
    transforms.Resize((224, 224)),
])

# Delta values matching the first and last beta of the schedule.
delta_range = range_of_delta(start, end, steps)


def drop_lowest_max_fuse(arr, ratio=10):
    """Zero each map's lowest values, then fuse all maps with an element-wise max.

    For every array in ``arr``, the entries below that array's own
    ``ratio``-th percentile are set to 0. The thresholded arrays are then
    combined with ``np.maximum``.

    The inputs are left untouched: thresholding is applied to copies. (The
    previous implementation zeroed entries of the caller's arrays in place.)

    Args:
        arr: Non-empty sequence of NumPy arrays with broadcast-compatible shapes.
        ratio: Percentile (0-100) below which values are dropped.

    Returns:
        A new array holding the fused result.

    Raises:
        IndexError: If ``arr`` is empty.
    """
    def _drop_lowest(values):
        # np.array(...) copies, so the caller's array is never modified.
        kept = np.array(values)
        kept[kept < np.percentile(kept, ratio)] = 0
        return kept

    fused = _drop_lowest(arr[0])
    for i in range(1, len(arr)):
        fused = np.maximum(fused, _drop_lowest(arr[i]))
    return fused
def max_fuse(arr):
    """Fuse a non-empty sequence of arrays with an element-wise maximum.

    A single-element sequence returns that element itself.
    """
    fused = arr[0]
    for candidate in arr[1:]:
        fused = np.maximum(fused, candidate)
    return fused
def mean_fuse(arr):
    """Fuse a non-empty sequence of arrays by averaging them element-wise."""
    total = arr[0]
    for term in arr[1:]:
        # Plain `+` (not `+=`) so the first input array is never modified.
        total = total + term
    return total / len(arr)
def normal(arr):
    """Min-max normalise ``arr`` to the range [0, 1].

    Args:
        arr: Non-empty array-like object exposing ``min()`` and ``max()``
            (e.g. a NumPy array or a torch tensor).

    Returns:
        ``(arr - min) / (max - min)``. When every element is equal the range
        is zero; an all-zero result is returned in that case instead of the
        NaN/inf values that a division by zero would produce.
    """
    low = arr.min()
    span = arr.max() - low
    if span == 0:
        # Dividing by 1 keeps the same floating-point result type as the
        # regular path while avoiding the division by zero.
        return (arr - low) / 1
    return (arr - low) / span

Logging to /scratch-local/ytjun.9714013/openai-2025-01-31-11-23-14-007753
creating model and diffusion...


In [5]:
# Perturbation budget used for the PGD attack (and reused as the noise level
# by the 'ours' explanation method).
noise_level = 7/255


def attack(image, model, noise_level, label_index=None, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
    """Create an adversarial version of ``image`` with a 10-step PGD attack.

    Args:
        image: Input batch handed to ``model`` and to the attack.
        model: Classifier under attack; assumed to output 1000 logits — the
            target vector below is hard-coded to that size.
        noise_level: PGD ``eps``; the step size is ``noise_level / 5``.
        label_index: Class index to attack. When ``None`` the model's own
            top-1 prediction for ``image`` is used.
        mean, std: Normalisation statistics reported to torchattacks. The
            default lists are only read, never modified.

    Returns:
        The adversarial images produced by ``torchattacks.PGD``.
    """
    import torchattacks
    torch.backends.cudnn.deterministic = True
    atk = torchattacks.PGD(model, eps=noise_level, alpha=noise_level/5, steps=10)
    atk.set_normalization_used(mean, std)

    # `is None` is the correct identity test; `== None` would be dispatched to
    # the argument's __eq__ when a tensor index is passed in.
    if label_index is None:
        # NOTE(review): the forward pass is left outside torch.no_grad(), as in
        # the original (where the no_grad line was commented out) — presumably
        # the model needs autograd enabled during forward; confirm before changing.
        logits = model(image)
        label_index = logits.argmax()

    # One-hot target of shape (1, 1000).
    labels = torch.FloatTensor([0]*1000)
    labels[label_index] = 1
    labels = labels.reshape(1, 1000)
    adv_images = atk(image, labels.float())
    return adv_images

In [6]:
def compute_pred(output):
    """Return a one-hot encoding of the top-1 class for each row of ``output``.

    The earlier NumPy-broadcasting implementation only produced a valid result
    for a batch of one; this version gives one one-hot row per sample for any
    batch size.

    Args:
        output: Score tensor of shape (batch, classes); the one-hot width is
            fixed at 1000 classes, as before.

    Returns:
        Float tensor of shape (batch, 1000) on the CUDA device.
    """
    # Index of the highest score in every row.
    top_class = output.detach().argmax(dim=1)
    one_hot = torch.nn.functional.one_hot(top_class, num_classes=1000)
    # Same conversion path as before: CPU float32 tensor, then moved to CUDA.
    return one_hot.type(torch.FloatTensor).cuda()


def eval_batch(image, labels, evaluator, index):
    """Explain one batch with ``args.method`` and score the map as a segmentation.

    The relevance map is min-max normalised, thresholded at its mean into
    foreground/background, and compared against ``labels``.

    Relies on module-level objects: ``args``, ``saver``, ``baselines``,
    ``lrp``, ``orig_lrp``, ``batch_size`` and, for the 'ours' method, the
    diffusion helpers.

    Args:
        image: Input batch (assumed to be a 224x224 batch of ``batch_size``
            images — TODO confirm against the data loader).
        labels: Ground-truth segmentation mask(s) for the batch.
        evaluator: Model whose gradients are reset and which is run once on
            ``image``.
        index: Batch index, used only to name saved image files.

    Returns:
        Tuple ``(pix_correct, labeled, inter, union, ap, pred, target)``:
        pixel-accuracy counts, intersection/union counts, the average
        precision score, and the flattened prediction and target arrays.

    Raises:
        ValueError: If ``args.method`` is not one of the supported methods.
            Previously this case failed later with an unrelated
            "referenced before assignment" error.
    """
    evaluator.zero_grad()
    # Save input image
    if args.save_img:
        img = image[0].permute(1, 2, 0).data.cpu().numpy()
        img = 255 * (img - img.min()) / (img.max() - img.min())
        img = img.astype('uint8')
        Image.fromarray(img, 'RGB').save(os.path.join(saver.results_dir, 'input/{}_input.png'.format(index)))
        Image.fromarray((labels.repeat(3, 1, 1).permute(1, 2, 0).data.cpu().numpy() * 255).astype('uint8'), 'RGB').save(
            os.path.join(saver.results_dir, 'input/{}_mask.png'.format(index)))

    # Enable gradients on the input. (The original set the flag twice in a
    # row; one in-place call is enough.)
    image = image.requires_grad_()
    # NOTE(review): the result of this forward pass is not used below; it is
    # kept in case it populates state the explainers rely on — confirm before
    # removing.
    predictions = evaluator(image)

    # segmentation test for the rollout baseline
    if args.method == 'rollout':
        Res = baselines.generate_rollout(image.cuda(), start_layer=1).reshape(batch_size, 1, 14, 14)

    # segmentation test for the LRP baseline (this is full LRP, not partial)
    elif args.method == 'full_lrp':
        Res = orig_lrp.generate_LRP(image.cuda(), method="full").reshape(batch_size, 1, 224, 224)

    # segmentation test for our method
    elif args.method == 'transformer_attribution':
        Res = lrp.generate_LRP(image.cuda(), start_layer=1, method="transformer_attribution").reshape(batch_size, 1, 14, 14)

    # segmentation test for the partial LRP baseline (last attn layer)
    elif args.method == 'lrp_last_layer':
        Res = orig_lrp.generate_LRP(image.cuda(), method="last_layer", is_ablation=args.is_ablation)\
            .reshape(batch_size, 1, 14, 14)

    # segmentation test for the raw attention baseline (last attn layer)
    elif args.method == 'attn_last_layer':
        Res = orig_lrp.generate_LRP(image.cuda(), method="last_layer_attn", is_ablation=args.is_ablation)\
            .reshape(batch_size, 1, 14, 14)

    # segmentation test for the GradCam baseline (last attn layer)
    elif args.method == 'attn_gradcam':
        Res = baselines.generate_cam_attn(image.cuda()).reshape(batch_size, 1, 14, 14)

    # denoise the input with the diffusion model several times, explain each
    # denoised copy, and fuse the resulting maps
    elif args.method == 'ours':
        opt_t = get_opt_t(noise_level, start, end, steps)
        res_lists = []
        for trial in range(trial_num):
            # A different but fixed seed per trial.
            seed = 42 + trial
            set_seed(seed)
            image_denoised = trans_to_224(
                denoise(trans_to_256(image.cuda()).squeeze(0), opt_t, steps, start, end, noise_level)
            ).detach().cpu()
            # Add fresh Gaussian noise to the denoised image and clamp to [-1, 1].
            image_denoise_normal = image_denoised + torch.randn_like(image_denoised) * noise_level
            image_denoise_normal = torch.squeeze(image_denoise_normal)
            image_denoise_normal = torch.clamp(image_denoise_normal, -1, 1)

            res = lrp.generate_LRP(image_denoise_normal.cuda().unsqueeze(0), start_layer=1, method="transformer_attribution").reshape(batch_size, 1, 14, 14)
            res_lists.append(res.data.cpu().numpy())

        Res = drop_lowest_max_fuse(res_lists)
        Res = torch.tensor(Res).cuda()

    else:
        raise ValueError("Unsupported args.method: %r" % (args.method,))

    if args.method != 'full_lrp':
        # interpolate to full image size (224,224)
        Res = torch.nn.functional.interpolate(Res, scale_factor=16, mode='bilinear').cuda()

    # threshold between FG and BG is the mean
    Res = (Res - Res.min()) / (Res.max() - Res.min())

    ret = Res.mean()

    Res_1 = Res.gt(ret).type(Res.type())
    Res_0 = Res.le(ret).type(Res.type())

    # Soft scores for average precision. Res_1_AP is the same tensor object as
    # Res, so the NaN clean-up below also modifies Res.
    Res_1_AP = Res
    Res_0_AP = 1-Res

    # x != x is only true for NaN: replace NaNs (e.g. from a constant map) with 0.
    Res_1[Res_1 != Res_1] = 0
    Res_0[Res_0 != Res_0] = 0
    Res_1_AP[Res_1_AP != Res_1_AP] = 0
    Res_0_AP[Res_0_AP != Res_0_AP] = 0


    # TEST
    pred = Res.clamp(min=args.thr) / Res.max()
    pred = pred.view(-1).data.cpu().numpy()
    target = labels.view(-1).data.cpu().numpy()
    # print("target", target.shape)

    # Channel 0 = background, channel 1 = foreground.
    output = torch.cat((Res_0, Res_1), 1)
    output_AP = torch.cat((Res_0_AP, Res_1_AP), 1)

    if args.save_img:
        # Save predicted mask
        mask = F.interpolate(Res_1, [64, 64], mode='bilinear')
        mask = mask[0].squeeze().data.cpu().numpy()
        # mask = Res_1[0].squeeze().data.cpu().numpy()
        mask = 255 * mask
        mask = mask.astype('uint8')
        imageio.imsave(os.path.join(args.exp_img_path, 'mask_' + str(index) + '.jpg'), mask)

        relevance = F.interpolate(Res, [64, 64], mode='bilinear')
        relevance = relevance[0].permute(1, 2, 0).data.cpu().numpy()
        # relevance = Res[0].permute(1, 2, 0).data.cpu().numpy()
        hm = np.sum(relevance, axis=-1)
        maps = (render.hm_to_rgb(hm, scaling=3, sigma=1, cmap='seismic') * 255).astype(np.uint8)
        imageio.imsave(os.path.join(args.exp_img_path, 'heatmap_' + str(index) + '.jpg'), maps)

    # Evaluate Segmentation
    batch_inter, batch_union, batch_pix_correct, batch_label = 0, 0, 0, 0
    batch_ap = 0

    # Segmentation results
    pix_correct, labeled = batch_pix_accuracy(output[0].data.cpu(), labels[0])
    inter, union = batch_intersection_union(output[0].data.cpu(), labels[0], 2)
    batch_pix_correct += pix_correct
    batch_label += labeled
    batch_inter += inter
    batch_union += union
    # print("output", output.shape)
    # print("ap labels", labels.shape)

    # Calculate average precision
    ap = np.nan_to_num(get_ap_scores(output_AP, labels))
    batch_ap += ap

    return batch_pix_correct, batch_label, batch_inter, batch_union, batch_ap, pred, target


# 4. Calculate metrics

In [7]:
# Running totals over the whole dataset.
total_inter, total_union, total_pix_correct, total_label = np.int64(0), np.int64(0), np.int64(0), np.int64(0)
total_ap, total_acc = [], []

predictions, targets = [], []
for batch_idx, (image, labels) in enumerate(iterator):
    # NOTE(review): eval_batch has no 'blur' method, and the tuple built here
    # would not survive the .cuda() call in the PGD branch — this looks like a
    # leftover from another script; confirm before relying on it.
    if args.method == "blur":
        images = (image[0].cuda(), image[1].cuda())
    else:
        images = image.cuda()
    labels = labels.cuda()

    # Optionally replace the clean image with its adversarial counterpart.
    if args.att_strategy == 'PGD':
        images = attack(images.cuda(), model, noise_level)

    # print("image", image.shape)
    # print("lables", labels.shape)

    correct, labeled, inter, union, ap, pred, target = eval_batch(images, labels, model, batch_idx)

    predictions.append(pred)
    targets.append(target)
    total_pix_correct += correct.astype('int64')
    total_label += labeled.astype('int64')
    total_inter += inter.astype('int64')
    total_union += union.astype('int64')
    total_ap += [ap]
    # np.spacing(1) is a tiny epsilon that guards against division by zero.
    pixAcc = np.float64(1.0) * total_pix_correct / (np.spacing(1, dtype=np.float64) + total_label)
    # Named class_iou rather than IoU so the imported `IoU` metric class is not
    # overwritten; overwriting it broke re-running the setup cell that calls IoU(...).
    class_iou = np.float64(1.0) * total_inter / (np.spacing(1, dtype=np.float64) + total_union)
    mIoU = class_iou.mean()
    mAp = np.mean(total_ap)
    iterator.set_description('pixAcc: %.4f, mIoU: %.4f, mAP: %.4f' % (pixAcc, mIoU, mAp))

pixAcc: 0.6433, mIoU: 0.4285, mAP: 0.7797: 100%|██████████| 4276/4276 [48:24<00:00,  1.47it/s]


In [8]:
# Write the final metrics to a text file named after the achieved mIoU, and
# echo the same lines in the notebook output.
txtfile = os.path.join(saver.experiment_dir, 'result_mIoU_%.4f.txt' % mIoU)
# txtfile = 'result_mIoU_%.4f.txt' % mIoU
# The context manager closes the file even if formatting or writing fails.
with open(txtfile, 'w') as fh:
    print("Pixel-wise Accuracy: %.2f\n" % (pixAcc))
    print("Mean IoU over %d classes: %.2f\n" % (2, mIoU))
    print("Mean AP over %d classes: %.2f\n" % (2, mAp))

    fh.write("Pixel-wise Accuracy: %.2f\n" % (pixAcc))
    fh.write("Mean IoU over %d classes: %.2f\n" % (2, mIoU))
    fh.write("Mean AP over %d classes: %.2f\n" % (2, mAp))

Pixel-wise Accuracy: 0.64

Mean IoU over 2 classes: 0.43

Mean AP over 2 classes: 0.78

