In [1]:
import os
# The thread-count variables are set to "1" before torch / numpy are imported
# below (presumably so the MKL / NumExpr / OpenMP backends read them when they
# are first loaded -- TODO confirm for the installed builds).
os.environ["MKL_NUM_THREADS"] = "1" 
os.environ["NUMEXPR_NUM_THREADS"] = "1" 
os.environ["OMP_NUM_THREADS"] = "1" 
# Data directory is resolved relative to the current working directory, so it
# depends on where the kernel was started.
base_path = os.getcwd() + '/data/'
import torch
import random

import get_dataset
import numpy as np
from FairICP import utility_functions
from FairICP import FairICP_learning
import warnings
# NOTE: this silences every warning category for the rest of the session,
# including deprecation warnings raised by the libraries imported above.
warnings.filterwarnings('ignore')


In [2]:
# Load R through rpy2.
# R_HOME / R_USER are set before rpy2 is imported. The values below are
# machine-specific placeholder paths and must be edited to match the local
# R and rpy2 installation.
os.environ['R_HOME'] = r"user\R\R-4.3.0"
os.environ['R_USER'] = r"user\anaconda3\Lib\site-packages\rpy2"

from rpy2.robjects.packages import importr
# Handles to the R packages 'KPC' and 'kernlab' (presumably both need to be
# installed in the R library that R_HOME points to -- verify locally).
KPC = importr('KPC')
kernlab = importr('kernlab')
import rpy2.robjects
from rpy2.robjects import FloatVector

In [3]:
seed = 123

# The notebook draws from three independent generators (PyTorch, NumPy and
# Python's `random`); each one is seeded with the same value.
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Seed the CUDA generators too when a GPU is present.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

def mix_gamma(n):
    """Draw ``n`` standardized samples from a two-component gamma mixture.

    The mixture has equal weights on Gamma(shape=1, scale=1) and
    Gamma(shape=10, scale=1). Every draw is centred and scaled with the
    theoretical mean and standard deviation of the mixture.

    The component index is drawn with Python's ``random`` module and the gamma
    variate with NumPy's global generator, so both must be seeded for
    reproducible output.

    Parameters
    ----------
    n : int
        Number of samples to draw.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n,)`` holding the standardized samples.
    """
    # Settings
    n_components = 2
    # Mixture weights (non-negative, sum to 1)
    w = [0.5, 0.5]
    # Gamma shape and scale parameter of each component
    shapes = np.array([1, 10])
    scales = np.array([1, 1])

    # Per-component mean and variance of a gamma distribution.
    comp_means = shapes * scales
    comp_vars = shapes * np.square(scales)
    # Mixture moments: E[X] and sqrt(E[X^2] - E[X]^2).
    second_moments = np.square(comp_means) + comp_vars
    mix_mean = np.array(w).dot(comp_means)
    mix_std = np.sqrt(np.array(w).dot(second_moments) - np.square(mix_mean))

    samples = np.zeros(n)
    for idx in range(n):
        # Pick the mixture component according to the mixture weights.
        component = random.choices(range(n_components), weights=w, k=1)[0]
        # Draw a *scalar* variate (no ``size=1``): writing a shape-(1,) array
        # into ``samples[idx]`` relies on an implicit array-to-scalar
        # conversion that NumPy has deprecated.
        draw = np.random.gamma(shape=shapes[component], scale=scales[component])
        samples[idx] = (draw - mix_mean) / mix_std
    return samples
    
