In [24]:
import pandas as pd

In [45]:
import copy
import os
from typing import Any, Dict, List

import numpy as np
from sklearn.metrics import confusion_matrix, f1_score


def _recall_specificity(cnf_matrix: np.ndarray):
    """Return per-class ``(recall, specificity)`` float arrays from a square confusion matrix.

    Rows are true classes and columns predicted classes, in the order of the
    ``labels`` argument given to ``confusion_matrix``. A class with no support
    yields NaN (0/0) rather than emitting a divide-by-zero RuntimeWarning.
    """
    tp = np.diag(cnf_matrix).astype(float)
    fn = cnf_matrix.sum(axis=1) - tp
    fp = cnf_matrix.sum(axis=0) - tp
    tn = cnf_matrix.sum() - (fp + fn + tp)

    # 0/0 is expected for classes absent from a group; keep NaN, drop the warning.
    with np.errstate(divide="ignore", invalid="ignore"):
        recall = tp / (tp + fn)
        specificity = tn / (tn + fp)
    return recall, specificity


def _max_recall_gaps(recalls: Dict[str, np.ndarray], class_names: List[str]) -> Dict[str, list]:
    """For each class, return ``[group1, group2, gap, 1 - gap]`` for the largest recall gap.

    Every ordered pair of groups (including a group paired with itself) gets
    ``gap = recall[group1] - recall[group2]``, with NaN replaced by 0. Class
    ``i`` of ``class_names`` is read from position ``i`` of the recall arrays.
    Returns an empty dict when ``recalls`` is empty.
    """
    gaps = []
    for group1, recall1 in recalls.items():
        for group2, recall2 in recalls.items():
            gap = np.nan_to_num(recall1 - recall2)
            gaps.append([group1, group2, gap, 1 - gap])

    if not gaps:
        return {}

    max_gaps: Dict[str, list] = {}
    for class_idx, class_name in enumerate(class_names):
        # Sorting the list left in place by the previous class (a stable sort)
        # decides ties; keep it so reported group pairs match earlier runs.
        gaps.sort(key=lambda g: g[2][class_idx], reverse=True)
        group1, group2, gap, one_minus_gap = gaps[0]
        max_gaps[class_name] = [group1, group2, gap[class_idx], one_minus_gap[class_idx]]
    return max_gaps


def metrics(
    preds: List[str],
    labels: List[str],
    demographics: List[List[str]],
    overall_demographics: List[str],
) -> Dict[str, Any]:
    """Compute overall and per-demographic metrics for binary toxicity predictions.

    Returns the overall macro F1, plus per-group recall, specificity and macro F1,
    and the largest recall gap between groups for each label.

    Code heavily modified from https://github.com/pocaguirre/jiant/blob/master/jiant/ext/fairness/DF_training.py

    Examples whose demographic list shares no group with ``overall_demographics``
    are dropped before any metric is computed. This is mostly for HateXplain,
    which can attach several demographic groups to one example.

    :param preds: model predictions, either label strings (``"toxic"`` /
        ``"non-toxic"``) or their integer encodings (``1`` / ``0``)
    :type preds: List[str]
    :param labels: ground-truth labels, each ``"toxic"`` or ``"non-toxic"``
    :type labels: List[str]
    :param demographics: for each example, the demographic groups associated with it
    :type demographics: List[List[str]]
    :param overall_demographics: demographic groups to report metrics for
    :type overall_demographics: List[str]
    :return: dictionary with keys

        * ``"recall"``, ``"specificity"``: dicts keyed by group. Each value is
          an array ordered ``[toxic, non-toxic]``, NaN where undefined.
        * ``"score"``: dict of per-group macro F1.
        * ``"total_score"``: macro F1 over all retained examples.
        * ``"max_gaps"``: dict keyed by label name. Each value is
          ``[group1, group2, gap, 1 - gap]`` for the group pair with the
          largest recall gap on that label.
    :rtype: Dict[str, Any]
    :raises KeyError: if a label, or a string prediction, is not ``"toxic"``
        or ``"non-toxic"``
    """
    scores = {"recall": {}, "specificity": {}, "score": {}}

    # Integer encoding used for every sklearn call. Its value order also fixes
    # the row/column order of the confusion matrix, and hence of the recall and
    # specificity arrays.
    label_map = {"toxic": 1, "non-toxic": 0}
    label_ids = list(label_map.values())

    dummy_labels = np.array([label_map[x] for x in labels])
    # Predictions may arrive as label strings (per the type hint) or already
    # encoded. Encode strings exactly like the labels so sklearn compares like with like.
    dummy_preds = np.array([label_map[p] if isinstance(p, str) else p for p in preds])

    # Remove examples that have no demographic in the requested set.
    wanted = set(overall_demographics)
    demographic_index = [
        i for i, groups in enumerate(demographics) if wanted.intersection(groups)
    ]
    demographics_filtered = [demographics[i] for i in demographic_index]
    dummy_preds = dummy_preds[demographic_index]
    dummy_labels = dummy_labels[demographic_index]

    # Overall macro F1 over all retained examples.
    scores["total_score"] = f1_score(
        dummy_labels, dummy_preds, average="macro", labels=label_ids
    )

    for dem in overall_demographics:
        # Examples tagged with this demographic (an example can count for several groups).
        index = [i for i, groups in enumerate(demographics_filtered) if dem in groups]
        group_labels = dummy_labels[index]
        group_preds = dummy_preds[index]

        cnf_matrix = confusion_matrix(group_labels, group_preds, labels=label_ids)
        recall, specificity = _recall_specificity(cnf_matrix)

        scores["recall"][dem] = recall
        scores["specificity"][dem] = specificity
        scores["score"][dem] = f1_score(
            group_labels, group_preds, average="macro", labels=label_ids
        )

    # list(label_map) is ordered like label_ids, so name i matches recall index i.
    scores["max_gaps"] = _max_recall_gaps(scores["recall"], list(label_map))

    return scores


