In [None]:
%load_ext autoreload
%autoreload 2

import os
import time
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import tensorflow as tf
import tensorflow_constrained_optimization as tfco

from utils import read_twitter_data, save_logs, load_logs, save_embeddings, load_embeddings, is_available
from embeddings import get_bert_embeddings
from model import get_dense_model
from train import train_model, create_tensors
from metrics import error_rate, group_false_positive_rates, f1, false_negative_equality_diff, false_positive_equality_diff
from evaluation import eval_report, mcnemar_test
from plot import plot_perf
from configs import config as cf

## Read Dataset

In [None]:
# Load the Twitter dataset, then print the share of toxic comments overall
# and, for every identity key, the group's share of the data and its own
# toxicity share.
data = read_twitter_data()

overall_toxic_pct = data['target'].mean() * 100
print("Overall toxicity proportion = {0:.2f}%".format(overall_toxic_pct))

group_line = "\t{} proportion = {:.2f}% | toxicity proportion in {} = {:.2f}%"
for identity_key in cf.identity_keys_twitter:
    # assumes the identity column is boolean so it acts as a row mask — TODO confirm
    group_rows = data[data[identity_key]]
    group_share_pct = data[identity_key].mean()*100
    group_toxic_pct = group_rows['target'].mean()*100
    print(group_line.format(identity_key, group_share_pct, identity_key, group_toxic_pct))

In [None]:
# Compute sentence embeddings from the comment text and store them, unless a
# stored copy is already available at the configured path, in which case it
# is loaded instead.
if not is_available(cf.twitter_embeddings_path):
    sentence_embeddings = get_bert_embeddings(data['comment'])
    save_embeddings(sentence_embeddings, dataset='twitter')
else:
    sentence_embeddings = load_embeddings(dataset='twitter')

## Train/Val/Test split

In [None]:
# Split the dataframe and its sentence embeddings in ONE call per stage.
# train_test_split applies the same shuffled indices to every array it is
# given, so rows and embeddings stay aligned by construction, and it raises
# if the two inputs differ in length (two separate calls that merely share a
# random_state would silently misalign in that case).
train_df, val_test_df, train, val_test = train_test_split(
    data, sentence_embeddings,
    train_size=cf.train_size, random_state=cf.random_state, shuffle=True)
val_df, test_df, val, test = train_test_split(
    val_test_df, val_test,
    train_size=cf.val_test_ratio, random_state=cf.random_state, shuffle=True)

# Targets as float column vectors of shape (n, 1).
train_labels = np.array(train_df['target']).reshape(-1, 1).astype(float)
val_labels = np.array(val_df['target']).reshape(-1, 1).astype(float)
test_labels = np.array(test_df['target']).reshape(-1, 1).astype(float)

# Identity-group membership matrices (one column per identity key), cast to int.
train_groups = np.array(train_df[cf.identity_keys_twitter]).astype(int)
val_groups = np.array(val_df[cf.identity_keys_twitter]).astype(int)
test_groups = np.array(test_df[cf.identity_keys_twitter]).astype(int)

# Positional indices of training rows that belong to at least one identity
# group (non-zero row sum over the identity columns).
train_relevant_obs_indices = np.where(train_df[cf.identity_keys_twitter].sum(axis=1))[0]

## Train plain (baseline / unconstrained) model

In [None]:
# Baseline: build a fresh dense model and hand it to train_model together
# with the training/validation embeddings and labels.
# NOTE(review): presumably train_model stores weights under the given model
# name (a later cell loads '<MODELS_DIR>/<name>.h5') — confirm in train.py.
model = get_dense_model()
plain_model = train_model(
    model,
    train, train_labels,
    val, val_labels,
    cf.twitter_plain_model_name,
)

## Evaluate plain (unconstrained) model on test data

In [None]:
# Rebuild the architecture and restore the baseline weights from disk so this
# cell does not depend on the in-memory result of the training cell.
plain_model = get_dense_model()
plain_model.load_weights('{}/{}.h5'.format(cf.MODELS_DIR, cf.twitter_plain_model_name))

# Run inference once and derive hard labels from the predicted probabilities.
# `Sequential.predict_classes` was removed in TensorFlow 2.6; the branch below
# applies the same rule it used: argmax when the output has several columns,
# otherwise a 0.5 threshold cast to int32.
test_probs_plain = plain_model.predict(test, batch_size=cf.hyperparams['batch_size'])
if test_probs_plain.shape[-1] > 1:
    test_preds_plain = test_probs_plain.argmax(axis=-1)
else:
    test_preds_plain = (test_probs_plain > 0.5).astype('int32')

# Report metrics and plot per-identity performance on the test split.
eval_report(test_labels, test_preds_plain, test_probs_plain, test_groups)
plot_perf(test_labels, test_preds_plain, test_groups, cf.identity_keys_twitter, 'Plain model')

## Train fairness constrained model

