In [None]:
# Lookup table from the integer class ids of Bias in Bios to readable profession names.
# The id of a profession is its position in the alphabetically ordered list below.
professions_dict = dict(enumerate([
    "accountant",
    "architect",
    "attorney",
    "chiropractor",
    "comedian",
    "composer",
    "dentist",
    "dietitian",
    "dj",
    "filmmaker",
    "interior_designer",
    "journalist",
    "model",
    "nurse",
    "painter",
    "paralegal",
    "pastor",
    "personal_trainer",
    "photographer",
    "physician",
    "poet",
    "professor",
    "psychologist",
    "rapper",
    "software_engineer",
    "surgeon",
    "teacher",
    "yoga_teacher",
]))

## Embed data and train a classification head

In [None]:
#from embedding import BertHuggingface
from datasets import load_dataset
import numpy as np
import torch
import os
import pickle
import scipy
import random
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

import matplotlib.pyplot as plt

from concept_helpers.bert_cockatiel import BertCockatielWrapper
from concept_helpers.cockatiel_sub import SubCockatiel
from cockatiel.cockatiel import occlusion_concepts, print_legend, viz_concepts

from experiment_helpers.experiment_helper_functions import *
from experiment_helpers.lm_experiment_helper_functions import *

In [None]:
# Load the Bias in Bios dataset by its hub identifier.
dataset = load_dataset("LabHC/bias_in_bios")

# Per-split columns used below: biography text, profession label and gender label.
text_train = dataset['train']['hard_text']
y_train = dataset['train']['profession']
gender_train = dataset['train']['gender']
text_test = dataset['test']['hard_text']
y_test = dataset['test']['profession']
gender_test = dataset['test']['gender']

# Number of classes is derived as "largest test label + 1".
# NOTE(review): assumes labels are contiguous from 0 and that the highest class id occurs in the test split — confirm.
NUM_CLASSES = np.max(y_test)+1

In [None]:
# Backbone name and batch size handed to the project wrapper.
MODEL_NAME = 'bert-base-uncased'
BATCH_SIZE = 8

# BertCockatielWrapper is a project class (concept_helpers.bert_cockatiel); below it is used
# for retrain/save/load, embed and predict.
bert = BertCockatielWrapper(NUM_CLASSES, model_name=MODEL_NAME, batch_size=BATCH_SIZE)

In [None]:
# Locations of the cached artefacts produced by this cell.
model_checkpoint = ('models/finetuned_model_%s_relu' % MODEL_NAME)
emb_savefile = ('Experiment3_results/embeddings_finetuned_%s_relu.pickle' % MODEL_NAME)

# Create the parent folders up front: otherwise the first write below fails on a fresh
# checkout, and it would do so only after the expensive training/embedding step has run.
for cache_path in (model_checkpoint, emb_savefile):
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)

# Fine-tuned model: reuse the checkpoint when present, otherwise train for two epochs and save.
if os.path.isdir(model_checkpoint):
    print("load model from checkpoint")
    bert.load(model_checkpoint)
else:
    print("train and save model")
    bert.retrain(text_train, y_train, epochs=2)
    bert.save(model_checkpoint)


# Embeddings of the train/test biographies: load from cache or compute and cache.
if os.path.isfile(emb_savefile):
    print("load embeddings")
    # NOTE: pickle.load can execute arbitrary code; only load cache files written by this notebook.
    with open(emb_savefile, 'rb') as handle:
        embeddings = pickle.load(handle)
    emb_train = embeddings['train']
    emb_test = embeddings['test']

    # A cache that does not match the current dataset size is stale.
    assert len(emb_test) == len(text_test)
    assert len(emb_train) == len(text_train)
else:
    print("compute and save embeddings")
    emb_test = bert.embed(text_test)
    emb_train = bert.embed(text_train)
    embeddings = {'train': emb_train, 'test': emb_test}
    with open(emb_savefile, 'wb') as handle:
        pickle.dump(embeddings, handle)

# Predicted class ids on the test split, cached as a .npy file.
pred_file = ('Experiment3_results/test_predictions_%s_relu.npy' % MODEL_NAME)
if os.path.isfile(pred_file):
    y_pred = np.load(pred_file)
else:
    pred = bert.predict(list(text_test))
    y_pred = np.argmax(pred, axis=1)
    np.save(pred_file, y_pred)

# The later NMF step needs non-negative inputs, so the embeddings must not contain negatives.
assert (emb_test >= 0).all()
assert (emb_train >= 0).all()

In [None]:
# Macro-averaged F1 of the cached/predicted test labels against the true test labels.
print("Bert F1-macro score: %.2f" % f1_score(y_test, y_pred, average='macro'))

In [None]:
def plot_feature_importance(activations, gender, class_id):
    """Plot how useful each concept activation is for predicting gender.

    A depth-2 random forest (fixed seed 0) is fitted on ``activations`` with ``gender``
    as the target; its ``feature_importances_`` are drawn as a bar chart on a new
    10x4 figure. ``class_id`` only appears in the plot title.
    """
    forest = RandomForestClassifier(max_depth=2, random_state=0)
    forest.fit(activations, gender)

    title = "RF concept importances for gender prediction (class %i)" % class_id
    fig, ax = plt.subplots(figsize=(10, 4))
    plot_bars(ax, forest.feature_importances_, plt.get_cmap('tab20'), title)

def get_pearsoncorr(activations, gender):
    """Return Pearson's r between ``gender`` and each column of ``activations``.

    ``activations`` must be 2-D (it is indexed as ``activations[:, c]``); the result is
    a list with one correlation coefficient per column, in column order.
    """
    n_columns = activations.shape[1]
    return [
        scipy.stats.pearsonr(gender, activations[:, column]).statistic
        for column in range(n_columns)
    ]

def plot_bars(ax, values, colormap, title):
    """Draw ``values`` as a bar chart on ``ax``, one bar per index, and set the title.

    Bar colours come from ``colormap.colors``; the tick labels are the bar indices.
    """
    positions = range(len(values))
    ax.bar(positions, values, color=colormap.colors, tick_label=positions)
    ax.set_title(title, fontsize=18)


def sample_by_class_id(emb, text, y, gender, class_id, multiclass=False):
    """Select all rows whose label in ``y`` matches ``class_id``.

    Parameters
    ----------
    emb, text, gender
        Parallel sequences (embeddings, texts, gender labels) aligned with ``y``.
    y
        Class label per row; any sequence is accepted and converted to an array.
    class_id
        A single integer class id, or a list of class ids when ``multiclass`` is True.
    multiclass
        When True, rows of every listed class are returned, grouped class by class.

    Returns
    -------
    tuple
        ``(sample, sample_emb, sample_text, sample_gender)`` where ``sample`` holds the
        selected row indices and the other three are numpy arrays restricted to them.

    Raises
    ------
    AssertionError
        If ``class_id`` has the wrong type for the chosen mode.
    """
    # Without this conversion a plain list compares as a whole (``[..] == cid`` is False)
    # and np.where silently returns an empty selection.
    y = np.asarray(y)

    if multiclass:
        assert isinstance(class_id, list)
        per_class = []
        for cid in class_id:
            sample_c = np.where(y == cid)[0]
            print("for class %i got %i samples" % (cid, sample_c.shape[0]))
            per_class.append(sample_c)
        # np.hstack rejects an empty list, so an empty class list yields an empty selection.
        sample = np.hstack(per_class) if per_class else np.array([], dtype=int)
    else:
        # Accept numpy integers as well (e.g. values taken from a label array); bool stays excluded.
        assert isinstance(class_id, (int, np.integer)) and not isinstance(class_id, bool)
        sample = np.where(y == class_id)[0]

    sample_emb = np.array(emb)[sample]
    sample_text = np.array(text)[sample]
    sample_gender = np.array(gender)[sample]
    return sample, sample_emb, sample_text, sample_gender
    

