In [None]:
import contextlib
from importlib import reload
from itertools import product as iter_product

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

import src, src.debias, src.models, src.ranking, src.datasets, src.data_utils

# Choose which device this notebook instance trains on. With several GPUs the user picks a
# cuda index (one notebook instance per GPU); with no CUDA we fall back to CPU, but only
# after an explicit confirmation so a broken driver isn't silently ignored.
n_cuda_devices = torch.cuda.device_count()
if n_cuda_devices > 1:
    use_device_id = int(input(f"Choose cuda index, from [0-{n_cuda_devices-1}]: ").strip())
    if not 0 <= use_device_id < n_cuda_devices:
        raise ValueError(f"cuda index must be in [0-{n_cuda_devices-1}], got {use_device_id}")
else:
    use_device_id = 0
use_device = "cuda:"+str(use_device_id) if torch.cuda.is_available() else "cpu"
if not torch.cuda.is_available():
    input("CUDA isn't available, so using cpu. Please press Enter to confirm this isn't an error: \n")
print("Using device", use_device)
if torch.cuda.is_available():
    # torch.cuda.set_device raises on machines without a usable CUDA runtime,
    # so only select a GPU when one actually exists.
    torch.cuda.set_device(use_device_id)

# Top-level experiment config. The train / debias / optim sub-configs are attached to it as
# cfg.train, cfg.debias and cfg.optim; note it is the sub-configs (not cfg itself) that are
# passed to src.debias.run_debiasing at the end of this cell.
cfg = src.Dotdict()
# training: data loading, logging and evaluation settings
train_cfg = src.Dotdict()
cfg.train = train_cfg
train_cfg.NEPTUNE_PROJNAME = "oxai-vlb-ht22/OxAI-Vision-Language-Bias" # set to None to disable Neptune logging
train_cfg.N_EPOCHS = 10
train_cfg.BATCH_SZ = 64
train_cfg.NUM_WORKERS = 6 # 0 for auto
train_cfg.LOG_EVERY = 10  # NOTE(review): unit (steps vs. batches) is defined in src.debias — confirm
train_cfg.DEVICE = use_device
# DATASET_NAME and BIAS_EVAL_DATASET_NAME are assigned per experiment in the loop at the end of this cell.
#train_cfg.DATASET_NAME = "FairFace"
train_cfg.DATASET_SUBSAMPLE = 1.0 # None or 1.0 for full
train_cfg.PERF_STOPPING_DECREASE = 0.8  # NOTE(review): presumably a relative-performance threshold for early stopping — confirm in src.debias
train_cfg.PERF_EVALS = ["cifar100", "flickr1k"]  # supported: cifar100, flickr1k
train_cfg.EVAL_EVERY = 0.25 # In epochs
train_cfg.BIAS_EVAL_SUBSAMPLE = 1.0  # presumably same convention as DATASET_SUBSAMPLE (1.0 = full) — confirm
#train_cfg.BIAS_EVAL_DATASET_NAME = "FairFace"
train_cfg.BIAS_EVALS = ["ndkl", "maxskew"] # ndkl and min/maxskew supported

# architecture: which CLIP model is debiased and how the learned debiasing tokens are inserted
debias_cfg = src.Dotdict()
cfg.debias = debias_cfg
debias_cfg.CLIP_ARCH = "openai/CLIP/ViT-B/16"
debias_cfg.DEBIAS_TYPE = "adv"  # NOTE(review): presumably adversarial debiasing (cf. the ADV_* optim settings) — confirm in src.debias
# The four settings below are assigned per experiment in the loop at the end of this cell.
# NOTE(review): that loop sets N_TRAIN_TEXT_LAYERS, not N_TRAIN_TXT_LAYERS as spelled here —
# confirm which key src.debias reads.
#debias_cfg.NUM_DEBIAS_TOKENS = 2
#debias_cfg.N_TRAIN_TXT_LAYERS = 0
#debias_cfg.N_TRAIN_VID_LAYERS = 0
#debias_cfg.FREEZE_PROJ = True
debias_cfg.DEBIAS_TOKEN_INIT = ["the"]  # how to initialize the debiasing tokens
debias_cfg.DEBIAS_POS = "prepend" # Prepending or appending the debiasing tokens to the prompt
debias_cfg.DEVICE = use_device

