# Read trained FarconVAE models for each dataset and sensitive attribute and evaluate their performance
This file reads and analyzes the trained FarconVAE models from the paper. Models newly trained with the code in this repo are included as well.

In [1]:
import os
import torch
import pandas as pd
from model.farcon import *
from collections import namedtuple
from train.train_farcon import *
from model.farcon import *
from eval import *
from sklearn.metrics import accuracy_score, confusion_matrix
# Prefer the first CUDA GPU when one is visible; otherwise run everything on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# Downstream y-predictor type used by get_models: 'lr' -> OneLinearLayer, anything else -> Predictor.
eval_model = 'lr'

# Args

In [2]:
# Attribute names carried by the args object returned by get_args.
_ARG_FIELDS = (
    'data_name', 'y_dim', 's_dim', 'n_features', 'latent_dim', 'hidden_units', 'sensitive', 'target',
    'train_file_name', 'test_file_name', 'encoder', 'batch_size', 'batch_size_te', 'epochs', 'fade_in',
    'beta_anneal', 'clf_act', 'clf_seq', 'clf_hidden_units', 'connection', 'enc_act', 'dec_act',
    'enc_seq', 'dec_seq', 'pred_act', 'pred_seq', 'model_path', 'data_path', 'result_path', 'clf_path',
    'patience', 'kernel', 'drop_p', 'neg_slop', 'clf_layers', 'cont_xs', 'env_flag', 'tr_ratio',
    'last_epmod_eval',
)

# Settings applied to every experiment.
_COMMON_ARGS = {
    'kernel': 'g', 'drop_p': 0.3, 'neg_slop': 0.1, 'clf_layers': 2,
    'cont_xs': 1, 'env_flag': 'nn', 'tr_ratio': 1.0, 'last_epmod_eval': 1,
}

# Architecture/training settings shared by the Adult and SCF experiments.
_GELU_MLP_ARGS = {
    'y_dim': 1, 's_dim': 1, 'encoder': 'mlp',
    'batch_size': 30162, 'batch_size_te': 15060, 'epochs': 300,
    'fade_in': 1, 'beta_anneal': 0,
    'clf_act': 'leaky', 'clf_seq': 'fad', 'clf_hidden_units': 64, 'connection': 0,
    'enc_act': 'gelu', 'dec_act': 'gelu', 'enc_seq': 'fa', 'dec_seq': 'fa',
    'pred_act': 'leaky',
}

# Data settings shared by both Adult experiments (sex and age as sensitive attribute).
_ADULT_DATA_ARGS = {
    'n_features': 95, 'latent_dim': 15, 'hidden_units': 64,
    'target': 'income_ >50K',
    'train_file_name': 'adult_train_bin.csv', 'test_file_name': 'adult_test_bin.csv',
    'data_path': './data/adult/',
}

# Data settings shared by both SCF experiments.
_SCF_DATA_ARGS = {
    'latent_dim': 120, 'hidden_units': 120,
    'sensitive': 'x6809',  # sensitive column (race in SCF_Race)
    'target': 'x3004',     # target column (loan)
    'data_path': './data/scf/',
    'clf_path': 'no',
}

# Per-experiment configuration. 'clf_path' == 'no' means there is no shared best
# x->y classifier; get_models then loads clf<id>.pth from model_path instead.
_DATASET_ARGS = {
    'adult': {
        **_GELU_MLP_ARGS, **_ADULT_DATA_ARGS,
        'sensitive': 'gender_ Male',
        'model_path': './model_adult',
        'result_path': './result_adult',
        'clf_path': './bestclf/bestclf_adult.pth',
    },
    'german': {
        'y_dim': 1, 's_dim': 1, 'n_features': 45, 'latent_dim': 5, 'hidden_units': 64,
        'sensitive': 'gender_ Male', 'target': 'risk_Bad',
        'train_file_name': 'german_train_bin.csv', 'test_file_name': 'german_test_bin.csv',
        'encoder': 'mlp', 'batch_size': 800, 'batch_size_te': 200, 'epochs': 2000,
        'fade_in': 0, 'beta_anneal': 1,
        'clf_act': 'leaky', 'clf_seq': 'fbad', 'clf_hidden_units': 64, 'connection': 0,
        'enc_act': 'prelu', 'dec_act': 'prelu', 'enc_seq': 'fa', 'dec_seq': 'fa',
        'pred_act': 'leaky', 'pred_seq': 'fba',
        'model_path': './model_german', 'data_path': './data/german/',
        'result_path': './result_german', 'clf_path': './bestclf/bestclf_german.pth',
    },
    # Adult with age as the sensitive attribute.
    'adult_age': {
        **_GELU_MLP_ARGS, **_ADULT_DATA_ARGS,
        'sensitive': 'age',
        'model_path': './model_adult_age_no_clf',
        'result_path': './result_adult_age_no_clf',
        'clf_path': 'no',
    },
    # SCF with race as the sensitive attribute.
    'SCF_Race': {
        **_GELU_MLP_ARGS, **_SCF_DATA_ARGS,
        'n_features': 3531,
        'train_file_name': 'SCF_train.csv', 'test_file_name': 'SCF_test.csv',
        'model_path': './model_scf', 'result_path': './result_scf_race',
    },
    # NOTE(review): uses the same sensitive column ('x6809') as SCF_Race, while
    # data_attributes labels this experiment's sensitive attribute '# Siblings' —
    # confirm the siblings CSVs store the siblings indicator under this column name.
    'SCF_Siblings': {
        **_GELU_MLP_ARGS, **_SCF_DATA_ARGS,
        'n_features': 3529,
        'train_file_name': 'SCF_Siblings_train.csv', 'test_file_name': 'SCF_Siblings_test.csv',
        'model_path': './model_scf_siblings', 'result_path': './result_scf_siblings',
    },
}