In [None]:
# Class under analysis: 19 is "physician" in professions_dict.
# 19 physician
class_id = 19
# False: class_id is a single id; True would mean class_id is a list of ids (see sample_by_class_id).
multi_class = False

# Number of concepts for the factorization, and the label used in output file/folder names.
n_concepts = 10
class_lbl = professions_dict[class_id]
#class_lbl = 'multiclass_med'

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook does not crash on the first
# ``.to(device)`` when no CUDA device is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# ReLU module, applied to freshly computed excerpt embeddings further below.
relu = torch.nn.ReLU()

In [None]:
# Output folder for all figures of the current class; created on first use.
cur_class_dir = ("Experiment3_results/finetune_%s/" % class_lbl)
if not os.path.isdir(cur_class_dir):
    os.makedirs(cur_class_dir)

if multi_class:
    print("choose multiple classes: ", class_id)
else:
    print("chose class %i (%s)" % (class_id, class_lbl))
# Note that selection is done on the *predicted* labels (y_pred), not the true ones,
# so `sample` contains every test item the model assigned to this class.
sample, sample_emb, sample_text, sample_gender = sample_by_class_id(emb_test, text_test, y_pred, gender_test, class_id=class_id, multiclass=multi_class)

# From here on sample_emb is a torch tensor living on `device`.
sample_emb = torch.from_numpy(sample_emb)
sample_emb = sample_emb.to(device)
print("got %i samples" % len(sample_emb))

In [None]:
# StochasticModel, predict_with_uncertainty_batched, uncertainty_matrices, entropy_uncertainty,
# get_threshold, UncertaintyWrapperWithSigmoid and excerpt_fct are not defined in this notebook;
# presumably they come from the experiment_helpers star imports — confirm.
dropout_prob = 0.3
# initialize the MC-dropout model around the classification head
stochastic_model = StochasticModel(bert.model.classifier, dropout_prob=dropout_prob).to(device)

# predict the output distribution with 100 stochastic passes over the class sample
predictions = predict_with_uncertainty_batched(stochastic_model.to(device), sample_emb, n_iter=100)
predictions = predictions.cpu().numpy()

# get uncertainty
# NOTE(review): `a` is assigned by both calls below; only the second value survives.
a, prob_mat = uncertainty_matrices(predictions)
t, e, a = entropy_uncertainty(prob_mat)

# sets threshold to separate uncertainty groups
threshold, _, t_norm = get_threshold(t)

# localizes the uncertainty (maps the uncertainty values to a binary probability distribution):
# values close to the threshold get a probability near .5, values far away get close to 0 or 1
loc = UncertaintyWrapperWithSigmoid(threshold)
unc_pred_probs = loc.predict_proba(t_norm)
unc_preds = np.argmax(unc_pred_probs, axis=1)

# filters for uncertainty classes: 0 is used as "low" and 1 as "high" throughout the notebook
l_indices = np.where(unc_preds == 0)[0]
h_indices = np.where(unc_preds == 1)[0]

# texts of the low / high uncertainty samples
l_drift_articles = sample_text[l_indices]
h_drift_articles = sample_text[h_indices]

# create excerpts from all texts of the class sample
excerpt_dataset = excerpt_fct(sample_text)

In [None]:
# map predictions, labels and prob/label of Stochastic model to excerpts

def remove_first_occurrence(original_string, substring):
    """Return ``original_string`` with the first occurrence of ``substring`` cut out.

    If ``substring`` does not occur (or is empty) the text is returned unchanged.
    """
    # A replacement count of 1 limits the removal to the left-most match.
    return original_string.replace(substring, "", 1)

# For every excerpt, record the true label, the predicted label and the uncertainty
# class/probability of the sample text it was cut from.
excerpt_labels = []
excerpt_pred = []
unc_excerpt_labels = []
unc_excerpt_probs = []

# Excerpts are expected in the same order as the sample texts, so a single forward-moving
# pointer is enough. Matched excerpts are removed from the current text so that a repeated
# excerpt is attributed to the next occurrence instead of the same one again.
n_sample_texts = len(sample_text)
cur_sample_id = 0
cur_sample = sample_text[cur_sample_id] if n_sample_texts > 0 else ""
for excerpt in excerpt_dataset:
    if excerpt not in cur_sample:
        # Look at up to two following samples, but never step past the last one
        # (the unguarded look-ahead raised IndexError near the end of the sample).
        for _step in range(2):
            if cur_sample_id + 1 >= n_sample_texts:
                break
            cur_sample_id += 1
            cur_sample = sample_text[cur_sample_id]
            if excerpt in cur_sample:
                break

    if excerpt in cur_sample:
        excerpt_labels.append(y_test[sample[cur_sample_id]])
        excerpt_pred.append(y_pred[sample[cur_sample_id]])
        unc_excerpt_probs.append(unc_pred_probs[cur_sample_id][1])
        unc_excerpt_labels.append(unc_preds[cur_sample_id])
        cur_sample = remove_first_occurrence(cur_sample, excerpt)
    else:
        # Keep the output aligned with excerpt_dataset by writing -1 placeholders.
        print("ERROR excerpt neither in current nor two next sample")
        excerpt_labels.append(-1)
        excerpt_pred.append(-1)
        unc_excerpt_probs.append(-1)
        unc_excerpt_labels.append(-1)

excerpt_labels = np.asarray(excerpt_labels)
excerpt_pred = np.asarray(excerpt_pred)
unc_excerpt_labels = np.asarray(unc_excerpt_labels)
unc_excerpt_probs = np.asarray(unc_excerpt_probs)

In [None]:
# Positions (within excerpt_dataset) of excerpts whose source sample was labelled
# low (0) or high (1) uncertainty; excerpts marked -1 fall into neither group.
l_indices_excerpt = np.where(unc_excerpt_labels == 0)[0]
h_indices_excerpt = np.where(unc_excerpt_labels == 1)[0]

# The corresponding excerpt texts as plain Python lists.
excerpt_dataset_l = np.array(excerpt_dataset)[l_indices_excerpt].tolist()
excerpt_dataset_h = np.array(excerpt_dataset)[h_indices_excerpt].tolist()

print(len(excerpt_dataset_l), ' low unc excerpts created.')
print(len(excerpt_dataset_h), ' high unc excerpts created.')

