In [1]:
import sys
src_path =   '../src' # change as needed
sys.path.insert(0,src_path)

import numpy as np
import data_generator
import model_utils

INFO:absl:Using /tmp/tfhub_modules to cache modules.


In [2]:
#Start session
max_length = 128

train_data, val_data, test_data = data_generator.GetData(max_length, sample=1000)

100%|██████████| 3000/3000 [00:00<00:00, 3007.87it/s]


        tag  cat  occurences
0    B-MISC    0           8
1     I-LOC    1        2563
2    I-MISC    2        1106
3     I-ORG    3        1871
4     I-PER    4        8216
5         O    5       53719
6  [nerCLS]    6        3000
7  [nerPAD]    7      293949
8  [nerSEP]    8        3000
9    [nerX]    9       16568

                tag  cat  occurences
0  AFRICAN-AMERICAN    0        3382
1          EUROPEAN    1        1555
2         [raceCLS]    2        3000
3         [racePAD]    3      293949
4         [raceSEP]    4        3000
5           [raceX]    5       79114

           tag  cat  occurences
0       FEMALE    0        2855
1         MALE    1        2082
2  [genderCLS]    2        3000
3  [genderPAD]    3      293949
4  [genderSEP]    4        3000
5    [genderX]    5       79114



In [3]:
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

def getDebiasedModel(max_input_length, train_layers):
    
    in_id = tf.keras.layers.Input(shape=(max_input_length,), name="input_ids")
    in_mask = tf.keras.layers.Input(shape=(max_input_length,), name="input_masks")
    in_segment = tf.keras.layers.Input(shape=(max_input_length,), name="segment_ids")

    bert_inputs = [in_id, in_mask, in_segment]

    bert_sequence = model_utils.BertLayer(n_fine_tune_layers=train_layers)(bert_inputs)

    dense = tf.keras.layers.Dense(256, activation='relu', name='pred_dense')(bert_sequence)

    dense = tf.keras.layers.Dropout(rate=0.1)(dense)

    pred = tf.keras.layers.Dense(10, activation='softmax', name='ner')(dense)
    
    genderPred = tf.keras.layers.Dense(6, activation='softmax', name='gender')(pred)

    racePred = tf.keras.layers.Dense(6, activation='softmax', name='race')(pred)
    
    model = tf.keras.models.Model(inputs=bert_inputs, outputs={
        "ner": pred,
        "race": racePred,
        "gender": genderPred
    })
    
    model.summary()
    
    return model

In [4]:
def random_batch(data, batch_size=32):
    idx = np.random.randint(len(data["nerLabels"]), size=batch_size)
    return [
        data["inputs"][0][idx], 
        data["inputs"][1][idx], 
        data["inputs"][2][idx], 
        data["nerLabels"][idx],
        data["genderLabels"][idx],
        data["raceLabels"][idx]
    ]

In [5]:
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

model = getDebiasedModel(max_length, 1)

pred_learning_rate = 2**-16
protect_learning_rate = 2**-16
num_epochs = 5

num_train_samples = len(train_data["nerLabels"])

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_ids (InputLayer)          [(None, 128)]        0                                            
__________________________________________________________________________________________________
input_masks (InputLayer)        [(None, 128)]        0                                            
__________________________________________________________________________________________________
segment_ids (InputLayer)        [(None, 128)]        0                                            
__________________________________________________________________________________________________
bert_layer (BertLayer)          (None, None, 768)    108931396   input_ids[0][0]                  
                                                                 input_masks[0][0]            

In [6]:
protect_loss_weight = 0.1

In [7]:
def tf_normalize(x):
    """Returns the input vector, normalized.

    A small number is added to the norm so that this function does not break when
    dealing with the zero vector (e.g. if the weights are zero-initialized).

    Args:
    x: the tensor to normalize
    """
    return x / (tf.norm(x) + np.finfo(np.float32).tiny)

In [8]:
ids_ph = tf.placeholder(tf.float32, shape=[32,128])
masks_ph = tf.placeholder(tf.float32, shape=[32,128])
sentenceIds_ph = tf.placeholder(tf.float32, shape=[32,128])

gender_ph = tf.placeholder(tf.float32, shape=[32,128])
ner_labels_ph = tf.placeholder(tf.float32, shape=[32,128])

global_step = tf.Variable(0, trainable=False)
starter_learning_rate = 0.001
learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step, 1000, 0.96, staircase=True)

protect_vars = [var for var in tf.trainable_variables() if 'gender' in var.name]
pred_vars = model.layers[3]._trainable_weights + [var for var in tf.trainable_variables() if any(x in var.name for x in ["pred_dense","ner"])]
    
