In [None]:
import argparse
import json
import os
import pickle
from typing import Dict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F
import yaml
from datasets import load_dataset, Dataset
from IPython.display import display, Markdown
from kneed import KneeLocator
from numpy.linalg import svd
from scipy.optimize import minimize
from sentence_transformers import util
from sklearn.decomposition import TruncatedSVD
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
from sklearn.svm import LinearSVC
from transformers import CLIPProcessor, CLIPModel, AutoTokenizer, CLIPTextModelWithProjection, AutoProcessor

In [None]:
# parser = argparse.ArgumentParser(description="Argument Parser")
# parser.add_argument('--config_path', type=str)
# args = parser.parse_args()

In [None]:
# Load the experiment configuration. The context manager guarantees the file
# handle is closed even if YAML parsing fails.
with open("exp_configs/ff_ff_stereotype_gender.yml") as config_file:
    config = yaml.safe_load(config_file)
# config = yaml.safe_load(open(args.config_path))

print(config)

In [None]:
# Unpack the experiment settings from the loaded YAML config.
att_to_debias = config['att_to_debias']
debias_with_ground_truth = config['debias_with_ground_truth']
debias_with_predicted_labels = config['debias_with_predicted_labels']
reference_dataset_name = config['reference_dataset_name']
target_dataset_name = config['target_dataset_name']
model_ID = config['model_ID']
optimization_method = config['optimization_method']
query_type = config['query_type']
random_seed = config['random_seed']

# Select the set of query classes to evaluate for the configured query type.
if query_type == 'hair':
    QUERY_TYPE = 'celeba_features'
    query_classes = ['Blond_Hair', 'Black_Hair', 'Brown_Hair', 'Gray_Hair']

elif query_type == 'stereotype':
    QUERY_TYPE = 'criminal_justice'
    # NOTE(review): get_query_classes is not defined in the visible cells; it
    # must already be in scope when this cell runs.
    query_classes = get_query_classes(QUERY_TYPE)
else:
    # `raise('...')` is itself a TypeError (a str is not an exception), which
    # hid the real problem; raise a proper exception that names the bad value.
    raise NotImplementedError(f"query type {query_type!r} not implemented")

print(f"query_classes: {query_classes}")

# Dataset name (as used in the config) -> JSON-lines file with the embeddings.
dataset_map = {
    'CelebaHQ_dialog': "data/celeba_hq_dialog_ff_race.jsonl",
    'FairFace': 'data/openai_clip_cit_large_patch14_fairface_vectorized.jsonl',
    'UTKFace': 'data/utk_center_cropped_ff_race.jsonl',
}

reference_ds_path = dataset_map[reference_dataset_name]
target_ds_path = dataset_map[target_dataset_name]

print(f"reference_dataset_name: {reference_dataset_name}, reference_ds_path: {reference_ds_path}")
print(f"target_dataset_name: {target_dataset_name}, target_ds_path: {target_ds_path}")

# Global knobs read by later cells: whether embeddings are L2-normalised, and
# the weight `lam` applied to M when forming G = lam * M + I.
normalize = True
lam = 1000

# Class labels of the attribute being debiased.
if att_to_debias == 'race':
    if target_dataset_name == 'UTKFace':
        att_elements = ['White', 'Black', 'Asian', 'Indian', 'Latino Hispanic']
    else:
        att_elements = ['Black', 'East Asian', 'Indian', 'Latino_Hispanic', 'Middle Eastern', 'Southeast Asian', 'White']
elif att_to_debias == 'gender':
    att_elements = ['Male', 'Female']
else:
    # Previously this only printed an un-interpolated message and left
    # `att_elements` undefined, deferring the failure to a NameError later on.
    raise NotImplementedError(f'{att_to_debias} not implemented')


In [None]:
# Use the GPU when torch reports one is available; otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [None]:

# NOTE(review): `model_id` is assigned here but not referenced anywhere in the
# visible code; the checkpoint that is actually loaded below is `model_ID`,
# which comes from the experiment config.
model_id = 'openai/clip-vit-large-patch14'

# Load the CLIP model onto `device` together with its matching processor.
vl_model = CLIPModel.from_pretrained(model_ID).to(device)
processor = AutoProcessor.from_pretrained(model_ID)

def get_embeddings(input_text : list, clip_model, clip_processor, normalize=True):
    """Encode a list of prompts into CLIP text embeddings.

    NOTE(review): a later cell defines `get_embeddings` again; once that cell
    has run, the later definition is the one that gets called.

    Args:
        input_text: list of prompt strings.
        clip_model: model exposing `get_text_features(**inputs)`.
        clip_processor: processor that tokenises `text=` into "pt" tensors.
        normalize: when True, scale every embedding row to unit L2 norm.

    Returns:
        A tensor with one embedding row per prompt, on the module-level `device`.
    """
    with torch.no_grad():
        # padding=True lets prompts of different token lengths be stacked into
        # one batch tensor; without it a multi-prompt call cannot be tensorised.
        inputs = clip_processor(text=input_text, return_tensors="pt", padding=True).to(device)

        query_text_embedding = clip_model.get_text_features(**inputs)

    if normalize:
        query_text_embedding /= query_text_embedding.norm(dim=-1, keepdim=True)
    return query_text_embedding



