In [1]:
import os
import json 
import logging

logging.basicConfig(
    filename='log/app.log',            # Specify the log file name
    level=logging.DEBUG,           # Set the log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
    format='%(asctime)s - %(levelname)s - %(message)s'  # Set the log format
)

# Load the environment configuration JSON data
json_path = 'env_config.json'
with open(json_path, 'r') as file:
    env_config = json.load(file)

hf_home = env_config['HF_HOME']
# Set the HF_HOME environment variable
os.environ['HF_HOME'] = hf_home
# Set the access token to huggingface hub
access_token = env_config['access_token']
os.environ['HUGGINGFACE_HUB_TOKEN'] = access_token

In [2]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import os
import transformers
from accelerate import Accelerator
from transformers import ViTImageProcessor, ViTForImageClassification, ViTModel, ViTConfig, TrainingArguments, Trainer
from PIL import Image
import requests
import matplotlib.pyplot as plt
import numpy as np
import cv2
from datasets import load_dataset,load_metric
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

accelerator = Accelerator()
device = accelerator.device

  from .autonotebook import tqdm as notebook_tqdm


# Load Model

In [3]:
# url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# url = "http://farm3.staticflickr.com/2066/1798910782_5536af8767_z.jpg"
# url = "http://farm1.staticflickr.com/184/399924547_98e6cef97a_z.jpg"
url = "http://farm1.staticflickr.com/128/318959350_1a39aae18c_z.jpg"
image = Image.open(requests.get(url, stream=True).raw)

pretrained_name = 'google/vit-base-patch16-224'
# pretrained_name = 'vit-base-patch16-224-finetuned-imageneteval'
# pretrained_name = 'openai/clip-vit-base-patch32'
config = ViTConfig.from_pretrained(pretrained_name)
processor = ViTImageProcessor.from_pretrained(pretrained_name)
# get mean and std to unnormalize the processed images
mean, std = processor.image_mean, processor.image_std

pred_model = ViTForImageClassification.from_pretrained(pretrained_name)
pred_model.to(device)
# set to eval mode
pred_model.eval()

with torch.no_grad():
    inputs = processor(images=image, return_tensors="pt")
    inputs.to(device)
    outputs = pred_model(**inputs, output_hidden_states=True)
    logits = outputs.logits
    # model predicts one of the 1000 ImageNet classes
    predicted_class_idx = logits.argmax(-1).item()
    print("Predicted class:", pred_model.config.id2label[predicted_class_idx])



Predicted class: golden retriever


In [4]:
from torch.utils.data import DataLoader

def load_data(seed=42): 
    dataset = load_dataset("mrm8488/ImageNet1K-val")
    dataset = dataset['train']
    splits = dataset.train_test_split(test_size=0.1, seed=seed)
    test_ds = splits['test']
    splits = splits['train'].train_test_split(test_size=0.1, seed=seed)
    train_ds = splits['train']
    val_ds = splits['test']
    return train_ds, val_ds, test_ds

train_ds, _, test_ds = load_data()

normalize = Normalize(mean=processor.image_mean, std=processor.image_std)
if "height" in processor.size:
    size = (processor.size["height"], processor.size["width"])
    crop_size = size
    max_size = None
elif "shortest_edge" in processor.size:
    size = processor.size["shortest_edge"]
    crop_size = (size, size)
    max_size = processor.size.get("longest_edge")

transforms = Compose(
        [
            Resize(size),
            CenterCrop(crop_size),
            ToTensor(),
            normalize,
        ]
    )

def preprocess(example_batch):
    """Apply val_transforms across a batch."""
    example_batch["pixel_values"] = [transforms(image.convert("RGB")) for image in example_batch["image"]]
    return example_batch

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

test_ds.set_transform(preprocess)
train_ds.set_transform(preprocess)

# batch size is limited to 2, because n_steps could could huge memory consumption
batch_size = 1000
test_dataloader = DataLoader(test_ds, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)
# test_dataloader = DataLoader(train_ds, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)

Repo card metadata block was not found. Setting CardData to empty.


In [7]:
import torch 
import numpy as np
from torch.nn import functional as F
from tqdm import tqdm

