<h1><center>ERM with DNN under penalty of Equalized Odds</center></h1>

We implement here a standard Empirical Risk Minimization (ERM) of a Deep Neural Network (DNN), penalized to enforce an Equalized Odds constraint. More formally, given a dataset of size $n$ consisting of context features $x$, a target $y$ and a sensitive attribute $a$ to protect, we want to solve
$$
\text{argmin}_{h\in\mathcal{H}}\frac{1}{n}\sum_{i=1}^n \ell(y_i, h(x_i)) + \lambda \chi^2|_1
$$
where $\ell$ is a loss such as the MSE, $\lambda \geq 0$ weights the fairness term (`fairness_weight` in the code), and the penalty is
$$
\chi^2|_1 = \left\lVert\chi^2\left(\hat{\pi}(h(x)|y, a|y), \hat{\pi}(h(x)|y)\otimes\hat{\pi}(a|y)\right)\right\rVert_1
$$
where $\hat{\pi}$ denotes the empirical density estimated through a Gaussian KDE.
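
To make this penalty concrete, here is a minimal PyTorch sketch of a grid-based estimate of $\chi^2|_1$. It rests on assumptions that are not part of this notebook: $h(x)$, $a$ and $y$ all lie in $[0, 1]$, a single fixed Gaussian bandwidth is used in every dimension, and the $L_1$ norm over $y$ is taken as the $\hat{\pi}(y)$-weighted sum of the per-slice $\chi^2$ values (which are non-negative). The names `gaussian_kernel` and `chi2_equalized_odds` and the `num_points`/`bandwidth` parameters are hypothetical. The training code further below uses a simpler binned penalty instead.

```python
import torch


def gaussian_kernel(values, grid, bandwidth):
    # (n,) samples x (m,) grid points -> (n, m) unnormalized Gaussian kernel weights
    return torch.exp(-0.5 * ((values[:, None] - grid[None, :]) / bandwidth) ** 2)


def chi2_equalized_odds(y_hat, a, y, num_points=30, bandwidth=0.1, eps=1e-8):
    """Hypothetical sketch of ||chi^2(pi(y_hat, a | y), pi(y_hat | y) x pi(a | y))||_1.

    Assumes all inputs lie in [0, 1]; densities are Gaussian KDEs evaluated on a
    grid and normalized into probability masses.
    """
    grid = torch.linspace(0.0, 1.0, num_points)
    k_yhat = gaussian_kernel(y_hat, grid, bandwidth)
    k_a = gaussian_kernel(a, grid, bandwidth)
    k_y = gaussian_kernel(y, grid, bandwidth)
    # Product-kernel KDE of the joint (y_hat, a, y) on the grid, normalized to sum to 1
    joint = torch.einsum("iu,iv,iw->uvw", k_yhat, k_a, k_y)
    joint = joint / joint.sum()
    p_y = joint.sum(dim=(0, 1))                 # marginal of y, shape (m,)
    cond = joint / (p_y[None, None, :] + eps)   # pi(y_hat, a | y) for every y slice
    p_yhat_given_y = cond.sum(dim=1)            # pi(y_hat | y), shape (m, m)
    p_a_given_y = cond.sum(dim=0)               # pi(a | y), shape (m, m)
    product = p_yhat_given_y[:, None, :] * p_a_given_y[None, :, :]
    # chi^2 divergence between the conditional joint and the product of conditionals
    chi2_per_y = ((cond - product) ** 2 / (product + eps)).sum(dim=(0, 1))
    # Assumption: the L1 norm over y is the pi(y)-weighted sum of chi^2(y) >= 0
    return (chi2_per_y * p_y).sum()


# Usage on random data: the penalty is differentiable with respect to y_hat
torch.manual_seed(0)
y_demo, a_demo = torch.rand(500), torch.rand(500)
y_hat_demo = torch.rand(500, requires_grad=True)
chi2_equalized_odds(y_hat_demo, a_demo, y_demo).backward()
```

Because every step is a smooth tensor operation, the same function could be plugged into the training loss; the grid size and bandwidth trade estimation accuracy against cost ($O(n \cdot m^3)$ for the joint).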

### Imports

In [1]:
import sys, os
sys.path.append(os.path.abspath(os.path.join('../..')))
import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.data as data_utils
import numpy as np

from examples.data_loading import read_dataset

### The dataset

We use here the _communities and crimes_ dataset, available on the UCI Machine Learning Repository (http://archive.ics.uci.edu/ml/datasets/communities+and+crime). Non-predictive attributes, such as the city name and the state, have been removed, and the file is stored in ARFF format for ease of loading.

In [3]:
x_train, y_train, a_train, x_test, y_test, a_test = read_dataset(name='crimes', fold=1)
n, d = x_train.shape

### The Deep Neural Network

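We define here a very simple DNN for regression: two hidden layers of 50 SELU units, followed by a sigmoid output so that predictions lie in $(0, 1)$.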

In [4]:
class NetRegression(nn.Module):
    """Two hidden layers of 50 SELU units followed by a sigmoid output."""
    def __init__(self, input_size, num_classes):
        super().__init__()
        hidden_size = 50
        self.first = nn.Linear(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, hidden_size)
        self.last = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out = F.selu(self.first(x))
        out = F.selu(self.fc(out))
        out = self.last(out)
        # The sigmoid keeps predictions in (0, 1), the range of the normalized targets
        out = torch.sigmoid(out)
        return out
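
For readers who prefer a declarative style, the same architecture can be sketched with `nn.Sequential`. The input size of 10 below is a hypothetical value for illustration (in the notebook it is `d`, the number of columns of `x_train`), and `make_net_regression` is a hypothetical helper name.

```python
import torch
from torch import nn


def make_net_regression(input_size, num_outputs, hidden_size=50):
    # Same layer sizes and activations as NetRegression, written declaratively
    return nn.Sequential(
        nn.Linear(input_size, hidden_size),
        nn.SELU(),
        nn.Linear(hidden_size, hidden_size),
        nn.SELU(),
        nn.Linear(hidden_size, num_outputs),
        nn.Sigmoid(),  # predictions in (0, 1)
    )


net = make_net_regression(input_size=10, num_outputs=1)  # 10 features: hypothetical
out = net(torch.randn(4, 10))
print(out.shape)  # torch.Size([4, 1])
n_params = sum(p.numel() for p in net.parameters())
print(n_params)   # 3151 = (10*50 + 50) + (50*50 + 50) + (50*1 + 1)
```

The two versions compute the same function but their `state_dict` keys differ (`0.weight` versus `first.weight`), so saved weights are not interchangeable without renaming.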

### The fairness-penalized ERM

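We now implement the full learning loop. On each mini-batch, the objective is the quadratic (MSE) loss plus `fairness_weight` times the fairness-inducing penalty `fairness_metric_train`; an L2 regularization is added through Adam's `weight_decay`. Note that this penalty is the binned discrepancy built by `generate_fairness_metric` below, used here in place of the KDE-based $\chi^2|_1$ described above.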
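
A quick check that `weight_decay` acts as an L2 penalty: in `torch.optim.Adam` (unlike `AdamW`), `weight_decay=wd` adds `wd * param` to the gradient, which is the gradient of $\frac{wd}{2}\lVert\theta\rVert^2$. The toy data and the single linear layer below are assumptions for illustration only.

```python
import copy
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(32, 5)  # toy inputs (hypothetical)
y = torch.rand(32)      # toy targets (hypothetical)
wd = 0.01

model_a = nn.Linear(5, 1)
model_b = copy.deepcopy(model_a)  # identical initial weights

# (a) L2 regularization through Adam's built-in weight_decay
opt_a = torch.optim.Adam(model_a.parameters(), lr=1e-2, weight_decay=wd)
opt_a.zero_grad()
F.mse_loss(model_a(x).flatten(), y).backward()
opt_a.step()

# (b) the same L2 term written explicitly in the loss: (wd / 2) * ||theta||^2
opt_b = torch.optim.Adam(model_b.parameters(), lr=1e-2)
opt_b.zero_grad()
l2 = sum((p ** 2).sum() for p in model_b.parameters())
(F.mse_loss(model_b(x).flatten(), y) + 0.5 * wd * l2).backward()
opt_b.step()

same = all(torch.allclose(pa, pb, atol=1e-6)
           for pa, pb in zip(model_a.parameters(), model_b.parameters()))
print(same)  # True
```

With `torch.optim.AdamW` the decay is applied directly to the weights instead of through the gradient, so the two variants would no longer coincide.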

In [12]:
def regularized_learning(x_train, y_train, a_train, model, fairness_metric_train, fairness_metric_test,
                         fairness_weight=1.0, lr=1e-5, num_epochs=10, print_progress=True):
    X = torch.tensor(x_train.astype(np.float32))
    A = torch.tensor(a_train.astype(np.float32))
    Y = torch.tensor(y_train.astype(np.float32))
    dataset = data_utils.TensorDataset(X, Y, A)
    dataset_loader = data_utils.DataLoader(dataset=dataset, batch_size=200, shuffle=True)

    # MSE regression objective
    data_fitting_loss = nn.MSELoss()

    # Adam optimizer; weight_decay=0.01 provides the L2 regularization
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.01)

    for j in range(num_epochs):
        if print_progress:
            print(f"EPOCH {j} started")
        for x, y, a in dataset_loader:
            def closure():
                optimizer.zero_grad()
                prediction = model(x).flatten()
                # Penalized ERM objective: MSE + fairness_weight * fairness penalty
                loss = fairness_weight * fairness_metric_train(prediction, a, y) + data_fitting_loss(prediction, y)
                loss.backward()
                return loss
            optimizer.step(closure)
        # Monitoring after each epoch. Note: x_test, y_test and a_test are read from the
        # notebook's global scope (loaded in the dataset cell), not passed as arguments.
        # Both splits are scored with fairness_metric_test (the max-based variant).
        mse_curr_test, nd_curr_test = evaluate(model, x_test, y_test, a_test, fairness_metric=fairness_metric_test)
        print(f"TEST -- mse: {mse_curr_test}, nd: {nd_curr_test}, combined: {mse_curr_test + fairness_weight * nd_curr_test}")
        mse_curr_train, nd_curr_train = evaluate(model, x_train, y_train, a_train, fairness_metric=fairness_metric_test)
        print(f"TRAIN -- mse: {mse_curr_train}, nd: {nd_curr_train}, combined: {mse_curr_train + fairness_weight * nd_curr_train}")
    return model
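
The `closure` passed to `optimizer.step` is only required by optimizers such as `torch.optim.LBFGS` that re-evaluate the loss several times per step; Adam calls it once and then updates the parameters. A minimal sketch on hypothetical toy data, checking that the closure form used above and the usual `zero_grad`/`backward`/`step` sequence produce the same parameters:

```python
import copy
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(16, 3)  # toy inputs (hypothetical)
y = torch.rand(16)      # toy targets (hypothetical)
model_closure = nn.Linear(3, 1)
model_plain = copy.deepcopy(model_closure)  # identical initial weights
loss_fn = nn.MSELoss()

# Closure form, as in regularized_learning
opt_c = torch.optim.Adam(model_closure.parameters(), lr=1e-3, weight_decay=0.01)

def closure():
    opt_c.zero_grad()
    loss = loss_fn(model_closure(x).flatten(), y)
    loss.backward()
    return loss

opt_c.step(closure)

# Plain form
opt_p = torch.optim.Adam(model_plain.parameters(), lr=1e-3, weight_decay=0.01)
opt_p.zero_grad()
loss_fn(model_plain(x).flatten(), y).backward()
opt_p.step()

same = all(torch.allclose(pc, pp)
           for pc, pp in zip(model_closure.parameters(), model_plain.parameters()))
print(same)  # True
```

Keeping the closure is harmless with Adam and makes it easy to switch to an optimizer that does need it.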

### Evaluation

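For evaluation (on the test set, and on the training set for monitoring), we compute two metrics: the MSE (accuracy) and the discrepancy `nd` returned by the fairness metric passed in (fairness). With `fairness_metric_test`, `nd` is the largest size-weighted gap between the mean prediction of a sensitive group and the mean prediction of its whole target bin.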

In [13]:
def evaluate(model, x, y, a, fairness_metric):
    X = torch.tensor(x.astype(np.float32))
    A = torch.tensor(a.astype(np.float32))
    Y = torch.tensor(y.astype(np.float32))

    # detach(): no gradients are needed for evaluation
    prediction = model(X).detach().flatten()
    loss = nn.MSELoss()(prediction, Y)
    discrimination = fairness_metric(prediction, A, Y)
    # Return plain Python floats for both metrics
    return loss.item(), discrimination.item()

### Running everything together


In [14]:
def generate_fairness_metric(constrained_intervals_A, quantization_intervals_Y, train, size_compensation=np.sqrt):
    # Intervals are half-open [start, end): with the intervals built by
    # generate_constrained_intervals, a value exactly equal to 1.0 falls in no bin.
    def inside(num, endpoints):
        start, end = endpoints
        return start <= num and num < end

    def fairness_metric(Y_hat, A, Y):
        nd_losses = []
        n = len(Y_hat)
        for inter_Y in quantization_intervals_Y:
            for inter_A in constrained_intervals_A:
                cnt_y_a = 0
                cnt_y = 0
                sum_y_yhat = 0
                sum_y_a_yhat = 0
                for i in range(n):  # could be sped up by combining with the outer loops
                    if inside(Y[i], inter_Y):
                        cnt_y += 1
                        sum_y_yhat += Y_hat[i]
                        if inside(A[i], inter_A):
                            cnt_y_a += 1
                            sum_y_a_yhat += Y_hat[i]
                if cnt_y_a > 0 and cnt_y > 0:
                    # |mean of Y_hat in the (Y, A) cell - mean of Y_hat in the Y bin|, weighted by cell size
                    curr_nd_loss = torch.abs(sum_y_a_yhat / cnt_y_a - sum_y_yhat / cnt_y) * size_compensation(cnt_y_a / n)
                    nd_losses.append(curr_nd_loss)
        # torch.stack raises an error if no (Y, A) cell was populated
        nd_losses_torch = torch.stack(nd_losses)
        # Average over cells for training, worst cell for evaluation
        return torch.mean(nd_losses_torch) if train else torch.max(nd_losses_torch)
    return fairness_metric
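
Written out, the quantity returned by `generate_fairness_metric` with the default `size_compensation` is the following, where $I^Y_k$ and $I^A_l$ are the half-open intervals, $n$ is the batch size, $S_k = \{i : y_i \in I^Y_k\}$ and $S_{kl} = \{i \in S_k : a_i \in I^A_l\}$:

$$
d_{kl} = \left|\frac{1}{|S_{kl}|}\sum_{i\in S_{kl}}\hat{y}_i - \frac{1}{|S_k|}\sum_{i\in S_k}\hat{y}_i\right|\sqrt{\frac{|S_{kl}|}{n}},
\qquad
\text{penalty} =
\begin{cases}
\operatorname{mean}_{(k,l)\,:\,|S_{kl}|>0}\; d_{kl} & \text{if } \texttt{train}\\
\max_{(k,l)\,:\,|S_{kl}|>0}\; d_{kl} & \text{otherwise}
\end{cases}
$$

Inside each $y$-bin, it compares the mean prediction of each sensitive group with the mean prediction of the whole bin: a binned, first-moment proxy for the conditional independence $h(x) \perp a \mid y$ that Equalized Odds requires.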

In [15]:
def generate_constrained_intervals(num_constrained_intervals):
    endpoints = np.linspace(0, 1, num_constrained_intervals + 1)
    constrained_intervals = []
    for i in range(len(endpoints) - 1):
        constrained_intervals.append((endpoints[i], endpoints[i + 1]))
    return constrained_intervals
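
The inner per-sample Python loop in `fairness_metric` dominates the running time, since it runs once per $(y, a)$ cell. Below is a sketch of a mask-based variant that computes the same quantity with tensor operations, assuming `y_hat`, `a` and `y` are 1-D tensors of equal length; `binned_fairness_penalty` is a hypothetical name, not part of the original code.

```python
import numpy as np
import torch


def binned_fairness_penalty(y_hat, a, y, intervals_a, intervals_y, train=True,
                            size_compensation=np.sqrt):
    """Mask-based equivalent of the loop in generate_fairness_metric."""
    n = y_hat.shape[0]
    losses = []
    for y_lo, y_hi in intervals_y:
        in_y = (y >= float(y_lo)) & (y < float(y_hi))  # half-open [y_lo, y_hi)
        cnt_y = int(in_y.sum())
        if cnt_y == 0:
            continue
        mean_y = y_hat[in_y].mean()
        for a_lo, a_hi in intervals_a:
            in_ya = in_y & (a >= float(a_lo)) & (a < float(a_hi))
            cnt_ya = int(in_ya.sum())
            if cnt_ya == 0:
                continue
            gap = torch.abs(y_hat[in_ya].mean() - mean_y)
            losses.append(gap * float(size_compensation(cnt_ya / n)))
    stacked = torch.stack(losses)  # raises if every cell is empty, like the original
    return stacked.mean() if train else stacked.max()


# Usage with the same intervals as generate_constrained_intervals(2)
torch.manual_seed(0)
endpoints = np.linspace(0, 1, 3)
intervals = list(zip(endpoints[:-1], endpoints[1:]))  # [(0.0, 0.5), (0.5, 1.0)]
y_hat = torch.rand(200, requires_grad=True)
penalty = binned_fairness_penalty(y_hat, torch.rand(200), torch.rand(200), intervals, intervals)
penalty.backward()  # gradients flow back to y_hat as in the loop version
```

Only the two short loops over intervals remain in Python, so the cost per cell becomes a few vectorized comparisons instead of a loop over the batch.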

In [18]:
%%time
model = NetRegression(d, 1)
num_epochs = 200
lr = 1e-5
fairness_weight = 1
num_constrained_intervals = 2
intervals = generate_constrained_intervals(num_constrained_intervals)
fairness_metric_train = generate_fairness_metric(intervals, intervals, True)
fairness_metric_test = generate_fairness_metric(intervals, intervals, False)

model = regularized_learning(x_train, y_train, a_train, model=model,
                             fairness_metric_train=fairness_metric_train,
                             fairness_metric_test=fairness_metric_test,
                             lr=lr, num_epochs=num_epochs, fairness_weight=fairness_weight)
mse, discrimination = evaluate(model, x_test, y_test, a_test, fairness_metric=fairness_metric_test)

EPOCH 0 started
TEST -- mse: 0.13169682025909424, nd: 0.018028298392891884, combined: 0.14972512423992157
TRAIN -- mse: 0.13208803534507751, nd: 0.014867031946778297, combined: 0.14695507287979126
EPOCH 1 started
TEST -- mse: 0.1298317015171051, nd: 0.01720607280731201, combined: 0.14703777432441711
TRAIN -- mse: 0.13022999465465546, nd: 0.014116917736828327, combined: 0.1443469077348709
EPOCH 2 started
TEST -- mse: 0.12800323963165283, nd: 0.01639064960181713, combined: 0.1443938910961151
TRAIN -- mse: 0.12840695679187775, nd: 0.013366438448429108, combined: 0.14177340269088745
EPOCH 3 started
TEST -- mse: 0.1262025088071823, nd: 0.015578131191432476, combined: 0.14178064465522766
TRAIN -- mse: 0.12660987675189972, nd: 0.012618571519851685, combined: 0.1392284482717514
EPOCH 4 started
TEST -- mse: 0.12446454912424088, nd: 0.014787674881517887, combined: 0.13925223052501678
TRAIN -- mse: 0.12487106025218964, nd: 0.01188971009105444, combined: 0.13676077127456665
EPOCH 5 started
...

TEST -- mse: 0.08600173145532608, nd: 0.005367509555071592, combined: 0.09136924147605896
TRAIN -- mse: 0.08651749044656754, nd: 0.005812514573335648, combined: 0.09233000874519348
EPOCH 43 started
TEST -- mse: 0.08550184220075607, nd: 0.005306053441017866, combined: 0.09080789238214493
TRAIN -- mse: 0.08600928634405136, nd: 0.0059921215288341045, combined: 0.09200140833854675
EPOCH 44 started
TEST -- mse: 0.0849839597940445, nd: 0.0052267382852733135, combined: 0.0902106985449791
TRAIN -- mse: 0.08548987656831741, nd: 0.006220520474016666, combined: 0.0917103961110115
EPOCH 45 started
TEST -- mse: 0.08449902385473251, nd: 0.005157111212611198, combined: 0.08965613692998886
TRAIN -- mse: 0.08500072360038757, nd: 0.006417182274162769, combined: 0.09141790866851807
EPOCH 46 started
TEST -- mse: 0.08400703221559525, nd: 0.0053914510644972324, combined: 0.08939848095178604
TRAIN -- mse: 0.08450458943843842, nd: 0.006640880834311247, combined: 0.09114547073841095
EPOCH 47 started
...

TEST -- mse: 0.07063963264226913, nd: 0.010678439401090145, combined: 0.08131807297468185
TRAIN -- mse: 0.0709143728017807, nd: 0.01168497372418642, combined: 0.08259934931993484
EPOCH 85 started
TEST -- mse: 0.07035481184720993, nd: 0.010836594738066196, combined: 0.08119140565395355
TRAIN -- mse: 0.07063020765781403, nd: 0.011822276748716831, combined: 0.08245248347520828
EPOCH 86 started
TEST -- mse: 0.07008422911167145, nd: 0.010951923206448555, combined: 0.08103615045547485
TRAIN -- mse: 0.07035385817289352, nd: 0.011931453831493855, combined: 0.0822853147983551
EPOCH 87 started
TEST -- mse: 0.06984204798936844, nd: 0.01100276131182909, combined: 0.08084481209516525
TRAIN -- mse: 0.07010448724031448, nd: 0.01198981236666441, combined: 0.08209429681301117
EPOCH 88 started
TEST -- mse: 0.0695854052901268, nd: 0.011089337058365345, combined: 0.08067474514245987
TRAIN -- mse: 0.06983985006809235, nd: 0.012076295912265778, combined: 0.08191614598035812
EPOCH 89 started
...

TRAIN -- mse: 0.06181276589632034, nd: 0.014085293747484684, combined: 0.07589805871248245
EPOCH 126 started
TEST -- mse: 0.06155407056212425, nd: 0.013267943635582924, combined: 0.07482201606035233
TRAIN -- mse: 0.06162012740969658, nd: 0.014133413322269917, combined: 0.07575353980064392
EPOCH 127 started
TEST -- mse: 0.06135977804660797, nd: 0.013335630297660828, combined: 0.0746954083442688
TRAIN -- mse: 0.0614241361618042, nd: 0.014192293398082256, combined: 0.07561642676591873
EPOCH 128 started
TEST -- mse: 0.06118624657392502, nd: 0.013354513794183731, combined: 0.07454076409339905
TRAIN -- mse: 0.06124812364578247, nd: 0.01421119924634695, combined: 0.075459323823452
EPOCH 129 started
TEST -- mse: 0.061005305498838425, nd: 0.013395465910434723, combined: 0.07440076768398285
TRAIN -- mse: 0.06106339395046234, nd: 0.014256039634346962, combined: 0.07531943172216415
EPOCH 130 started
TEST -- mse: 0.06081712618470192, nd: 0.013455179519951344, combined: 0.07427230477333069
...

TEST -- mse: 0.054835930466651917, nd: 0.014606069773435593, combined: 0.06944200396537781
TRAIN -- mse: 0.054799530655145645, nd: 0.015338506549596786, combined: 0.07013803720474243
EPOCH 168 started
TEST -- mse: 0.05468929931521416, nd: 0.014626688323915005, combined: 0.06931598484516144
TRAIN -- mse: 0.054653216153383255, nd: 0.015352503396570683, combined: 0.07000572234392166
EPOCH 169 started
TEST -- mse: 0.054551221430301666, nd: 0.014653043821454048, combined: 0.06920426338911057
TRAIN -- mse: 0.054507747292518616, nd: 0.015370426699519157, combined: 0.06987817585468292
EPOCH 170 started
TEST -- mse: 0.054413776844739914, nd: 0.014648287557065487, combined: 0.06906206160783768
TRAIN -- mse: 0.05436701700091362, nd: 0.015370508655905724, combined: 0.0697375237941742
EPOCH 171 started
TEST -- mse: 0.054273821413517, nd: 0.014640258625149727, combined: 0.06891407817602158
TRAIN -- mse: 0.0542263500392437, nd: 0.01537407748401165, combined: 0.0696004256606102
EPOCH 172 started
...

In [11]:
print("MSE:{} Beta loss:{}".format(mse, discrimination))

MSE:0.06243539974093437 Beta loss:0.015345253981649876
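
The per-epoch logs above follow a fixed pattern, so they can be turned into numbers (for instance to plot learning curves) with a small regular expression. A sketch using Python's standard `re` module; `parse_log_line` and `LOG_PATTERN` are hypothetical names, and the sample line is copied from the epoch-0 output above. It also checks that `combined` equals `mse + fairness_weight * nd` with `fairness_weight = 1`, up to float32 rounding.

```python
import re

LOG_PATTERN = re.compile(
    r"(?P<split>TEST|TRAIN) -- mse: (?P<mse>[\d.eE+-]+), "
    r"nd: (?P<nd>[\d.eE+-]+), combined: (?P<combined>[\d.eE+-]+)"
)


def parse_log_line(line):
    """Return a dict with the split name and the three metrics, or None if no match."""
    match = LOG_PATTERN.search(line)
    if match is None:
        return None
    record = {"split": match.group("split")}
    for key in ("mse", "nd", "combined"):
        record[key] = float(match.group(key))
    return record


line = "TEST -- mse: 0.13169682025909424, nd: 0.018028298392891884, combined: 0.14972512423992157"
rec = parse_log_line(line)
print(rec["split"], rec["mse"], rec["nd"])
# combined should equal mse + fairness_weight * nd (fairness_weight = 1 in the run above)
print(abs(rec["mse"] + 1 * rec["nd"] - rec["combined"]) < 1e-6)  # True
```

Lines that were cut off in the output (such as a bare `TEST --`) simply return `None` and can be skipped.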