In [None]:
# Load the target dataset's JSON(-lines) file as a single 'train' split, and
# return the "embedding" column as NumPy while still exposing every other column.
embeddings_dataset = load_dataset("json", data_files=target_ds_path, split='train')
embeddings_dataset = embeddings_dataset.with_format("np", columns=["embedding"], output_all_columns=True)

In [None]:
# 70/30 split of the target embeddings. Seeded with the configured
# `random_seed` so that re-running the notebook reproduces the same partition.
split = embeddings_dataset.train_test_split(test_size=0.3, seed=random_seed)

In [None]:
# Display the train portion of the split as the cell output.
split['train']

In [None]:

def utk_ethnicity_map(x):
    """Translate a UTKFace integer race code into its label string.

    Codes 0 through 4 map to 'White', 'Black', 'Asian', 'Indian' and
    'Latino Hispanic' respectively. Any other value yields None.
    """
    code_labels = (
        (0, 'White'),
        (1, 'Black'),
        (2, 'Asian'),
        (3, 'Indian'),
        (4, 'Latino Hispanic'),
    )
    for code, label in code_labels:
        if x == code:
            return label
    return None



def load_embedding_dataset(ds_path, seed=None):
    """Load an embedding dataset, attach fold flags and label columns, and index it.

    Steps, in order:
      1. Read the JSON(-lines) file at `ds_path` as a single 'train' split and
         return the "embedding" column as NumPy.
      2. Add five columns `fold_0` .. `fold_4`; each row independently gets 1
         with probability 0.9 and 0 otherwise.
      3. If a 'Male' column exists, derive a string "gender" column
         ('Male' when the value equals 1, else 'Female').
      4. If a 'utk_race' column exists, derive a "race" column through
         `utk_ethnicity_map`.
      5. Flatten every embedding to 1-D (and L2-normalise it when the
         module-level `normalize` flag is truthy), then add a FAISS index on it.

    Args:
        ds_path: path passed to `load_dataset(..., data_files=ds_path)`.
        seed: seed for the fold assignment. When None, the notebook-level
            `random_seed` (read from the experiment config) is used, so fold
            membership is reproducible across runs.

    Returns:
        The processed dataset.
    """
    if seed is None:
        seed = random_seed
    # A local generator keeps the fold draw reproducible without touching
    # NumPy's global random state.
    rng = np.random.default_rng(seed)

    embeddings_dataset = load_dataset("json", data_files=ds_path, split='train')
    embeddings_dataset = embeddings_dataset.with_format("np", columns=["embedding"], output_all_columns=True)

    n_rows = len(embeddings_dataset)
    for i in range(5):
        fold_list = (rng.random(n_rows) < 0.9).astype(int).tolist()
        embeddings_dataset = embeddings_dataset.add_column(f'fold_{i}', fold_list)

    if 'Male' in embeddings_dataset.features:
        embeddings_dataset = embeddings_dataset.map(
            lambda x: {"gender": 'Male' if x['Male'] == 1 else 'Female'})
    if 'utk_race' in embeddings_dataset.features:
        embeddings_dataset = embeddings_dataset.map(
            lambda x: {"race": utk_ethnicity_map(x['utk_race'])})

    def _flatten_embedding(x):
        # Flatten to 1-D; divide by the L2 norm only when normalisation is enabled.
        vec = x['embedding'].reshape(-1)
        if normalize:
            vec = vec / np.linalg.norm(vec)
        return {"embedding": vec}

    embeddings_dataset = embeddings_dataset.map(_flatten_embedding)
    embeddings_dataset.add_faiss_index(column="embedding")
    return embeddings_dataset

def load_split_embedding_dataset(ds_path, test_split = 0.5, seed=None):
    """Load one embedding dataset and split it into reference and target parts.

    The file at `ds_path` is read as a single 'train' split with "embedding"
    returned as NumPy. Five columns `fold_0` .. `fold_4` are added (each row is
    1 with probability 0.9, else 0). A "gender" column is derived from 'Male'
    and a "race" column from 'utk_race' when those columns exist. Every
    embedding is flattened to 1-D and, when the module-level `normalize` flag
    is truthy, L2-normalised. The result is then split with `train_test_split`
    and each half receives its own FAISS index on "embedding".

    Args:
        ds_path: path passed to `load_dataset(..., data_files=ds_path)`.
        test_split: `test_size` forwarded to `train_test_split`.
        seed: seed for both the fold assignment and the train/test split. When
            None, the notebook-level `random_seed` (read from the experiment
            config) is used, so both are reproducible across runs.

    Returns:
        (reference_dataset, target_dataset): the 'train' and 'test' halves.
    """
    if seed is None:
        seed = random_seed
    # A local generator keeps the fold draw reproducible without touching
    # NumPy's global random state.
    rng = np.random.default_rng(seed)

    embeddings_dataset = load_dataset("json", data_files=ds_path, split='train')
    embeddings_dataset = embeddings_dataset.with_format("np", columns=["embedding"], output_all_columns=True)

    n_rows = len(embeddings_dataset)
    for i in range(5):
        fold_list = (rng.random(n_rows) < 0.9).astype(int).tolist()
        embeddings_dataset = embeddings_dataset.add_column(f'fold_{i}', fold_list)

    if 'Male' in embeddings_dataset.features:
        embeddings_dataset = embeddings_dataset.map(
            lambda x: {"gender": 'Male' if x['Male'] == 1 else 'Female'})
    if 'utk_race' in embeddings_dataset.features:
        embeddings_dataset = embeddings_dataset.map(
            lambda x: {"race": utk_ethnicity_map(x['utk_race'])})

    def _flatten_embedding(x):
        # Flatten to 1-D; divide by the L2 norm only when normalisation is enabled.
        vec = x['embedding'].reshape(-1)
        if normalize:
            vec = vec / np.linalg.norm(vec)
        return {"embedding": vec}

    embeddings_dataset = embeddings_dataset.map(_flatten_embedding)

    # Seeded split so the reference/target partition is the same on every run.
    split = embeddings_dataset.train_test_split(test_size=test_split, seed=seed)
    r_embeddings_dataset = split['train']
    t_embeddings_dataset = split['test']
    r_embeddings_dataset.add_faiss_index(column="embedding")
    t_embeddings_dataset.add_faiss_index(column="embedding")
    return r_embeddings_dataset, t_embeddings_dataset

