In [1]:
# Show the kernel's working directory. Presumably get_data() below resolves
# data files relative to it — TODO confirm.
# NOTE(review): the rendered output is an absolute local path; clear it before sharing.
%pwd

'/Users/ryandevera/data-science/umn_environments/Constrained-Deep-Learning-Survey'

In [2]:
# stdlib
import math
import random

# third party
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import tensorflow.compat.v1 as tf
import tensorflow_constrained_optimization as tfco
import warnings

# first party
from cdlsurvey.data import get_data
from cdlsurvey.metrics import get_exp_error_rate_constraints
from cdlsurvey.models import Model
from cdlsurvey.utils import training_helper

# Disable eager execution
tf.disable_eager_execution()

# suppress warnings
warnings.filterwarnings('ignore')

# For plotting in notebook
%matplotlib inline

2024-01-18 14:12:09.770822: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# Columns identifying the protected groups that the fairness constraint is
# measured over (presumably one-hot indicator columns produced by get_data()).
PROTECTED_COLUMNS = [
    'gender_Female',
    'gender_Male',
    'race_White',
    'race_Black',
]

In [4]:
# Column groups of the raw dataset schema (it looks like the UCI Adult census
# data). NOTE(review): CATEGORICAL_COLUMNS / CONTINUOUS_COLUMNS / COLUMNS are
# not referenced by the cells below; presumably get_data() keeps its own copy
# — TODO confirm, or drop them.
CATEGORICAL_COLUMNS = [
    'workclass', 'education', 'marital_status', 'occupation', 'relationship',
    'race', 'gender', 'native_country'
]
CONTINUOUS_COLUMNS = [
    'age', 'capital_gain', 'capital_loss', 'hours_per_week', 'education_num'
]
COLUMNS = [
    'age', 'workclass', 'fnlwgt', 'education', 'education_num',
    'marital_status', 'occupation', 'relationship', 'race', 'gender',
    'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
    'income_bracket'
]

# Target column of the frames returned by get_data(). The raw schema above has
# 'income_bracket' rather than 'label'; presumably the label is derived from it.
LABEL_COLUMN = 'label'

# PROTECTED_COLUMNS is defined in the previous cell; it is deliberately not
# redefined here, so there is a single source of truth.

In [5]:
# Load the train/test DataFrames plus the feature names handed to Model below.
(train_df, test_df, FEATURE_NAMES) = get_data()

In [6]:
# Make sure both splits contain positive labels: a true-positive-rate
# constraint (tpr_max_diff below) is undefined on a split without positives.
n_pos_train = train_df[LABEL_COLUMN].sum()
n_pos_test = test_df[LABEL_COLUMN].sum()
assert n_pos_train > 0 and n_pos_test > 0, "expected positive labels in both splits"
n_pos_train, n_pos_test

(7841, 3846)

In [7]:
# Baseline: train without enforcing the fairness constraint (unconstrained=True).
# tpr_max_diff is still passed, presumably so the same constraint violations
# are measured and reported for comparison — TODO confirm.
model = Model(
    tpr_max_diff=0.05,
    protected_columns=PROTECTED_COLUMNS,
    feature_names=FEATURE_NAMES,
    label_column=LABEL_COLUMN,
)
model.build_train_op(0.01, unconstrained=True)  # 0.01: presumably the learning rate

# training_helper returns histories of error rates and constraint-violation
# vectors on each split (presumably one entry per outer loop).
# NOTE(review): 100 is presumably the minibatch size, so 326 iterations per
# loop would be roughly one pass over train_df — confirm against training_helper.
train_errors, train_violations, test_errors, test_violations = training_helper(
    model,
    train_df,
    test_df,
    100,
    num_iterations_per_loop=326,
    num_loops=40,
)

# Keep the baseline results under their own name: the constrained run below
# rebinds train_errors / train_violations / test_errors / test_violations.
unconstrained_results = (train_errors, train_violations, test_errors, test_violations)