In [None]:
# Number of low- and high-uncertainty samples (whole texts, not excerpts).
print(len(l_indices))
print(len(h_indices))

In [None]:
# Sample an equal number of low- and high-uncertainty excerpts (and of whole samples),
# embed the excerpts, and cache everything so the random selection is reused on re-runs.
excerpt_file = 'Experiment3_results/finetune_excerpts_%s.pickle' % class_lbl

max_n_samples = 1000
max_n_excerpts = 1000

if os.path.isfile(excerpt_file):
    print("load excerpts and their embeddings")
    # NOTE: pickle.load can execute arbitrary code; only load cache files written by this notebook.
    # Cache files written before the sampling fix below hold the old (uncertainty-agnostic)
    # selection; delete the file to draw a fresh sample.
    with open(excerpt_file, 'rb') as handle:
        savedict = pickle.load(handle)

    excerpt_samples = savedict['excerpts']
    n_excerpts = savedict['n_excerpts']
    excerpt_sample_ids = savedict['excerpt_sample_ids']
    emb_excerpt = savedict['embeddings']

    # Layout of excerpt_samples / emb_excerpt: first n_excerpts rows are the low-uncertainty
    # excerpts, the remaining n_excerpts rows the high-uncertainty ones.
    excerpt_low_ids = np.array(range(n_excerpts))
    excerpt_high_ids = np.array(range(n_excerpts)) + n_excerpts
    excerpt_samples_l = excerpt_samples[:n_excerpts]
    print(len(excerpt_samples_l))
    excerpt_samples_h = excerpt_samples[n_excerpts:]
    print(len(excerpt_samples_h))

    emb_excerpt_l = emb_excerpt[excerpt_low_ids,:]
    emb_excerpt_h = emb_excerpt[excerpt_high_ids,:]

    n_samples = savedict['n_samples']

    sample_ids_l = savedict['sample_ids_l']
    sample_ids_h = savedict['sample_ids_h']

    sample_emb_for_imp_l = sample_emb[sample_ids_l]
    sample_emb_for_imp_h = sample_emb[sample_ids_h]

else:
    print("sample excerpts")
    # Same number of excerpts per group, capped at max_n_excerpts.
    n_excerpts = int(np.min([len(excerpt_dataset_l), len(excerpt_dataset_h), max_n_excerpts]))
    # Row positions of the two groups inside the concatenated excerpt_samples / emb_excerpt.
    excerpt_low_ids = np.array(range(n_excerpts))
    excerpt_high_ids = np.array(range(n_excerpts)) + n_excerpts

    # Draw from the excerpts that actually belong to each uncertainty group. Sampling from the
    # positional ranges 0..n-1 / n..2n-1 instead would simply pick the first 2n excerpts of
    # excerpt_dataset, regardless of their uncertainty label.
    excerpt_samples_ids_l = random.sample(l_indices_excerpt.tolist(), n_excerpts)
    excerpt_samples_ids_h = random.sample(h_indices_excerpt.tolist(), n_excerpts)

    excerpt_samples_l = np.asarray(excerpt_dataset)[excerpt_samples_ids_l].tolist()
    excerpt_samples_h = np.asarray(excerpt_dataset)[excerpt_samples_ids_h].tolist()

    excerpt_sample_ids = excerpt_samples_ids_l + excerpt_samples_ids_h
    excerpt_samples = excerpt_samples_l + excerpt_samples_h

    print("embed %i excerpts..." % (len(excerpt_samples)))
    emb_excerpt = bert.embed(excerpt_samples)
    # Clamp to non-negative values (the full-text embeddings are asserted non-negative as well).
    emb_excerpt = relu(torch.from_numpy(emb_excerpt)).detach().numpy()

    emb_excerpt_l = emb_excerpt[excerpt_low_ids,:]
    emb_excerpt_h = emb_excerpt[excerpt_high_ids,:]

    # Same number of whole samples per uncertainty group, capped at max_n_samples.
    n_samples = int(np.min([len(l_indices), len(h_indices), max_n_samples]))

    sample_ids_l = random.sample(l_indices.tolist(), n_samples)
    sample_ids_h = random.sample(h_indices.tolist(), n_samples)

    sample_emb_for_imp_l = sample_emb[sample_ids_l]
    sample_emb_for_imp_h = sample_emb[sample_ids_h]

    savedict = {'excerpt_sample_ids': excerpt_sample_ids, 'excerpts': excerpt_samples, 'n_excerpts': n_excerpts, 'embeddings': emb_excerpt,
                'sample_ids_l': sample_ids_l, 'sample_ids_h': sample_ids_h, 'n_samples': n_samples}

    with open(excerpt_file, 'wb') as handle:
        pickle.dump(savedict, handle)

# need ndarray and torch tensor
sample_emb_np = sample_emb.cpu().detach().numpy()

In [None]:
# Concept extraction (NMF) for the current class plus concept importances for the
# low- and high-uncertainty groups; results are cached per class and concept count.
n_concepts = 10
savefile = 'Experiment3_results/finetune_results_%s_%s.pickle' % (class_lbl, n_concepts)

# All sampled excerpts are used; of the full texts either everything (small classes)
# or a tenth of the excerpt count.
len_data = len(emb_excerpt)
if n_samples < 100:
    len_samples = len(sample_text)
else:
    len_samples = len_data//10
print(len_data, len_samples)
if os.path.isfile(savefile):
    # NOTE: pickle.load can execute arbitrary code; only load cache files written by this notebook.
    with open(savefile, 'rb') as handle:
        results = pickle.load(handle)

    low_unc_importances = results['imp_low']
    high_unc_importances = results['imp_high']
    segments = results['segments']
    u_segments = results['u_segments']
    cockatiel_explainer = results['cockatiel']
    factorization = results['factorization']
    global_importance = results['global_importance']
    _pa = results['_pa']


else:
    print("compute NMF and concept importances..")

    # NMF for current class

    with torch.no_grad():
        cockatiel_explainer = SubCockatiel(bert, bert.tokenizer, n_concepts, 64, device)
        # The excerpt texts must be the ones emb_excerpt was computed from, in the same order:
        # that is excerpt_samples (the sampled, shuffled selection), not the head of
        # excerpt_dataset.
        # NOTE(review): the meaning of the positional 0 is defined by SubCockatiel.extract_concepts — confirm.
        segments, u_segments, factorization, global_importance ,_pa = cockatiel_explainer.extract_concepts(
                                                        excerpt_samples[:len_data],
                                                        np.asarray(emb_excerpt[:len_data]),
                                                        sample_text[:(len_samples)],
                                                        sample_emb_np[:(len_samples)],
                                                        0, limit_sobol = 1_000)

    # global importances for the high and low uncertainty samples (target 1 = high, 0 = low)
    print("compute high uncertainty concept importance...")
    high_unc_importances, high_sens_dict = estimate_importance(cockatiel_explainer, stochastic_model, sample_emb_for_imp_h, 1, factorization.components_)
    print("compute low uncertainty concept importance...")
    low_unc_importances, low_sens_dict = estimate_importance(cockatiel_explainer, stochastic_model, sample_emb_for_imp_l, 0, factorization.components_)

    results = {'imp_low': low_unc_importances, 'imp_high': high_unc_importances, 'segments': segments, 'u_segments': u_segments, 'cockatiel': cockatiel_explainer, 'factorization': factorization, 'global_importance': global_importance , '_pa': _pa}

    with open(savefile, 'wb') as handle:
        pickle.dump(results, handle)