def get_args(data_name):
    """Return the model/training arguments for one experiment.

    Args:
        data_name: one of 'adult', 'german', 'adult_age', 'SCF_Race', 'SCF_Siblings'.

    Returns:
        An object exposing every name in _ARG_FIELDS as a (reassignable) attribute.

    Raises:
        ValueError: if data_name is not a configured experiment.
    """
    try:
        dataset_args = _DATASET_ARGS[data_name]
    except KeyError:
        raise ValueError(
            f"Unknown data_name {data_name!r}; expected one of {sorted(_DATASET_ARGS)}"
        ) from None

    config = {**_COMMON_ARGS, 'data_name': data_name, **dataset_args}
    # Only the German config pins pred_seq; the others derive it from the kernel.
    config.setdefault('pred_seq', 'fba' if config['kernel'] == 't' else 'fad')
    # patience = 10% of the epoch budget (presumably for early stopping).
    config['patience'] = int(config['epochs'] * 0.10)

    # The namedtuple *class* serves as a mutable attribute holder (callers such as
    # get_models reassign args.pred_seq), so values are set as class attributes.
    args = namedtuple('args', _ARG_FIELDS)
    for name, value in config.items():
        setattr(args, name, value)
    return args

# Baselines

In [3]:
# Display metadata and summary statistics per experiment.
# NOTE(review): 's=k parity' values look like P(y=1 | s=k) and 's=k proportion' like
# P(s=k) on the data — confirm how they were computed before relying on them.
data_attributes = {
    'SCF_Siblings': {
        'data_name': 'SCF',
        'sensitive': '# Siblings',
        's=0': 'Few (<4)',
        's=1': 'Many (4+)',
        's=0 parity': 0.13523065476190477,
        's=1 parity': 0.21307072515666964,
        's=0 proportion': 0.8783356932795992,
        's=1 proportion': 0.12166430672040082,
    },
    'SCF_Race': {
        'data_name': 'SCF',
        'sensitive': 'Race',
        's=0': 'White',
        's=1': 'Black',
        's=0 parity': 0.12142170989433237,
        's=1 parity': 0.27702948671277755,
        's=0 proportion': 0.8503975601786298,
        's=1 proportion': 0.14960243982137023,
    },
    'adult': {
        'data_name': 'Adult',
        'sensitive': 'Sex',
        's=0': 'Female',
        's=1': 'Male',
        's=0 parity': 0.11357604627424293,
        's=1 parity': 0.31247747895305794,
        's=0 proportion': 0.32495245676882933,
        's=1 proportion': 0.6750475432311707,
    },
    'adult_age': {
        'data_name': 'Adult',
        'sensitive': 'Age',
        's=0': 'Young',
        's=1': 'Old',
        's=0 parity': 0.1421809180527207,
        's=1 parity': 0.3574678981752647,
        's=0 proportion': 0.4908009375967449,
        's=1 proportion': 0.5091990624032551,
    },
}


# Load Models