# bias prompts
# Social attribute being debiased. prepare_prompt_cfg below reads it, so it must be assigned
# before that call; on a fresh kernel it would otherwise be read before it is ever set.
debias_cfg.DEBIAS_CLASS = "gender"
# Prompt templates; "{}" is filled with an adjective from the pairs below.
_train_prompts = ["a photo of a {} person"]
_test_prompts = ["a photo of a {} person", "a cropped photo of a {} face"]
# Contrasting adjective pairs (positive, negative) used to build the bias prompts.
_paired_prompts = [('good'       , 'evil'),
                   ('smart'      , 'dumb'),
                   ('attractive' , 'unattractive'),
                   ('lawful'     , 'criminal'),
                   ('friendly'   , 'unfriendly'),
                   ('clever'     , 'stupid'),
                   ('heroic'     , 'villainous'),
                   ('hardworking', 'lazy'),
                   ('trustworthy', 'dangerous'),
                   ('caring'     , 'nasty')]
# All pairs go to training; test_on_train=False keeps evaluation on the test templates.
_prompts_n_train = len(_paired_prompts)
prompt_cfg = src.debias.prepare_prompt_cfg(debias_cfg.DEBIAS_CLASS, _paired_prompts, _train_prompts, _test_prompts, _prompts_n_train, test_on_train=False)

# optimisation: learning rates and adversary settings for the debiasing run
optim_cfg = src.Dotdict()
cfg.optim = optim_cfg
optim_cfg.CL_ADV_TRAIN_SWITCH = 10  # NOTE(review): presumably how often training alternates between CLIP and adversary — confirm unit in src.debias
optim_cfg.ADV_HIDDEN_SIZE = 32
optim_cfg.ADV_N_INPUT = prompt_cfg.N_TRAIN  # tied to prompt_cfg, so prompt_cfg must be built before this line
optim_cfg.ADV_DEVICE = use_device
optim_cfg.ADV_LR = 0.00002  # = 2e-5, same value as CL_LR
optim_cfg.CL_LR = 2e-5 # 2e-6 was the previously used alternative
optim_cfg.L_REG_TYPE = 2 # p of the L_p regularisation loss (1 or 2)
optim_cfg.L_REG_WEIGHT = 0  # weight of the L_p regularisation term (0 presumably disables it)
optim_cfg.N_ADV_INIT_EPOCHS = 0 # 2 was the previously used alternative


# Experiment grid: each entry is one run, listing its parameters as [n_debias_tokens].
experiment_grid = [
    [1],
]
# Split the grid into contiguous chunks, one per visible GPU, so one notebook instance per
# cuda index can share the work. Splitting *indices* keeps every comb a plain Python list
# (np.array_split on the nested list itself would turn each comb into a numpy array), and
# max(..., 1) avoids a zero-way split on CPU-only machines.
n_instances = max(torch.cuda.device_count(), 1)
all_combs = [
    [experiment_grid[i] for i in chunk_idx]
    for chunk_idx in np.array_split(np.arange(len(experiment_grid)), n_instances)
]
# torch.cuda.device() needs a working CUDA runtime; on CPU run without a device context.
device_ctx = torch.cuda.device(use_device_id) if torch.cuda.is_available() else contextlib.nullcontext()
with device_ctx:
    for (n_db_tkz,) in all_combs[use_device_id]:
        print(f"Experiment with: {n_db_tkz} tokens")
        debias_cfg.NUM_DEBIAS_TOKENS = int(n_db_tkz)
        debias_cfg.FREEZE_PROJ = True
        train_cfg.DATASET_NAME = "FairFace"
        train_cfg.BIAS_EVAL_DATASET_NAME = "FairFace"
        # prompt_cfg was built from DEBIAS_CLASS before this loop; keep the two in sync.
        debias_cfg.DEBIAS_CLASS = "gender"
        # NOTE(review): the commented-out default in the architecture section spells this
        # N_TRAIN_TXT_LAYERS — confirm which key src.debias actually reads.
        debias_cfg.N_TRAIN_TEXT_LAYERS = 0
        debias_cfg.N_TRAIN_VID_LAYERS = 0
        # Train debiasing.
        src.debias.run_debiasing(debias_cfg, train_cfg, prompt_cfg, optim_cfg)
    


In [None]:
# Guard cell: halt "Run All" here so the analysis cells below are only run by hand.
# An explicit raise is used because `assert` statements are removed under `python -O`.
raise RuntimeError("don't run following cells when running full notebook to train")

In [None]:

# Plot every configured bias metric for a finished debiasing run.
# NOTE(review): `debias_exp_res` is not created anywhere in this notebook; it must be
# loaded/assigned in the kernel before running this cell.
if "debias_exp_res" not in globals():
    raise NameError("`debias_exp_res` is undefined: load or assign the debiasing experiment results before plotting.")
reload(src.debias)  # pick up edits to src.debias without restarting the kernel
for bias_eval in train_cfg.BIAS_EVALS:
    src.debias.plot_comparison_rankmetrics(prompt_cfg, debias_exp_res, bias_eval)