In [None]:
# compute concept activations
# activation_transform is not defined in this notebook (presumably from the experiment_helpers
# star imports); it receives embeddings and the concept basis factorization.components_.
# activations_all covers the whole test split, activations_class only the selected class sample.
activations_all = activation_transform(torch.from_numpy(emb_test), factorization.components_)
activations_class = activation_transform(sample_emb, factorization.components_)

In [None]:
# plot correlation and feature importance for gender
# Random-forest importances of the class-level concept activations for predicting gender.
plot_feature_importance(activations_class, sample_gender, class_id)
plt.savefig("%s/gender_concept_importance.png" % cur_class_dir)

# Pearson correlation of every concept with gender: over the full test split and within the class.
rs_all = get_pearsoncorr(activations_all, gender_test)
rs_class = get_pearsoncorr(activations_class, sample_gender)

# Side-by-side bar charts of the two correlation vectors.
fig, axes = plt.subplots(1, 2, figsize=(22, 5))
cm = plt.get_cmap('tab20')

plot_bars(axes[0], rs_all, cm, "Pearson correlation of concept activations with gender labels (all classes)")
plot_bars(axes[1], rs_class, cm, "Pearson correlation of concept activations with gender labels (%s)" % class_lbl)
plt.savefig("%s/gender_corr.png" % cur_class_dir)
plt.show()

In [None]:
# Stand-alone version of the all-classes correlation plot, with axis labels, saved separately.
fig, ax = plt.subplots(figsize=(8, 5))
cm = plt.get_cmap('tab20')

plot_bars(ax, rs_all, cm, "Pearson correlation of concept activations with gender labels")
plt.ylabel('Pearson R', fontsize=14)
plt.xlabel('Concepts', fontsize=14)
plt.tight_layout()
plt.savefig("%s/gender_corr_.png" % cur_class_dir)
plt.show()

In [None]:
# feature importance for low/ high uncertainty
fig, axes = plt.subplots(1, 2, figsize=(16, 5))
cm = plt.get_cmap('tab20')

# Only element [0] of each estimate_importance result is plotted.
# NOTE(review): what the remaining elements hold depends on estimate_importance — confirm.
plot_bars(axes[0], low_unc_importances[0], cm, "Low uncertainty samples (%s)" % class_lbl)
plot_bars(axes[1], high_unc_importances[0], cm, "High uncertainty samples (%s)" % class_lbl)

axes[0].set_ylabel('Global concept importance', fontsize=14)
axes[0].set_xlabel('Concepts', fontsize=14)
axes[1].set_xlabel('Concepts', fontsize=14)
plt.tight_layout()
plt.savefig("%s/unc_importance.png" % cur_class_dir)

In [None]:
# Mean of the gender label within each uncertainty group.
# NOTE(review): reading this as a "female ratio" presumes gender is coded 1 = female — confirm against the dataset card.
print("low unc female ratio: ", np.sum(sample_gender[l_indices])/len(l_indices))
print("high unc female ratio: ", np.sum(sample_gender[h_indices])/len(h_indices))

## Intervene on concepts and measure the effect on predictions

In [None]:
# Project the class embeddings onto the concepts and map them back to embedding space
# (activations @ components) without changing any concept.
perturbated_activations = activation_transform(sample_emb, factorization.components_)
perturbated_activations = perturbated_activations @ factorization.components_

# Frobenius norm of the difference between original and reconstructed embeddings.
err = np.linalg.norm(sample_emb.cpu().detach().numpy() - perturbated_activations, 'fro')
print("reconstruction error", err)

In [None]:
# Concepts to neutralise: each selected concept's activation is replaced by its mean over the class sample.
concepts_sel = [6]

# Two copies of the activations: one left untouched (pure reconstruction), one with the
# selected concepts set to their mean ("_rem").
perturbated_activations = activation_transform(sample_emb, factorization.components_)
perturbated_activations_rem = activation_transform(sample_emb, factorization.components_)
# Note: `concept_id` keeps its last value after the loop and is referenced again in a later cell.
for concept_id in concepts_sel:
    print("set concept %i to its mean=%.3f" % (concept_id, np.mean(perturbated_activations[:,concept_id])))
    perturbated_activations_rem[:,concept_id] = np.mean(perturbated_activations_rem[:,concept_id])
    
# Back to embedding space for both variants.
perturbated_activations = perturbated_activations @ factorization.components_
perturbated_activations_rem = perturbated_activations_rem @ factorization.components_

# Frobenius distance to the original embeddings: reconstruction only vs. reconstruction plus concept change.
err = np.linalg.norm(sample_emb.cpu().detach().numpy() - perturbated_activations, 'fro')
print("reconstruction error: ", err)
err = np.linalg.norm(sample_emb.cpu().detach().numpy() - perturbated_activations_rem, 'fro')
print("reconstruction error + concept change: ", err)


In [None]:
def predict_batched(f_model, inputs, device="cuda"):
    """Apply ``f_model`` to ``inputs`` in mini-batches of 64 with gradients disabled.

    ``inputs`` is moved to ``device`` and iterated along its first dimension. The
    per-batch outputs are stacked vertically and returned as one tensor; a batch
    output that comes back one-dimensional is treated as a single row.
    """
    loader = torch.utils.data.DataLoader(inputs.to(device), batch_size=64)

    collected = []
    for batch in loader:
        batch = batch.to(device)
        with torch.no_grad():
            batch_out = f_model(batch)
        # vstack needs 2-D blocks: promote a flat output vector to shape (1, n).
        collected.append(batch_out.unsqueeze(0) if batch_out.dim() == 1 else batch_out)

    # Stack predictions across batches
    return torch.vstack(collected)

# Class predictions of the classification head on the reconstructed embeddings.
pred = predict_batched(bert.model.classifier, torch.from_numpy(perturbated_activations))
pred = pred.cpu().detach().numpy()
y_pred_perturbed = np.argmax(pred, axis=1)


# Same for the embeddings whose selected concepts were set to their mean.
pred = predict_batched(bert.model.classifier, torch.from_numpy(perturbated_activations_rem))
pred = pred.cpu().detach().numpy()
y_pred_perturbed_rem = np.argmax(pred, axis=1)

In [None]:
# How many predictions of the class sample differ from the original ones after each intervention.
print("%i predictions changed by reconstruction" % (np.sum(y_pred[sample] != y_pred_perturbed)))
print("%i predictions changed by concept removal" % (np.sum(y_pred[sample] != y_pred_perturbed_rem)))

