# FR-Train on poisoned synthetic data

## Import libraries

In [1]:
import sys, os
import numpy as np
import math

from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch

import math
import matplotlib.pyplot as plt

from argparse import Namespace

from FRTrain_arch import Generator, DiscriminatorF, DiscriminatorR, weights_init_normal, test_model

import warnings
warnings.filterwarnings("ignore")

## Load and process data (training labels are poisoned; validation and test labels are clean)

In [2]:
# a namespace object which contains some of the hyperparameters
# Hyperparameters that control how the loaded arrays are split into
# train / validation / (unused) second validation / test parts.
opt = Namespace(
    num_train=2000,
    num_val1=200,
    num_val2=500,
    num_test=1000,
)

In [3]:
# Split sizes taken from the hyperparameter namespace.
num_train = opt.num_train
num_val1 = opt.num_val1
num_val2 = opt.num_val2
num_test = opt.num_test

# Load the synthetic dataset and turn every array into a float32 tensor.
X = torch.FloatTensor(np.load('X_synthetic.npy'))        # Input features
y = torch.FloatTensor(np.load('y_synthetic.npy'))        # Original (clean) labels
y_poi = torch.FloatTensor(np.load('y_poi.npy'))          # Poisoned train labels
s1 = torch.FloatTensor(np.load('s1_synthetic.npy'))      # Sensitive features

# Row boundaries of each split; every split slices the same row order.
# NOTE(review): rows [num_train - num_val1, num_train) are not used by any
# split below -- confirm this gap is intentional.
_train_stop = num_train - num_val1
_val_start, _val_stop = num_train, num_train + num_val1
_test_start = num_train + num_val1 + num_val2
_test_stop = _test_start + num_test

# Training split: uses the POISONED labels.
X_train = X[:_train_stop]
y_train = y_poi[:_train_stop]
s1_train = s1[:_train_stop]

# Validation split: uses the clean labels.
X_val = X[_val_start:_val_stop]
y_val = y[_val_start:_val_stop]
s1_val = s1[_val_start:_val_stop]

# A second validation split is currently not used:
# X_val2 = X[num_train + num_val1 : num_train + num_val1 + num_val2]
# y_val2 = y[num_train + num_val1 : num_train + num_val1 + num_val2]
# s1_val2 = s1[num_train + num_val1 : num_train + num_val1 + num_val2]

# Test split: uses the clean labels.
X_test = X[_test_start:_test_stop]
y_test = y[_test_start:_test_stop]
s1_test = s1[_test_start:_test_stop]

# Model inputs: features with the sensitive attribute appended as one extra column.
XS_train = torch.cat([X_train, s1_train.reshape(len(s1_train), 1)], dim=1)
XS_val = torch.cat([X_val, s1_val.reshape(len(s1_val), 1)], dim=1)
XS_test = torch.cat([X_test, s1_test.reshape(len(s1_test), 1)], dim=1)

In [4]:
# Report how many rows ended up in each split.
print("--------------------- Number of Data -------------------------" )
print(f"Train data : {len(y_train)}, Validation data : {len(y_val)}, Test data : {len(y_test)} ")
print("--------------------------------------------------------------")

--------------------- Number of Data -------------------------
Train data : 1800, Validation data : 200, Test data : 1000 
--------------------------------------------------------------


# Training on poisoned data (with a clean validation set)

