In [None]:
# Exp4_optim1.ipynb

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random
from pandas.core.frame import DataFrame

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch.nn.functional as F

from sklearn.datasets import fetch_openml
from sklearn import metrics
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils.class_weight import compute_class_weight

from fairlearn.metrics import MetricFrame
import ipdb
import argparse

seed = 8

torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
batch_size = 320
random_state = seed
# shuffle_dataset = True

In [4]:
# ------------- data ---------------

In [5]:
# ------------- source -------------

sens_attr = 'sex'
predict_attr = 'is_higher_than_50k'

src = fetch_openml(data_id=1590)

src.data[predict_attr] = src.target.map({'<=50K':0, '>50K':1})
header = list(src.data.columns)
src.data.dropna(inplace=True)
src.data[predict_attr] = src.data[predict_attr].astype(int)

src.data = src.data[src.data['native-country']!='United-States']
y_true_src = src.data[predict_attr].values

sensitive_attr_src = src.data[sens_attr]
A1_true = pd.get_dummies(sensitive_attr_src)
A1_true = A1_true.drop(['Male'], axis=1)
A1_true = A1_true['Female'].values

header.remove(sens_attr)
header.remove(predict_attr)

src.data = src.data[header]

X_src = pd.get_dummies(src.data)
X_src = X_src.sort_index(axis=1)



# ------------- target -------------

tgt = fetch_openml(data_id=1590)

header = list(tgt.data.columns)

tgt.data[predict_attr] = tgt.target.map({'<=50K':0, '>50K':1})
tgt.data.dropna(inplace=True)
tgt.data[predict_attr] = tgt.data[predict_attr].astype(int)
tgt.data = tgt.data[tgt.data['native-country']=='United-States']
y_true_tgt = tgt.data[predict_attr].values

sensitive_attr_tgt = tgt.data[sens_attr]

header.remove(sens_attr)

tgt.data = tgt.data[header]

X_tgt = pd.get_dummies(tgt.data)
X_tgt = X_tgt.sort_index(axis=1)

n_classes = y_true_tgt.max()+1

In [6]:
# ------------- preprocess ---------------

In [7]:
class PandasDataSet(TensorDataset):

    def __init__(self, *dataframes):
        tensors = (self._df_to_tensor(df) for df in dataframes)
        super(PandasDataSet, self).__init__(*tensors)

    def _df_to_tensor(self, df):
        if isinstance(df, np.ndarray):
            return torch.from_numpy(df).float()
        return torch.from_numpy(df.values).float()

# ------------- source -------------
# if domain=='source': train/test = 0.6:0.4
train_ratio = 0.6
test_ratio = 0.4

indict_src = np.arange(sensitive_attr_src.shape[0])
(X_train_src, X_test_src, y_train_src, y_test_src, A1_train, A1_test, ind_train_src, ind_test_src) = train_test_split(X_src, y_true_src, A1_true, indict_src, test_size=test_ratio, stratify=y_true_src, random_state=random_state)

# processed_X_train_src = X_train_src

# standardize the data
scaler_src = StandardScaler().fit(X_train_src)
X_train_src = scaler_src.transform(X_train_src)
X_test_src = scaler_src.transform(X_test_src)

train_data_src = PandasDataSet(X_train_src, y_train_src, A1_train, ind_train_src)
test_data_src = PandasDataSet(X_test_src, y_test_src, A1_test, ind_test_src)

src_train_loader = DataLoader(train_data_src, batch_size=batch_size, shuffle=True, drop_last=True)

print('# source training samples:', len(train_data_src))
print('# source batches:', len(src_train_loader))

# ------------- target -------------
# else: domain=='target': train/valid/test set = 0.5:0.25:0.25
train_ratio = 0.5
test_ratio = 0.25
valid_ratio = 0.25

# split into train/test set
indict_tgt = np.arange(sensitive_attr_tgt.shape[0])
(X_train_tgt, X_test_tgt, y_train_tgt, y_test_tgt, ind_train_tgt, ind_test_tgt) = train_test_split(X_tgt, y_true_tgt, indict_tgt, test_size=test_ratio, stratify=y_true_tgt, random_state=random_state)

# split training set into train/validation set
(X_train_tgt, X_valid_tgt, y_train_tgt, y_valid_tgt, ind_train_tgt, ind_valid_tgt) = train_test_split(X_train_tgt, y_train_tgt, ind_train_tgt, test_size=valid_ratio/(train_ratio+valid_ratio), stratify=y_train_tgt, random_state=random_state)

# processed_X_train_tgt = X_train_tgt