In [None]:
# Restrict to samples whose prediction changed after *reconstruction only* (y_pred_perturbed):
# their true labels, new predictions and gender labels.
y_true_change = np.array(y_test)[sample][y_pred[sample] != y_pred_perturbed]
y_pred_change = y_pred_perturbed[y_pred[sample] != y_pred_perturbed]
gender_change = sample_gender[y_pred[sample] != y_pred_perturbed]
# 1 for positions of the class sample that are in the low-uncertainty group, 0 otherwise.
is_low = [1 if idx in l_indices else 0 for idx in range(len(sample))]
is_low_change = np.array(is_low)[y_pred[sample] != y_pred_perturbed]
# Positions *within the changed subset* that are low / not-low uncertainty.
l_indices_change = np.where(is_low_change == 1)[0]
h_indices_change = np.where(is_low_change == 0)[0]

In [None]:
# Same bookkeeping for samples whose prediction changed after *concept removal*
# (y_pred_perturbed_rem); all names carry the _rem suffix.
y_true_change_rem = np.array(y_test)[sample][y_pred[sample] != y_pred_perturbed_rem]
y_pred_change_rem = y_pred_perturbed_rem[y_pred[sample] != y_pred_perturbed_rem]
gender_change_rem = sample_gender[y_pred[sample] != y_pred_perturbed_rem]
# 1 for positions of the class sample that are in the low-uncertainty group, 0 otherwise.
is_low_rem = [1 if idx in l_indices else 0 for idx in range(len(sample))]
is_low_change_rem = np.array(is_low_rem)[y_pred[sample] != y_pred_perturbed_rem]
# Positions *within the changed subset* that are low / not-low uncertainty.
l_indices_change_rem = np.where(is_low_change_rem == 1)[0]
h_indices_change_rem = np.where(is_low_change_rem == 0)[0]

In [None]:
# Breakdown of the predictions changed by reconstruction only:
#   corrected    -> new prediction equals the true label
#   still wrong  -> new prediction is wrong and the true label is not the analysed class
#   turned wrong -> new prediction is wrong although the true label is the analysed class
print("%i predictions were corrected" % np.sum(y_true_change == y_pred_change))
print("%i predictions are still wrong" % np.sum((y_true_change != y_pred_change) & (y_true_change != class_id)))
print("%i predictions turned wrong" % np.sum((y_true_change != y_pred_change) & (y_true_change == class_id)))

# Share of not-low-uncertainty samples among the changed ones; divides by zero if nothing changed.
print("percentage of high uncertainty samples among changed predictions: %.3f" % (len(h_indices_change)/(len(h_indices_change)+len(l_indices_change))))

In [None]:
# Breakdown of the predictions changed by concept removal:
#   corrected    -> new prediction equals the true label
#   still wrong  -> new prediction is wrong and the true label is not the analysed class
#   turned wrong -> new prediction is wrong although the true label is the analysed class
print("%i predictions were corrected" % np.sum(y_true_change_rem == y_pred_change_rem))
print("%i predictions are still wrong" % np.sum((y_true_change_rem != y_pred_change_rem) & (y_true_change_rem != class_id)))
print("%i predictions turned wrong" % np.sum((y_true_change_rem != y_pred_change_rem) & (y_true_change_rem == class_id)))

# Use the concept-removal (_rem) index sets here; the reconstruction-only sets would just
# repeat the number of the reconstruction summary. NaN is reported when nothing changed.
n_high_changed_rem = len(h_indices_change_rem)
n_changed_rem = n_high_changed_rem + len(l_indices_change_rem)
high_share_rem = n_high_changed_rem / n_changed_rem if n_changed_rem else float('nan')
print("percentage of high uncertainty samples among changed predictions: %.3f" % high_share_rem)

In [None]:
# Same three counts as the reconstruction summary, restricted to the not-low-uncertainty
# samples among the changed predictions (h_indices_change).
print("high uncertainty:")
print("%i predictions were corrected" % np.sum((y_true_change == y_pred_change)[h_indices_change]))
print("%i predictions are still wrong" % np.sum(((y_true_change != y_pred_change) & (y_true_change != class_id))[h_indices_change]))
print("%i predictions turned wrong" % np.sum(((y_true_change != y_pred_change) & (y_true_change == class_id))[h_indices_change]))

In [None]:
# Same three counts, restricted to the low-uncertainty samples among the changed
# predictions (l_indices_change).
print("low uncertainty:")
print("%i predictions were corrected" % np.sum((y_true_change == y_pred_change)[l_indices_change]))
print("%i predictions are still wrong" % np.sum(((y_true_change != y_pred_change) & (y_true_change != class_id))[l_indices_change]))
print("%i predictions turned wrong" % np.sum(((y_true_change != y_pred_change) & (y_true_change == class_id))[l_indices_change]))

In [None]:
# Mean gender label among corrected and among wrong-turned predictions (reconstruction-only change set).
# np.mean of an empty selection yields nan with a warning.
print("female ratio among correct samples: ", np.mean(gender_change[y_true_change == y_pred_change]))
print("female ratio among wrong-turned samples: ", np.mean(gender_change[(y_true_change != y_pred_change) & (y_true_change == class_id)]))

In [None]:
# Mean gender label over the whole test split: for items whose true label is the class,
# and for items the model predicted as the class.
print("female ratio of class %i (true label): %.3f" % (class_id, np.mean(np.array(gender_test)[np.array(y_test) == class_id])))
print("female ratio of class %i predictions: %.3f" % (class_id, np.mean(np.array(gender_test)[y_pred == class_id])))

In [None]:
# Mean gender label among corrected predictions, split by true class
# (21 professor, 25 surgeon, 3 chiropractor, 26 teacher in professions_dict).
# `concept_id` is the value left over from the last iteration of the concept-removal loop.
# NOTE(review): the messages attribute the change to a concept, but gender_change / y_true_change /
# y_pred_change are the reconstruction-only variables — confirm whether the *_rem variants were intended.
print("female ratio among prof->physician (bc of concept %i: %.3f" % (concept_id, np.mean(gender_change[(y_true_change == y_pred_change) & (y_true_change == 21)])))
print("female ratio among surgeon->physician (bc of concept %i: %.3f" % (concept_id, np.mean(gender_change[(y_true_change == y_pred_change) & (y_true_change == 25)])))
print("female ratio among chiro->physician (bc of concept %i: %.3f" % (concept_id, np.mean(gender_change[(y_true_change == y_pred_change) & (y_true_change == 3)])))
print("female ratio among teacher->physician (bc of concept %i: %.3f" % (concept_id, np.mean(gender_change[(y_true_change == y_pred_change) & (y_true_change == 26)])))


In [None]:
# For the same four classes: mean gender label over the test split (by true label) and the
# number of corrected predictions whose true label is that class.
for cid in [21, 25, 3, 26]:
    print("gender ratio of class %i (%s): %.3f" % (cid, professions_dict[cid], np.mean(np.array(gender_test)[np.array(y_test) == cid])))
    print("FP of class %i: %i" % (cid, np.sum((y_true_change == y_pred_change) & (y_true_change == cid))))

In [None]:
# True classes of the corrected predictions, with how often each occurs.
np.unique(y_true_change[(y_true_change == y_pred_change)], return_counts=True)

