In [None]:
%load_ext autoreload
%autoreload 2
import sys, os
sys.path.insert(1, '../')
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patheffects as pe
import seaborn as sns
import pandas as pd
import xgboost as xgb
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import pdb
import xgboost as xgb
from scipy.stats import norm
from scipy.stats import multivariate_normal
from scipy.stats import bernoulli

## Implementation of classical inference, PPI with data splitting, and cross-prediction

In [None]:
def trial(X, Y, n, d_inf, alpha):
    """Run one simulation trial and return the three competing intervals.

    The data are split at random into ``n`` labeled rows and the remaining
    unlabeled rows; the outcomes of the unlabeled rows are not used.

    Parameters
    ----------
    X : 2-D array of features, one row per observation.
    Y : array of outcomes aligned with the rows of ``X``.
    n : number of rows to treat as labeled.
    d_inf : number of leading feature columns that define the OLS target.
    alpha : level forwarded to every interval routine.

    Returns
    -------
    tuple
        ``(classical_interval, pp_interval, cpp_interval)``.
    """
    labeled_X, unlabeled_X, labeled_Y, _ = train_test_split(X, Y, train_size=n)

    # The classical baseline only sees the labeled rows and the inference columns.
    classical = classical_ols_interval(labeled_X[:, :d_inf], labeled_Y, alpha)

    # PPI with data splitting: half of the labeled rows are reserved for model training.
    n_model = int(0.5*n)
    ppi = pp_ols_interval(labeled_X, unlabeled_X, labeled_Y, alpha, d_inf, n_model)

    cross = cross_prediction_ols_interval(labeled_X, unlabeled_X, labeled_Y, alpha, d_inf)

    return classical, ppi, cross

In [None]:
def ols(features, outcome):
    """Return least-squares coefficients computed with the pseudo-inverse of ``features``."""
    pseudo_inverse = np.linalg.pinv(features)
    return np.dot(pseudo_inverse, outcome)

In [None]:
def pp_ols_interval(X_labeled, X_unlabeled, Y_labeled, alpha, d_inf, n_tr):
    """Prediction-powered OLS interval with data splitting and a CLT interval.

    The labeled data are split at random: ``n_tr`` rows train an XGBoost
    regressor and the remaining ``n - n_tr`` rows rectify its predictions.

    Parameters
    ----------
    X_labeled : 2-D array of labeled features (all columns feed the model).
    X_unlabeled : 2-D array of unlabeled features.
    Y_labeled : 1-D array of outcomes for ``X_labeled``.
    alpha : level; the half-width uses the ``1 - alpha/2`` normal quantile.
    d_inf : number of leading columns that define the OLS target.
    n_tr : number of labeled rows used for model training.

    Returns
    -------
    list
        ``[lower, upper]``, each a 1-D array with ``d_inf`` entries.
    """
    n = X_labeled.shape[0]
    N = X_unlabeled.shape[0]
    n_val = n - n_tr  # labeled rows left for rectification

    X_train, X_val, Y_train, Y_val = train_test_split(X_labeled, Y_labeled, train_size=n_tr)

    # Hold out 10% of the training rows as an evaluation set for xgboost.
    X_train1, X_train2, y_train1, y_train2 = train_test_split(X_train, Y_train, test_size=0.1)
    dtrain = xgb.DMatrix(X_train1, label=y_train1)
    dtest = xgb.DMatrix(X_train2, label=y_train2)
    # 'rmse'/'mae' are regression metrics; the former 'error' metric is the
    # binary-classification error rate and is meaningless for this objective.
    param = {'max_depth': 7, 'eta': 0.1, 'objective': 'reg:squarederror', 'eval_metric': ['rmse', 'mae']}
    evallist = [(dtest, 'eval'), (dtrain, 'train')]
    num_round = 500
    tree = xgb.train(param, dtrain, num_round, evallist, verbose_eval=False)

    Yhat_unlabeled = tree.predict(xgb.DMatrix(X_unlabeled))
    Yhat_val = tree.predict(xgb.DMatrix(X_val))

    X_unlabeled_inf = X_unlabeled[:, :d_inf]
    X_val_inf = X_val[:, :d_inf]

    thetaPP = rectified_ols_point_estimate(X_val_inf, X_unlabeled_inf, Y_val, Yhat_val, Yhat_unlabeled)

    Hessian = 1/N * X_unlabeled_inf.T @ X_unlabeled_inf
    inv_Hessian = np.linalg.inv(Hessian)

    # Row-wise scaling by broadcasting: same values as np.diag(v) @ X but
    # without materialising an N x N dense matrix.
    residual_unlabeled = np.dot(X_unlabeled_inf, thetaPP) - Yhat_unlabeled
    grads_til = residual_unlabeled[:, None] * X_unlabeled_inf
    # np.cov returns a 0-d array for a single variable; keep everything 2-D
    # so the matrix products below also work when d_inf == 1.
    var_unlabeled = np.atleast_2d(np.cov(grads_til.T))

    pred_error = Yhat_val - Y_val
    grad_diff = pred_error[:, None] * X_val_inf
    var_labeled = np.atleast_2d(np.cov(grad_diff.T))

    Sigma_hat = inv_Hessian @ (n_val/N * var_unlabeled + var_labeled) @ inv_Hessian

    halfwidth = norm.ppf(1-alpha/2) * np.sqrt(np.diag(Sigma_hat)/n_val)

    return [thetaPP - halfwidth, thetaPP + halfwidth]