In [5]:
def train_model(train_tensors, val_tensors, test_tensors, train_opt, lambda_f, lambda_r, seed):
    """
      Trains FR-Train by using the classes in FRTrain_arch.py.

      Per epoch:
        * D_F (fairness discriminator) is trained to predict the sensitive
          attribute s1_train from the (detached) generator output.
        * D_R (robustness discriminator) is trained to label clean validation
          rows as 1 and generated training rows as 0.
        * The generator is trained with a scaled classification loss during
          the first 500 epochs (warm-up). After that it is updated every k-th
          epoch with a loss that combines the re-weighted classification
          loss and the (negated) losses of D_F and D_R.

      Args:
        train_tensors: Namespace with XS_train, y_train, s1_train (training data).
        val_tensors: Namespace with XS_val, y_val (clean validation data).
        test_tensors: Namespace with XS_test, y_test, s1_test (test data).
        train_opt: Options for the training: val (number of validation rows
                fed to D_R), n_epochs, k (generator/discriminator update
                ratio, 1:k), and learning rates lr_g, lr_f, lr_r.
        lambda_f: The tuning knob for L_2 (ref: FR-Train paper, Section 3.3).
        lambda_r: The tuning knob for L_3 (ref: FR-Train paper, Section 3.3).
        seed: An integer value for specifying torch random seed.

      Returns:
        A one-element list [[lambda_f, lambda_r, test_accuracy, disparate_impact]].
        The last two values come from test_model on the test tensors.
    """
    XS_train = train_tensors.XS_train
    y_train = train_tensors.y_train
    s1_train = train_tensors.s1_train

    XS_val = val_tensors.XS_val
    y_val = val_tensors.y_val

    XS_test = test_tensors.XS_test
    y_test = test_tensors.y_test
    s1_test = test_tensors.s1_test

    val = train_opt.val            # Number of validation rows fed to D_R
    k = train_opt.k                # Update ratio of generator and discriminator (1:k training)
    n_epochs = train_opt.n_epochs  # Number of training epochs

    warmup_epochs = 500   # Generator-only warm-up on the training labels
    reweight_every = 100  # How often (in epochs) sample weights are refreshed from D_R
    log_every = 200       # How often (in epochs) losses are printed

    # D_R consumes validation rows as [XS_val | y_val] concatenated column-wise,
    # mirroring gen_data = [XS_train | gen_y] below.
    XSY_val = torch.cat([XS_val, y_val.reshape((y_val.shape[0], 1))], dim=1)
    XSY_val_data = XSY_val[:val]

    # Loss histories, one entry per epoch. They are stored detached so they do
    # not keep autograd graphs alive; use them for epoch-loss plots if needed.
    g_losses = []
    d_f_losses = []
    d_r_losses = []

    bce_loss = torch.nn.BCELoss()

    # Initializes generator and discriminators
    generator = Generator()
    discriminator_F = DiscriminatorF()
    discriminator_R = DiscriminatorR()

    # Seed immediately before weight initialization so that the
    # weights_init_normal draws are reproducible for a given seed.
    torch.manual_seed(seed)
    generator.apply(weights_init_normal)
    discriminator_F.apply(weights_init_normal)
    discriminator_R.apply(weights_init_normal)

    optimizer_G = torch.optim.Adam(generator.parameters(), lr=train_opt.lr_g)
    optimizer_D_F = torch.optim.SGD(discriminator_F.parameters(), lr=train_opt.lr_f)
    optimizer_D_R = torch.optim.SGD(discriminator_R.parameters(), lr=train_opt.lr_r)

    # Ground-truth targets for D_R. `clean` is sized by the rows actually fed
    # to D_R (XSY_val_data). A train_opt.val smaller than the validation set
    # therefore no longer causes a target/input size mismatch in BCELoss.
    train_len = XS_train.shape[0]
    clean = torch.ones(XSY_val_data.shape[0], 1, dtype=torch.float32)
    fake = torch.zeros(train_len, 1, dtype=torch.float32)       # D_R's target for generated rows
    generated = torch.zeros(train_len, 1, dtype=torch.float32)  # target whose BCE the generator maximizes

    # Per-sample weights for the generator losses (uniform until first re-weighting).
    r_weight = torch.ones_like(y_train, requires_grad=False).float()
    r_ones = torch.ones_like(y_train, requires_grad=False).float()

    for epoch in range(n_epochs):
        in_warmup = epoch < warmup_epochs
        updates_generator = in_warmup or epoch % k == 0

        # -------------------
        #  Forwards Generator
        # -------------------
        if updates_generator:
            optimizer_G.zero_grad()

        gen_y = generator(XS_train)
        gen_data = torch.cat([XS_train, gen_y.reshape((gen_y.shape[0], 1))], dim=1)

        # -------------------------------
        #  Trains Fairness Discriminator
        # -------------------------------
        optimizer_D_F.zero_grad()

        # D_F tries to predict the sensitive group from the generator output.
        d_f_loss = bce_loss(discriminator_F(gen_y.detach()), s1_train)
        d_f_loss.backward()
        d_f_losses.append(d_f_loss.detach())
        optimizer_D_F.step()

        # ---------------------------------
        #  Trains Robustness Discriminator
        # ---------------------------------
        optimizer_D_R.zero_grad()

        # D_R tries to distinguish clean validation rows from generated rows.
        clean_loss = bce_loss(discriminator_R(XSY_val_data), clean)
        poison_loss = bce_loss(discriminator_R(gen_data.detach()), fake)
        d_r_loss = 0.5 * (clean_loss + poison_loss)

        d_r_loss.backward()
        d_r_losses.append(d_r_loss.detach())
        optimizer_D_R.step()

        # ---------------------
        #  Updates Generator
        # ---------------------
        if in_warmup:
            # Warm-up: scaled BCE between tanh(gen_y) mapped to [0, 1] and the
            # labels mapped by (y + 1) / 2 (presumably from {-1, 1} to {0, 1}).
            g_loss = 0.1 * bce_loss((torch.tanh(gen_y) + 1) / 2, (y_train + 1) / 2)
            g_loss.backward()
            optimizer_G.step()
        elif epoch % k == 0:
            r_decision = discriminator_R(gen_data)
            r_gen = bce_loss(r_decision, generated)

            # ---------------------------------
            #  Re-weights using output of D_R
            # ---------------------------------
            if epoch % reweight_every == 0:
                # Mixing factor a = sigmoid(g_loss / d_r_loss - 3). The larger the
                # loss ratio, the more weight D_R's per-sample output receives
                # relative to uniform weights.
                loss_ratio = (g_losses[-1] / d_r_losses[-1]).detach()
                a = 1 / (1 + torch.exp(-(loss_ratio - 3)))
                b = 1 - a
                r_weight_tmp = r_decision.detach().squeeze()
                r_weight = a * r_weight_tmp + b * r_ones

            f_cost = F.binary_cross_entropy(discriminator_F(gen_y), s1_train, reduction="none").squeeze()
            g_cost = F.binary_cross_entropy_with_logits(gen_y.squeeze(), (y_train.squeeze() + 1) / 2, reduction="none").squeeze()

            f_gen = torch.mean(f_cost * r_weight)
            # The generator minimizes its weighted classification loss while
            # maximizing the losses of D_F and of D_R (against the "generated" label).
            g_loss = (1 - lambda_f - lambda_r) * torch.mean(g_cost * r_weight) - lambda_f * f_gen - lambda_r * r_gen

            g_loss.backward()
            optimizer_G.step()

        # Exactly one entry per epoch. Between generator updates the most
        # recent generator loss is repeated.
        g_losses.append(g_loss.detach())

        if epoch % log_every == 0:
            print(
                "[Lambda: %1f] [Epoch %d/%d] [D_F loss: %f] [D_R loss: %f] [G loss: %f]"
                % (lambda_f, epoch, n_epochs, d_f_losses[-1], d_r_losses[-1], g_losses[-1])
            )

    tmp = test_model(generator, XS_test, y_test, s1_test)
    return [[lambda_f, lambda_r, tmp[0].item(), tmp[1]]]

