In [10]:
from transformers import ViTImageProcessor, ViTForImageClassification, ViTModel, ViTConfig
from maskgen.utils import get_preprocess, collate_fn, load_imagenet
# from maskgen.utils.img_utils import plot_overlap_np
from torch.utils.data import DataLoader
import torch
import json
import os
from tqdm import tqdm
import numpy as np
from typing import Dict, Any

# Run configuration for the GradSHAP baseline notebook.
config = dict(
    # Pretrained checkpoint name handed to the model loader.
    pretrained_name="google/vit-base-patch16-224",
    # Directory that receives the saved results.
    results_path="/scratch365/dpan/new_results/gradshap",
    # NOTE(review): not referenced by any cell visible here -- confirm it is still needed.
    max_samples=100,
    dataset_split="tiny",
    # Forwarded to the dataloader factory as `num_samples`.
    num_samples=1000,
    # Forwarded to the dataloader factory as `batch_size`.
    batch_size=1,
)

In [2]:
from maskgen.utils.model_utils import get_pred_model

# Create the results directory. exist_ok=True makes the call idempotent and
# avoids the check-then-create race of a separate os.path.exists() test.
os.makedirs(config['results_path'], exist_ok=True)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the image processor and the model whose predictions are attributed below.
pretrained_name = config['pretrained_name']
processor, target_model = get_pred_model(pretrained_name, device)

print("target model loaded")

target model loaded


In [3]:
from maskgen.utils.image_utils import get_image_example


# Example image used for the single-image sanity check.
image = get_image_example(2)

with torch.no_grad():
    # Preprocess and move the whole batch to the target device in one step.
    inputs = processor(images=image, return_tensors="pt").to(device)
    img = inputs['pixel_values'].to(device)

    # Run the model once and derive both class indices from the same logits
    # instead of paying for two identical forward passes.
    logits = target_model(img).logits
    # .item() requires a single element, i.e. exactly one image in the batch.
    predicted_class_idx = logits.argmax(-1).item()
    # Second-highest scoring class for the first (only) image.
    secondary_class_idx = logits.argsort(descending=True)[0][1].item()

# Class to use as the label; switch to the runner-up via the commented line.
label = predicted_class_idx
# label = secondary_class_idx
label = torch.tensor([label]).to(device)
print("Predicted class:", target_model.config.id2label[predicted_class_idx])

Predicted class: Siamese cat, Siamese


In [4]:
from maskgen.baselines.gradshap import GradShapAnalyzer, downsample_attribution

# Reload and preprocess the example image so this cell does not depend on
# tensors left over from the prediction cell.
image = get_image_example(2)
inputs = processor(images=image, return_tensors="pt").to(device)
# The processor output was already moved to `device` above, so no second
# transfer is needed here.
pixel_values = inputs.pixel_values

# Build the GradSHAP analyzer around the target model.
analyzer = GradShapAnalyzer(
    model=target_model,
    device=device
)

# Get attribution for the single image, plus the analyzer's predicted class
# and confidence for it.
attribution_map, pred_class, confidence = analyzer.get_attribution(pixel_values)

print(f"Predicted class: {pred_class}")
print(f"Confidence: {confidence:.3f}")

Predicted class: 284
Confidence: 0.964


In [5]:
from maskgen.utils.img_utils import plot_overlap_np
from maskgen.utils.data_utils import get_imagenet_dataloader

# Build the evaluation dataloader. Split, batch size and sample count all come
# from `config` so the run is controlled from a single place (the split was
# previously hard-coded to the same value that config['dataset_split'] holds).
dataloader = get_imagenet_dataloader(split=config['dataset_split'],
                                    batch_size=config['batch_size'],
                                    processor=processor,
                                    shuffle=False,
                                    num_samples=config['num_samples'])

Repo card metadata block was not found. Setting CardData to empty.


In [6]:
# Per-batch model inputs and their downsampled attribution maps, collected as
# lists and stacked into arrays once the loop finishes.
all_inputs = []
all_heatmaps = []

# The batch index is not needed, so the dataloader is iterated directly.
# batch['labels'] is not used by the attribution step and is no longer copied
# to the device on every iteration.
for batch in tqdm(dataloader, total=len(dataloader), desc="Processing batches"):
    pixel_values = batch['pixel_values'].to(device)
    attribution_map, pred_class, confidence = analyzer.get_attribution(pixel_values)
    attribution_map = downsample_attribution(attribution_map, patch_size=16)
    # Add a leading axis so the per-batch maps concatenate along axis 0.
    # NOTE(review): assumes downsample_attribution returns a single map without
    # a batch axis (batch_size is 1 in config) -- confirm before raising batch_size.
    attribution_map = np.expand_dims(attribution_map, axis=0)

    inputs_np = pixel_values.cpu().numpy()
    heatmap_np = attribution_map
    all_inputs.append(inputs_np)
    all_heatmaps.append(heatmap_np)

# Stack everything along the sample axis. This only runs if the loop completes;
# after an interrupt both names still refer to the partial Python lists.
all_inputs = np.concatenate(all_inputs, axis=0)
all_heatmaps = np.concatenate(all_heatmaps, axis=0)
    

Processing batches:   1%|▏         | 13/1000 [00:03<04:50,  3.39it/s]


KeyboardInterrupt: 

In [9]:
# Debug peek: heatmap_np is the attribution map of the last batch handled by the attribution loop.
heatmap_np

array([[[-1.9054529e-03,  1.5604202e-04,  1.4694926e-04,  1.7749018e-04,
          2.3441883e-03,  1.3592049e-03,  4.5233557e-04, -2.4134452e-03,
          4.8402196e-04, -9.8332076e-04,  2.9184911e-03,  2.8672721e-03,
          3.3591727e-03,  2.2763824e-03],
        [ 1.5924835e-03,  6.2206975e-04,  3.7740611e-03,  1.1426775e-04,
          1.1477394e-03,  3.1292683e-04,  3.0242073e-04,  2.6950052e-03,
          4.5743445e-04,  1.9871923e-03,  3.2793727e-05,  1.2660279e-03,
         -1.2142060e-03,  2.5763519e-03],
        [ 4.2918284e-04, -3.7089051e-04,  1.0725774e-03, -8.8884583e-04,
         -2.0371848e-03,  7.0266845e-04,  2.4799481e-03,  1.1165568e-04,
          4.1146944e-03,  8.2621612e-03,  2.0356448e-03,  6.3023518e-04,
         -1.0789346e-03,  1.4509135e-03],
        [-2.5765563e-05, -6.9733697e-04, -5.3256151e-04, -9.2111941e-04,
         -2.0169171e-03, -1.6627368e-03, -1.2189907e-03,  4.2291026e-04,
         -2.1550287e-03,  1.3859791e-03,  1.0740218e-03,  7.2542857e-04

In [7]:
from maskgen.utils.save_utils import save_pixel_heatmap_pairs

# Make sure the output directory exists. exist_ok=True is idempotent and avoids
# the exists()-then-create race.
results_dir = config['results_path']
os.makedirs(results_dir, exist_ok=True)

# Hand the collected inputs and heatmaps to the saver together with the target
# file path inside the results directory.
save_path = os.path.join(results_dir, 'pixel_heatmap_pairs.npz')
save_pixel_heatmap_pairs(all_inputs, all_heatmaps, save_path)