hm_path = 'results/maskgen-vit/heatmap-5000.npy'
# hm_path = 'results/rise-vit/heatmap-5000.npy'
# hm_path = 'results/maskgen_model3-vit/heatmap-5000.npy'
# hm_path = 'results/attvis-vit/heatmap-5000.npy'
# hm_path = 'results/gradcam-vit/heatmap-5000.npy'
# hm_path = 'random'

def obtain_masks_on_topk(attribution, topk, mode='positive'):
    """ 
    attribution: [N, H_a, W_a]
    """
    H_a, W_a = attribution.shape[-2:]
    attribution = attribution.reshape(-1, H_a * W_a) # [N, H_a*W_a]
    attribution_perturb = attribution + 1e-6*torch.randn_like(attribution) # to avoid equal attributions (typically all zeros or all ones)
    
    attribution_size = H_a * W_a
    topk_scaled = int(topk * attribution_size / 100)
    if mode == 'positive':
        a, _ = torch.topk(attribution_perturb, k=topk_scaled, dim=-1)
        a = a[:, -1].unsqueeze(-1)
        mask = (attribution_perturb <= a).float()
    elif mode == 'negative':
        a, _ = torch.topk(attribution_perturb, k=topk_scaled, dim=-1, largest=False)
        a = a[:, -1].unsqueeze(-1)
        mask = (attribution_perturb >= a).float()
        # a, _ = torch.topk(attribution_perturb, k=topk_scaled, dim=-1)
        # a = a[:, -1].unsqueeze(-1)
        # mask = (attribution_perturb >= a).float()

    else:
        raise ValueError('Enter game mode either as positive or negative.')
    return mask.reshape(-1, H_a, W_a) # [N, H_a*W_a]


def obtain_masked_input_on_topk(x, attribution, topk, mode='positive'):
    """ 
    x: [N, C, H, W]
    attribution: [N, H_a, W_a]
    """
    mask = obtain_masks_on_topk(attribution, topk, mode)
    mask = mask.unsqueeze(1) # [N, 1, H_a, W_a]
    mask = F.interpolate(mask, size=x.shape[-2:], mode='nearest')

    masked_input = x * mask

    # mean_pixel = masked_input.sum(dim=(-1, -2), keepdim=True) / mask.sum(dim=(-1, -2), keepdim=True)
    mean_pixel = x.mean(dim=(-1, -2), keepdim=True)
    masked_input = masked_input + (1 - mask) * mean_pixel

    return masked_input

def load_heatmap(path='results/maskgen-vit/heatmap-5000.npy', batch_size=1000):
    if path == 'random':
        heatmap = np.random.rand(5000, 1, 14, 14)
        heatmap = torch.tensor(heatmap)
    else:
        heatmap = np.load(path)
        heatmap = torch.tensor(heatmap) # [N, 1, 14, 14]
    batches = torch.split(heatmap, batch_size, dim=0)
    return batches


topk = 20
# divide heatmap into batches of size batch_size
heatmap_batches = load_heatmap(path=hm_path, batch_size=1000)

total_acc = []
for topk in [10, 20, 30,  40,  50,  60,  70, 80,  90]:
    for idx, data in tqdm(enumerate(zip(test_dataloader, heatmap_batches))):
        pixel_values = data[0]['pixel_values'].to(device) # [N, C, H, W]
        attribution = data[1].to(device) # [N, 1, 14, 14]
        with torch.no_grad():
            pseudo_label = pred_model(pixel_values).logits.argmax(-1).view(-1)
            masked_input = obtain_masked_input_on_topk(pixel_values, attribution, topk, mode='positive')
            logits = pred_model(masked_input).logits
            preds = logits.argmax(-1).view(-1)
            acc = (pseudo_label == preds).float().mean().item()
            print(acc)
            total_acc.append(acc)
            break
        


0it [00:13, ?it/s]


0.8940000534057617


0it [00:13, ?it/s]


0.8080000281333923


0it [00:13, ?it/s]


0.734000027179718


0it [00:13, ?it/s]


0.6610000133514404


0it [00:13, ?it/s]


0.5640000104904175


0it [00:13, ?it/s]


0.453000009059906


0it [00:13, ?it/s]


0.3370000123977661


0it [00:13, ?it/s]


0.19200000166893005


0it [00:13, ?it/s]

0.07500000298023224





In [8]:
sum(total_acc) / len(total_acc)

0.5242222398519516