In [4]:
def evaluator_predict(predictor, test_dl, target, device):
    """Run `predictor` over `test_dl` and return hard 0/1 predictions plus accuracy.

    Args:
        predictor: module mapping a batch of features to logits; switched to eval mode.
        test_dl: iterable of (feature, s, y) batches.
        target: 'y' scores the predictions against y; any other value scores against s.
        device: device the batches are moved to before prediction.

    Returns:
        (pred, acc): `pred` is a CPU tensor of shape (N, 1) that is 1 where the logit
        is > 0 and 0 elsewhere (same dtype as the logits); `acc` is the fraction of
        samples whose thresholded prediction equals the chosen label.

    Raises:
        ValueError: if `test_dl` yields no samples (accuracy would be undefined).
    """
    predictor.eval()
    n_seen, n_correct = 0, 0
    logit_batches = []
    with torch.no_grad():
        for feature, s, y in test_dl:
            feature, s, y = feature.to(device), s.to(device), y.to(device)
            logits = predictor(feature)
            labels = y if target == 'y' else s

            n_seen += feature.size(0)
            n_correct += ((logits > 0) == labels).sum().item()
            # Collect batches and concatenate once (repeated torch.cat is quadratic).
            logit_batches.append(logits)

    if n_seen == 0:
        raise ValueError('evaluator_predict: test_dl yielded no samples')

    all_logits = torch.cat(logit_batches)
    # Same "logit > 0" rule as the accuracy count, so a logit of exactly 0 is
    # predicted as 0 in both the returned predictions and the reported accuracy.
    pred = (all_logits > 0).to(all_logits.dtype)
    return pred.cpu().reshape(-1, 1), n_correct / n_seen

In [5]:
# Return the paths of the entries in a folder whose names start with a given prefix.
def get_files_starting_with(folder_path, prefix):
    """Return os.path.join(folder_path, name) for each entry name starting with `prefix`.

    Entries are returned in os.listdir order (arbitrary, not sorted).
    """
    return [
        os.path.join(folder_path, entry)
        for entry in os.listdir(folder_path)
        if entry.startswith(prefix)
    ]

In [6]:
def _load_state(path):
    """Load a saved state dict from `path`, mapping its tensors onto the notebook `device`."""
    return torch.load(path, map_location=device)


def _load_s_classifier(args, s_file):
    """Build the downstream s-predictor and load its weights from `s_file`.

    Saved s-predictors come in two layer sequences, so pred_seq 'fad' is tried first
    and 'fba' (the newer structure) second. Returns the loaded Predictor, or None if
    neither matches. `args.pred_seq` is always left as 'fad' on return.
    """
    state = _load_state(s_file)
    try:
        for seq in ('fad', 'fba'):
            args.pred_seq = seq
            clf_s = Predictor(args.latent_dim, args.s_dim, args.hidden_units, args=args)
            try:
                clf_s.load_state_dict(state)
            except Exception:
                continue
            return clf_s
    finally:
        # Predictor reads args.pred_seq, so restore the value the callers expect.
        args.pred_seq = 'fad'
    print("Error loading s model")
    print("Suggested keys: ", state.keys())
    return None


