In [None]:
import random
from importlib import reload
import os
import json


import numpy as np
import torch
from tqdm import tqdm, trange
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import matplotlib.pyplot as plt
from itertools import product as iter_product

import src, src.debias, src.models, src.ranking, src.datasets, src.data_utils

# Local import (kept next to the import block of this cell): provides a no-op
# context manager so the model-building loop also works without CUDA.
from contextlib import nullcontext

# --- Device selection --------------------------------------------------------
# With several GPUs the user picks one interactively; otherwise index 0 is used.
cuda_available = torch.cuda.is_available()
if torch.cuda.device_count() > 1:
    use_device_id = int(input(f"Choose cuda index, from [0-{torch.cuda.device_count() - 1}]: ").strip())
else:
    use_device_id = 0
use_device = "cuda:" + str(use_device_id) if cuda_available else "cpu"
if not cuda_available:
    # Explicit confirmation so an accidental CPU run does not go unnoticed.
    input("CUDA isn't available, so using cpu. Please press any key to confirm this isn't an error: \n")
print("Using device", use_device)
if cuda_available:
    # Only touch the CUDA runtime when it exists; calling set_device on a
    # machine without CUDA raises and would defeat the CPU fallback above.
    torch.cuda.set_device(use_device_id)

# --- Run metadata ------------------------------------------------------------
# JSON mapping run id -> metadata; the "epoch" and "step" entries of the chosen
# run are used below to build the checkpoint file name.
with open(src.PATHS.TRAINED_MODELS.METADATA, mode="r") as _runs_metafile:
    runs_metadata = json.load(_runs_metafile)

clip_arch = "openai/CLIP/ViT-B/16"

run_id = "91"
run_metadata = runs_metadata[run_id]
model_save_name = f"best_ndkl_oai-clip-vit-b-16_neptune_run_OXVLB-91_model_e{run_metadata['epoch']}_step_{run_metadata['step']}.pt"
n_debias_tokens = 2

# --- Model construction ------------------------------------------------------
# Two models are built in a fixed order:
#   k=True  -> `n_debias_tokens` debias tokens, weights loaded from the checkpoint
#   k=False -> 0 debias tokens, no checkpoint loaded
# NOTE: `model`, `preprocess`, `tokenizer` and `model_alias` keep the values of
# the LAST iteration (k=False) after the loop; later cells read some of them.
models = []
model_aliases = []
_device_ctx = torch.cuda.device(use_device_id) if cuda_available else nullcontext()
with _device_ctx:
    for k in (True, False):
        model, preprocess, tokenizer, model_alias = src.models.DebiasCLIP.from_cfg(src.Dotdict({
            "CLIP_ARCH": clip_arch, "DEVICE": use_device, "num_debias_tokens": n_debias_tokens if k else 0
        }))
        if k:
            # strict=True: fail loudly if the checkpoint and the model disagree on keys.
            model.load_state_dict(
                torch.load(os.path.join(src.PATHS.TRAINED_MODELS.BASE, model_save_name), map_location=use_device),
                strict=True)
        models.append(model.eval().to(use_device))
        model_aliases.append(model_alias)

# Order follows the loop above: checkpoint-loaded model first, plain model second.
tmodel, umodel = models


In [None]:
# --- Retrieval configuration -------------------------------------------------
from_topk = 10   # size of the highest-similarity pool each sample is drawn from
return_k = 10    # number of image indices kept per (prompt, model)
sample_seed = 0  # seed for the sampler below; set to None for a non-reproducible draw

prompts = ["a photo of a smart person"] # example prompt
dataset = "FairFace"

# random.Random.sample raises if asked for more items than the pool holds;
# fail early with a readable message instead.
assert return_k <= from_topk, "return_k must not exceed from_topk"

# The dataset class is looked up by name on src.datasets.
ds = getattr(src.datasets, dataset)(lazy=True, _n_samples=None, transforms=preprocess, mode="val")
dl = DataLoader(ds, batch_size=256, shuffle=False, num_workers=8)

with torch.no_grad():
    # One text embedding per model, stacked along a new leading "model" axis,
    # then L2-normalised over the last axis.
    prompt_embeds = torch.stack([model.encode_text(tokenizer(prompts).to(use_device)) for model in models])
    prompt_embeds /= torch.norm(prompt_embeds, dim=-1, keepdim=True)

    # Image embeddings per model. Each model is paired with ITS OWN alias from
    # `model_aliases` (previously the alias leaked from the last iteration of the
    # model-building loop was used for both). The checkpoint-loaded model gets a
    # run-specific prefix so its alias differs from the plain model's.
    # NOTE(review): the alias is presumably a cache key inside
    # compute_img_embeddings -- confirm against src.datasets.
    img_embeds = []
    for model, alias, trained in zip(models, model_aliases, (True, False)):
        embed_alias = (f"train_run_{run_id}_" if trained else "") + alias
        img_embeds.append(src.datasets.compute_img_embeddings(ds, model, embed_alias, device=use_device))
    img_embeds = torch.stack(img_embeds)
    img_embeds /= torch.norm(img_embeds, dim=-1, keepdim=True)

# --- Top-k retrieval ---------------------------------------------------------
# res[prompt]["trained" | "untrained"] -> list of `return_k` dataset indices,
# sampled without replacement from the `from_topk` most similar images.
rng = random.Random(sample_seed)
res = {}
for i, prompt in enumerate(prompts):
    res[prompt] = {}
    # Index j follows the order of `models`: 0 = checkpoint-loaded, 1 = plain.
    for j, trained in enumerate(("trained", "untrained")):
        prompt_embed = prompt_embeds[j][i].to(use_device)
        img_embed = img_embeds[j].to(use_device)
        # Both sides were normalised above, so this dot product is a cosine similarity.
        similarities = prompt_embed @ img_embed.T
        top_indices = similarities.topk(from_topk).indices.cpu().tolist()
        print(prompt, trained, similarities.min())
        res[prompt][trained] = rng.sample(top_indices, return_k)

In [None]:
from PIL import Image

# For every prompt and every model label, show the retrieved images together
# with the dataset's gender label, then report how far the share of "Female"
# samples is from 50%.
for prompt, pvals in res.items():
    print(prompt)
    for trained, tvals in pvals.items():
        # Print the model label (the prompt was already printed above).
        print("\t", trained)
        n_female = 0
        for inx in tvals:
            sample = ds[inx]
            print(sample.gender)
            if sample.gender == "Female": n_female += 1
            # Open the image file directly for display and close the handle
            # right after it has been rendered.
            with Image.open(ds._img_fnames[inx]) as image:
                display(image)
        # Use the number of indices actually stored for this entry, so the
        # figure stays correct even if `return_k` was changed after `res` was built.
        n_shown = len(tvals)
        if n_shown:
            print(f"\t\t% deviation from parity: {abs(0.5-(n_female/n_shown)):.1%}")