In [6]:
train_result = []

# Bundle the tensors produced by the data-loading cell for train_model.
train_tensors = Namespace(XS_train=XS_train, y_train=y_train, s1_train=s1_train)
val_tensors = Namespace(XS_val=XS_val, y_val=y_val, s1_val=s1_val)
test_tensors = Namespace(XS_test=XS_test, y_test=y_test, s1_test=s1_test)

# Training options shared by every run of the sweep.
train_opt = Namespace(
    val=len(y_val),
    n_epochs=10000,
    k=5,
    lr_g=0.001,
    lr_f=0.001,
    lr_r=0.001,
)
seed = 1

# Sweep the fairness knob (lambda_f) while the robustness knob (lambda_r) stays fixed.
lambda_f_set = [0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.52]
lambda_r = 0.4

for lambda_f in lambda_f_set:
    run_result = train_model(
        train_tensors,
        val_tensors,
        test_tensors,
        train_opt,
        lambda_f=lambda_f,
        lambda_r=lambda_r,
        seed=seed,
    )
    train_result.append(run_result)

[Lambda: 0.100000] [Epoch 0/10000] [D_F loss: 0.686341] [D_R loss: 0.716233] [G loss: 0.060358]
[Lambda: 0.100000] [Epoch 200/10000] [D_F loss: 0.682480] [D_R loss: 0.705212] [G loss: 0.057829]
[Lambda: 0.100000] [Epoch 400/10000] [D_F loss: 0.679609] [D_R loss: 0.698599] [G loss: 0.057678]
[Lambda: 0.100000] [Epoch 600/10000] [D_F loss: 0.672426] [D_R loss: 0.689105] [G loss: -0.083195]
[Lambda: 0.100000] [Epoch 800/10000] [D_F loss: 0.675748] [D_R loss: 0.686940] [G loss: -0.078719]
[Lambda: 0.100000] [Epoch 1000/10000] [D_F loss: 0.674420] [D_R loss: 0.684453] [G loss: -0.074216]
[Lambda: 0.100000] [Epoch 1200/10000] [D_F loss: 0.673476] [D_R loss: 0.682800] [G loss: -0.070523]
[Lambda: 0.100000] [Epoch 1400/10000] [D_F loss: 0.672513] [D_R loss: 0.681618] [G loss: -0.067542]
[Lambda: 0.100000] [Epoch 1600/10000] [D_F loss: 0.671613] [D_R loss: 0.680717] [G loss: -0.065112]
[Lambda: 0.100000] [Epoch 1800/10000] [D_F loss: 0.670873] [D_R loss: 0.680004] [G loss: -0.063117]
[Lambda: 0

[Lambda: 0.150000] [Epoch 6200/10000] [D_F loss: 0.681877] [D_R loss: 0.678068] [G loss: -0.111400]
[Lambda: 0.150000] [Epoch 6400/10000] [D_F loss: 0.681847] [D_R loss: 0.678019] [G loss: -0.111119]
[Lambda: 0.150000] [Epoch 6600/10000] [D_F loss: 0.681840] [D_R loss: 0.677954] [G loss: -0.110843]
[Lambda: 0.150000] [Epoch 6800/10000] [D_F loss: 0.681796] [D_R loss: 0.677864] [G loss: -0.110574]
[Lambda: 0.150000] [Epoch 7000/10000] [D_F loss: 0.681815] [D_R loss: 0.677781] [G loss: -0.110301]
[Lambda: 0.150000] [Epoch 7200/10000] [D_F loss: 0.681768] [D_R loss: 0.677701] [G loss: -0.110019]
[Lambda: 0.150000] [Epoch 7400/10000] [D_F loss: 0.681719] [D_R loss: 0.677602] [G loss: -0.109760]
[Lambda: 0.150000] [Epoch 7600/10000] [D_F loss: 0.681623] [D_R loss: 0.677501] [G loss: -0.109508]
[Lambda: 0.150000] [Epoch 7800/10000] [D_F loss: 0.681566] [D_R loss: 0.677377] [G loss: -0.109268]
[Lambda: 0.150000] [Epoch 8000/10000] [D_F loss: 0.681487] [D_R loss: 0.677279] [G loss: -0.109035]


[Lambda: 0.250000] [Epoch 2000/10000] [D_F loss: 0.686747] [D_R loss: 0.682007] [G loss: -0.247917]
[Lambda: 0.250000] [Epoch 2200/10000] [D_F loss: 0.686878] [D_R loss: 0.681598] [G loss: -0.246595]
[Lambda: 0.250000] [Epoch 2400/10000] [D_F loss: 0.687032] [D_R loss: 0.681332] [G loss: -0.245469]
[Lambda: 0.250000] [Epoch 2600/10000] [D_F loss: 0.687178] [D_R loss: 0.681141] [G loss: -0.244499]
[Lambda: 0.250000] [Epoch 2800/10000] [D_F loss: 0.687370] [D_R loss: 0.680988] [G loss: -0.243658]
[Lambda: 0.250000] [Epoch 3000/10000] [D_F loss: 0.687492] [D_R loss: 0.680869] [G loss: -0.242916]
[Lambda: 0.250000] [Epoch 3200/10000] [D_F loss: 0.687491] [D_R loss: 0.680804] [G loss: -0.242256]
[Lambda: 0.250000] [Epoch 3400/10000] [D_F loss: 0.687513] [D_R loss: 0.680760] [G loss: -0.241658]
[Lambda: 0.250000] [Epoch 3600/10000] [D_F loss: 0.687549] [D_R loss: 0.680739] [G loss: -0.241106]
[Lambda: 0.250000] [Epoch 3800/10000] [D_F loss: 0.687515] [D_R loss: 0.680758] [G loss: -0.240610]


[Lambda: 0.300000] [Epoch 8200/10000] [D_F loss: 0.685752] [D_R loss: 0.679776] [G loss: -0.297264]
[Lambda: 0.300000] [Epoch 8400/10000] [D_F loss: 0.685732] [D_R loss: 0.679572] [G loss: -0.297119]
[Lambda: 0.300000] [Epoch 8600/10000] [D_F loss: 0.685692] [D_R loss: 0.679404] [G loss: -0.296986]
[Lambda: 0.300000] [Epoch 8800/10000] [D_F loss: 0.685666] [D_R loss: 0.679241] [G loss: -0.296862]
[Lambda: 0.300000] [Epoch 9000/10000] [D_F loss: 0.685651] [D_R loss: 0.679053] [G loss: -0.296743]
[Lambda: 0.300000] [Epoch 9200/10000] [D_F loss: 0.685660] [D_R loss: 0.678864] [G loss: -0.296615]
[Lambda: 0.300000] [Epoch 9400/10000] [D_F loss: 0.685609] [D_R loss: 0.678713] [G loss: -0.296476]
[Lambda: 0.300000] [Epoch 9600/10000] [D_F loss: 0.685592] [D_R loss: 0.678545] [G loss: -0.296343]
[Lambda: 0.300000] [Epoch 9800/10000] [D_F loss: 0.685596] [D_R loss: 0.678349] [G loss: -0.296204]
Test accuracy: 0.8240000009536743
P(y_hat=1 | z=0) = 0.436, P(y_hat=1 | z=1) = 0.574
P(y_hat=1 | y=1

[Lambda: 0.400000] [Epoch 4000/10000] [D_F loss: 0.687337] [D_R loss: 0.684144] [G loss: -0.429876]
[Lambda: 0.400000] [Epoch 4200/10000] [D_F loss: 0.687282] [D_R loss: 0.684001] [G loss: -0.429401]
[Lambda: 0.400000] [Epoch 4400/10000] [D_F loss: 0.687213] [D_R loss: 0.683840] [G loss: -0.428987]
[Lambda: 0.400000] [Epoch 4600/10000] [D_F loss: 0.687148] [D_R loss: 0.683651] [G loss: -0.428610]
[Lambda: 0.400000] [Epoch 4800/10000] [D_F loss: 0.687099] [D_R loss: 0.683427] [G loss: -0.428265]
[Lambda: 0.400000] [Epoch 5000/10000] [D_F loss: 0.687062] [D_R loss: 0.683200] [G loss: -0.427946]
[Lambda: 0.400000] [Epoch 5200/10000] [D_F loss: 0.687020] [D_R loss: 0.682970] [G loss: -0.427649]
[Lambda: 0.400000] [Epoch 5400/10000] [D_F loss: 0.686962] [D_R loss: 0.682769] [G loss: -0.427381]
[Lambda: 0.400000] [Epoch 5600/10000] [D_F loss: 0.686901] [D_R loss: 0.682546] [G loss: -0.427139]
[Lambda: 0.400000] [Epoch 5800/10000] [D_F loss: 0.686862] [D_R loss: 0.682354] [G loss: -0.426919]


[Lambda: 0.520000] [Epoch 0/10000] [D_F loss: 0.686341] [D_R loss: 0.716233] [G loss: 0.060358]
[Lambda: 0.520000] [Epoch 200/10000] [D_F loss: 0.682480] [D_R loss: 0.705212] [G loss: 0.057829]
[Lambda: 0.520000] [Epoch 400/10000] [D_F loss: 0.679609] [D_R loss: 0.698599] [G loss: 0.057678]
[Lambda: 0.520000] [Epoch 600/10000] [D_F loss: 0.689413] [D_R loss: 0.702650] [G loss: -0.617674]
[Lambda: 0.520000] [Epoch 800/10000] [D_F loss: 0.708826] [D_R loss: 0.711040] [G loss: -0.622047]
[Lambda: 0.520000] [Epoch 1000/10000] [D_F loss: 0.713562] [D_R loss: 0.710284] [G loss: -0.617979]
[Lambda: 0.520000] [Epoch 1200/10000] [D_F loss: 0.707583] [D_R loss: 0.704548] [G loss: -0.609954]
[Lambda: 0.520000] [Epoch 1400/10000] [D_F loss: 0.696940] [D_R loss: 0.697852] [G loss: -0.601033]
[Lambda: 0.520000] [Epoch 1600/10000] [D_F loss: 0.687138] [D_R loss: 0.692088] [G loss: -0.593741]
[Lambda: 0.520000] [Epoch 1800/10000] [D_F loss: 0.681830] [D_R loss: 0.687905] [G loss: -0.589522]
[Lambda: 0

In [7]:
# Summarize every run of the lambda_f sweep, one line per run.
print("-----------------------------------------------------------------------------------")
print("------------------ Training Results of FR-Train on poisoned data ------------------" )
for run in train_result:
    summary = run[0]  # [lambda_f, lambda_r, accuracy, disparate impact] as built by train_model
    print(
        "[Lambda_f: %.2f] [Lambda_r: %.2f] Accuracy : %.3f, Disparate Impact : %.3f "
        % (summary[0], summary[1], summary[2], summary[3])
    )
print("-----------------------------------------------------------------------------------")

-----------------------------------------------------------------------------------
------------------ Training Results of FR-Train on poisoned data ------------------
[Lambda_f: 0.10] [Lambda_r: 0.40] Accuracy : 0.842, Disparate Impact : 0.657 
[Lambda_f: 0.15] [Lambda_r: 0.40] Accuracy : 0.835, Disparate Impact : 0.704 
[Lambda_f: 0.20] [Lambda_r: 0.40] Accuracy : 0.833, Disparate Impact : 0.722 
[Lambda_f: 0.25] [Lambda_r: 0.40] Accuracy : 0.827, Disparate Impact : 0.745 
[Lambda_f: 0.30] [Lambda_r: 0.40] Accuracy : 0.824, Disparate Impact : 0.760 
[Lambda_f: 0.35] [Lambda_r: 0.40] Accuracy : 0.821, Disparate Impact : 0.764 
[Lambda_f: 0.40] [Lambda_r: 0.40] Accuracy : 0.821, Disparate Impact : 0.770 
[Lambda_f: 0.45] [Lambda_r: 0.40] Accuracy : 0.810, Disparate Impact : 0.786 
[Lambda_f: 0.52] [Lambda_r: 0.40] Accuracy : 0.814, Disparate Impact : 0.827 
-----------------------------------------------------------------------------------