In [None]:
# Two distinct corpora are each loaded in full. A single shared corpus is
# instead split in half, one half per role.
if reference_dataset_name != target_dataset_name:
    reference_embeddings_dataset = load_embedding_dataset(reference_ds_path)
    target_embeddings_dataset = load_embedding_dataset(target_ds_path)
else:
    reference_embeddings_dataset, target_embeddings_dataset = load_split_embedding_dataset(reference_ds_path, test_split = 0.5)

In [None]:
def get_cos_neighbors(query_vec, embed_dataset, k = None):
    """Rank dataset rows by cosine similarity to a query embedding.

    Args:
        query_vec: NumPy query embedding. Only the first row of the top-k
            indices is used, so this is intended for a single query vector.
        embed_dataset: dataset whose 'embedding' column is returned as a NumPy
            array and which supports indexing by an array of row indices.
        k: number of neighbours to return. Values larger than the dataset are
            capped at the dataset size. When None, every row is ranked and the
            list is then cut at the knee of the similarity curve found by
            `KneeLocator` (if no knee is found, nothing is cut).

    Returns:
        (dist_scores, neighbors): `dist_scores` is `1 - cosine similarity` for
        the selected rows, most similar first; `neighbors` is the result of
        indexing `embed_dataset` with the selected row indices.
    """
    cos_scores = util.cos_sim(query_vec.astype(float), embed_dataset['embedding'].astype(float))
    n_candidates = len(embed_dataset)
    # torch.topk raises when k exceeds the number of candidates (e.g. a fixed
    # K=500 against a smaller fold), so cap the request at the dataset size.
    _k = n_candidates if k is None else min(k, n_candidates)
    top_results = torch.topk(cos_scores, k=_k)
    topk_sim = top_results.values.cpu().numpy().reshape(-1)
    top_indices = top_results.indices.cpu().numpy()[0]

    if k is None:
        kn = KneeLocator([i for i in range(len(topk_sim))], topk_sim, curve='convex', direction='decreasing').knee
        print(kn)
        # A knee of None leaves the slice open-ended, i.e. keeps every row.
        top_indices = top_indices[:kn]
        topk_sim = topk_sim[:kn]
    dist_scores = 1. - topk_sim
    neighbors = embed_dataset[top_indices]

    return dist_scores, neighbors

def get_embeddings(input_text : list, clip_model, clip_processor, normalize=True):
    """Encode a list of prompts into CLIP text embeddings.

    The prompts are tokenised with padding, moved to the module-level `device`
    and passed through `clip_model.get_text_features` without gradient
    tracking. When `normalize` is truthy each row is scaled to unit L2 norm.

    Returns:
        Tensor holding one embedding row per prompt.
    """
    with torch.no_grad():
        encoded = clip_processor(text=input_text, return_tensors="pt", padding=True).to(device)
        text_features = clip_model.get_text_features(**encoded)

    if not normalize:
        return text_features

    text_features /= text_features.norm(dim=-1, keepdim=True)
    return text_features

def get_proj_matrix(embeddings):
    """Build the projector onto the orthogonal complement of the embeddings' span.

    A TruncatedSVD with as many components as there are rows in `embeddings` is
    fitted; its components (transposed into columns) form the basis B. The
    returned matrix is I - B (B^T B)^{-1} B^T.
    """
    svd_model = TruncatedSVD(n_components=len(embeddings))
    # Only the fitted components are needed; the transformed data is discarded.
    svd_model.fit_transform(embeddings)
    basis = svd_model.components_.T

    # Orthogonal projection onto span(basis), then its complement.
    gram_inverse = np.linalg.inv(basis.T @ basis)
    span_projector = basis @ gram_inverse @ basis.T
    return np.eye(span_projector.shape[0]) - span_projector

def get_A(z_i, z_j):
    """Return z_i z_i^T + z_j z_j^T - z_i z_j^T - z_j z_i^T for two 1-D vectors."""
    self_terms = np.outer(z_i, z_i) + np.outer(z_j, z_j)
    return self_terms - np.outer(z_i, z_j) - np.outer(z_j, z_i)

def get_M(embeddings, S):
    """Average `get_A` over the index pairs in S.

    `embeddings` is a 2-D array with one embedding per row; each element of S
    supplies two row indices (positions 0 and 1). The result is a square matrix
    whose side equals the embedding dimension.
    """
    dim = embeddings.shape[1]
    accumulated = np.zeros((dim, dim))
    for pair in S:
        first_idx, second_idx = pair[0], pair[1]
        accumulated += get_A(embeddings[first_idx], embeddings[second_idx])
    return accumulated / len(S)
