# Calculate accuracy for a global threshold using the Balance Faces in the Wild (BFW) dataset.

Uses the data in `data/bfw-datatable.pkl` to determine the NN. Saves the summary to `results/bfw-stats.csv`.

In [2]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [4]:
import numpy as np
import pandas as pd
import sys

sys.path.append('../')
from facebias.utils import find_best_threshold

from facebias.iotools import load_bfw_datatable
import matplotlib.pyplot as plt


In [6]:
bfw_version = "0.1.5"
dir_data = "../../data/bfw/"

dir_features = f"{dir_data}features/sphereface/"
dir_meta = f"{dir_data}meta/"
f_datatable = f"{dir_meta}bfw-v{bfw_version}-datatable.pkl"
f_threshold = f"{dir_meta}thresholds.pkl"

thresholds_arr = np.linspace(0.18, 0.4, 500)
global_threshold = []
data = load_bfw_datatable(f_datatable, default_score_col="sphereface")

data["yp0"] = 0

folds = data.fold.unique()

In [7]:
for fold in folds:
    ids = data.fold != fold
    threshold, score = find_best_threshold(
        thresholds_arr, data.loc[ids, ["label", "score"]]
    )
    print(f"Fold {fold}: with t_g={threshold}, acc={score}")
    data.loc[ids, "yp0"] = (data["score"] >= threshold).astype(int)
    global_threshold.append(threshold)


Fold 1: with t_g=0.2589178356713427, acc=0.962591485274422
Fold 2: with t_g=0.2545090180360722, acc=0.9612347208770041
Fold 3: with t_g=0.2606813627254509, acc=0.9621221858861
Fold 4: with t_g=0.2589178356713427, acc=0.9612281233426775
Fold 5: with t_g=0.2606813627254509, acc=0.9625876439756565


In [8]:
data['iscorrect'] = (data["yp0"] == data['label']).astype(int)

In [10]:
pd.DataFrame(data.groupby('att1').sum()['iscorrect']/data.groupby('att1').count()['iscorrect'])

Unnamed: 0_level_0,iscorrect
att1,Unnamed: 1_level_1
asian_females,0.941492
asian_males,0.951706
black_females,0.967425
black_males,0.960341
indian_females,0.958432
indian_males,0.962255
white_females,0.972814
white_males,0.981124


In [7]:
accuracy = (data.groupby('att1').sum()['iscorrect']/data.groupby('att1').count()['iscorrect']).mean()
std = (data.groupby('att1').sum()['iscorrect']/data.groupby('att1').count()['iscorrect']).std()

In [8]:
print(f"Accuracy:{accuracy}\nSTD:{std}")

Accuracy:0.9619485648779713
STD:0.01227155889146341