def synthetic_example_md(dim_insnst = 5, dim_snst = 1, dim_noisy_a = 0, alpha = 0.5, eps = 1, n = 1000, include_A = False):
    """Simulate features X, sensitive attributes A and response Y.

    * ``dim_insnst`` insensitive features are drawn from N(0, I).
    * For each of the ``dim_snst`` sensitive dimensions, an attribute column
      ``A_j = mix_gamma(n)`` is drawn and the matching feature is
      ``sqrt(alpha) * A_j + sqrt(1 - alpha) * N(0, 1)``.
    * ``dim_noisy_a`` extra standard-normal columns are appended to A only;
      they do not enter X (unless ``include_A``) or Y.
    * ``Y`` is the sum of all feature columns plus ``eps * N(0, 1)`` noise.

    All draws use NumPy's global generator (and Python's ``random`` through
    ``mix_gamma``), in the order: insensitive X, then per sensitive dimension
    (A, feature noise), then noisy A columns, then response noise.

    Parameters
    ----------
    dim_insnst : int
        Number of insensitive feature columns.
    dim_snst : int
        Number of sensitive attribute / feature column pairs.
    dim_noisy_a : int
        Number of additional pure-noise attribute columns.
    alpha : float
        Weight of A in the sensitive features; ``sqrt(alpha)`` and
        ``sqrt(1 - alpha)`` are taken, so it should lie in [0, 1].
    eps : float
        Standard deviation multiplier of the response noise.
    n : int
        Number of samples.
    include_A : bool
        If True, the columns of A are appended to the returned X.

    Returns
    -------
    X : numpy.ndarray, shape (n, dim_insnst + dim_snst [+ A.shape[1]])
    A : numpy.ndarray, shape (n, dim_snst + dim_noisy_a)
    Y : numpy.ndarray, shape (n,)
    """
    # Imported here because no import cell of the notebook brings `mvn` into
    # scope; without it the function raises NameError on a fresh kernel.
    from scipy.stats import multivariate_normal as mvn

    # insensitive X
    X_insnst = mvn.rvs(mean=np.zeros(dim_insnst), cov=np.eye(dim_insnst), size=n)
    # `rvs` squeezes singleton axes (dim_insnst == 1 or n == 1); restore the
    # (n, dim_insnst) layout explicitly in every case.
    X_insnst = np.asarray(X_insnst).reshape(n, dim_insnst)

    # sensitive X and the attributes A that drive it
    A_cols = [np.empty((n, 0))]
    X_snst_cols = [np.empty((n, 0))]
    for _ in range(dim_snst):
        A_col = mix_gamma(n)
        X_col = np.sqrt(alpha) * A_col + np.sqrt(1 - alpha) * np.random.randn(n)
        A_cols.append(A_col[:, None])
        X_snst_cols.append(X_col[:, None])

    # additional A: pure-noise attribute columns
    for _ in range(dim_noisy_a):
        A_cols.append(np.random.randn(n)[:, None])

    A = np.concatenate(A_cols, axis = 1)
    X_snst = np.concatenate(X_snst_cols, axis = 1)

    X = np.concatenate([X_insnst, X_snst], axis = 1)
    # Every feature enters the response with coefficient 1.
    beta = [1] * dim_insnst + [1] * dim_snst
    Y = np.dot(X, beta) + eps * np.random.randn(n)

    if include_A:
        X = np.concatenate((X, A), axis = 1)

    return X, A, Y

def known_prob(Y, A, dim_insnst = 5, alpha = 0.5, eps = 1):
    """Pairwise Gaussian log-likelihood terms of each Y against each row of A.

    Entry ``[i, j]`` equals ``-Y[i]**2 / (2 * s2) + Y[i] * mu[j] / s2`` with
    ``mu[j] = sqrt(alpha) * sum(A[j])`` and
    ``s2 = dim_insnst + A.shape[1] * (1 - alpha) + eps**2``.
    The ``mu[j]**2`` term of the full Gaussian exponent is not included.

    Parameters
    ----------
    Y : numpy.ndarray, shape (m,)
    A : numpy.ndarray, shape (k, d)
    dim_insnst : int
        Number of unit-variance insensitive features contributing to Y.
    alpha : float
    eps : float

    Returns
    -------
    numpy.ndarray, shape (m, k)
    """
    n_rows, n_attr = A.shape

    # Variance contributed by the insensitive features: 1' I 1.
    unit = np.ones(dim_insnst)
    var_insnst = unit @ np.eye(dim_insnst) @ unit
    # Variance left in the sensitive features once A is known.
    var_snst = n_attr * (1 - alpha)

    sig2 = np.full(n_rows, var_insnst + var_snst + np.power(eps, 2))
    mu = np.sqrt(alpha) * A.sum(axis = 1)

    quadratic = np.negative(np.square(Y))[:, None] * (0.5 / sig2)[None, :]
    cross = Y[:, None] * (mu / sig2)[None, :]
    return quadratic + cross


In [4]:
# load data
simulation_type = 1 # 1/2 for SIM1/SIM2 in the paper

if simulation_type == 1:
    dim_snst = 10
    dim_insnst = dim_snst
    dim_noisy_a = 0
elif simulation_type == 2:
    dim_snst = 1
    dim_insnst = dim_snst
    dim_noisy_a = 10
else:
    # Fail here instead of hitting a NameError (or stale kernel values) below.
    raise ValueError(f"simulation_type must be 1 or 2, got {simulation_type!r}")

alpha = 0.9
# Response-noise scale grows with the number of feature columns.
eps = np.sqrt(dim_insnst + dim_snst)

# One shared argument set so the train and test draws cannot drift apart.
sim_kwargs = dict(dim_insnst = dim_insnst, dim_snst = dim_snst, dim_noisy_a = dim_noisy_a, alpha = alpha, eps = eps)
X, A, Y = synthetic_example_md(n = 500, **sim_kwargs)
X_test, A_test, Y_test = synthetic_example_md(n = 400, **sim_kwargs)