In [None]:
def rectified_ols_point_estimate(X_labeled, X_unlabeled, Y_labeled, Yhat_labeled, Yhat_unlabeled):
    """Compute the rectified (prediction-powered) OLS point estimate.

    OLS is solved on the unlabeled features against their predictions, after
    subtracting the average feature-weighted prediction error measured on the
    labeled data. The same routine serves PPI and cross-prediction.

    Returns the coefficient vector, one entry per feature column.
    """
    num_labeled = X_labeled.shape[0]
    num_unlabeled = X_unlabeled.shape[0]

    # Inverse second-moment matrix of the unlabeled features.
    gram_unlabeled = 1/num_unlabeled * X_unlabeled.T @ X_unlabeled
    gram_inverse = np.linalg.inv(gram_unlabeled)

    # Cross-moment of the unlabeled features with their predictions.
    moment_pred = 1/num_unlabeled * X_unlabeled.T @ Yhat_unlabeled

    # Rectifier: how far predictions are from the truth on labeled data.
    prediction_error = Yhat_labeled - Y_labeled
    rectifier = 1/num_labeled * X_labeled.T @ prediction_error

    return gram_inverse @ (moment_pred - rectifier)

In [None]:
def cross_prediction_ols_interval(X_labeled, X_unlabeled, Y_labeled, alpha, d_inf, K = 10):
    """Cross-prediction OLS interval.

    The labeled rows are partitioned into ``K`` contiguous folds. For each
    fold an XGBoost regressor is trained on the other folds; it predicts the
    held-out fold and the unlabeled data. Unlabeled predictions are averaged
    over the ``K`` models.

    Parameters
    ----------
    X_labeled : 2-D array of labeled features (all columns feed the models).
    X_unlabeled : 2-D array of unlabeled features.
    Y_labeled : 1-D array of outcomes for ``X_labeled``.
    alpha : level; the half-width uses the ``1 - alpha/2`` normal quantile.
    d_inf : number of leading columns that define the OLS target.
    K : number of folds.

    Returns
    -------
    list
        ``[lower, upper]``, each a 1-D array with ``d_inf`` entries.

    Raises
    ------
    ValueError
        If ``K`` is smaller than 2 or larger than the number of labeled rows.
    """
    n = X_labeled.shape[0]
    N = X_unlabeled.shape[0]

    if K < 2 or K > n:
        raise ValueError(f"K must satisfy 2 <= K <= n (got K={K}, n={n})")

    fold_n = int(n/K)

    Yhat_labeled = np.zeros(n)
    Yhat_unlabeled = np.zeros(N)

    # array_split covers every labeled row even when K does not divide n.
    # Slicing with a fixed fold size would leave the trailing rows without a
    # prediction (Yhat_labeled == 0), which biases the rectifier.
    all_ind = np.arange(n)
    for val_ind in np.array_split(all_ind, K):

        train_ind = np.delete(all_ind, val_ind)
        X_train = X_labeled[train_ind,:]
        Y_train = Y_labeled[train_ind]

        # use train data to train a tree
        X_train1, X_train2, y_train1, y_train2 = train_test_split(X_train, Y_train, test_size=0.1)
        dtrain = xgb.DMatrix(X_train1, label=y_train1)
        dtest = xgb.DMatrix(X_train2, label=y_train2)
        # 'rmse'/'mae' are regression metrics ('error' is a classification metric)
        param = {'max_depth': 7, 'eta': 0.1, 'objective': 'reg:squarederror', 'eval_metric': ['rmse', 'mae']}
        evallist = [(dtest, 'eval'), (dtrain, 'train')]
        num_round = 500
        tree = xgb.train(param, dtrain, num_round, evallist, verbose_eval=False)

        Yhat_unlabeled += (tree.predict(xgb.DMatrix(X_unlabeled)))/K
        Yhat_labeled[val_ind] = tree.predict(xgb.DMatrix(X_labeled[val_ind,:]))

    thetaPP = rectified_ols_point_estimate(X_labeled[:,:d_inf], X_unlabeled[:,:d_inf], Y_labeled, Yhat_labeled, Yhat_unlabeled)

    # The bootstrap models use the same training-set size as the original fold layout.
    Sigma_hat = bootstrap_covariance(X_labeled, X_unlabeled, Y_labeled, n-fold_n, thetaPP)

    # np.asarray makes the diagonal extraction independent of whether the
    # covariance comes back as an ndarray or an np.matrix.
    halfwidth = norm.ppf(1-alpha/2) * np.sqrt(np.diag(np.asarray(Sigma_hat))/n)

    return [thetaPP - halfwidth, thetaPP + halfwidth]

