In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.insert(1, '../')
import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.model_selection import train_test_split
from scipy.sparse import hstack
from tqdm import tqdm
import scipy.sparse
from scipy.stats import norm
import matplotlib.pyplot as plt
import os, sys
import folktables
import seaborn as sns
from sklearn.preprocessing import OneHotEncoder

from ols_utils import get_data, transform_features

## Implementation of classical inference, PPI with data splitting, and cross-prediction

In [None]:
def trial(income_features, income, n, alpha, enc):
    """Run one trial: draw a random labeled/unlabeled split and build all three intervals.

    Parameters
    ----------
    income_features : DataFrame with at least the columns 'AGEP' and 'SEX'.
    income : pandas object holding the outcome; only the labeled part is used.
    n : passed to ``train_test_split`` as ``train_size`` (the labeled split).
    alpha : error level forwarded to every interval routine.
    enc : fitted encoder forwarded to ``transform_features``.

    Returns
    -------
    (classical_interval, pp_interval, cpp_interval)

    Note: the feature-type mask ``ft`` is read from module scope.
    """

    def _age_sex(frame):
        # regressors of the OLS target: columns AGEP and SEX, stacked side by side
        return np.stack([frame['AGEP'].to_numpy(), frame['SEX'].to_numpy()], axis=1)

    # the outcome of the unlabeled split is never used
    feats_lab, feats_unlab, outcome_lab, _ = train_test_split(income_features, income, train_size=n)

    ols_features_labeled = _age_sex(feats_lab)
    ols_features_unlabeled = _age_sex(feats_unlab)
    outcome_lab = outcome_lab.to_numpy()

    # encoded design matrices consumed by the predictive model
    model_features_labeled = transform_features(feats_lab, ft, enc)[0]
    model_features_unlabeled = transform_features(feats_unlab, ft, enc)[0]

    classical_interval = classical_ols_interval(ols_features_labeled, outcome_lab, alpha)

    shared_args = (
        ols_features_labeled,
        model_features_labeled,
        ols_features_unlabeled,
        model_features_unlabeled,
        outcome_lab,
        alpha,
    )
    pp_interval = pp_ols_interval(*shared_args)
    cpp_interval = cross_prediction_ols_interval(*shared_args)

    return classical_interval, pp_interval, cpp_interval

In [None]:
def ols(features, outcome):
    """Return the least-squares coefficients pinv(features) . outcome."""
    pseudo_inverse = np.linalg.pinv(features)
    return np.dot(pseudo_inverse, outcome)

In [None]:
def pp_ols_interval(Xols_labeled, X_labeled, Xols_unlabeled, X_unlabeled, Y_labeled, alpha):
    """Prediction-powered OLS interval with a single split of the labeled data.

    20% of the labeled rows are used to fit a gradient-boosted tree model; the
    remaining labeled rows (the validation part) are used to rectify the
    estimate computed from predictions on the unlabeled rows.

    Parameters
    ----------
    Xols_labeled, Xols_unlabeled : 2-D arrays of OLS regressors (one row per sample).
    X_labeled, X_unlabeled : feature matrices given to the predictive model.
    Y_labeled : 1-D array of outcomes for the labeled rows.
    alpha : error level; the interval uses the 1 - alpha/2 normal quantile.

    Returns
    -------
    list ``[lower, upper]`` with one entry per OLS coefficient.
    """

    n = X_labeled.shape[0]
    N = X_unlabeled.shape[0]

    n_tr = int(0.2*n)   # labeled rows reserved for fitting the model
    n_val = n - n_tr    # labeled rows left for the rectifier

    X_train, X_val, _, Xols_val, Y_train, Y_val = train_test_split(X_labeled, Xols_labeled, Y_labeled, train_size=n_tr)

    # Evaluation sets are passed to xgb.train; no early_stopping_rounds is given.
    X_train1, X_train2, y_train1, y_train2 = train_test_split(X_train, Y_train, test_size=0.2)
    dtrain = xgb.DMatrix(X_train1, label=y_train1)
    dtest = xgb.DMatrix(X_train2, label=y_train2)

    evallist = [(dtest, 'eval'), (dtrain, 'train')]
    tree = xgb.train({'eta': 0.3, 'max_depth': 7, 'objective': 'reg:pseudohubererror'}, dtrain, 10000, evallist, verbose_eval=False)

    Yhat_unlabeled = tree.predict(xgb.DMatrix(X_unlabeled))
    Yhat_val = tree.predict(xgb.DMatrix(X_val))

    thetaPP = rectified_ols_point_estimate(Xols_val, Xols_unlabeled, Y_val, Yhat_val, Yhat_unlabeled)

    Hessian = 1/N * Xols_unlabeled.T @ Xols_unlabeled
    inv_Hessian = np.linalg.inv(Hessian)

    # Per-sample gradients on the unlabeled data, computed in one vectorised
    # step instead of a Python loop over all N rows.
    grads_til = (Xols_unlabeled @ thetaPP - Yhat_unlabeled)[:, None] * Xols_unlabeled
    var_unlabeled = np.atleast_2d(np.cov(grads_til.T))

    # Row-wise scaling by the prediction error. Broadcasting avoids building a
    # dense (n_val x n_val) diagonal matrix, which is quadratic in memory.
    pred_error = Yhat_val - Y_val
    grad_diff = pred_error[:, None] * Xols_val
    var_labeled = np.atleast_2d(np.cov(grad_diff.T))

    Sigma_hat = inv_Hessian @ (n_val/N * var_unlabeled + var_labeled) @ inv_Hessian

    halfwidth = norm.ppf(1-alpha/2) * np.sqrt(np.diag(Sigma_hat)/n_val)

    return [thetaPP - halfwidth, thetaPP + halfwidth]