# Objective handed to the optimiser: minimising it maximises the dot product
# between the candidate e_star and the starting embedding initial_e.
def objective(e_star, initial_e):
    """Return the negated dot product of `e_star` and `initial_e` as a float."""
    alignment = np.dot(e_star, initial_e)
    return float(-alignment)

# Equality constraint: e_star must have the same dot product with y_mean as
# with x_mean (the function returns zero exactly when it does).
def eq_dist_constraint(e_star, y_mean, x_mean):
    """Return dot(e_star, y_mean) - dot(e_star, x_mean) as a float."""
    toward_y = np.dot(e_star, y_mean)
    toward_x = np.dot(e_star, x_mean)
    return float(toward_y - toward_x)

def norm_constraint(e_star):
    """Return ||e_star||^2 - 1 as a float; zero exactly when e_star has unit norm."""
    squared_norm = np.dot(e_star, e_star)
    return float(squared_norm - 1)

def _class_prototype(ref_scores, ref_embed_array, ref_spurious_array, spurious_class, num_neighbors):
    """Mean embedding of the leading reference samples of one spurious class.

    The reference arrays are expected in ranked order (nearest first). With
    `num_neighbors` set, that many leading samples of the class are averaged.
    With `num_neighbors=None`, the cut-off is the knee of the class's
    similarity curve (`1 - distance`) as located by `KneeLocator`; the knee is
    printed, and a knee of None averages every sample of the class.
    """
    class_mask = ref_spurious_array == spurious_class
    if num_neighbors is None:
        class_sims = 1. - ref_scores[class_mask]
        s_k = KneeLocator([i for i in range(len(class_sims))], class_sims, curve='convex', direction='decreasing').knee
        print(s_k)
    else:
        s_k = num_neighbors
    return ref_embed_array[class_mask][:s_k].mean(axis=0)


def legrange_text(query_embedding, ref_dataset, spurious_label, spurious_class_list, num_neighbors, proj_matrix, normalize=True):
    """Find an embedding close to the query that is equidistant from class prototypes.

    The query is optionally projected with `proj_matrix` and optionally
    L2-normalised. Every reference sample is ranked against it with
    `get_cos_neighbors`, and a prototype (mean embedding of the nearest
    samples) is built for each class in `spurious_class_list`. SLSQP then
    minimises `objective` (the negated dot product with the processed query)
    subject to `norm_constraint` and, for each non-anchor class, an
    `eq_dist_constraint` against the anchor prototype.

    Args:
        query_embedding: NumPy query embedding; it is not modified.
        ref_dataset: reference dataset passed to `get_cos_neighbors`.
        spurious_label: column name holding the spurious attribute.
        spurious_class_list: class values; the first one is the anchor class.
        num_neighbors: samples per class used for a prototype, or None to
            choose the count from the knee of the similarity curve.
        proj_matrix: matrix applied as `query @ proj_matrix.T`, or None to skip.
        normalize: when truthy, L2-normalise the (projected) query.

    Returns:
        (e_star, x_mean, y_means): the optimiser's solution reshaped to one
        row, the anchor prototype as a flat float array, and the list of
        prototypes of the remaining classes.
    """
    if proj_matrix is not None:
        query_embedding = np.matmul(query_embedding, proj_matrix.T)

    if normalize:
        norm = np.linalg.norm(query_embedding, axis=-1, keepdims=True)
        # Out-of-place division: an in-place `/=` would silently rescale the
        # caller's array whenever no projection was applied above.
        query_embedding = query_embedding / norm

    ref_scores, ref_samples = get_cos_neighbors(query_embedding, ref_dataset, k = len(ref_dataset))

    ref_embed_array = np.asarray(ref_samples['embedding'])
    ref_spurious_array = np.asarray(ref_samples[spurious_label])

    # The first listed class is the anchor every other class is balanced against.
    anchor_prototype = _class_prototype(
        ref_scores, ref_embed_array, ref_spurious_array, spurious_class_list[0], num_neighbors)

    # Initial guess for the optimiser, and the anchor prototype as a flat vector.
    x0 = query_embedding.reshape(-1).astype(float)
    x_mean = anchor_prototype.reshape(-1).astype(float)

    y_means = []
    for spurious_class in spurious_class_list[1:]:
        s_prototype = _class_prototype(
            ref_scores, ref_embed_array, ref_spurious_array, spurious_class, num_neighbors)
        print()
        y_means.append(s_prototype)

    # Unit-norm constraint plus one equal-similarity constraint per non-anchor class.
    cons = [{'type': 'eq', 'fun': norm_constraint}]
    for _y_mean in y_means:
        cons.append({'type': 'eq', 'fun': eq_dist_constraint, 'args': (_y_mean, x_mean)})

    solution = minimize(objective, x0, args=(query_embedding.reshape(-1),), method='SLSQP', constraints=cons)
    e_star = solution.x.reshape(1,-1)

    return e_star, x_mean, y_means



def log_with_eps(x, eps=1e-10):
    """Natural log of `x`, with any `x` below `eps` replaced by `eps` first."""
    floored = eps if x < eps else x
    return np.log(floored)