y_pred = model([ids_ph, masks_ph, sentenceIds_ph], training=True)

ner_loss = model_utils.custom_loss(ner_labels_ph, y_pred["ner"])
gender_loss = model_utils.custom_loss_protected(gender_ph, y_pred["gender"])

protect_opt = tf.train.AdamOptimizer(protect_learning_rate)
pred_opt = tf.train.AdamOptimizer(pred_learning_rate)

protect_grads = {var: grad for (grad, var) in protect_opt.compute_gradients(gender_loss,var_list=pred_vars)}
pred_grads = []

for (grad, var) in pred_opt.compute_gradients(ner_loss, var_list=pred_vars):
    unit_protect = tf_normalize(protect_grads[var])
    # the two lines below can be commented out to train without debiasing
    grad -= tf.reduce_sum(grad * unit_protect) * unit_protect
    grad -= tf.math.scalar_mul(protect_loss_weight, protect_grads[var])
    pred_grads.append((grad, var))

pred_min = pred_opt.apply_gradients(pred_grads, global_step=global_step)
protect_min = protect_opt.minimize(gender_loss, var_list=[protect_vars], global_step=global_step)

model_utils.initialize_vars(sess)

# Begin training
for epoch in range(num_epochs):
    
    shuffled_ids = np.random.choice(num_train_samples, num_train_samples)

    for i in range(num_train_samples//32):

        ids, masks, sentence_ids, ner_labels, gender_labels, race_labels = random_batch(train_data)

        batch_feed_dict = {ids_ph: ids, 
                           masks_ph: masks,
                           sentenceIds_ph: sentence_ids,
                           gender_ph: gender_labels,
                           ner_labels_ph: race_labels}


        _, _, pred_labels_loss_value, pred_protected_attributes_loss_vale = sess.run([
            pred_min,
            protect_min,
            ner_loss,
            gender_loss
        ], feed_dict=batch_feed_dict)

        #if i % 200 == 0:
        print("epoch %d; iter: %d; batch classifier loss: %f; batch adversarial loss: %f" % (epoch, i, pred_labels_loss_value,
                                                                 pred_protected_attributes_loss_vale))


epoch 0; iter: 0; batch classifier loss: 2.678376; batch adversarial loss: 3.734007
epoch 0; iter: 1; batch classifier loss: 2.585400; batch adversarial loss: 3.721679
epoch 0; iter: 2; batch classifier loss: 2.546495; batch adversarial loss: 3.730783
epoch 0; iter: 3; batch classifier loss: 2.499578; batch adversarial loss: 3.728687
epoch 0; iter: 4; batch classifier loss: 2.417089; batch adversarial loss: 3.737558
epoch 0; iter: 5; batch classifier loss: 2.289476; batch adversarial loss: 3.731979
epoch 0; iter: 6; batch classifier loss: 2.252623; batch adversarial loss: 3.734449
epoch 0; iter: 7; batch classifier loss: 2.161414; batch adversarial loss: 3.727175
epoch 0; iter: 8; batch classifier loss: 2.166079; batch adversarial loss: 3.722918
epoch 0; iter: 9; batch classifier loss: 2.101038; batch adversarial loss: 3.724814
epoch 0; iter: 10; batch classifier loss: 2.035396; batch adversarial loss: 3.725134
epoch 0; iter: 11; batch classifier loss: 1.998918; batch adversarial loss:

epoch 3; iter: 4; batch classifier loss: 0.306072; batch adversarial loss: 3.728918
epoch 3; iter: 5; batch classifier loss: 0.307814; batch adversarial loss: 3.742266
epoch 3; iter: 6; batch classifier loss: 0.313967; batch adversarial loss: 3.729022
epoch 3; iter: 7; batch classifier loss: 0.299893; batch adversarial loss: 3.728607
epoch 3; iter: 8; batch classifier loss: 0.287814; batch adversarial loss: 3.718416
epoch 3; iter: 9; batch classifier loss: 0.287074; batch adversarial loss: 3.725401
epoch 3; iter: 10; batch classifier loss: 0.279586; batch adversarial loss: 3.746905
epoch 3; iter: 11; batch classifier loss: 0.273372; batch adversarial loss: 3.758242
epoch 3; iter: 12; batch classifier loss: 0.268778; batch adversarial loss: 3.740175
epoch 3; iter: 13; batch classifier loss: 0.271226; batch adversarial loss: 3.721240
epoch 3; iter: 14; batch classifier loss: 0.253401; batch adversarial loss: 3.761651
epoch 3; iter: 15; batch classifier loss: 0.261009; batch adversarial l