In [None]:
# Tensors that the training loop fills with each minibatch via `.assign`:
# features/labels for the overall stream, features/labels for the
# identity-group stream, and the group-membership matrix.
(
    feat_tensor,
    feat_tensor_group,
    label_tensor,
    label_tensor_group,
    group_tensor,
) = create_tensors(cf.num_identities_twitter)

constrained_model = get_dense_model()


def predictions():
    """Apply the module-level `constrained_model` to the overall-stream features.

    The model and tensor are looked up when the function is called, so the
    result reflects whatever `feat_tensor` currently holds.
    """
    overall_output = constrained_model(feat_tensor)
    return overall_output


def predictions_group():
    """Apply the module-level `constrained_model` to the group-stream features.

    Same late lookup as `predictions`, but reads `feat_tensor_group`.
    """
    group_output = constrained_model(feat_tensor_group)
    return group_output

In [None]:
# Two rate contexts: one over the overall training stream, one over the
# stream restricted to observations from the identity groups of interest.
context = tfco.rate_context(predictions, lambda: label_tensor)
context_group = tfco.rate_context(predictions_group, lambda: label_tensor_group)

# Objective: minimise the negative F-score of the overall stream.
objective = -1 * tfco.f_score(context)

# Per-identity false negative / false positive rate expressions.
fnrs, fprs = [], []
for group_idx in range(cf.num_identities_twitter):
    # The default argument pins this iteration's column index inside the lambda.
    in_group = context_group.subset(lambda kk=group_idx: group_tensor[:, kk] > 0)
    fnrs.append(tfco.false_negative_rate(in_group))
    fprs.append(tfco.false_positive_rate(in_group))

# Upper- and lower-bound constraints (see equation 3 in paper): the bounds over
# the group rates may deviate from the overall rate by at most the configured
# allowance, in either direction.
constraints = [
    tfco.upper_bound(fnrs) - tfco.false_negative_rate(context) <= cf.twitter_allowed_fnr_deviation,
    tfco.upper_bound(fprs) - tfco.false_positive_rate(context) <= cf.twitter_allowed_fpr_deviation,
    tfco.false_negative_rate(context) - tfco.lower_bound(fnrs) <= cf.twitter_allowed_fnr_deviation,
    tfco.false_positive_rate(context) - tfco.lower_bound(fprs) <= cf.twitter_allowed_fpr_deviation,
]

# Problem, optimizer, and the full list of variables to optimise: model
# weights plus the variables owned by the problem and by the optimizer.
problem = tfco.RateMinimizationProblem(objective, constraints)
optimizer = tfco.ProxyLagrangianOptimizerV2(
    optimizer=tf.keras.optimizers.Adam(cf.hyperparams['lr']),
    constraint_optimizer=tf.keras.optimizers.Adam(cf.hyperparams['lr_constraints']),
    num_constraints=problem.num_constraints,
)
var_list = (
    constrained_model.trainable_weights
    + problem.trainable_variables
    + optimizer.trainable_variables()
)

In [None]:
# Train the fairness-constrained model with two minibatch streams (overall
# data and identity-group data), snapshotting weights and validation metrics
# a few times per epoch.
num_obs = train.shape[0]
num_obs_sen = train_relevant_obs_indices.shape[0]
batch_size = cf.hyperparams['batch_size']

if num_obs_sen == 0:
    raise ValueError('No training observation belongs to any identity group; '
                     'the group stream would be empty.')

# define checkpoint frequency (at least 1 so the modulo below never divides by zero)
num_steps = int(num_obs / batch_size)
skip_steps = max(1, int(num_steps / 3))

# list of recorded objectives and constraint violations for validation set
error_list = []
f1_list = []
fped_list = []
fned_list = []
violations_list = []