# Model input layout: attribute columns first, then the feature columns.
input_data_train = np.concatenate((A, X), 1)
input_data_test = np.concatenate((A_test, X_test), 1)

In [5]:
# optimisation settings handed to the learner
lr_loss = 0.001
lr_dis = 0.0001
batch_size = 16

# equalized odds penalty (passed to the learner as `lambda_vec`)
mu_val = 0.9

# training schedule: the last entry is used as the learner's `epochs`
# and the whole list is forwarded to `fit`
epochs_list = [140]

# utility loss
cost_pred = torch.nn.MSELoss()

In [6]:
# Keyword arguments for the learner, gathered in one place for readability.
learner_config = dict(
    lr_loss = lr_loss,
    lr_dis = lr_dis,
    epochs = epochs_list[-1],
    loss_steps = 1,
    dis_steps = 1,
    cost_pred = cost_pred,
    in_shape = X.shape[1],
    batch_size = batch_size,
    model_type = "linear_model",
    lambda_vec = mu_val,
    out_shape = 1,
    A_shape = A.shape[1],
)
model = FairICP_learning.EquiRegLearner(**learner_config)

# The training input is the [A | X] matrix built in the data cell.
model.fit(input_data_train, Y, epochs_list = epochs_list)

Epoch 58: early stopping


In [7]:
# Generate permuted copies of A (\tilde A) from the known conditional law.
# Only the first `dim_snst` attribute columns enter the likelihood.
A_test_snst = A_test[:, :dim_snst]
log_lik_mat = known_prob(Y_test, A_test_snst, dim_insnst, alpha, eps)

# NOTE(review): the arguments 50 and 100 are presumably the number of sampler
# steps and the number of permutations -- confirm against
# utility_functions.generate_X_CPT. The evaluation cell reads 100 entries.
cpt_draws = utility_functions.generate_X_CPT(50, 100, log_lik_mat)
y_perm_index = np.squeeze(cpt_draws)

# Invert each permutation with argsort (along the last axis), then use the
# result to reorder the rows of A_test.
A_perm_index = np.argsort(y_perm_index)
A_tilde_list = A_test[A_perm_index]

In [8]:
def _as_r_matrix(arr):
    """Convert a 2-D NumPy array into an R matrix of the same shape."""
    # R fills matrices column by column, hence the transpose before flattening.
    return rpy2.robjects.r.matrix(FloatVector(arr.T.flatten()), nrow=arr.shape[0], ncol=arr.shape[1])


stat = KPC.KPCgraph

# Number of permuted attribute sets; taken from the data rather than hard-coded.
n_perm = len(A_tilde_list)

# Quantities that do not depend on the checkpoint are built once.
A_test_snst = A_test[:, :dim_snst]
rZ = _as_r_matrix(A_test_snst)
rY = rpy2.robjects.r.matrix(FloatVector(Y_test), nrow=A_test_snst.shape[0], ncol=1)
rZt_list = [_as_r_matrix(At_test[:, :dim_snst]) for At_test in A_tilde_list]
# Baseline: MSE of predicting the test mean for every sample.
mse_trivial = np.mean((np.mean(Y_test)-Y_test)**2)

for cp_idx, _ in enumerate(model.checkpoint_list):
    # Switch the learner to the networks stored for this checkpoint.
    model.model = model.cp_model_list[cp_idx]
    model.dis = model.cp_dis_list[cp_idx]

    Yhat_out_train = model.predict(input_data_train)
    Yhat_out_test = model.predict(input_data_test)

    mse_model = np.mean((Yhat_out_test-Y_test)**2)
    print(f"mse_trivial: {mse_trivial}")
    print(f"mse_model: {mse_model}")

    rYhat = FloatVector(Yhat_out_test)

    # Statistic on the observed attributes.
    res_ = stat(Y = rYhat, X = rY, Z = rZ, Knn = "MST")[0]
    print(f"estimated KPC: {res_}")

    # Same statistic with each permuted copy of the attributes.
    # A separate loop name is used so the checkpoint index is never overwritten.
    res_list = np.zeros(n_perm)
    for perm_idx, rZt in enumerate(rZt_list):
        res_list[perm_idx] = stat(Y = rYhat, X = rY, Z = rZt, Knn = "MST")[0]

    # Permutation p-value with the usual +1 correction.
    p_val = 1.0/(n_perm+1) * (1 + np.sum(res_list >= res_))
    print(f"p-value: {p_val}")

mse_trivial: 41.44066956537443
mse_model: 20.690114613178263
estimated KPC: -0.01357897111950448
p-value: 0.31683168316831684