def max_skew(returned_samples, target_dist, spurious_label='gender', target_classes = [-1,1]):
    """Largest log-ratio between a class's share in the results and its target share.

    For each class in `target_classes`, the fraction of `returned_samples`
    carrying that value in `spurious_label` is divided by `target_dist[cl]` and
    passed through `log_with_eps`. The running maximum starts at 0, so the
    result is never negative. `target_dist` is printed on entry.
    """
    print(target_dist)
    largest_skew = 0
    for cl in target_classes:
        expected_share = target_dist[cl]
        observed_share = (np.asarray(returned_samples[spurious_label])==cl).astype(int).mean()
        class_skew = log_with_eps(observed_share/expected_share)
        largest_skew = class_skew if class_skew > largest_skew else largest_skew
    return largest_skew

def get_kl(returned_samples, target_dist, spurious_label='gender', target_classes = [-1,1]):
    """Sum over classes of p_returned * log(p_returned / p_target).

    `p_returned` is the fraction of `returned_samples` whose `spurious_label`
    equals the class; `p_target` is `target_dist[cl]`. The logarithm goes
    through `log_with_eps`.
    """
    divergence = 0
    for cl in target_classes:
        expected_share = target_dist[cl]
        observed_share = (np.asarray(returned_samples[spurious_label])==cl).astype(int).mean()
        divergence = divergence + observed_share * log_with_eps(observed_share/expected_share)
    return divergence

def _cl_with_max_skew(returned_samples, target_dist, spurious_label='gender', target_classes = [-1,1]):
    """Return the class whose log share-ratio is the largest one above zero.

    The ratio for a class is its share among `returned_samples` divided by
    `target_dist[cl]`, passed through `log_with_eps`. If no class has a
    positive value, the last class iterated is returned.
    """
    top_skew, top_class = 0, None

    for cl in target_classes:
        expected_share = target_dist[cl]
        observed_share = (np.asarray(returned_samples[spurious_label])==cl).astype(int).mean()
        class_skew = log_with_eps(observed_share/expected_share)
        if class_skew > top_skew:
            top_skew, top_class = class_skew, cl

    # Nothing exceeded zero: fall back to the final loop value.
    return cl if top_class is None else top_class

def _cl_with_min_skew(returned_samples, target_dist, spurious_label='gender', target_classes = [-1,1]):
    """Return the class with the smallest log share-ratio.

    The ratio for a class is its share among `returned_samples` divided by
    `target_dist[cl]`, passed through `log_with_eps`. The running minimum
    starts at 100000; if no class falls below it, the last class iterated is
    returned.
    """
    lowest_skew, lowest_class = 100000, None

    for cl in target_classes:
        expected_share = target_dist[cl]
        observed_share = (np.asarray(returned_samples[spurious_label])==cl).astype(int).mean()
        class_skew = log_with_eps(observed_share/expected_share)
        if class_skew < lowest_skew:
            lowest_skew, lowest_class = class_skew, cl

    # Nothing fell below the sentinel: fall back to the final loop value.
    return cl if lowest_class is None else lowest_class

def group_accuracy(retrieved_samples, q_class,  spurious_label, spurious_class_list):
    """Per-group fraction of retrieved samples whose `q_class` value equals 1.

    Returns a dict keyed by each value in `spurious_class_list`; the groups are
    formed by matching `retrieved_samples[spurious_label]` against that value.
    """
    group_of_sample = np.asarray(retrieved_samples[spurious_label])
    label_of_sample = np.asarray(retrieved_samples[q_class])
    return {
        s_class: (label_of_sample[group_of_sample == s_class] == 1).mean()
        for s_class in spurious_class_list
    }

def auc_roc(q_embed, embed_dataset,  q_class, spurious_label, spurious_class_list, metric_is_distance = True):
    """ROC AUC of the query ranking, computed separately for each spurious group.

    Every row of `embed_dataset` is ranked against `q_embed` with
    `get_cos_neighbors`. A row counts as positive when its `q_class` value
    equals 1. When `metric_is_distance` is truthy the scores are negated before
    being used as ranking scores. Groups whose positives/negatives contain
    exactly one distinct value are left out of the result.

    Returns:
        dict mapping group value -> `roc_auc_score` for that group.
    """
    ranked_scores, ranked_samples = get_cos_neighbors(q_embed, embed_dataset, k = len(embed_dataset))

    group_of_sample = np.asarray(ranked_samples[spurious_label])
    print(group_of_sample.shape)
    raw_targets = np.asarray(ranked_samples[q_class])
    print(raw_targets)
    binary_targets = (raw_targets == 1).astype(int)

    ranking_scores = np.asarray(ranked_scores)
    if metric_is_distance:
        ranking_scores = -1*ranking_scores

    per_group_auc = {}
    for s_class in spurious_class_list:
        in_group = group_of_sample == s_class
        group_targets = binary_targets[in_group]
        # AUC is undefined when only one class is present in the group.
        if len(np.unique(group_targets)) == 1:
            continue
        per_group_auc[s_class] = roc_auc_score(group_targets, ranking_scores[in_group])
    return per_group_auc