In [None]:
def rectified_ols_point_estimate(X_labeled, X_unlabeled, Y_labeled, Yhat_labeled, Yhat_unlabeled):
    """Rectified OLS point estimate, shared by PPI and cross-prediction.

    Solves the normal equations built from the unlabeled regressors and their
    predicted outcomes, after subtracting the average prediction-error moment
    measured on the labeled rows.
    """
    num_labeled = X_labeled.shape[0]
    num_unlabeled = X_unlabeled.shape[0]

    # average of x_i * (yhat_i - y_i) over the labeled rows
    prediction_error = Yhat_labeled - Y_labeled
    rectifier = ((1/num_labeled) * X_labeled.T) @ prediction_error

    # second-moment matrix and cross moment from the unlabeled rows
    second_moment = ((1/num_unlabeled) * X_unlabeled.T) @ X_unlabeled
    imputed_moment = ((1/num_unlabeled) * X_unlabeled.T) @ Yhat_unlabeled

    return np.linalg.inv(second_moment) @ (imputed_moment - rectifier)

In [None]:
def cross_prediction_ols_interval(Xols_labeled, X_labeled, Xols_unlabeled, X_unlabeled, Y_labeled, alpha, K = 5):
    """Cross-prediction OLS interval using K-fold cross-fitting on the labeled data.

    For each fold a tree model is trained on the other folds; it predicts the
    held-out fold, and the K models' predictions on the unlabeled rows are
    averaged. The covariance comes from ``bootstrap_covariance``.

    Parameters
    ----------
    Xols_labeled, Xols_unlabeled : 2-D arrays of OLS regressors.
    X_labeled, X_unlabeled : feature matrices given to the predictive model.
    Y_labeled : 1-D array of outcomes for the labeled rows.
    alpha : error level; the interval uses the 1 - alpha/2 normal quantile.
    K : number of folds.

    Returns
    -------
    list ``[lower, upper]`` with one entry per OLS coefficient.
    """

    n = X_labeled.shape[0]
    N = X_unlabeled.shape[0]

    # Nominal fold size; n - fold_n is the training-set size handed to the bootstrap.
    fold_n = int(n/K)

    # Fold boundaries that cover every labeled row. Slicing with j*fold_n only
    # would leave the last n - K*fold_n rows without an out-of-fold prediction
    # (their Yhat_labeled would stay 0 and bias the rectifier). When n is a
    # multiple of K these boundaries equal j*fold_n.
    fold_edges = [(j*n)//K for j in range(K + 1)]

    Yhat_labeled = np.zeros(n)
    Yhat_unlabeled = np.zeros(N)

    for j in range(K):
        start, stop = fold_edges[j], fold_edges[j + 1]

        X_val = X_labeled[start:stop,:]
        train_ind = np.delete(np.arange(n), np.arange(start, stop))
        X_train = X_labeled[train_ind,:]
        Y_train = Y_labeled[train_ind]

        # use train data to train a tree
        X_train1, X_train2, y_train1, y_train2 = train_test_split(X_train, Y_train, test_size=0.2)
        dtrain = xgb.DMatrix(X_train1, label=y_train1)
        dtest = xgb.DMatrix(X_train2, label=y_train2)
        evallist = [(dtest, 'eval'), (dtrain, 'train')]
        tree = xgb.train({'eta': 0.3, 'max_depth': 7, 'objective': 'reg:pseudohubererror'}, dtrain, 10000, evallist, verbose_eval=False)
        print('cross-fit tree trained')

        # average the K models on the unlabeled rows; out-of-fold predictions on the labeled rows
        Yhat_unlabeled += (tree.predict(xgb.DMatrix(X_unlabeled)))/K
        Yhat_labeled[start:stop] = tree.predict(xgb.DMatrix(X_val))

    thetaPP = rectified_ols_point_estimate(Xols_labeled, Xols_unlabeled, Y_labeled, Yhat_labeled, Yhat_unlabeled)

    Sigma_hat = bootstrap_covariance(Xols_labeled, X_labeled, Xols_unlabeled, X_unlabeled, Y_labeled, n - fold_n, thetaPP)

    halfwidth = norm.ppf(1-alpha/2) * np.sqrt(np.diag(Sigma_hat)/n)

    return [thetaPP - halfwidth, thetaPP + halfwidth]

In [None]:
def bootstrap_covariance(Xols_labeled, X_labeled, Xols_unlabeled, X_unlabeled, Y_labeled, n_train, thetaPP, B = 10):
    """Estimate the covariance used by the cross-prediction interval.

    For each of ``B`` rounds, ``n_train`` labeled rows are drawn with
    replacement, a tree model is fitted on them, and its prediction errors are
    recorded on ``n - n_train`` rows that were not drawn. Predictions on the
    unlabeled rows are averaged over the rounds.

    Parameters
    ----------
    Xols_labeled, Xols_unlabeled : 2-D arrays of OLS regressors.
    X_labeled, X_unlabeled : feature matrices given to the predictive model.
    Y_labeled : 1-D array of outcomes for the labeled rows.
    n_train : number of labeled rows drawn per round.
    thetaPP : point estimate at which the unlabeled gradients are evaluated.
    B : number of bootstrap rounds.

    Returns
    -------
    Sigma_hat : ``np.matrix`` of shape (d, d); the matrix type is kept so the
        return type is unchanged for existing callers.
    """

    n = X_labeled.shape[0]
    N = X_unlabeled.shape[0]
    n_heldout = n - n_train  # rows evaluated per round

    Yhat_unlabeled = np.zeros(N)

    d_inf = len(thetaPP)

    grad_diff = np.zeros((int(n_heldout*B), d_inf))

    for j in range(B):
        print(j)

        # sampled with replacement, so at least n_heldout rows are always left out
        train_ind = np.random.choice(n, n_train)
        X_train = X_labeled[train_ind,:]
        Y_train = Y_labeled[train_ind]

        # use train data to train a tree
        X_train1, X_train2, y_train1, y_train2 = train_test_split(X_train, Y_train, test_size=0.1)
        dtrain = xgb.DMatrix(X_train1, label=y_train1)
        dtest = xgb.DMatrix(X_train2, label=y_train2)
        evallist = [(dtest, 'eval'), (dtrain, 'train')]
        tree = xgb.train({'eta': 0.3, 'max_depth': 7, 'objective': 'reg:pseudohubererror'}, dtrain, 10000, evallist, verbose_eval=False)

        Yhat_unlabeled += (tree.predict(xgb.DMatrix(X_unlabeled)))/B

        # first n_heldout rows (in index order) that were not drawn this round
        other_inds = np.delete(np.arange(n), train_ind)[:n_heldout]
        Yhat_heldout = tree.predict(xgb.DMatrix(X_labeled[other_inds, :]))

        # Row-wise scaling by the prediction error. Broadcasting avoids a dense
        # (n_heldout x n_heldout) diagonal matrix in every round.
        pred_error = Yhat_heldout - Y_labeled[other_inds]
        grad_diff[j*n_heldout:(j+1)*n_heldout, :] = pred_error[:, None] * Xols_labeled[other_inds, :]

    Hessian = 1/N * Xols_unlabeled.T @ Xols_unlabeled
    inv_Hessian = np.matrix(np.linalg.inv(Hessian))

    # Per-sample gradients on the unlabeled data, vectorised over all N rows.
    grads_til = (Xols_unlabeled @ thetaPP - Yhat_unlabeled)[:, None] * Xols_unlabeled
    var_unlabeled = np.matrix(np.cov(grads_til.T))

    var_labeled = np.matrix(np.cov(grad_diff.T))

    Sigma_hat = inv_Hessian @ (n/N * var_unlabeled + var_labeled) @ inv_Hessian

    return Sigma_hat

In [None]:
def classical_ols_interval(X, Y, alpha):
    """Normal-approximation interval for the OLS coefficients with a sandwich covariance.

    Returns a ``(lower, upper)`` tuple with one entry per column of ``X``.
    """
    num_obs = X.shape[0]
    coef = ols(X, Y)

    # squared residuals weight the "meat" of the sandwich estimator
    sq_resid = (Y - X@coef)**2
    bread = np.linalg.inv(1/num_obs * X.T@X)
    meat = 1/num_obs * (X.T*sq_resid[None,:])@X
    sandwich = bread@meat@bread

    z = norm.ppf(1-alpha/2)
    halfwidth = z * np.sqrt(np.diag(sandwich))/np.sqrt(num_obs)
    return coef - halfwidth, coef + halfwidth

In [None]:
# get data. we only look at year 2019.
# Columns requested from the ACS table, in the order the model sees them.
features = [
    'AGEP', 'SCHL',
    'MAR', 'DIS', 'ESP', 'CIT', 'MIG', 'MIL', 'ANC1P', 'NATIVITY',
    'DEAR', 'DEYE', 'DREM', 'SEX', 'RAC1P', 'SOCP', 'COW',
]
# Type flag per column, aligned with `features`: the first two columns are "q"
# (cast to float by transform_features), every other column is "c" (one-hot encoded).
ft = np.array(["q"] * 2 + ["c"] * (len(features) - 2))


def get_data(year,features,outcome, randperm=True):
    """Load 1-year ACS person records for California through folktables.

    Parameters
    ----------
    year : survey year passed to ``folktables.ACSDataSource``.
    features : list of column names to keep as predictors.
    outcome : name of the outcome column.
    randperm : if True, rows are shuffled with one random permutation.

    Returns
    -------
    (income_features, income, employed) where missing values in the first two
    are filled with -1 and ``employed`` is a boolean array that is True for
    rows whose 'COW' value is one of 1..7.
    """
    source = folktables.ACSDataSource(survey_year=year, horizon='1-Year', survey='person')
    acs = source.get_data(states=["CA"], download=True)

    income_features = acs[features].fillna(-1)
    income = acs[outcome].fillna(-1)
    employed = np.isin(acs['COW'], np.array([1,2,3,4,5,6,7]))

    if not randperm:
        return income_features, income, employed

    # apply the same permutation to all three outputs so rows stay aligned
    order = np.random.permutation(income.shape[0])
    return income_features.iloc[order], income.iloc[order], employed[order]


def transform_features(features, ft, enc=None):
    """Encode a feature frame into a sparse design matrix.

    Parameters
    ----------
    features : DataFrame whose columns line up with ``ft``.
    ft : array of flags, one per column; "q" columns are cast to float and
        "c" columns are converted to strings and one-hot encoded.
    enc : optional fitted encoder exposing ``transform``; when None a new
        ``OneHotEncoder`` is created and fitted on the "c" columns.

    Returns
    -------
    (matrix, enc) : ``scipy.sparse.csc_matrix`` with the "q" columns first,
        followed by the encoded "c" columns, and the encoder that was used.
    """
    c_features = features.T[ft == "c"].T.astype(str)
    if enc is None:
        encoder_options = dict(handle_unknown='ignore', drop='if_binary')
        try:
            # scikit-learn >= 1.2 names the dense-output switch `sparse_output`;
            # the old `sparse` keyword was removed in 1.4.
            enc = OneHotEncoder(sparse_output=False, **encoder_options)
        except TypeError:
            # older scikit-learn releases only know `sparse`
            enc = OneHotEncoder(sparse=False, **encoder_options)
        enc.fit(c_features)
    c_features = enc.transform(c_features)
    q_features = features.T[ft == "q"].T.astype(float)
    features = scipy.sparse.csc_matrix(np.concatenate([q_features, c_features], axis=1))
    return features, enc

## Construct confidence intervals

In [None]:
income_features, income, employed = get_data(year=2019, features=features, outcome='PINCP')
income_features_enc, enc = transform_features(income_features, ft)

age = income_features['AGEP'].to_numpy()
sex = income_features['SEX'].to_numpy()

# OLS on the full data set serves as the reference value for coverage
true_val = ols(np.stack([age, sex], axis=1), income.to_numpy())


num_trials = 10
alpha = 0.1
ps = [0.1, 0.2, 0.3] # fraction of labeled data

# coefficient of the first regressor (AGEP); intervals below are read at index 0 as well
theta_true = true_val[0]

# store results
columns = ["lb","ub","coverage","estimator","n"]

filename = "./census_results/simulation_results.csv"

# create the output directory before the long-running loop so a bad path fails early
os.makedirs(os.path.dirname(filename), exist_ok=True)

# One plain record per (trial, estimator). The frame is built once at the end,
# so column dtypes come from the data instead of a float placeholder frame.
results = []

for p in ps:

    n = int(p*len(income))

    for i in range(num_trials):
        ci, ppi, cppi = trial(income_features, income, n, alpha, enc)
        for estimator_name, interval in (("cross-prediction", cppi), ("classical", ci), ("PPI", ppi)):
            lb, ub = interval[0][0], interval[1][0]
            results.append({
                "lb": lb,
                "ub": ub,
                "coverage": bool((lb <= theta_true) & (theta_true <= ub)),
                "estimator": estimator_name,
                "n": n,
            })

df = pd.DataFrame(results, columns=columns)
df["width"] = df["ub"] - df["lb"]
df_list = [df]

final_df = pd.concat(df_list, ignore_index=True)

final_df.to_csv(filename)

## Plot results

In [None]:
alpha = 0.1
# reorder the first three Set2 colours for the three estimators
set2 = sns.color_palette("Set2")
col = np.array([set2[1], set2[2], set2[0]])
sns.set_theme(font_scale=1.4, style='white', palette=col, rc={'lines.linewidth': 3})

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10,3.3))
coverage_ax, width_ax = axs