# standardize the data
scaler_tgt = StandardScaler().fit(X_train_tgt)
X_train_tgt = scaler_tgt.transform(X_train_tgt)
X_valid_tgt = scaler_tgt.transform(X_valid_tgt)
X_test_tgt = scaler_tgt.transform(X_test_tgt)

train_data_tgt = PandasDataSet(X_train_tgt, y_train_tgt, ind_train_tgt)
test_data_tgt = PandasDataSet(X_test_tgt, y_test_tgt, ind_test_tgt)
valid_data_tgt = PandasDataSet(X_valid_tgt, y_valid_tgt, ind_valid_tgt)

tgt_train_loader = DataLoader(train_data_tgt, batch_size=batch_size, shuffle=True, drop_last=True)

print('# target training samples:', len(train_data_tgt))
print('# target batches:', len(tgt_train_loader))


# source training samples: 2358
# source batches: 7
# target training samples: 20646
# target batches: 64


In [8]:
class Mapper(nn.Module):
    """
        Mapping feature dimensions of src and tgt domain into the same.
    """
    def __init__(self, ori_dims, n_features=16):
        super().__init__()
        self.map = nn.Sequential(
            nn.Linear(ori_dims, n_features),
        )
        
    def forward(self, x):
        m = self.map(x).squeeze()
        return m

In [9]:
class Encoder(nn.Module):
    """
        Hidden Embedding Layer h
    """
    def __init__(self, n_features=16, n_hidden=32, p_dropout=0.2):
        super().__init__()
        self.emb = nn.Sequential(
            nn.Linear(n_features, n_hidden*2),
            nn.ReLU(),
            nn.Dropout(p_dropout),
            nn.Linear(n_hidden*2, n_hidden),
            nn.ReLU(),
            nn.Dropout(p_dropout),
#             nn.Linear(n_hidden, n_class),
        )

    def forward(self, m):
        emb = self.emb(m)
        return emb

In [10]:
class Classifier(nn.Module):
    """
        Classification head
    """
    def __init__(self, n_hidden=32, n_class=2):
        super().__init__()
        self.cls = nn.Sequential(
            nn.Linear(n_hidden, n_class),
        )
    def forward(self, emb):
        cls = self.cls(emb)
        return cls

In [11]:
from torch.autograd import Function

class ReverseLayerF(Function):

    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg()

        return output, None


In [12]:
class Discriminator(nn.Module):
    """
        Discrimination head
    """
    def __init__(self, n_hidden=32, n_adv=2):
        super().__init__()
        self.adv = nn.Sequential(
            nn.Linear(n_hidden, n_adv),
        )
    def forward(self, emb):
        reversed_input = ReverseLayerF.apply(emb)
        adv = self.adv(reversed_input)
        return adv

In [13]:
ori_dims_src = X_src.shape[1]
ori_dims_tgt = X_tgt.shape[1]
lr = 0.0001

# n_features = X.shape[1]

M1 = Mapper(ori_dims=ori_dims_src).to(DEVICE)
M2 = Mapper(ori_dims=ori_dims_tgt).to(DEVICE)
H1 = Encoder().to(DEVICE)
H2 = Encoder().to(DEVICE)
FA = Classifier().to(DEVICE)
L2 = Classifier().to(DEVICE)
D1 = Discriminator().to(DEVICE)
D2 = Discriminator().to(DEVICE)

L_params = list(M1.parameters()) + list(M2.parameters()) + list(H1.parameters()) + list(FA.parameters()) + list(D1.parameters())
R_params = list(H2.parameters()) + list(L2.parameters()) + list(D2.parameters())

L_optim = optim.Adam(L_params, lr = lr)
R_optim = optim.Adam(R_params, lr = lr)


In [14]:
def FairDA_transfer_train(M1, M2, H1, FA, D1, src_train_loader, L_optim, criterion):

    for batch_idx, (src, tgt) in enumerate(zip(src_train_loader, tgt_train_loader)):
        X1_train, _, A1_train, _ = src
        X2_train, Y2_train, _ = tgt
        
        X1 = X1_train.to(torch.float32).to(DEVICE)
        X2 = X2_train.to(torch.float32).to(DEVICE)
        A1 = A1_train.to(torch.float32).to(DEVICE)
        
        L_optim.zero_grad()
        
        m1 = M1(X1)
        m2 = M2(X2)
        m12 = torch.cat([m1, m2], dim=0)
        h1 = H1(m12)
        
        
        # ---------------------------
        #   Train: FA, H1, M1, M2
        # ---------------------------
        A1_hat = FA(h1[:batch_size])
        loss_FA = criterion(A1_hat, A1.long())
        
        # --------------
        #    Train D1
        # --------------
        D1_hat = D1(h1)
        
        D_src = torch.zeros(X1.shape[0]).to(DEVICE)
        D_tgt = torch.ones(X2.shape[0]).to(DEVICE)
        D_labels = torch.cat([D_src, D_tgt], dim=0).long()
        
        loss_D1 = criterion(D1_hat, D_labels)
        
        # -------------------
        #    Transfer loss
        # -------------------
        loss_left = loss_FA + loss_D1
        loss_left.backward()
        
        L_optim.step()
        

    return M1, M2, H1, FA, D1