# Returns the trained farcon models, and s and y classifiers for their downstream representations
def get_models(args):
    """Load every trained FarconVAE in args.model_path together with its classifiers.

    Returns:
        dict mapping model id -> {'farcon': FarconVAE, 'clf': BestClf (x -> y),
        's': downstream s-predictor, 'y': downstream y-predictor}. Ids whose
        FarconVAE fails to load are skipped; ids whose downstream files are
        missing or fail to load are dropped.

    Side effect: args.pred_seq is set to 'fad' while building the downstream predictors.
    """
    data_name = args.data_name
    models = {}
    ids_to_del = []

    model_path = args.model_path
    file_start = 'farcon_'

    for file in get_files_starting_with(model_path, file_start):
        # os.path.join inserted a separator after model_path, so slicing at
        # len(model_path) + len(file_start) keeps the final '_' of 'farcon_'
        # (e.g. 'farcon_2024...pth' -> '_2024...'). The clf/downstream file names
        # below are built around that leading '_', so keep this exact slicing.
        model_id = file[len(model_path) + len(file_start):]
        model_id = model_id[:model_id.index('.pth')]

        model = FarconVAE(args, None)
        state_dict = _load_state(file)
        try:
            model.load_state_dict(state_dict)
        except Exception as err:
            print("Error loading farcon model", file, err)
            continue

        # x -> y classifier: the shared best one if configured, otherwise the one saved
        # next to this farcon model.
        clf_xy = BestClf(args.n_features, args.y_dim, args.clf_hidden_units, args)
        if args.clf_path != 'no':
            clf_file = args.clf_path
        else:
            clf_file = model_path + '/clf' + model_id + '.pth'
        clf_xy.load_state_dict(_load_state(clf_file))
        models[model_id] = {'farcon': model, 'clf': clf_xy}

    # Downstream s and y predictors trained on each model's representation.
    downstream_prefix_s = data_name + '_downstream_s'
    downstream_prefix_y = data_name + '_downstream_y'
    result_model_path = args.result_path

    for model_id in models:
        s_file = result_model_path + '/' + downstream_prefix_s + model_id + '.pth.pt'
        y_file = result_model_path + '/' + downstream_prefix_y + model_id + '.pth.pt'
        if not (os.path.isfile(s_file) and os.path.isfile(y_file)):
            ids_to_del.append(model_id)
            print('missing file for id:', model_id)
            continue

        args.pred_seq = 'fad'
        clf_y = (OneLinearLayer(args.latent_dim, args.y_dim) if eval_model == "lr"
                 else Predictor(args.latent_dim, args.y_dim, args.hidden_units, args))

        clf_s = _load_s_classifier(args, s_file)
        if clf_s is None:
            ids_to_del.append(model_id)
            continue

        try:
            clf_y.load_state_dict(_load_state(y_file))
        except Exception:
            print("Error loading y model")
            ids_to_del.append(model_id)
            continue

        models[model_id]['s'] = clf_s
        models[model_id]['y'] = clf_y

    print('ids to delete:', ids_to_del)
    for model_id in set(ids_to_del):
        del models[model_id]
    return models



In [7]:
# Encode the train and test splits with one trained FarconVAE and return both latent representations.
def get_latent_representations(models: dict, args, train_loader=None, test_loader=None):
    """Return (train_representation, test_representation) for one trained model.

    Args:
        models: the per-id entry from get_models(), holding at least 'farcon'
            (the FarconVAE) and 'clf' (its x -> y classifier).
        args: experiment args from get_args().
        train_loader, test_loader: optional loaders from get_xsy_loaders(). If either
            is omitted, both are rebuilt from args (re-reading the CSV files), so pass
            them in when encoding many models of the same experiment.
    """
    if train_loader is None or test_loader is None:
        train_loader, test_loader = get_xsy_loaders(
            os.path.join(args.data_path, args.train_file_name),
            os.path.join(args.data_path, args.test_file_name),
            args.data_name, args.sensitive, args.batch_size_te, args)
    farcon_model = models['farcon']
    clf_model = models['clf']
    # encode_all comes from eval.py; only its first return value (the representation) is used.
    test_representation, _, _, _ = encode_all(args, test_loader.dataset, farcon_model, clf_model,
                                              device=device, is_train=False)
    train_representation, _, _, _ = encode_all(args, train_loader.dataset, farcon_model, clf_model,
                                               device=device, is_train=True)
    return train_representation, test_representation

In [8]:
# Score the downstream y and s predictors on a latent representation of the test data.
def get_preds(cur_models, latent_representation, train_loader, test_loader, args):
    """Return (y_pred, y_acc, s_pred, s_acc) for one model's downstream predictors.

    `cur_models` must hold the 'y' and 's' predictors from get_models().
    """
    y_model, s_model = cur_models['y'], cur_models['s']
    # The same representation is passed for both splits; only the test loader is used.
    _unused_train_dl, rep_test_dl = get_representation_loader(
        latent_representation, latent_representation,
        train_loader.dataset, test_loader.dataset, args.batch_size_te)
    y_pred, y_acc = evaluator_predict(y_model, rep_test_dl, 'y', device)
    s_pred, s_acc = evaluator_predict(s_model, rep_test_dl, 's', device)
    return y_pred, y_acc, s_pred, s_acc
    

# Evaluate the trained models of each experiment

In [9]:
def get_accuracies_by_experiment(data_name):
    """Evaluate every trained model of one experiment on its test split.

    Returns:
        dict mapping model id -> {'y_acc', 's_acc', 'y_pred', 'y_true', 's_pred',
        's_true'}; each value is a one-element list holding that model's result
        (or the test split's true labels for 'y_true' / 's_true').
    """
    args = get_args(data_name = data_name)
    models = get_models(args)
    xsy_train_dl, xsy_test_dl = get_xsy_loaders(
        os.path.join(args.data_path, args.train_file_name),
        os.path.join(args.data_path, args.test_file_name),
        args.data_name, args.sensitive, args.batch_size_te, args)

    results = {}
    for model_id, cur_models in models.items():
        print(model_id)
        # The train representation is produced by the helper but not needed here.
        _train_rep, test_rep = get_latent_representations(cur_models, args)
        y_pred, y_acc, s_pred, s_acc = get_preds(cur_models, test_rep, xsy_train_dl, xsy_test_dl, args)
        results[model_id] = {
            'y_acc': [y_acc],
            's_acc': [s_acc],
            'y_pred': [y_pred],
            'y_true': [xsy_test_dl.dataset.y],
            's_pred': [s_pred],
            's_true': [xsy_test_dl.dataset.s],
        }
    return results


