# Auditing Vision Models using Language

AI is increasingly used in social contexts, so it is important to understand the decisions these models make and how they affect people. In this notebook, we explore how to audit the decisions of a vision model using language, following the [DRML](https://arxiv.org/abs/1905.13677) method. We use the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset to train a vision model and the [IMDB](https://www.imdb.com/interfaces/) dataset to train a language model, and then use the language model to audit the vision model's decisions.

## Load Waterbird dataset and extract features

In [1]:
import json
import os

import clip
import torch
from torchvision.datasets import ImageNet

from datasets import ImageDataset, TextDataset, create_dataloader
from trainer import extract_features
from utils import openai_imagenet_classes, openai_imagenet_template


def filter_name(name: str) -> str:
    return "".join([c for c in name.lower() if c.isalnum()])


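def extract_features_others(model_name: str, dataset: str):
    # Load the CLIP model and its matching image preprocessing transform.
    clip_model, transform = clip.load(name=model_name, device="cuda")
    clip_model = clip_model.float()  # cast the weights to float32

    # Read one JSON record per line; the context manager closes the file.
    attributes_path = (
        f"../../data/{dataset}/processed_attribute_dataset/attributes.jsonl"
    )
    with open(attributes_path) as f:
        data = [json.loads(line) for line in f]
    # Assign a dummy label of 0 to every record.
    for item in data:
        item["label"] = 0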

    image_dataset = ImageDataset(data)
    image_dataloader = create_dataloader(
        dataset=image_dataset,
        modality="image",
        transform=transform,
        shuffle=False,
        batch_size=1024,
        num_workers=16,
    )
    image_features = extract_features(
        dataloader=image_dataloader,
        clip_model=clip_model,
        modality="image",
        verbose=True,
    )

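    # Save the features to the current working directory,
    # e.g. waterbird_features_vitb32.pt for the call below.
    torch.save(
        image_features,
        f"{dataset.lower()}_features_{filter_name(model_name)}.pt",
    )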


extract_features_others(model_name="ViT-B/32", dataset="Waterbird")

Extracting features for image: 100%|██████████| 12/12 [00:29<00:00,  2.45s/it]
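
As a quick sanity check on the cell above, the following minimal sketch reproduces how the output filename is derived and how the batch count in the progress bar relates to the number of records. `feature_filename` and `num_batches` are hypothetical helpers introduced only for illustration, and the sketch assumes the dataloader keeps the final partial batch.

```python
import math


def filter_name(name: str) -> str:
    # Same helper as in the cell above: lowercase, keep only letters and digits.
    return "".join(c for c in name.lower() if c.isalnum())


def feature_filename(model_name: str, dataset: str) -> str:
    # Hypothetical helper mirroring the f-string passed to torch.save above.
    return f"{dataset.lower()}_features_{filter_name(model_name)}.pt"


def num_batches(num_items: int, batch_size: int) -> int:
    # Hypothetical helper: batches needed when the last partial batch is kept
    # (an assumption about the dataloader, not confirmed by the notebook).
    return math.ceil(num_items / batch_size)


print(feature_filename("ViT-B/32", "Waterbird"))  # waterbird_features_vitb32.pt
print(num_batches(11_265, 1024), num_batches(12_288, 1024))  # 12 12
```

Under that assumption, the 12 batches reported at `batch_size=1024` correspond to somewhere between 11,265 and 12,288 records in `attributes.jsonl`.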