In [None]:
# Number of predictions changed by reconstruction only (y_pred_perturbed).
print("%i predictions changed" % (np.sum(y_pred[sample] != y_pred_perturbed)))

In [None]:
# Mean gender label among samples whose prediction changed (reconstruction only).
g_change = sample_gender[(y_pred[sample] != y_pred_perturbed)]
print("female ratio of changed samples: ", np.sum(g_change)/len(g_change))

# Mean gender label among misclassified samples before and after the intervention.
# NOTE(review): the last message says "after concept removal", but g_perturbed_error is built from
# y_pred_perturbed (reconstruction only), not y_pred_perturbed_rem — confirm which one is intended.
g_prev_error = sample_gender[(y_pred[sample] != np.array(y_test)[sample])]
g_perturbed_error = sample_gender[(y_pred_perturbed != np.array(y_test)[sample])]
print("female ratio of previous errors: ", np.sum(g_prev_error)/len(g_prev_error))
print("female ratio of errors after concept removal: ", np.sum(g_perturbed_error)/len(g_perturbed_error))

In [None]:
# Rebuild full-test-set label/prediction/gender lists in the order
# "everything outside the class sample, then the class sample", so that the perturbed
# predictions of the sample can be appended to the untouched predictions of the rest.
# A boolean mask replaces the per-index `i not in sample` test, which scanned the whole
# `sample` array once for every test item.
in_sample_mask = np.zeros(len(y_test), dtype=bool)
in_sample_mask[sample] = True

y_test_neg_sample = np.asarray(y_test)[~in_sample_mask].tolist()
y_pred_neg_sample = np.asarray(y_pred)[~in_sample_mask].tolist()
gender_neg_sample = np.asarray(gender_test)[~in_sample_mask].tolist()

y_test_all_pert_order = y_test_neg_sample + np.array(y_test)[sample].tolist()
y_pred_all_pert_order = y_pred_neg_sample + y_pred_perturbed.tolist()
y_pred_all_pert_order_rem = y_pred_neg_sample + y_pred_perturbed_rem.tolist()
gender_all_pert_order = gender_neg_sample + np.array(gender_test)[sample].tolist()


In [None]:
# note: changed to FP because that makes more sense with cockatiel sampling
def equalized_odds(y_true, y_pred, gender):
    """Per-class gap in true-positive rate between gender group 1 and gender group 0.

    Parameters
    ----------
    y_true, y_pred
        True and predicted class ids, one entry per item (any sequence or array).
        Class ids are taken to run from 0 to ``max(y_true)``.
    gender
        Group label per item; only the values 0 and 1 are evaluated.

    Returns
    -------
    scores : list
        For every class ``c``: TPR among gender-1 items with true label ``c`` minus TPR
        among gender-0 items with true label ``c``. The score is 0 when either group
        has no item of that class.
    tp_count : list
        For every class, the number of correctly predicted items (both groups together).
    """
    # Always convert: boolean masking below needs arrays, and a plain type check on `list`
    # would let tuples or other sequences through unconverted.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    gender = np.asarray(gender)

    # Nothing to score on empty input.
    if y_true.size == 0:
        return [], []

    n_classes = int(y_true.max()) + 1
    scores = []
    tp_count = []
    for c in range(n_classes):
        is_class = (y_true == c)
        # True where an item of class c was predicted as c.
        tp_mask = (y_pred[is_class] == c)
        gender_c = gender[is_class]

        tp_mask_0 = tp_mask[gender_c == 0]
        tp_mask_1 = tp_mask[gender_c == 1]

        n_samples_0 = tp_mask_0.shape[0]
        n_samples_1 = tp_mask_1.shape[0]

        if n_samples_0 == 0 or n_samples_1 == 0:
            # A rate is undefined for an empty group; report "no gap".
            scores.append(0)
        else:
            tp_0 = np.sum(tp_mask_0)/n_samples_0
            tp_1 = np.sum(tp_mask_1)/n_samples_1
            scores.append(tp_1 - tp_0)
        tp_count.append(np.sum(tp_mask_0) + np.sum(tp_mask_1))

    return scores, tp_count

# Per-class scores for the original predictions, after reconstruction only, and after concept removal.
# The latter two use the reordered lists (non-sample items first, class sample last).
eo_base, tp_base = equalized_odds(y_test, y_pred, gender_test)
eo_sample_recon, tp_recon = equalized_odds(y_test_all_pert_order,y_pred_all_pert_order, gender_all_pert_order)
eo_sample_pert, tp_pert = equalized_odds(y_test_all_pert_order, y_pred_all_pert_order_rem, gender_all_pert_order)

In [None]:
# Mean of the signed per-class scores (positive and negative gaps can cancel out).
print("baseline: ", np.mean(eo_base))
print("after reconstruction: ", np.mean(eo_sample_recon))
print("after concept removal: ", np.mean(eo_sample_pert))

In [None]:
class_names = list(professions_dict.values())

# Keep only the classes whose score differs between baseline and concept removal.
# NOTE: eo_base, eo_sample_pert and eo_sample_recon are overwritten with their filtered
# versions, so re-running this cell without re-running the equalized_odds cell first
# operates on already-filtered arrays (and the class_names mask no longer fits).
eo_diff = np.array(eo_base)-np.array(eo_sample_pert)
eo_base = np.array(eo_base)[eo_diff != 0]
eo_sample_pert = np.array(eo_sample_pert)[eo_diff != 0]
eo_sample_recon = np.array(eo_sample_recon)[eo_diff != 0]

eo_diff_sel = eo_diff[eo_diff != 0]
class_names_sel = np.array(class_names)[eo_diff != 0]

In [None]:
# Per-class difference in absolute score: reconstruction minus concept removal (selected classes only).
np.abs(eo_sample_recon)-np.abs(eo_sample_pert)

In [None]:
# Mean of that difference over the selected classes.
np.mean(np.abs(eo_sample_recon)-np.abs(eo_sample_pert))

In [None]:
# Grouped bar chart: one group per selected class, three bars per group.
x = np.arange(len(eo_base))

# Set the width of the bars
bar_width = 0.2

# Create a single figure and axis
fig, ax = plt.subplots(figsize=(8,5))

# Plot the first set of bars (baseline), shifted left
ax.bar(x - bar_width, eo_base, width=bar_width, label='Baseline', color='b')

# Plot the second set of bars (after reconstruction), centred
ax.bar(x, eo_sample_recon, width=bar_width, label='After concept reconstruction', color='r')

# Plot the third set of bars (after concept perturbation), shifted right
ax.bar(x + bar_width, eo_sample_pert, width=bar_width, label='After concept perturbation', color='g')

# Set the x-ticks, labels, and rotation
ax.set_xticks(x)
ax.set_xticklabels(class_names_sel, rotation=90)
ax.set_xlabel('Classes', fontsize=14)
ax.set_ylabel('Equalized Odds', fontsize=14)
ax.set_title('Equalized Odds Comparison', fontsize=18)
ax.legend()