In [None]:
def bootstrap_covariance(X_labeled, X_unlabeled, Y_labeled, n_train, thetaPP, B = 30):
    """Estimate the asymptotic covariance of the cross-prediction estimator.

    For each of ``B`` rounds, ``n_train`` labeled rows are drawn with
    replacement to train an XGBoost regressor. The model predicts the
    unlabeled data and ``n - n_train`` labeled rows that were not drawn.
    Unlabeled predictions are averaged over rounds; the per-row gradient
    differences on the left-out rows are pooled across rounds.

    Parameters
    ----------
    X_labeled : 2-D array of labeled features.
    X_unlabeled : 2-D array of unlabeled features.
    Y_labeled : 1-D array of outcomes for ``X_labeled``.
    n_train : number of rows drawn (with replacement) per bootstrap model.
    thetaPP : 1-D point estimate; its length sets the inference dimension.
    B : number of bootstrap models.

    Returns
    -------
    numpy.matrix
        ``d_inf x d_inf`` covariance estimate (matrix type kept for backward
        compatibility; treat it as a 2-D array).
    """
    n = X_labeled.shape[0]
    N = X_unlabeled.shape[0]
    d_inf = len(thetaPP)
    n_out = n - n_train  # left-out rows scored per bootstrap round

    Yhat_unlabeled = np.zeros(N)
    grad_diff = np.zeros((int(n_out*B), d_inf))

    for j in range(B):

        train_ind = np.random.choice(range(n), n_train)
        X_train = X_labeled[train_ind,:]
        Y_train = Y_labeled[train_ind]

        # use train data to train a tree
        X_train1, X_train2, y_train1, y_train2 = train_test_split(X_train, Y_train, test_size=0.1)
        dtrain = xgb.DMatrix(X_train1, label=y_train1)
        dtest = xgb.DMatrix(X_train2, label=y_train2)
        # 'rmse'/'mae' are regression metrics ('error' is a classification metric)
        param = {'max_depth': 7, 'eta': 0.1, 'objective': 'reg:squarederror', 'eval_metric': ['rmse', 'mae']}
        evallist = [(dtest, 'eval'), (dtrain, 'train')]
        num_round = 500
        tree = xgb.train(param, dtrain, num_round, evallist, verbose_eval=False)

        Yhat_unlabeled += (tree.predict(xgb.DMatrix(X_unlabeled)))/B

        # Sampling with replacement leaves at least n - n_train rows untouched;
        # keep exactly that many so every round fills the same number of rows.
        other_inds = np.delete(range(n), train_ind)[:n_out]
        Yhat_out = tree.predict(xgb.DMatrix(X_labeled[other_inds, :]))

        # Row-wise scaling via broadcasting (same values as np.diag(v) @ X).
        pred_error = Yhat_out - Y_labeled[other_inds]
        grad_diff[j*n_out:(j+1)*n_out, :] = pred_error[:, None] * X_labeled[other_inds, :d_inf]

    X_unlabeled_inf = X_unlabeled[:, :d_inf]
    Hessian = 1/N * X_unlabeled_inf.T @ X_unlabeled_inf
    inv_Hessian = np.linalg.inv(Hessian)

    # Broadcasting avoids building an N x N dense diagonal matrix.
    residual_unlabeled = np.dot(X_unlabeled_inf, thetaPP) - Yhat_unlabeled
    grads_til = residual_unlabeled[:, None] * X_unlabeled_inf
    # np.cov returns a 0-d array for a single variable; keep things 2-D.
    var_unlabeled = np.atleast_2d(np.cov(grads_til.T))

    var_labeled = np.atleast_2d(np.cov(grad_diff.T))

    Sigma_hat = inv_Hessian @ (n/N * var_unlabeled + var_labeled) @ inv_Hessian

    return np.matrix(Sigma_hat)