In [None]:
def get_worst_group_performance(method_metric_dict, higher_better=True):
    """Pick the group with the worst metric value.

    With `higher_better` the worst value is the smallest one, otherwise the
    largest. The search starts from a sentinel (100000 or -100000); if no
    entry beats it, the sentinel and None are returned. Ties keep the first
    group encountered.

    Returns:
        (worst_metric, worst_class)
    """
    worst_metric = 100000 if higher_better else -100000
    worst_class = None
    for group, value in method_metric_dict.items():
        is_worse = value < worst_metric if higher_better else value > worst_metric
        if is_worse:
            worst_metric, worst_class = value, group
    return worst_metric, worst_class


def get_best_group_performance(method_metric_dict, higher_better=True):
    """Pick the group with the best metric value.

    With `higher_better` the best value is the largest one, otherwise the
    smallest. The search starts from a sentinel (-100000 or 100000); if no
    entry beats it, the sentinel and None are returned. Ties keep the first
    group encountered.

    Returns:
        (best_metric, best_class)
    """
    best_metric = -100000 if higher_better else 100000
    best_class = None
    for group, value in method_metric_dict.items():
        is_better = value > best_metric if higher_better else value < best_metric
        if is_better:
            best_metric, best_class = value, group
    return best_metric, best_class


def relevency(returned_samples, q_class, spurious_label='gender', spurious_class_list = [-1,1]):
    """Per-group fraction of returned samples whose `q_class` value equals 1.

    Returns a dict keyed by each value in `spurious_class_list`; groups are
    formed by matching `returned_samples[spurious_label]` against that value.
    """
    group_of_sample = np.asarray(returned_samples[spurious_label])
    label_of_sample = np.asarray(returned_samples[q_class])
    return {
        cl: (label_of_sample[group_of_sample==cl]==1).astype(int).mean()
        for cl in spurious_class_list
    }

def get_metrics(q_embedding, query_class, att_to_debias, K, spurious_att_prior, target_spurious_class_list, target_datasets, name='Vanilla', QUERY_IS_LABELED=True):
    """Evaluate one query embedding on each of the five per-fold target datasets.

    For every fold the top-K neighbours are retrieved with `get_cos_neighbors`.
    When `QUERY_IS_LABELED` is truthy, per-group ROC AUC (worst, best, gap) and
    worst-group relevancy are recorded. In all cases the max-skew and KL
    against `spurious_att_prior` are recorded. Each value is also printed,
    prefixed with `name`.

    Returns:
        dict mapping "fold_<i>" -> dict of metric name -> value.
    """
    per_fold = {}
    for fold_idx in range(5):
        fold_ds = target_datasets[fold_idx]
        fold_metrics = {}

        _, top_samples = get_cos_neighbors(q_embedding, fold_ds, k = K)

        if QUERY_IS_LABELED:
            group_auc = auc_roc(q_embedding, fold_ds,  query_class, att_to_debias, spurious_class_list = target_spurious_class_list)
            worst_auc, worst_auc_group = get_worst_group_performance(group_auc)
            best_auc, best_auc_group = get_best_group_performance(group_auc)
            print(f"{name} worst group AUC ROC: {worst_auc}, worst group: {worst_auc_group}")
            fold_metrics['worst_auc_roc_val'] = worst_auc
            fold_metrics['worst_auc_roc_group'] = worst_auc_group
            fold_metrics['best_auc_roc_val'] = best_auc
            fold_metrics['best_auc_roc_group'] = best_auc_group

            print(f"{name} gap for AUC ROC: {best_auc - worst_auc}")
            fold_metrics['auc_roc_gap'] = best_auc - worst_auc

            group_relevancy = relevency(top_samples,  query_class, att_to_debias, spurious_class_list = target_spurious_class_list)
            worst_rel, worst_rel_group = get_worst_group_performance(group_relevancy)
            print(f"{name} worst group relevency: {worst_rel}, worst group: {worst_rel_group}")
            fold_metrics['worst_rel_val'] = worst_rel
            fold_metrics['worst_rel_group'] = worst_rel_group

        skew_vs_prior = max_skew(top_samples, spurious_att_prior, spurious_label = att_to_debias , target_classes=target_spurious_class_list)
        kl_vs_prior = get_kl(top_samples, spurious_att_prior, spurious_label = att_to_debias , target_classes=target_spurious_class_list)
        print(f"{name} Max Skew Prior: {skew_vs_prior}")
        print(f"{name} KL Prior: {kl_vs_prior}")

        fold_metrics['max_skew_prior'] = skew_vs_prior
        fold_metrics['kl_prior'] = kl_vs_prior
        per_fold[f"fold_{fold_idx}"] = fold_metrics
    print()
    return per_fold

In [None]:
# Materialise one evaluation subset per fold: fold i keeps exactly the target
# rows whose `fold_i` flag equals 1.
target_datasets = {}
for i in range(5):
    # np.asarray keeps the `== 1` comparison element-wise even if the column is
    # handed back as a plain Python list rather than an ndarray.
    fold_flags = np.asarray(target_embeddings_dataset[f'fold_{i}'])
    keep_idx = np.flatnonzero(fold_flags == 1)
    _t_ds = Dataset.from_dict(target_embeddings_dataset[keep_idx])
    _t_ds = _t_ds.with_format("np", columns=["embedding"], output_all_columns=True)
    target_datasets[i] = _t_ds

In [None]:
# Only the 'hair' query type is treated as labelled; the flag is forwarded to
# get_metrics as QUERY_IS_LABELED.
query_is_labeled = (query_type == 'hair')