def get_s_baselines_by_experiment(data_name):
    """Placeholder for computing s-prediction baselines for one experiment.

    Not implemented: it only builds the experiment args and returns None.
    """
    _unused_args = get_args(data_name = data_name)  # TODO: compute baselines from these args
    return None
    


# Enter accuracies into dictionary

In [10]:
# Evaluate every experiment and collect the per-model results keyed by experiment name.
accuracies = {}
for data_name in ('adult', 'adult_age', 'SCF_Race', 'SCF_Siblings'):
    results = get_accuracies_by_experiment(data_name=data_name)
    accuracies[data_name] = results


missing file for id: _2024_03_05_17_30_18_369778_ours
ids to delete: ['_2024_03_05_17_30_18_369778_ours']
_2024_03_06_01_40_56_403815_ours
_2024_03_05_22_46_11_320915_ours
_2024_03_05_22_09_40_994430_ours
_2024_03_05_17_58_04_316124_ours
_2024_03_05_20_10_17_982934_ours
_2024_03_05_19_14_29_896811_ours
_2024_03_05_19_39_34_253995_ours
_2024_03_05_21_06_17_486998_ours
_2024_03_05_23_45_55_384556_ours
_2024_03_05_18_49_10_855960_ours
_2024_03_06_03_49_40_369269_ours
_2024_03_05_18_23_49_638763_ours
_2024_03_05_20_39_06_565216_ours
_2024_03_06_03_14_38_296293_ours
_2024_03_06_02_14_03_781213_ours
_2024_03_05_17_32_16_991758_ours
_2024_03_06_01_05_07_432427_ours
_2024_03_06_02_47_18_326329_ours
_2024_03_06_04_21_00_427210_ours
_2024_03_05_23_14_43_338087_ours
ids to delete: []
_2024_03_19_14_11_57_463850_ours
_2024_03_19_16_17_03_961684_ours
_2024_03_19_06_03_38_440582_ours
_2024_03_19_12_02_17_222115_ours
_2024_03_18_14_44_59_266007_ours
_2024_03_19_18_07_12_463912_ours
_2024_03_19_15_14_

# Compute Fairness Metrics

In [11]:
def _safe_div(num, den):
    """Return num / den, or NaN when den is 0."""
    return num / den if den else float('nan')


def _binary_rates(y_true, y_pred):
    """Return fpr, fnr, positive rate and true-label base rate for binary 0/1 labels.

    Passing labels to confusion_matrix keeps it 2x2 even when a group contains a
    single class in both y_true and y_pred (it would otherwise be 1x1).
    """
    (tn, fp), (fn, tp) = confusion_matrix(y_true, y_pred, labels=[0.0, 1.0])
    n = len(y_true)
    return {
        'fpr': _safe_div(fp, fp + tn),
        'fnr': _safe_div(fn, fn + tp),
        'positive_rate': _safe_div(tp + fp, n),
        'base_rate': _safe_div(tp + fn, n),
    }