In [None]:
def classical_ols_interval(X, Y, alpha):
    """Classical CLT interval for OLS coefficients using labeled data only.

    The covariance is the sandwich form ``S^{-1} M S^{-1}``, where ``S`` is the
    feature second-moment matrix and ``M`` weights it by squared residuals.

    Returns a tuple ``(lower, upper)`` of 1-D arrays.
    """
    num_obs = X.shape[0]
    coef = ols(X, Y)

    bread = np.linalg.inv(1/num_obs * X.T@X)
    squared_residuals = (Y - X@coef)**2
    meat = 1/num_obs * (X.T*squared_residuals[None,:])@X
    sandwich = bread@meat@bread

    z_quantile = norm.ppf(1-alpha/2)
    halfwidth = z_quantile * np.sqrt(np.diag(sandwich))/np.sqrt(num_obs)
    return coef - halfwidth, coef + halfwidth

## Main cell: generate data and form intervals

In [None]:
N = 10000 # size of unlabeled data
ns = np.array([100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]) # size of labeled data
num_trials = 2
alpha = 0.1
seed = 0 # seed for NumPy's global random generator (used by np.random.randn below)

# parameters of data generating process:
d = 3 # total dimensionality of data
d_inf = 2 # dimension used in inference problem
beta = np.zeros(d)
Rsqs = [0, 0.5, 1]

var_y = 4

df_list = []

# Store results
columns = ["lb","ub","coverage","estimator","n","d", r'$R^2$']

os.makedirs('./ols_results/', exist_ok=True)
np.random.seed(seed)