# Save and show the plot
plt.tight_layout()
plt.savefig('%s/eq_odds.png' % cur_class_dir)
plt.show()

In [None]:
# Mean absolute score over the selected classes (the arrays were filtered in the cell that builds eo_diff).
print("baseline: ", np.mean(np.abs(eo_base)))
print("after reconstruction: ", np.mean(np.abs(eo_sample_recon)))
print("after concept removal: ", np.mean(np.abs(eo_sample_pert)))

In [None]:
# Upper bound for this intervention: pretend every item of the class sample is predicted
# with its true label while all other predictions stay as they are.
y_pred_all_pert_order_best = y_pred_neg_sample + np.array(y_test)[sample].tolist()
eo_sample_best, tp_best = equalized_odds(y_test_all_pert_order, y_pred_all_pert_order_best, gender_all_pert_order)

# Restrict to the same selected classes as the other score arrays.
eo_sample_best = np.array(eo_sample_best)[eo_diff != 0]
print("baseline: ", np.mean(np.abs(eo_base)))
print("best possible EO from this intervention: ", np.mean(np.abs(eo_sample_best)))

In [None]:
# Difference in mean absolute score: best case minus baseline (negative means a smaller gap).
np.mean(np.abs(eo_sample_best)) - np.mean(np.abs(eo_base))

In [None]:
# visualize samples that changed due to feature removal

# corrected_fp_mask: prediction changed under concept removal, did NOT change under plain
# reconstruction, and the new prediction equals the true label.
corrected_fp_mask = ((y_pred[sample] != y_pred_perturbed_rem) & (y_pred[sample] == y_pred_perturbed) & (np.array(y_test)[sample] == y_pred_perturbed_rem))
# new_fn_mask: originally correct, concept-removal prediction differs from the reconstruction
# prediction and no longer matches the true label.
new_fn_mask = ((y_pred[sample] == np.array(y_test)[sample]) & (y_pred_perturbed_rem != y_pred_perturbed) & (np.array(y_test)[sample] != y_pred_perturbed_rem))

# Texts of the two groups.
sent_corrected_fp = sample_text[corrected_fp_mask]
sent_new_fn = sample_text[new_fn_mask]

In [None]:
# Confusion between class 21 (professor) and class 19 (physician) on the full test split.
print("professors classified as physician: ", np.sum((y_pred == 19) & (np.array(y_test) == 21)))
print("professors in the test data: ", np.sum(np.array(y_test) == 21))

## Explain concepts by token attribution

In [None]:
!pip install imgkit

In [None]:
import imgkit
from typing import List, Optional
from IPython.core.display import display, HTML
import nltk
from nltk.tokenize import word_tokenize
from cockatiel.cockatiel.utils import extract_clauses
nltk.download('punkt')

phi_thresh = 0.05

def viz_concepts_to_img(
        output_image_path,
        text,
        explanation,
        colors,
        ignore_words: Optional[List[str]] = None,
        extract_fct: str = "clause"
):
    """
    Generates the visualization for COCKATIEL's explanations and renders it to an image file.

    Parameters
    ----------
    output_image_path
        Path of the image file that imgkit writes.
    text
        A string (or UTF-8 bytes) with the text we wish to explain.
    explanation
        An array that corresponds to the output of the occlusion function; it is indexed
        as ``explanation[concept][part]``.
    colors
        A dictionary with the colors for each label. Each entry is used as the opening of
        a CSS colour function to which the alpha value and ``)`` are appended.
    ignore_words
        A list of strings to ignore when applying occlusion. ``None`` means nothing is ignored.
    extract_fct
        A string indicating at which level we wish to explain: "clause" splits the text
        with ``extract_clauses``, any other value tokenizes it into words.
    """
    import html  # local import: only needed for escaping the text parts below

    # Accept bytes as well as str; anything that cannot be decoded is stringified.
    try:
        text = text.decode('utf-8')
    except (AttributeError, UnicodeDecodeError):
        text = str(text)

    # The default of None would otherwise fail on the membership test in the loop.
    if ignore_words is None:
        ignore_words = []

    if extract_fct == "clause":
        words = extract_clauses(text, clause_type=None)
    else:
        words = word_tokenize(text)

    l_phi = np.array(explanation)

    def _plain_span(content):
        # Fully transparent box: keeps the layout identical to highlighted parts.
        return f'<span style="background-color: rgba(233,30,99,0);  padding: 1px 5px; border: solid 3px ; border-color:  rgba(233,30,99,0); #EFEFEF">{content}</span>'

    phi_html = []

    p = 0  # pointer to get current color for the words (it does not color words that have no phi)
    for word in words:
        # Escape the text so characters such as '<' or '&' cannot break the generated markup.
        content = html.escape(str(word))

        if word in ignore_words:
            # Ignored parts have no column in l_phi, so the pointer is not advanced.
            phi_html.append(_plain_span(content))
            continue

        # Concept with the largest attribution for this part (first one wins on ties).
        k = max(range(len(l_phi)), key=lambda j: l_phi[j][p])

        if l_phi[k][p] > phi_thresh:
            # Background opacity encodes the attribution value; the border uses full opacity.
            phi_html.append(f'<span style="background-color: {colors[k]} {l_phi[k][p]}); padding: 1px 5px; border: solid 3px ; border-color: {colors[k]} 1); #EFEFEF">{content}</span>')
        else:
            phi_html.append(_plain_span(content))
        p += 1

    complete_html = f"<div style='display: flex; width: 400px; flex-wrap: wrap'>{' '.join(phi_html)}</div>"
    imgkit.from_string(complete_html, output_image_path)

def print_legend_img(output_image_path, colors, label_to_criterion):
    """
    Renders the colour legend for the concept plot to an image file.

    Parameters
    ----------
    output_image_path
        Path of the image file that imgkit writes.
    colors
        A dictionary with the colors for each label.
    label_to_criterion
        A dictionary with the text to put on each label.
    """
    # One coloured box per label: half-transparent fill, opaque border.
    legend_spans = [
        f'<span style="background-color: {colors[label_id]} 0.5); padding: 1px 5px; border: solid 3px ; border-color: {colors[label_id]} 1); #EFEFEF">{label_to_criterion[label_id]} </span>'
        for label_id in label_to_criterion.keys()
    ]

    legend_markup = "<div style='display: flex; width: 400px; flex-wrap: wrap'>" + " ".join(legend_spans) + "</div>"
    imgkit.from_string(legend_markup, output_image_path)

In [None]:
# copied from cockatiel (to adjust the threshold of phi)