In [15]:
def FairDA_debias_train(H2, L2, D2, FA, H1, M2, tgt_train_loader, R_optim, criterion, D2_weights):
    for batch_idx, (X2_train, Y2_train, _) in enumerate(tgt_train_loader):
        
        X2 = X2_train.to(torch.float32).to(DEVICE)
        Y2 = Y2_train.to(torch.float32).to(DEVICE)
        
        R_optim.zero_grad()

        h2 = H2(M2(X2))

        # ---------------------------
        #   Train: H2, L2
        # ---------------------------
        Y2_hat = L2(h2)
        loss_L2 = criterion(Y2_hat, Y2.long())

        # --------------
        #    Train D2
        # --------------
        A2_hat = D2(h2)
        A2_ground = FA(H1(M2(X2))).detach()
        A2_ground = A2_ground.argmax(dim=1)
        loss_D2 = criterion(A2_hat, A2_ground)
        
        # --------------------
        #    Debiasing loss
        # --------------------
        loss_right = loss_L2 + D2_weights * loss_D2
        loss_right.backward()

        R_optim.step()
        
        
    return H2, L2, D2

In [16]:
def fair_metric(sens_ind, y2_pred, y2_true):

    group_dp = []
    group_equal_odds = []
    sens_data = sensitive_attr_tgt.iloc[sens_ind]


    for sens_value in set(sens_data):
        y_sense_pred = y2_pred[(sens_data==sens_value).values]
        y_sense_test = y2_true[(sens_data==sens_value).values]
        sens_dp = []
        sens_eo = []

        for label in set(y2_true):
            if label>0:
                sens_dp_label = (y_sense_pred==label).sum()/y_sense_pred.shape[0]
                sens_eo_label = (y_sense_pred[y_sense_test==label]==label).sum()/(y_sense_test==label).sum()

                sens_dp.append(sens_dp_label)
                sens_eo.append(sens_eo_label)

        group_dp.append(sens_dp)
        group_equal_odds.append(sens_eo)

    group_dp = np.array(group_dp)
    group_eo = np.array(group_equal_odds)

    dp_diff = np.mean(np.absolute(group_dp-np.mean(group_dp, axis=0, keepdims=True)))
    eo_diff = np.mean(np.absolute(group_equal_odds-np.mean(group_equal_odds, axis=0, keepdims=True)))

    return group_dp, group_eo, dp_diff, eo_diff


In [17]:
ce = nn.CrossEntropyLoss()

n_epoch_pre = 150
n_epoch_tgt = 150

D1_weights = 0.01
D2_weights = 0.01

best_acc_pre = 0
best_epoch_pre = -1

# while best_epoch_pre < 50:
    
for epoch in range(1, n_epoch_pre):

    M1 = M1.train()
    M2 = M2.train()
    H1 = H1.train()
    FA = FA.train()
    D1 = D1.train()


    M1, M2, H1, FA, D1 = FairDA_transfer_train(M1, M2, H1, FA, D1, 
                                               src_train_loader, 
                                               L_optim, criterion=ce)
    M1 = M1.eval()
    M2 = M2.eval()
    H1 = H1.eval()
    FA = FA.eval()
    D1 = D1.eval()

    with torch.no_grad():
        pre_a1_test = FA(H1(M1(test_data_src.tensors[0].to(DEVICE))))
        a1_pred = pre_a1_test.argmax(dim=1)
        acc_a1 = accuracy_score(A1_test, a1_pred.cpu())

        if acc_a1 > best_acc_pre:
            best_acc_pre = acc_a1
            best_epoch_pre = epoch
            torch.save(FA,"saved_models/FA.pt")
            torch.save(H1,"saved_models/H1.pt")
            torch.save(M2,"saved_models/M2.pt")

    
print('==================================================================')
print('                 Pretraining on source finished                   ')
print('                       Best epoch: {:04d}                         '.format(best_epoch_pre+1))
print('                 Best ACC on pretraining: {:.4f}                  '.format(best_acc_pre))
print('==================================================================')