# Reference and target sets share the same attribute class list.
ref_spurious_class_list = target_spurious_class_list = att_elements

# Main experiment loop: for every query class, build the baseline and debiased
# query embeddings and evaluate each with get_metrics. Results are collected as
# result_dict[query_class][method_name] -> per-fold metric dict.
result_dict = {}
for query_class in query_classes:
    result_dict[query_class] = {}
    # NOTE(review): `instantiated_search_classes` is not defined in any visible
    # cell; it has to be in scope (keyed by query class, with 'query',
    # 'spurious_prompts', 'inclusive_candidate_prompts' and
    # 'exclusive_candidate_prompts' entries) before this cell runs.
    query_text = [instantiated_search_classes[query_class]['query']]
    print(query_class)
    print(query_text)


    spurious_prompt = instantiated_search_classes[query_class]['spurious_prompts']
    inclusive_candidate_prompt = instantiated_search_classes[query_class]['inclusive_candidate_prompts']
    exclusive_candidate_prompt = instantiated_search_classes[query_class]['exclusive_candidate_prompts']
    # Index pairs handed to get_M: only candidate prompts 0 and 1 are paired.
    S = [[0,1]]

    print(spurious_prompt)
    print(inclusive_candidate_prompt)
    print(exclusive_candidate_prompt)

    # Attribute labels of every target / reference sample.
    spurious_att_array = np.asarray(target_embeddings_dataset[att_to_debias])
    ref_spurious_att_array = np.asarray(reference_embeddings_dataset[att_to_debias])

    # Empirical share of each attribute class in the target set.
    spurious_att_prior = {}
    for spurious_att in target_spurious_class_list:
        spurious_att_prior[spurious_att] = spurious_att_array[spurious_att_array==spurious_att].shape[0]/spurious_att_array.shape[0]
    print(f'spurious att prior: {spurious_att_prior}')


    # Same share, computed over the reference set.
    ref_spurious_att_prior = {}
    for r_spurious_att in ref_spurious_class_list:
        ref_spurious_att_prior[r_spurious_att] = ref_spurious_att_array[ref_spurious_att_array==r_spurious_att].shape[0]/ref_spurious_att_array.shape[0]
    print(f'ref spurious att prior: {ref_spurious_att_prior}')

    # For labelled queries: attribute share among target samples whose
    # query-class value equals 1. It is only printed; the metrics below use the
    # unconditional prior in both branches.
    if query_is_labeled:
        eg_array = np.asarray(target_embeddings_dataset[query_class])
        conditional_spurious_att_prior = {}
        conditional_spurious_att_array = spurious_att_array[eg_array==1]
        for spurious_att in target_spurious_class_list:
            conditional_spurious_att_prior[spurious_att] = conditional_spurious_att_array[conditional_spurious_att_array==spurious_att].shape[0]/conditional_spurious_att_array.shape[0]
        print(f'conditional spurious att prior: {conditional_spurious_att_prior}')

    if query_is_labeled:
        prior_for_metric =spurious_att_prior# conditional_spurious_att_prior
    else:
        prior_for_metric = spurious_att_prior




    # Text embeddings of the query and of the three prompt sets, as NumPy on CPU.
    query_text_embedding = get_embeddings(query_text, vl_model, processor, normalize).to('cpu').numpy()
    spurious_prompt_embedding = get_embeddings(spurious_prompt, vl_model, processor, normalize).to('cpu').numpy()
    inclusive_candidate_prompt_embedding = get_embeddings(inclusive_candidate_prompt, vl_model, processor, normalize).to('cpu').numpy()
    exclusive_candidate_prompt_embedding = get_embeddings(exclusive_candidate_prompt, vl_model, processor, normalize).to('cpu').numpy()#

    # P0: projector built from the spurious prompt embeddings.
    P0 = get_proj_matrix(spurious_prompt_embedding)

    # Calibrated projectors: P* = P0 @ inv(lam * M + I), with M taken from the
    # inclusive and the exclusive candidate prompts respectively.
    M = get_M(inclusive_candidate_prompt_embedding, S)
    G = lam * M + np.eye(M.shape[0])
    inclusive_P_star = np.matmul(P0, np.linalg.inv(G))

    M = get_M(exclusive_candidate_prompt_embedding, S)#
    G = lam * M + np.eye(M.shape[0])#
    exclusive_P_star = np.matmul(P0, np.linalg.inv(G))#

    # Apply each projector to the query and re-normalise along the last axis.
    P0_embeddings = np.matmul(query_text_embedding, P0.T)
    P0_embeddings = F.normalize(torch.tensor(P0_embeddings), dim=-1).numpy()

    inclusive_P_star_embeddings = np.matmul(query_text_embedding, inclusive_P_star.T)
    inclusive_P_star_embeddings = F.normalize(torch.tensor(inclusive_P_star_embeddings), dim=-1).numpy()

    # NOTE(review): exclusive_P_star_embeddings is computed but never passed to
    # get_metrics in this cell.
    exclusive_P_star_embeddings = np.matmul(query_text_embedding, exclusive_P_star.T)
    exclusive_P_star_embeddings = F.normalize(torch.tensor(exclusive_P_star_embeddings), dim=-1).numpy()



    # Half-differences of every unordered pair of inclusive candidate prompt
    # embeddings, each kept as a single row.
    rewrite_pair_list = []
    for e_i in range(inclusive_candidate_prompt_embedding.shape[0]):
        for e_j in range(e_i+1, inclusive_candidate_prompt_embedding.shape[0]):
            rewrite_pair_list.append(((inclusive_candidate_prompt_embedding[e_i] - inclusive_candidate_prompt_embedding[e_j])/2).reshape(1,-1) )


    # "Local" projector: built from the spurious prompt embeddings stacked with
    # the pairwise half-differences above.
    sub_local_embeddings = np.concatenate(rewrite_pair_list)
    print(sub_local_embeddings.shape)
    sub_local_embeddings = np.concatenate([spurious_prompt_embedding, sub_local_embeddings])
    P0_local = get_proj_matrix(sub_local_embeddings)
    #
    # NOTE(review): P0_local_embeddings and local_exclusive_P_star below are
    # computed but not used later in this cell.
    P0_local_embeddings = np.matmul(query_text_embedding, P0_local.T)
    P0_local_embeddings = F.normalize(torch.tensor(P0_local_embeddings), dim=-1).numpy()


    M = get_M(exclusive_candidate_prompt_embedding, S)#
    G = lam * M + np.eye(M.shape[0])#
    local_exclusive_P_star = np.matmul(P0_local, np.linalg.inv(G))#


    #######

    # Samples per class used for the prototypes in legrange_text, and the
    # retrieval depth K used by get_metrics.
    if att_to_debias == 'gender':
        num_neighbors = 100#100#500#100 #100 if gender
        K = 500
    else: 
        num_neighbors = 10#10#500#100 #100 if gender
        K = 500
    # NOTE(review): ref_dataset, spurious_class_list and target_dist are
    # assigned here but not read again in this cell; the legrange_text call
    # below passes reference_embeddings_dataset and att_elements directly.
    ref_dataset = reference_embeddings_dataset#target_embeddings_dataset # reference_embeddings_dataset
    spurious_class_list = ref_spurious_class_list#['Female','Male']#[-1,1]#['Female','Male']
    
    target_dist = ref_spurious_att_prior#spurious_att_prior




    # Constrained optimisation of the query, starting from its P0_local projection.
    legrange_local_proj_embeddings, x_mean, y_means = legrange_text(query_text_embedding, reference_embeddings_dataset, spurious_label=att_to_debias, 
                                            spurious_class_list=att_elements, num_neighbors=num_neighbors, proj_matrix = P0_local, normalize=normalize)





 



    # Evaluate the unmodified query embedding and each debiased variant.
    result_dict[query_class]['Vanilla'] = get_metrics(query_text_embedding, query_class, att_to_debias, 
                                             K, prior_for_metric, target_spurious_class_list, name='Vanilla', target_datasets = target_datasets,
                                             QUERY_IS_LABELED=query_is_labeled)


    result_dict[query_class]['P0'] = get_metrics(P0_embeddings, query_class, att_to_debias, K, prior_for_metric, 
                                                      target_spurious_class_list, name='P0', target_datasets = target_datasets, QUERY_IS_LABELED=query_is_labeled)

    result_dict[query_class]['Inclusive_P0'] = get_metrics(inclusive_P_star_embeddings, query_class, att_to_debias, K, prior_for_metric,
                                                  target_spurious_class_list, name='Inclusive_P0', target_datasets = target_datasets, QUERY_IS_LABELED=query_is_labeled)


    print('***'*7)
    print()
    result_dict[query_class]['legrange_local_proj'] = get_metrics(legrange_local_proj_embeddings, query_class, att_to_debias, K, prior_for_metric, 
                                                           target_spurious_class_list, name='legrange_local_proj', target_datasets = target_datasets, 
                                                           QUERY_IS_LABELED=query_is_labeled)


    print('--'*7)
    print()

