In [1]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.exceptions import ConvergenceWarning

import numpy as np
import warnings

from data import name_to_dataset
from models import prepend

# Load adult dataset as a testbed
dataset = name_to_dataset("employmentCA_ras")
groups = dataset.groups
tree = dataset.tree
group_names = dataset.group_names

In [2]:
def group_error(ypred, ytrue, indices):
    return np.mean(ypred[indices] != ytrue[indices])

def std_err(n:int, e:float):
    """Return the lower and upper bound on error rate when test set size is n and empirical error rate is e"""
    assert e >= 0. and e <= 1 and n >= 0, f'Invalid input: n={n}, e={e}'
    a = 4.+n
    b = 2.+n*e
    c = n*e**2
    d = 2.*np.sqrt(1.+n*e*(1.-e))
    return ((b-d)/a, (b+d)/a)

def treepend(models, tree, X_train, y_train, gps_train,
             X_test, y_test, gps_test, group_names,
             epsilon=0, verbose=False):
    """
    Runs the MGL-Tree algorithm for already fitted models in `model`.

    Args:
        models: fitted sklearn-type models with a .fit() and a .predict()
        tree: a list of lists designating which groups in gps_train and gps_test are in each level of the tree.
        X_train: full training dataset
        y_train: full training labels
        gps_train: list of Boolean arrays for indexing X_train, y_train by group.
        X_test: full test dataset
        y_test: full test labels
        gps_test: list of Boolean arrays for indexing X_test, y_test by group
        group_names: name for each group
        epsilon: tolerance for new predictor
    """
    declist = [0]
    dectree = [[0] * len(level) for level in tree]

    num_groups = len(gps_train)
    assert(num_groups == len(models))
    assert(num_groups == len(gps_test))
    assert(num_groups == len(group_names))

    H_train = {}     # predictions of group-wise models on training data
    H_test = {}      # predictions of group-wise models on test data
    H_train_err = {} # number of groups in test 
    ng_test = {}     # number of samples in test for a group

    # Get predictions for every model on the train and test set
    for g in range(num_groups):
        if models[g]:   # Possible that a group is empty
            H_train[g] = models[g].predict(X_train)
            H_test[g] = models[g].predict(X_test)
            diff = H_train[g][gps_train[g]] != y_train[gps_train[g]]
            H_train_err[g] = np.mean(diff)
            ng_test[g] = np.sum(gps_test[g])
        else:
            H_train_err[g] = np.inf
    
    # Initialize predictions for the tree predictor
    F_train = H_train[0].copy()
    F_test = H_test[0].copy()
    F_train_err = {}
    for g in range(num_groups):
        diff = F_train[gps_train[g]] != y_train[gps_train[g]]
        F_train_err[g] = np.mean(diff)

    # BFS through the tree
    for i, level in enumerate(tree):
        for j, g in enumerate(level):
            if H_train_err[g] < F_train_err[g] + epsilon:
                declist.insert(0, g)
                dectree[i][j] = g
                F_train[gps_train[g]] = H_train[g][gps_train[g]]
                F_test[gps_test[g]] = H_test[g][gps_test[g]]
                for g in range(num_groups):
                    diff = F_train[gps_train[g]] != y_train[gps_train[g]]
                    F_train_err[g] = np.mean(diff)

    # Find test error for each group
    F_test_err = {}
    for g in range(num_groups):
        if models[g]:
            diff = F_test[gps_test[g]] != y_test[gps_test[g]]
            F_test_err[g] = np.mean(diff)
            if verbose:
                print('TREE group {0} ({4}): {1} (+/-{2}; n={3})'.format(
                    g, F_test_err[g], std_err(F_test_err[g], ng_test[g]), ng_test[g], group_names[g]))
            elif verbose:
                print("TREE group {} had no data!".format(g))

    return declist, dectree, F_test_err

In [3]:
# Run treepend
warnings.simplefilter("ignore", category=ConvergenceWarning)

X = dataset.X
y = dataset.y
scaler = StandardScaler(with_mean=False)
X = scaler.fit_transform(X)

splits = train_test_split(*tuple([X, y] + dataset.groups), 
                              test_size=0.2, random_state=0)
X_train = splits[0]
X_test = splits[1]
y_train = splits[2]
y_test = splits[3]
groups_train = splits[4::2]
groups_test = splits[5::2]

# TODO: CHANGE THIS
X_val = splits[1]
y_val = splits[3]
groups_val = splits[5::2]

# Fitting model for each group
group_models = []
for g, group_name in enumerate(dataset.group_names):
    n_g = np.sum(groups_train[g])
    #print("\tFitting group={} with n={}...".format(group_name, n_g))
    model = LogisticRegression()
    
    if np.sum(groups_train[g]) > 0 and len(np.unique(y_train[groups_train[g]])) > 1:
        model.fit(X_train[groups_train[g]], y_train[groups_train[g]])
    else:
        model = None
    group_models.append(model)

# Treepend
print("Fitting TREE model...")
treepend_results = treepend(group_models, tree, X_val, y_val, groups_val, 
                            X_test, y_test, groups_test, dataset.group_names)