start_time = time.time()
model_counter = 0
for ep in range(cf.hyperparams['epochs']):  # loop over epochs
    # Shuffle by permuting INDICES only. `train`, `train_labels`, `train_groups`
    # and `train_relevant_obs_indices` must keep sharing one row order;
    # permuting the arrays themselves would break the group stream's alignment
    # and make the cell non-idempotent on re-run.
    perm = np.random.permutation(num_obs)
    for batch_index in range(num_steps):  # loop over minibatches
        positions = np.arange(batch_index * batch_size, (batch_index + 1) * batch_size)

        # training data indices of overall stream (shuffled order)
        batch_indices = perm[positions % num_obs]

        # training data indices of group stream (wraps around the group rows)
        batch_indices_group = train_relevant_obs_indices[positions % num_obs_sen]

        # assign training data features, labels, groups from the minibatches to the respective tensors
        feat_tensor.assign(train[batch_indices, :])
        label_tensor.assign(train_labels[batch_indices])
        feat_tensor_group.assign(train[batch_indices_group, :])
        label_tensor_group.assign(train_labels[batch_indices_group])
        group_tensor.assign(train_groups[batch_indices_group, :])

        # gradient update
        optimizer.minimize(problem, var_list=var_list)

        # snapshot model parameters, evaluate objective and constraint violations on validation set
        if batch_index % skip_steps == 0:

            # Hard labels from probabilities; same rule as the removed
            # `Sequential.predict_classes` (argmax for multi-column output,
            # otherwise a 0.5 threshold).
            val_probs = constrained_model.predict(val)
            if val_probs.shape[-1] > 1:
                val_scores = val_probs.argmax(axis=-1)
            else:
                val_scores = (val_probs > 0.5).astype('int32')

            fped_list.append(false_positive_equality_diff(val_labels, val_scores, val_groups))
            fned_list.append(false_negative_equality_diff(val_labels, val_scores, val_groups))
            violations_list.append(fped_list[-1] + fned_list[-1])
            error_list.append(error_rate(val_labels, val_scores))
            f1_list.append(f1(val_labels, val_scores))

            # save model weights; the numeric suffix is the snapshot index
            constrained_model.save_weights('{}/{}_{}.h5'.format(cf.MODELS_DIR, cf.twitter_constrained_model_name, model_counter))
            model_counter += 1

        # display most recently recorded objective and constraint violation for validation set
        # (batch_index 0 always snapshots, so the lists are non-empty here)
        elapsed_time = time.time() - start_time
        sys.stdout.write(
            '\rEpoch {}/{} | iter {}/{} | total elapsed time = {:.0f} secs | current error rate (val) = {:.4f} | current f1 (val) = {:.4f} | current total bias (val) = {:.4f}'.format(
            ep + 1, cf.hyperparams['epochs'], batch_index + 1, num_steps, elapsed_time, error_list[-1], f1_list[-1], violations_list[-1]))
print('\nTraining finalized.')
save_logs(error_list, fped_list, fned_list, f1_list, cf.twitter_log_name)

## Investigate discovered solutions interactively

### Note that there may be several solutions that satisfy the constraints.

In [None]:
import plotly.graph_objects as go

# Coerce the logs to float arrays. With plain Python lists, `+` would
# concatenate the two bias logs instead of adding them element-wise, and
# `log == 0` would be a single False rather than a per-snapshot mask.
error_list, fped_list, fned_list, f1_list = (
    np.asarray(log, dtype=float) for log in load_logs(cf.twitter_log_name))
violations_list = fped_list + fned_list  # total bias per snapshot

# Bias and F1 per saved snapshot; the x position is the snapshot index.
fig = go.Figure()
fig.add_trace(go.Scatter(x=np.arange(len(violations_list)), y=violations_list, mode='lines+markers', name='bias'))
fig.add_trace(go.Scatter(x=np.arange(len(f1_list)), y=f1_list, mode='lines+markers', name='f1'))
fig.show()

# Replace exact zeros with +inf so they cannot be chosen as the minimum.
nonzero_f1 = np.where(f1_list == 0, np.inf, f1_list)
nonzero_bias = np.where(violations_list == 0, np.inf, violations_list)

min_bias_index = np.argmin(nonzero_bias)
max_f1_index = np.argmax(f1_list)
print('Min bias = {:.2f} is at index {}. f1_score of that model is {:.3f} (warning: first iterations may give very low bias whole having very low f1 score)'.format(100 * nonzero_bias[min_bias_index], min_bias_index, f1_list[min_bias_index]))
print('Max f1 = {:.3f} is at index {}. Bias of that model is {:.2f}.'.format(f1_list[max_f1_index], max_f1_index, 100 * violations_list[max_f1_index]))

## Evaluate fairness constrained model on test data

In [None]:
# Pick one of the snapshots written by the training loop. The index must match
# a saved '<constrained_model_name>_<index>.h5' file; choose it from the
# bias/F1 inspection above.
selected_index = 130  # select one of the discovered models
constrained_model = get_dense_model()
constrained_model.load_weights('{}/{}_{}.h5'.format(cf.MODELS_DIR, cf.twitter_constrained_model_name, selected_index))

# Run inference once and derive hard labels from the predicted probabilities.
# `Sequential.predict_classes` was removed in TensorFlow 2.6; the branch below
# applies the same rule it used: argmax when the output has several columns,
# otherwise a 0.5 threshold cast to int32.
test_probs_const = constrained_model.predict(test, batch_size=cf.hyperparams['batch_size'])
if test_probs_const.shape[-1] > 1:
    test_preds_const = test_probs_const.argmax(axis=-1)
else:
    test_preds_const = (test_probs_const > 0.5).astype('int32')

# Report metrics and plot per-identity performance on the test split.
eval_report(test_labels, test_preds_const, test_probs_const, test_groups)
plot_perf(test_labels, test_preds_const, test_groups, cf.identity_keys_twitter, 'Constrained model')

## Compare unconstrained vs. fairness constrained model statistically

In [None]:
# Compare the baseline and constrained models' hard test predictions against
# the flattened test labels with the project's McNemar test helper.
mcnemar_test(
    labels=test_labels.ravel(),
    model1_preds=test_preds_plain.ravel(),
    model1_name='Baseline model',
    model2_preds=test_preds_const.ravel(),
    model2_name='Constrained model',
)