def get_fairness_metrics(y_pred, y_true, s_true, s_pred = None):
    """Compute accuracy, error-rate and positive-rate metrics overall and per group.

    The "sensitive" group is s_true == 1 and the "not_sensitive" group is s_true == 0.
    Labels and predictions are binary 0/1 and must support boolean-mask indexing
    (e.g. numpy arrays or CPU torch tensors).

    Returned keys:
        accuracy, accuracy_sensitive, accuracy_not_sensitive
        s_accuracy, s_accuracy_sensitive, s_accuracy_not_sensitive (only if s_pred is given)
        fpr, fnr (+ _sensitive / _not_sensitive): FP/(FP+TN) and FN/(FN+TP)
        positive_rate_sensitive, positive_rate_not_sensitive, positive_rate:
            fraction predicted positive, (TP+FP)/N
        true_pos_rate_sensitive, true_pos_rate_not_sensitive: fraction whose TRUE
            label is positive, (TP+FN)/N, i.e. the group base rate (not the TPR/recall).
    A rate whose denominator is 0 is NaN.
    """
    metrics = {}
    sens = s_true == 1
    not_sens = s_true == 0

    # Overall and per-group y accuracy.
    metrics['accuracy'] = accuracy_score(y_true, y_pred)
    metrics['accuracy_sensitive'] = accuracy_score(y_true[sens], y_pred[sens])
    metrics['accuracy_not_sensitive'] = accuracy_score(y_true[not_sens], y_pred[not_sens])

    if s_pred is not None:
        # NOTE(review): the s accuracy was previously flagged as problematic here —
        # confirm s_pred and s_true have matching shapes and dtypes.
        metrics['s_accuracy'] = accuracy_score(s_true, s_pred)
        metrics['s_accuracy_sensitive'] = accuracy_score(s_true[sens], s_pred[sens])
        metrics['s_accuracy_not_sensitive'] = accuracy_score(s_true[not_sens], s_pred[not_sens])

    overall = _binary_rates(y_true, y_pred)
    sensitive = _binary_rates(y_true[sens], y_pred[sens])
    not_sensitive = _binary_rates(y_true[not_sens], y_pred[not_sens])

    metrics['fpr'] = overall['fpr']
    metrics['fnr'] = overall['fnr']
    metrics['fpr_sensitive'] = sensitive['fpr']
    metrics['fnr_sensitive'] = sensitive['fnr']
    metrics['fpr_not_sensitive'] = not_sensitive['fpr']
    metrics['fnr_not_sensitive'] = not_sensitive['fnr']
    metrics['positive_rate_sensitive'] = sensitive['positive_rate']
    metrics['positive_rate_not_sensitive'] = not_sensitive['positive_rate']
    metrics['positive_rate'] = overall['positive_rate']
    metrics['true_pos_rate_sensitive'] = sensitive['base_rate']
    metrics['true_pos_rate_not_sensitive'] = not_sensitive['base_rate']

    return metrics

In [12]:
# Write all of the metrics for one dataset to a single file in the fairness_results folder.
def write_fairness(data_name, metrics_list, out_dir='./fairness_results'):
    """Write every metric dict in `metrics_list` to the file `<out_dir>/<data_name>`.

    Each metric is written as one 'name:value' line; the dicts are written one after
    another with no separator. `out_dir` is created if it does not exist yet.
    """
    os.makedirs(out_dir, exist_ok=True)
    file_path = os.path.join(out_dir, data_name)
    with open(file_path, 'w', encoding='utf-8') as f:
        for metrics in metrics_list:
            for metric, value in metrics.items():
                f.write(metric + ':' + str(value) + '\n')
                


In [13]:
def get_fairness_by_name(data_name, accuracies):
    """Compute get_fairness_metrics for every model stored under accuracies[data_name].

    Uses the test labels ('y_true', 's_true') that get_accuracies_by_experiment stored
    next to each model's predictions. Both labels therefore come from the same test
    split object, and the dataset does not have to be reloaded.

    Returns:
        list of metric dicts, one per model id, in accuracies[data_name] order.
    """
    metrics_list = []
    for record in accuracies[data_name].values():
        metrics = get_fairness_metrics(
            record['y_pred'][0],
            record['y_true'][0],
            record['s_true'][0],
            s_pred=record['s_pred'][0],
        )
        metrics_list.append(metrics)
    return metrics_list

In [14]:
# Per-model fairness metrics for each experiment (one dict per model id).
adult_metrics, adult_age_metrics, scf_metrics, scf_siblings_metrics = [
    get_fairness_by_name(experiment, accuracies)
    for experiment in ('adult', 'adult_age', 'SCF_Race', 'SCF_Siblings')
]

In [15]:
def average_metrics(metrics_list):
    """Return the mean of each metric across a list of metric dicts.

    The metric names are taken from the first dict, and every dict must contain them.
    An empty list yields {} instead of raising IndexError.
    """
    if not metrics_list:
        return {}
    count = len(metrics_list)
    return {
        metric: sum(metrics[metric] for metrics in metrics_list) / count
        for metric in metrics_list[0]
    }


In [16]:
# Average each experiment's metrics across its trained models.
avg_adult_metrics, avg_adult_age_metrics, avg_scf_metrics, avg_scf_siblings_metrics = map(
    average_metrics,
    (adult_metrics, adult_age_metrics, scf_metrics, scf_siblings_metrics),
)