tree_declist = treepend_results[0]
dectree = treepend_results[1]
tree_errs = treepend_results[2]
print("\tResulting decision list: {}".format(tree_declist))

Fitting TREE model...
	Resulting decision list: [46, 43, 39, 36, 30, 29, 27, 26, 25, 23, 21, 20, 19, 18, 15, 14, 13, 12, 11, 10, 9, 8, 4, 1, 0]


In [4]:
# Run treepend
warnings.simplefilter("ignore", category=ConvergenceWarning)

X = dataset.X
y = dataset.y
scaler = StandardScaler(with_mean=False)
X = scaler.fit_transform(X)

splits = train_test_split(*tuple([X, y] + dataset.groups), 
                              test_size=0.2, random_state=1)
X_train = splits[0]
X_test = splits[1]
y_train = splits[2]
y_test = splits[3]
groups_train = splits[4::2]
groups_test = splits[5::2]

# TODO: CHANGE THIS
splits = train_test_split(*tuple([X_train, y_train] + groups_train),
                          test_size=0.5, random_state=1)
X_train = splits[0]
X_val = splits[1]
y_train = splits[2]
y_val = splits[3]
groups_train = splits[4::2]
groups_val = splits[5::2]

# Fitting model for each group
group_models = []
for g, group_name in enumerate(dataset.group_names):
    n_g = np.sum(groups_train[g])
    #print("\tFitting group={} with n={}...".format(group_name, n_g))
    model = LogisticRegression()
    
    if np.sum(groups_train[g]) > 0 and len(np.unique(y_train[groups_train[g]])) > 1:
        model.fit(X_train[groups_train[g]], y_train[groups_train[g]])
    else:
        model = None
    group_models.append(model)


# Prepend
print("Fitting PREPEND model...")
prepend_results = prepend(group_models, X_train, y_train, groups_train, 
                            X_test, y_test, groups_test, dataset.group_names)
prepend_declist = prepend_results[0]
prepend_errs = prepend_results[1]
print("\tResulting decision list: {}".format(prepend_declist))

# Treepend
print("Fitting TREE model...")
treepend_results = treepend(group_models, tree, X_test, y_test, groups_test, 
                            X_test, y_test, groups_test, dataset.group_names)
tree_declist = treepend_results[0]
dectree = treepend_results[1]
tree_errs = treepend_results[2]
print("\tResulting decision list: {}".format(tree_declist))

# Evaluate on each group
for g, group_name in enumerate(dataset.group_names):
    n_g = np.sum(groups_test[g])
    print("\n=== Error on G{}: {} (n_g={}) ===\n".format(g, group_name, n_g))
    y_g = y_test[groups_test[g]]
    erm_pred = group_models[0].predict(X_test[groups_test[g]])
    g_erm_pred = group_models[g].predict(X_test[groups_test[g]])

    erm_err = np.mean(y_g != erm_pred)
    erm_std = std_err(n_g, erm_err)
    g_erm_err = np.mean(y_g != g_erm_pred)
    g_erm_std = std_err(n_g, g_erm_err)
    prepend_err = prepend_errs[g]
    prepend_std = std_err(n_g, prepend_err)
    treepend_err = tree_errs[g]
    treepend_std = std_err(n_g, treepend_err)

    print("\tGlobal ERM = {} +/- {}".format(erm_err, erm_std))
    print("\tGroup ERM = {} +/- {}".format(g_erm_err, g_erm_std))
    print("\tPrepend = {} +/- {}".format(prepend_err, prepend_std))
    print("\tTreepend = {} +/- {}".format(treepend_err, treepend_std))

Fitting PREPEND model...
	Resulting decision list: [26, 10, 14, 11, 22, 26, 34, 42, 45, 23, 43, 46, 37, 49, 47, 25, 35, 27, 38, 40, 30, 31, 33, 28, 48, 24, 39, 41, 36, 44, 32, 0]
Fitting TREE model...
	Resulting decision list: [49, 43, 42, 38, 37, 28, 26, 24, 23, 22, 21, 20, 19, 18, 16, 15, 14, 12, 11, 10, 9, 8, 1, 0]

=== Error on G0: ALL (n_g=75207) ===

	Global ERM = 0.21152286356323213 +/- (0.208559906904022, 0.21451650479107573)
	Group ERM = 0.21152286356323213 +/- (0.208559906904022, 0.21451650479107573)
	Prepend = 0.1953807491323946 +/- (0.1925053898001942, 0.19828851002828835)
	Treepend = 0.19478240057441462 +/- (0.19191043054458562, 0.19768683581272914)

=== Error on G1: R1 (n_g=46260) ===

	Global ERM = 0.20946822308690013 +/- (0.2057094681205318, 0.213277216991002)
	Group ERM = 0.2085603112840467 +/- (0.20480767575871325, 0.2123633427437941)
	Prepend = 0.193558149589278 +/- (0.18991087467533, 0.19725841462088306)
	Treepend = 0.193558149589278 +/- (0.18991087467533, 0.1972584