for Rsq in Rsqs:

    filename = "ols_results/" + f"Rsq_{Rsq}".replace(".", "_") + ".csv"
    if os.path.exists(filename):
        # Reuse cached results instead of silently dropping this R^2 from
        # final_df (which left plot panels empty, or made pd.concat fail when
        # every file was already cached).
        cached_df = pd.read_csv(filename)
        # Files written by earlier runs carry the saved index as an unnamed column.
        cached_df = cached_df.loc[:, ~cached_df.columns.str.startswith("Unnamed")]
        df_list += [cached_df]
        continue

    beta[d_inf:] = np.sqrt(var_y * Rsq)
    beta[:d_inf] = 1

    # compute true target
    theta_true = beta[0]

    # One record (dict) per estimator and trial; the frame is built once at the end.
    results = []
    for j in tqdm(range(ns.shape[0])):
        for i in range(num_trials):
            n = ns[j]

            X = multivariate_normal.rvs(mean=np.zeros(d), cov=np.eye(d), size=(n+N)) # feature matrix
            y = X @ beta + np.sqrt(var_y * (1-Rsq))*np.random.randn(n+N) # outcomes

            ci, ppi, cppi = trial(X, y, n, d_inf, alpha)

            # Only the interval for the first coefficient is recorded.
            for estimator, interval in (("cross-prediction", cppi), ("classical", ci), ("PPI", ppi)):
                lb, ub = interval[0][0], interval[1][0]
                results.append({
                    "lb": lb,
                    "ub": ub,
                    "coverage": bool((lb <= theta_true) and (theta_true <= ub)),
                    "estimator": estimator,
                    "n": n,
                    "d": d_inf,
                    r'$R^2$': Rsq,
                })

    df = pd.DataFrame(results, columns=columns)
    df["width"] = df["ub"] - df["lb"]
    df_list += [df]
    df.to_csv(filename, index=False)
final_df = pd.concat(df_list, ignore_index=True)























100%|████████████████████████████████████████| 7/7 [14:00:47<00:00, 7206.72s/it]




























100%|████████████████████████████████████████| 7/7 [12:08:28<00:00, 6244.03s/it]


## Plot results

In [None]:
Rsqs = [0, 0.5, 1]
alpha = 0.1

# plots coverage (left column) and width (right column) as a function of n,
# with one row of panels per value of R^2
col = np.array([sns.color_palette("Set2")[1], sns.color_palette("Set2")[2], sns.color_palette("Set2")[0]])
sns.set_theme(font_scale=1.4, style='white', palette=col, rc={'lines.linewidth': 3})
fig, axs = plt.subplots(nrows=len(Rsqs), ncols=2, figsize=(10,10))
for row, Rsq in enumerate(Rsqs):
    panel_df = final_df[final_df[r'$R^2$'] == Rsq]
    sns.lineplot(ax=axs[row,0], data=panel_df, x='n', y='coverage', hue='estimator', alpha=0.9, errorbar=None)
    sns.lineplot(ax=axs[row,1], data=panel_df, x='n', y='width', hue='estimator', alpha=0.9)


grid = plt.GridSpec(len(Rsqs), 1)
for i in range(len(Rsqs)):
    # create fake subplot just to title set of subplots
    fake = fig.add_subplot(grid[i])
    # '\n' is important
    fake.set_title(r'$R_0^2$' + f' = {Rsqs[i]}\n', size=18)
    fake.set_axis_off()


for i in range(axs.shape[0]):
    # nominal coverage level as a dashed reference line
    axs[i,0].axhline(1-alpha, color="#888888", linestyle='dashed', zorder=1, alpha=0.9)
    axs[i,0].set_ylim([0.5,1])
    for j in range(axs.shape[1]):
        if i == 0 and j == 1:
            # remove the legend title
            handles, labels = axs[i,j].get_legend_handles_labels()
            axs[i,j].legend(handles=handles, labels=labels)
        else:
            # remove the legend; a panel without a legend returns None here
            legend = axs[i,j].get_legend()
            if legend is not None:
                legend.remove()

sns.despine(top=True, right=True)
plt.tight_layout()

# save plot
plt.savefig('./ols_results/ols_comparison.pdf')
plt.show()

In [None]:
# Reload previously saved results from disk into final_df.
datadir = './ols_results/'
filenames = os.listdir(datadir)
data = []
for fn in filenames:
    # only files with 'Rsq' in their name are result tables
    if 'Rsq' not in fn:
        continue
    data.append(pd.read_csv(os.path.join(datadir, fn)))
final_df = pd.concat(data, axis=0, ignore_index=True)