FA = torch.load("saved_models/FA.pt")
H1 = torch.load("saved_models/H1.pt")
M2 = torch.load("saved_models/M2.pt")
FA.eval()
H1.eval()
M2.eval()

best_result = {}
best_epoch = -1
best_fair = 100
acc_thrsh = 0.8
F1_thrsh = 0.5

# while best_epoch < 50:
    
for epoch in range(1, n_epoch_tgt):

    H2 = H2.train()
    L2 = L2.train()
    D2 = D2.train()

    H2, L2, D2 = FairDA_debias_train(H2, L2, D2, FA, H1, M2, 
                                     tgt_train_loader, 
                                     R_optim, criterion=ce, 
                                     D2_weights = D2_weights)
    H2 = H2.eval()
    L2 = L2.eval()
    D2 = D2.eval()

    with torch.no_grad():
        pre_y2_val = L2(H2(M2(valid_data_tgt.tensors[0].to(DEVICE))))
        y2_pre_val = pre_y2_val.argmax(dim=1).cpu()
        acc_y2_val = accuracy_score(y_valid_tgt, y2_pre_val)
        f1_y2_val = f1_score(y_valid_tgt, y2_pre_val)
        _, _, dp_diff_val, eo_diff_val = fair_metric(ind_valid_tgt, y2_pre_val, y_valid_tgt)

        pre_y2_test = L2(H2(M2(test_data_tgt.tensors[0].to(DEVICE))))
        y2_pred = pre_y2_test.argmax(dim=1).cpu()
        acc_y2 = accuracy_score(y_test_tgt, y2_pred)
        f1_y2 = f1_score(y_test_tgt, y2_pred)
        group_dp, group_eo, dp_diff, eo_diff = fair_metric(ind_test_tgt, y2_pred, y_test_tgt)

#             print('************ Test result: accuracy_Y2: {:.4f}, F1_Y2: {:.4f}       ***********'.format(acc_y2, f1_y2))
#             print('================================================================================')



        if acc_y2_val > acc_thrsh and f1_y2_val > F1_thrsh:
            if best_fair > dp_diff_val + eo_diff_val:
                best_fair = dp_diff_val + eo_diff_val
                best_result['Epoch'] = epoch
                best_result['ACC'] = acc_y2
                best_result['F1'] = f1_y2
                best_result['DP'] = dp_diff
                best_result['EO'] = eo_diff
                print('Group DP:')
                print(group_dp)
                print('Group EO:')
                print(group_eo)



                 Pretraining on source finished                   
                       Best epoch: 0118                         
                 Best ACC on pretraining: 0.8581                  
Group DP:
[[0.26681066]
 [0.03641208]]
Group EO:
[[0.5306394 ]
 [0.21111111]]
Group DP:
[[0.262491  ]
 [0.03641208]]
Group EO:
[[0.52886325]
 [0.21111111]]
Group DP:
[[0.26580274]
 [0.04055654]]
Group EO:
[[0.53507996]
 [0.2361111 ]]
Group DP:
[[0.2650828 ]
 [0.04174067]]
Group EO:
[[0.53419185]
 [0.24444444]]
Group DP:
[[0.26148307]
 [0.04114861]]
Group EO:
[[0.5315275 ]
 [0.23888889]]
Group DP:
[[0.26191506]
 [0.04233274]]
Group EO:
[[0.5328597 ]
 [0.24444444]]
Group DP:
[[0.26018718]
 [0.04292481]]
Group EO:
[[0.5315275]
 [0.25     ]]
Group DP:
[[0.262347  ]
 [0.04410894]]
Group EO:
[[0.5346359 ]
 [0.25833333]]
Group DP:
[[0.25788337]
 [0.04410894]]
Group EO:
[[0.5306394 ]
 [0.25555557]]
Group DP:
[[0.25572354]
 [0.04499704]]
Group EO:
[[0.5293073]
 [0.2638889]]
Group DP:
[[0.25946724]
 

In [18]:
print('============================ Performace on Test ============================')
if len(best_result) > 0:
    
    print('************         best epoch: {:.4f}                       ***********'.format(best_result['Epoch']))
    print('************         best ACC: {:.4f}, best F1: {:.4f}        ***********'.format(best_result['ACC'], best_result['F1']))
    print('************         best DP: {:.4f}, best EO: {:.4f}        ***********'.format(best_result['DP'], best_result['EO']))
    print('================================================================================')

else:
    print('Please set smaller ACC and F1 thresholds')


************         best epoch: 91.0000                       ***********
************         best ACC: 0.8189, best F1: 0.5879        ***********
************         best DP: 0.0960, best EO: 0.0898        ***********


##### 