In [46]:
# Prediction CSV (one row per example, with per-model prediction columns).
# Set HATEXPLAIN_PREDS_CSV to point elsewhere instead of editing this absolute path.
PREDS_CSV = os.environ.get(
    "HATEXPLAIN_PREDS_CSV",
    "/data/caguirre/MultitaskFairness/preds_hatexplain_race_all.csv",
)
df = pd.read_csv(PREDS_CSV)

In [47]:
# Wrap each row's race in a one-element list: `metrics` takes List[List[str]] demographics.
demographics = list(map(lambda race: [race], df["race"].tolist()))

In [48]:
# Per-example predictions for each model, in the row order of `df`.
bernice_predictions, bertweet_predictions = (
    df[column].tolist() for column in ("bernice", "bertweet")
)

In [49]:
# Race groups to report per-group metrics and recall gaps for.
overall_demographics = "African Arab Asian Hispanic Caucasian".split()

In [50]:
# Dataset tag for this run. Note: `metrics` takes no dataset argument, so it is not passed there.
dataset: str = "hatexplain-race"

In [51]:
labels = df["label"].tolist()

# Integer encoding of the string labels (toxic -> 1, non-toxic -> 0);
# an unknown label raises KeyError.
label_map = {"toxic": 1, "non-toxic": 0}
dummy_labels = list(map(label_map.__getitem__, labels))

In [None]:
models = ['bernice', 'bertweet']
predictions = [bernice_predictions, bertweet_predictions]

results = []

# Each row holds, in order:
#   - the model name and two fixed tags ('N/A', 'supervised');
#   - the overall macro F1;
#   - one {group: macro F1} dict per demographic group;
#   - the smallest (1 - gap) across classes;
#   - one {class: [group1, group2, gap, 1 - gap]} dict per class,
#     ordered by ascending (1 - gap).
for model, prediction in zip(models, predictions):
    performance = metrics(prediction, labels, demographics, overall_demographics)

    row = [model, 'N/A', 'supervised', performance["total_score"]]
    row.extend({group: f1} for group, f1 in performance["score"].items())

    ranked_gaps = sorted(performance["max_gaps"].items(), key=lambda item: item[1][3])
    row.append(ranked_gaps[0][1][3])
    row.extend({class_name: list(gap)} for class_name, gap in ranked_gaps)

    results.append(row)

In [54]:
# Show the per-model rows. Layout of each row:
#   [model, 'N/A', 'supervised', overall macro F1,
#    {group: F1} per group, min(1 - gap),
#    {class: [group1, group2, gap, 1 - gap]} per class]
results

[['bernice',
  'N/A',
  'supervised',
  0.6182363667591202,
  {'African': 0.6452756073609154},
  {'Arab': 0.5564253098499674},
  {'Asian': 0.47878787878787876},
  {'Hispanic': 0.4615384615384615},
  {'Caucasian': 0.512743628185907},
  0.5800000000000001,
  {'non-toxic': ['Caucasian', 'Hispanic', 0.42, 0.5800000000000001]},
  {'toxic': ['Hispanic',
    'Caucasian',
    0.35663082437275984,
    0.6433691756272402]}],
 ['bertweet',
  'N/A',
  'supervised',
  0.697002997002997,
  {'African': 0.7357954545454546},
  {'Arab': 0.5202741290691034},
  {'Asian': 0.5520833333333334},
  {'Hispanic': 0.7321428571428572},
  {'Caucasian': 0.6533333333333333},
  0.23076923076923084,
  {'non-toxic': ['Hispanic', 'Arab', 0.7692307692307692, 0.23076923076923084]},
  {'toxic': ['African', 'Asian', 0.26538629419467474, 0.7346137058053253]}]]