def viz_concepts(
        text,
        explanation,
        colors,
        ignore_words: Optional[List[str]] = None,
        extract_fct: str = "clause"
):
    """
    Generates the visualization for COCKATIEL's explanations and displays it inline.

    Each extracted word/clause is wrapped in a ``<span>``. A non-ignored element is
    coloured with the concept that has the highest value for it, provided that value
    exceeds the module-level ``phi_thresh`` (which must be defined before this
    function is called); otherwise it is rendered with a transparent background.

    Parameters
    ----------
    text
        The text we wish to explain. Objects with a ``decode`` method (e.g. bytes)
        are decoded as UTF-8; anything else is converted with ``str``.
    explanation
        An array that corresponds to the output of the occlusion function. It is
        indexed as ``explanation[concept][position]``, where ``position`` counts
        only the elements that are not in ``ignore_words``.
    colors
        A dictionary with the colors for each label, given as open
        ``rgba(r, g, b, `` prefixes (alpha and closing parenthesis are appended here).
    ignore_words
        A list of strings to ignore when applying occlusion. ``None`` (the default)
        means that nothing is ignored.
    extract_fct
        ``"clause"`` splits the text with ``extract_clauses``; any other value falls
        back to ``word_tokenize``.
    """
    try:
        text = text.decode('utf-8')
    except (AttributeError, UnicodeDecodeError):
        # Not a bytes-like object (or not valid UTF-8): use its string form instead.
        text = str(text)

    # The signature advertises None as default; without this guard the membership
    # test below raises "argument of type 'NoneType' is not iterable".
    if ignore_words is None:
        ignore_words = []

    if extract_fct == "clause":
        words = extract_clauses(text, clause_type=None)
    else:
        words = word_tokenize(text)

    l_phi = np.array(explanation)

    phi_html = []

    p = 0  # pointer to get current color for the words (it does not color words that have no phi)
    for word in words:
        # Default rendering: fully transparent background and border.
        span = f'<span style="background-color: rgba(233,30,99,0);  padding: 1px 5px; border: solid 3px ; border-color:  rgba(233,30,99,0); #EFEFEF">{word}</span>'

        if word not in ignore_words:
            # Index of the concept with the largest value at this position
            # (the first one wins on ties).
            k = 0
            for j in range(len(l_phi)):
                if l_phi[k][p] < l_phi[j][p]:
                    k = j

            if l_phi[k][p] > phi_thresh:
                # The value itself is used as the alpha of the background colour.
                span = f'<span style="background-color: {colors[k]} {l_phi[k][p]}); padding: 1px 5px; border: solid 3px ; border-color: {colors[k]} 1); #EFEFEF">{word}</span>'

            # Only non-ignored elements consume a column of the explanation.
            p += 1

        phi_html.append(span)

    display(HTML("<div style='display: flex; width: 400px; flex-wrap: wrap'>" +  " ".join(phi_html) + " </div>" ))
    display(HTML('<br><br>'))

In [None]:

# Palette for concept highlighting, as (red, green, blue) triples in slot order.
_PALETTE_RGB = (
    (9, 221, 55),    # green
    (9, 221, 161),   # turquoise
    (9, 175, 221),   # blue
    (221, 9, 34),    # red
    (221, 9, 140),   # pink
    (221, 90, 9),    # orange
    (221, 9, 221),   # magenta
    (221, 221, 9),   # yellow
    (9, 55, 221),    # dark blue
    (9, 221, 9),     # lime
)

# Each value is an *open* CSS "rgba(r, g, b, " prefix: the visualization helpers
# append the alpha value and the closing parenthesis themselves.
colors = {
    slot: f"rgba({red}, {green}, {blue}, "
    for slot, (red, green, blue) in enumerate(_PALETTE_RGB)
}

# Tokenizer taken from the `bert` wrapper; it is passed to `occlusion_concepts` in the cells below.
tokenizer = bert.tokenizer

In [None]:
# When True, the legend and each sentence's token attribution are written as PNG
# files under `cur_class_dir`;
# set to False to display the token attribution instead of saving images
SAVE_IMG = True

### false-positives with high uncertainty (fixed by concept removal)

In [None]:
# high uncertainty
m_pos = 4  # number of concepts that get a legend entry and a colour
l_concept_id_l = [2,9,4,6] #[2,3,7,9,4,6]

# Sanity check, done *before* the dictionaries are built: every concept id needs
# its own legend entry and colour slot. (Presumably `occlusion_concepts` returns one
# row per entry of `l_concept_ids`, and the visualization looks up `colors[row]` --
# a concept without a colour would then fail with a KeyError mid-run.)
missing_colors = [idx for idx in range(m_pos) if idx not in colors]
if m_pos != len(l_concept_id_l) or missing_colors:
    raise ValueError(
        "Check that you have the correct number of colors and labels to cover the "
        "number of concepts being looked at: m_pos=%i, %i concept ids, "
        "color slots without a color: %s"
        % (m_pos, len(l_concept_id_l), missing_colors)
    )

label_to_criterion = {idx: {'id': l_concept_id_l[idx], 'label': "Positive label: concept %i" % l_concept_id_l[idx]} for idx in range(m_pos)}

colors_to_draw = {idx: colors[idx] for idx in range(m_pos)}
label_criterion_for_legend = {idx: label_to_criterion[idx]['label'] for idx in range(m_pos)}
l_concept_ids = list(l_concept_id_l)
next_idx = m_pos

if SAVE_IMG:
    print_legend_img(('%s/legend.png' % cur_class_dir), colors_to_draw, label_criterion_for_legend)
else:
    print_legend(colors_to_draw, label_criterion_for_legend)

# Print the section header once, and only if there is something to explain.
n = len(sent_corrected_fp)
if n > 0:
    print("\n")
    print("samples that are falsely predicted as physician due to concept 6:")
    print("\n")

for i, sentence in enumerate(sent_corrected_fp):
    phi = occlusion_concepts(sentence, bert, tokenizer, factorization,
                           l_concept_ids, ignore_words = [], two_labels = False, device = device)
    # Scale by the largest absolute value; the epsilon avoids dividing by zero
    # when every value is 0.
    phi /= np.max(np.abs(phi)) + 1e-5

    if SAVE_IMG:
        img_name = '%s/sent_corrected_fp_%i.png' % (cur_class_dir, i)
        viz_concepts_to_img(img_name, sentence, phi, colors_to_draw, ignore_words = [])
    else:
        viz_concepts(sentence, phi, colors_to_draw, ignore_words = [])


### true-positives with high uncertainty (turned false-negative after concept removal)

In [None]:
# Explain the samples in `sent_new_fn` with the same concepts and colours as the
# previous cell (`l_concept_ids` and `colors_to_draw` are defined there).
n = len(sent_new_fn)
if n > 0:
    # Header printed once. The previous text was copied from the false-positive
    # cell and described the wrong set of samples.
    print("\n")
    print("true-positive samples that turn false-negative after concept removal:")
    print("\n")

for i, sentence in enumerate(sent_new_fn):
    phi = occlusion_concepts(sentence, bert, tokenizer, factorization,
                           l_concept_ids, ignore_words = [], two_labels = False, device = device)
    # Scale by the largest absolute value; the epsilon avoids dividing by zero
    # when every value is 0.
    phi /= np.max(np.abs(phi)) + 1e-5

    if SAVE_IMG:
        img_name = '%s/sent_new_fn_%i.png' % (cur_class_dir, i)
        viz_concepts_to_img(img_name, sentence, phi, colors_to_draw, ignore_words = [])
    else:
        viz_concepts(sentence, phi, colors_to_draw, ignore_words = [])