In [1]:
from maskgen.utils.model_utils import get_pred_model
from fastshap import ImageSurrogate
from accelerate import Accelerator 
from maskgen.baselines.unet import UNet
from fastshap import FastSHAP
import torch
import torch.nn as nn
from fastshap.utils import MaskLayer2d, KLDivLoss, DatasetInputOnly


accelerator = Accelerator()
device = accelerator.device

config = {
    'pretrained_name': 'google/vit-base-patch16-224',
    "results_path": "/scratch365/dpan/new_results/fastshap_imagenet",
    "csv_path": "./new_results/fastshap_imagenet",
}


In [7]:
from torchvision.transforms import Compose, RandomResizedCrop, RandomHorizontalFlip, ToTensor, Normalize, Resize
from transformers import ViTImageProcessor

def create_transforms(processor: ViTImageProcessor):
    """Create image transforms based on processor config."""
    normalize = Normalize(mean=processor.image_mean, std=processor.image_std)
    
    if "height" in processor.size:
        size = (processor.size["height"], processor.size["width"])
        crop_size = size
    elif "shortest_edge" in processor.size:
        size = processor.size["shortest_edge"]
        crop_size = (size, size)
    
    return Compose([
        # RandomResizedCrop(crop_size),
        # RandomHorizontalFlip(),
        Resize(crop_size),
        ToTensor(),
        normalize,
    ])

def get_preprocess(processor):
    """Apply transforms across a batch."""
    transforms = create_transforms(processor)
    def preprocess(example_batch):
        example_batch["pixel_values"] = [
            transforms(image.convert("RGB")) 
            for image in example_batch["image"]
        ]
        # keep only pixel_values from dictionary
        example_batch = {key: example_batch[key] for key in ["pixel_values"]}
        return example_batch
    return preprocess

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

In [8]:
# Load Model
pretrained_name = config['pretrained_name']
processor, target_model = get_pred_model(pretrained_name, device)
vit_config = target_model.config



In [9]:
pred_model = lambda x: target_model(x).logits

In [10]:
imputer = ImageSurrogate(pred_model, width=224, height=224, superpixel_size=16)
explainer = UNet(n_classes=1000, num_down=4, num_up=0, num_convs=2).to(device)
fastshap = FastSHAP(explainer, imputer, link=nn.LogSoftmax(dim=1))

In [12]:
from maskgen.utils.data_utils import load_imagenet
import torch
from torch.utils.data import Dataset

def preprocess(example_batch):
    transforms = create_transforms(processor)
    example_batch["pixel_values"] = [
        transforms(image.convert("RGB")) 
        for image in example_batch["image"]
    ]
    # keep only pixel_values from dictionary
    example_batch = {key: example_batch[key] for key in ["pixel_values"]}
    return example_batch


class DatasetInputOnly(Dataset):
    '''
    A wrapper around a dataset object to ensure that only the first element is
    returned.

    Args:
      dataset: dataset object.
    '''

    def __init__(self, dataset):
        # assert isinstance(dataset, Dataset)
        self.dataset = dataset
        self.transforms = create_transforms(processor)

    def __getitem__(self, index):
        image = self.dataset[index]['image']
        pixel_values = self.transforms(image.convert("RGB"))
        return pixel_values

    def __len__(self):
        return len(self.dataset)


dataset = load_imagenet(split='tiny')
# dataset.set_transform(preprocess)
# dataset.set_format(type="torch", columns=["pixel_values"])

# Split dataset
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_set, val_set = torch.utils.data.random_split(dataset, [train_size, val_size])
# # Set up datasets

fastshap_train = DatasetInputOnly(train_set)
fastshap_val = DatasetInputOnly(val_set)

# Train
fastshap.train(
    train_set,
    val_set,
    batch_size=128,
    num_samples=2,
    max_epochs=200,
    eff_lambda=1e-2,
    validation_samples=1,
    lookback=10,
    bar=True,
    verbose=True)

# Save explainer
explainer.cpu()
torch.save(explainer, 'cifar missingness explainer.pt')
explainer.to(device)

Repo card metadata block was not found. Setting CardData to empty.


TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.JpegImagePlugin.JpegImageFile'>

In [3]:
from maskgen.utils.data_utils import load_imagenet

dataset = load_imagenet(split='tiny')


Repo card metadata block was not found. Setting CardData to empty.


In [4]:
dataset[0]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x375>,
 'label': 0}

In [5]:
import requests
from PIL import Image
from maskgen.utils.img_utils import plot_overlap_np
import torch

# url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# url = "http://farm3.staticflickr.com/2066/1798910782_5536af8767_z.jpg"
# url = "http://farm1.staticflickr.com/184/399924547_98e6cef97a_z.jpg"
# url = "http://farm1.staticflickr.com/128/318959350_1a39aae18c_z.jpg"
# url = "http://farm9.staticflickr.com/8490/8179481059_41be7bf062_z.jpg"
# url = "http://farm1.staticflickr.com/76/197438957_b20800e7cf_z.jpg"
# url = "http://farm3.staticflickr.com/2284/5730266001_7d051b01b7_z.jpg"
url = "https://github.com/EliSchwartz/imagenet-sample-images/blob/master/n01491361_tiger_shark.JPEG?raw=true"
# url = "https://github.com/EliSchwartz/imagenet-sample-images/blob/master/n03000684_chain_saw.JPEG?raw=true"
# url = "https://github.com/EliSchwartz/imagenet-sample-images/blob/master/n04009552_projector.JPEG?raw=true"
image = Image.open(requests.get(url, stream=True).raw)

with torch.no_grad():
    inputs = processor(images=image, return_tensors="pt")
    inputs.to(device)
    img = inputs['pixel_values']
    img = img.to(device)
    predicted_class_idx = target_model(img).logits.argmax(-1).item()
    secondary_class_idx = target_model(img).logits.argsort(descending=True)[0][1].item()