In [8]:
# Metrics of the last recorded iterate. Each violations entry presumably holds
# one value per constraint; report the worst (largest). A positive value
# presumably means the constraint is violated.
last_train_error = train_errors[-1]
worst_train_violation = max(train_violations[-1])
last_test_error = test_errors[-1]
worst_test_violation = max(test_violations[-1])

print("Train Error", last_train_error)
print("Train Violation", worst_train_violation)
print()
print("Test Error", last_test_error)
print("Test Violation", worst_test_violation)

Train Error 0.1428088817910998
Train Violation 0.02899853091646859

Test Error 0.14341870892451325
Test Violation 0.05963578207932785


In [None]:
# Constrained run: same setup as the baseline, but the TPR constraint is
# enforced during training (unconstrained=False).
model = Model(
    tpr_max_diff=0.05,
    protected_columns=PROTECTED_COLUMNS,
    feature_names=FEATURE_NAMES,
    label_column=LABEL_COLUMN,
)
model.build_train_op(0.01, unconstrained=False)  # 0.01: presumably the learning rate

# training_helper returns histories of error rates and constraint-violation
# vectors on each split (presumably one entry per outer loop).
# NOTE(review): 100 is presumably the minibatch size, so 326 iterations per
# loop would be roughly one pass over train_df — confirm against training_helper.
train_errors, train_violations, test_errors, test_violations = training_helper(
    model,
    train_df,
    test_df,
    100,
    num_iterations_per_loop=326,
    num_loops=40,
)

# Keep the constrained results under their own name as well, so they survive
# any later rebinding of the generic train_errors / ... names.
constrained_results = (train_errors, train_violations, test_errors, test_violations)

In [None]:
# Metrics of the last recorded iterate of the constrained run. Each violations
# entry presumably holds one value per constraint; report the worst (largest).
last_train_error = train_errors[-1]
worst_train_violation = max(train_violations[-1])
last_test_error = test_errors[-1]
worst_test_violation = max(test_violations[-1])

print("Train Error", last_train_error)
print("Train Violation", worst_train_violation)
print()
print("Test Error", last_test_error)
print("Test Violation", worst_test_violation)

In [None]:
# Uniform average over all recorded iterates. By linearity, this is the expected
# error / per-constraint violation of a model that picks one iterate uniformly
# at random. Violations are averaged per constraint (axis=0) before the worst
# one is taken.
avg_train_violations = np.mean(train_violations, axis=0)
avg_test_violations = np.mean(test_violations, axis=0)

print("Train Error", np.mean(train_errors))
print("Train Violation", max(avg_train_violations))
print()
print("Test Error", np.mean(test_errors))
print("Test Violation", max(avg_test_violations))

In [None]:
# Weight the recorded iterates: per the TFCO docs, this finds a distribution
# over candidates that (approximately) minimizes the objective (training error)
# subject to the constraints. Only training metrics are used, so test stays held out.
cand_dist = tfco.find_best_candidate_distribution(
    train_errors,
    train_violations,
)
print(cand_dist)

In [None]:
# Evaluate the stochastic model defined by cand_dist on each split.
# get_exp_error_rate_constraints presumably returns the cand_dist-weighted
# (expected) error rate and per-constraint violations — TODO confirm.
m_stoch_error_train, m_stoch_violations_train = get_exp_error_rate_constraints(
    cand_dist, train_errors, train_violations
)
m_stoch_error_test, m_stoch_violations_test = get_exp_error_rate_constraints(
    cand_dist, test_errors, test_violations
)

worst_stoch_train_violation = max(m_stoch_violations_train)
worst_stoch_test_violation = max(m_stoch_violations_test)

print("Train Error", m_stoch_error_train)
print("Train Violation", worst_stoch_train_violation)
print()
print("Test Error", m_stoch_error_test)
print("Test Violation", worst_stoch_test_violation)