# left panel: coverage per labeled-sample size; right panel: interval width
line_style = dict(hue='estimator', alpha=0.9, marker="*", markersize=14)
sns.lineplot(ax=coverage_ax, data=final_df, x='n', y='coverage', errorbar=None, **line_style)
sns.lineplot(ax=width_ax, data=final_df, x='n', y='width', **line_style)

# dashed reference line at the nominal level 1 - alpha
coverage_ax.axhline(1-alpha, color="#888888", linestyle='dashed', zorder=1, alpha=0.9)

# keep a single legend, on the width panel
handles, labels = width_ax.get_legend_handles_labels()
width_ax.legend(handles=handles, labels=labels)
coverage_ax.get_legend().remove()
coverage_ax.set_ylim([0.5,1])

# the first three lines of each panel are drawn dashed
for ax in (coverage_ax, width_ax):
    for j in range(3):
        ax.lines[j].set_linestyle("--")


sns.despine(top=True, right=True)
plt.tight_layout()

# save plot
plt.savefig('./census_results/ols_comparison_census.pdf')
plt.show()

In [None]:
# for reading data after it has been saved
datadir = './census_results/'
filenames = os.listdir(datadir)

# load every file in the results directory whose name contains 'simulation_results.'
data = []
for fn in filenames:
    if 'simulation_results.' in fn:
        data.append(pd.read_csv(os.path.join(datadir, fn)))

final_df = pd.concat(data, axis=0, ignore_index=True)