In [1]:
import numpy as np

import torch
import torch.optim as optim
import torch.nn as nn
from FairRanking.datasets.adult import Adult
from FairRanking.datasets.law import Law
from FairRanking.datasets.compas import Compas
from FairRanking.datasets.wiki import Wiki
from FairRanking.models.BaseDirectRanker import convert_data_to_tensors
from FairRanking.models.DirectRankerAdv import DirectRankerAdv
from FairRanking.TrainingFunctions.DirectRankerAdvTrain import train

In [17]:

def rND(prediction, s, step=10, start=10, protected_group_idx=1, non_protected_group_idx=0):
    '''
    Computes the normalized Discounted Difference in PyTorch. This metric measures the disparity in 
    ranking outcomes between protected and non-protected groups in a binary classification context.
    Lower values indicate less disparity.

    Parameters:
    - prediction (torch.Tensor): The model predictions or scores.
    - s (torch.Tensor or list): The group labels (0 or 1), indicating whether each instance belongs 
                                to the protected group or not.
    - step (int): The step size to evaluate the ranking at different cutoffs.
    - start (int): The starting point for evaluating the ranking.
    - protected_group_idx (int): The index representing the protected group in `s`.
    - non_protected_group_idx (int): The index representing the non-protected group in `s`.

    Returns:
    - float: The normalized Discounted Difference score.
    '''

    # Ensure s is a 1D tensor
    #s = torch.as_tensor(s).flatten()

    # Check for size mismatch
    if len(prediction) != len(s):
        raise AssertionError(f'len of prediction {len(prediction)} and s {len(s)} are unequal')

    # Count occurrences of each group
    unique, counts = torch.unique(s, return_counts=True)
    count_dict_all = {k.item(): v.item() for k, v in zip(unique, counts)}
    #print(f"before sorting: {torch.unique(s, return_counts=True)}")
    # Ensure both groups are represented
    keys = [protected_group_idx, non_protected_group_idx]
    for key in keys:
        if key not in count_dict_all:
            count_dict_all[key] = 0

    # Sort predictions and corresponding group labels
    sorted_indices = torch.argsort(prediction, descending=True, dim=0)
    print(sorted_indices)
    sorted_s = s[sorted_indices]
    #print(f"after sorting: {torch.unique(sorted_s, return_counts=True)}")

    # Create 'worst-case' sorted lists for regularization
    # first only the non protected group
    fake_horrible_s = torch.cat([torch.full((count_dict_all[non_protected_group_idx],), non_protected_group_idx),
                                 torch.full((count_dict_all[protected_group_idx],), protected_group_idx)])

    # first only the protected group
    fake_horrible_s_2 = torch.cat([torch.full((count_dict_all[protected_group_idx],), protected_group_idx),
                                   torch.full((count_dict_all[non_protected_group_idx],), non_protected_group_idx)])

    rnd, max_rnd, max_rnd_2 = 0.0, 0.0, 0.0

    for i in range(start, len(s), step):
        # Count occurrences in top i of the sorted list
        unique, counts = torch.unique(sorted_s[:i], return_counts=True)
        count_dict_top_i = {k.item(): v.item() for k, v in zip(unique, counts)}

        unique, counts = torch.unique(fake_horrible_s[:i], return_counts=True)
        count_dict_reg = {k.item(): v.item() for k, v in zip(unique, counts)}

        unique_2, counts_2 = torch.unique(fake_horrible_s_2[:i], return_counts=True)
        count_dict_reg_2 = {k.item(): v.item() for k, v in zip(unique_2, counts_2)}

        for key in keys:
            if key not in count_dict_reg:
                count_dict_reg[key] = 0
            if key not in count_dict_top_i:
                count_dict_top_i[key] = 0
            if key not in count_dict_reg_2:
                count_dict_reg_2[key] = 0
        #print(count_dict_top_i)
        # Update rnd and max_rnd
        rnd += abs(
            count_dict_top_i[protected_group_idx] / i - count_dict_all[protected_group_idx] / len(s))
        max_rnd += abs(
            count_dict_reg[protected_group_idx] / i - count_dict_all[protected_group_idx] / len(s))
        max_rnd_2 += abs(
            count_dict_reg_2[protected_group_idx] / i - count_dict_all[protected_group_idx] / len(s))

    max_rnd = max(max_rnd, max_rnd_2)
    print(rnd)
    print(max_rnd)
    return rnd / max_rnd if max_rnd != 0 else 0

In [2]:
#data = Law('Race','/Users/robert/Desktop/Bachelor/FairRanker/data')
data = Adult('/Users/robert/Desktop/Bachelor/FairRanker/data')
#data = Compas()
#data = Wiki()
(X_train, s_train, y_train), (X_val, s_val, y_val), (X_test, s_test, y_test) = data.get_data()

In [4]:
X_train0, X_train1, s_train0, s_train1, y_train, X_val0, X_val1, s_val0, s_val1, y_val, X_test0, X_test1, s_test0, s_test1, y_test = convert_data_to_tensors(data)

  x0 = torch.tensor(x0, dtype=torch.float32)


In [5]:
X_train0.shape

torch.Size([10128, 55])

In [7]:
from FairRanking.helpers import rND_torch
y_test_full = torch.cat((y_test, (-1)*y_test), dim=0)
s_test_full = torch.cat((s_test0, s_test1), dim=0)
base_rnd = rND(y_test_full, torch.argmax(s_test_full, dim=1))
base_rnd

0.21943661406783785

In [18]:
from FairRanking.helpers import rND_torch
y_test_full = y_train
s_test_full = s_train0
base_rnd = rND(y_test_full, torch.argmax(s_test_full, dim=1))
base_rnd

tensor([[ 5064],
        [10126],
        [10125],
        ...,
        [    2],
        [    1],
        [ 2532]])


81.66092254315369
355.1319701526687


0.22994528627779762

In [36]:
print(torch.argmax(s_test_full, dim=1))

tensor([1, 1, 1,  ..., 0, 1, 0])


In [38]:
y_test_full.squeeze()

tensor([-1., -1., -1.,  ...,  1.,  1.,  1.])

In [14]:
torch.manual_seed(42)
model = DirectRankerAdv(num_features=X_train0.shape[1],
                    kernel_initializer=nn.init.normal_,
                    hidden_layers=[64, 32, 16],
                    bias_layers=[128, 64, 32, 16],
            )

data_train = [[X_train0, X_train1, y_train, s_train0, s_train1],
              [X_test0, X_test1, y_test, s_test0, s_test1]]

train(model, data_train, n_epochs=100, path='./', schedule=[1,1], threshold=0.4, adv_lr=0.001)

Test Loss: 0.0284	 Test Accuracy: 0.9991	 DI: 133.5815
Finished Schedule: [1, 1]