# Persist the per-query metrics as JSON. The output directory is created first
# so a finished run is not lost to a FileNotFoundError at the very last step.
results_path = f'results/{reference_dataset_name}_{target_dataset_name}_{query_type}_{att_to_debias}_fixed.json'
os.makedirs(os.path.dirname(results_path), exist_ok=True)
with open(results_path, 'w') as fp:
    json.dump(result_dict, fp)

In [None]:
# Metric to summarise: one of the keys stored in each per-fold result dict.
METRIC = 'kl_prior'

# (key in result_dict[query_class], label printed in the summary).
# The experiment loop stores its constrained-optimisation results under
# 'legrange_local_proj'. Methods that were not evaluated for any query class
# are skipped instead of raising KeyError.
SUMMARY_METHODS = [
    ('legrange_local_proj', 'legrange_local_proj'),
    ('legrange_reg_local_proj', 'legrange_reg_local_proj'),
    ('exclusive_P0', 'exclusive_P0'),
    ('Inclusive_P0', 'Inclusive_P0'),
    ('P0', 'P0'),
    ('Vanilla', 'Base CLIP model'),
]

# For each fold, print the mean of METRIC over all query classes, per method.
# Values are gathered into a local list so that notebook-level names such as
# `P0` (the projection matrix) are not overwritten by this cell.
for i in range(5):
    print(i)
    for method_key, method_label in SUMMARY_METHODS:
        fold_values = [
            per_method[method_key][f"fold_{i}"][METRIC]
            for per_method in result_dict.values()
            if method_key in per_method
        ]
        if fold_values:
            print(f"{method_label}: {np.mean(fold_values)}")
    print()
    print()