In [2]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from utils import *

In [1]:
def proj(x, lb, ub):
    """Clip each coordinate of ``x`` into the interval ``[lb[k], ub[k]]``.

    Returns a new object made with ``x.copy()``; ``x`` itself is left untouched.
    """
    projected = x.copy()
    for k, value in enumerate(x):
        if value < lb[k]:
            projected[k] = lb[k]
        elif value > ub[k]:
            projected[k] = ub[k]
    return projected


def in_box(x, lb, ub):
    """Return 1 if ``x`` lies in the box ``[lb, ub]`` (projecting it changes nothing), else 0."""
    unchanged = (x == proj(x, lb, ub)).all()
    return 1 if unchanged else 0


def fit(X, Y, verbose=True):
    """Ordinary least-squares fit via the normal equations.

    Parameters
    ----------
    X : array of shape (n, p)
        Design matrix. No intercept column is added here; include one in ``X``
        if an intercept is wanted.
    Y : array of shape (n,)
        Responses.
    verbose : bool, default True
        If True, print the fitted coefficients (the previous behaviour).

    Returns
    -------
    beta_hat : array of shape (p,)
        Solution of ``(X^T X) beta = X^T Y``.
    min_eig : float
        Smallest eigenvalue of ``X^T X`` divided by the number of rows ``n``.
    """
    XTX = X.T @ X
    beta_hat = np.linalg.solve(XTX, X.T @ Y)
    # X^T X is symmetric: eigvalsh returns real eigenvalues in ascending order,
    # whereas eigvals may return a complex array on which min() raises TypeError.
    min_eig = np.linalg.eigvalsh(XTX)[0] / len(X)
    if verbose:
        print(beta_hat)
    return beta_hat, min_eig


def cutoff(std, min_eig, n_core, n_full, alpha, x):
    """Residual threshold above which a point is treated as rejected.

    ``std`` scales the sum of two terms:
    * an estimation term, ``||x|| * sqrt(d log(4 d / alpha) / n_core) / min_eig``
      with ``d = len(x)``, and
    * a noise term, ``sqrt(2 log(4 n_full / alpha))``.
    ``alpha`` is presumably a significance level (0.05 is used below).
    """
    dim = len(x)
    estimation_term = np.linalg.norm(x) * np.sqrt(dim * np.log(4 * dim / alpha) / n_core) / min_eig
    noise_term = np.sqrt(2 * np.log(4 * n_full / alpha))
    return std * (estimation_term + noise_term)


def res(x, y, beta):
    """Absolute residual ``|y - x . beta|`` (elementwise if ``x`` is a matrix of rows)."""
    prediction = np.dot(x, beta)
    return abs(y - prediction)


def directed_infty_norm(x, S):
    """Largest signed coordinate ``x[j] * s`` over the still-open directions.

    ``S[j]`` is the set of signs still allowed for coordinate ``j``; coordinates
    whose set is empty are skipped. The result is never below 0.
    """
    largest = 0
    for j, xj in enumerate(x):
        allowed = S[j]
        if allowed == set():
            continue
        largest = max(largest, max(xj * sign for sign in allowed))
    return largest


def largest_box_heuristic(X, B):
    """Greedy heuristic for a large axis-aligned box that excludes the points in ``X``.

    ``B`` is a (d, 2) array holding the maximum allowed box: column 0 the lower
    bounds, column 1 the upper bounds. ``B`` must contain the origin and the
    origin must be contained in the final region; every row of ``X`` should lie
    within ``B``.

    Each iteration takes the remaining point with the smallest directed
    infinity norm over the still-open faces. It closes one open face (coordinate
    ``j``, direction ``sign``) at that point's coordinate, then drops every
    remaining point that the new face cuts off. The loop stops once no remaining
    point has a nonzero coordinate.

    Returns a copy of ``B`` with the closed faces moved in.

    Raises
    ------
    ValueError
        If the selected point's directed norm is not attained in any still-open
        direction, so no open face can be placed at it.
    """
    _, d = X.shape
    remaining = X.copy()
    box = B.copy()
    # open_signs[j]: directions of coordinate j not yet closed (+1 upper face, -1 lower face).
    open_signs = [{-1, 1} for _ in range(d)]

    while remaining.any():
        norms = [directed_infty_norm(x, open_signs) for x in remaining]
        i = int(np.argmin(norms))  # point which supports the new face
        target = norms[i]
        # First coordinate whose value, in a still-open direction, attains the
        # directed norm. Matching on |x_j| alone could select a coordinate whose
        # direction was already closed, making the set removal raise KeyError.
        face = next(
            ((j, s) for j in range(d) for s in open_signs[j] if remaining[i, j] * s == target),
            None,
        )
        if face is None:
            raise ValueError(
                "largest_box_heuristic: no open face attains the directed norm "
                "of the closest remaining point"
            )
        j, sign = face

        open_signs[j].remove(sign)
        box[j, (sign + 1) // 2] = remaining[i, j]  # -1 -> lower bound (col 0), +1 -> upper bound (col 1)

        # Keep only the points strictly on the inner side of the new face.
        remaining = remaining[sign * remaining[:, j] < target]

    return box

# Optimization-based approach
Here we implement the optimization-based approach to region finding. Let $R$ denote (a parameterization of) our estimate of the region. The idea is to maximize $$\textrm{vol}(R) - \lambda \sum_{i=1}^n f_i(R),$$ where $\lambda > 0$ is a regularization weight and the $f_i(R)$ are functions that penalize $R$ for including rejected points and, optionally, reward $R$ for including non-rejected points.

Note: we may want to use a normalized volume in the objective. The volume is a product of $d$ side lengths, so its derivative with respect to any one side equals the product of the other $d-1$ side lengths. When the region is small this gradient nearly vanishes, and when the region is large it becomes huge.

## Hard-thresholding method
We first hard-threshold each training point using a residual cutoff on the core fit. This gives a binary label $r_i \in \{0,1\}$ that is 1 if the $i$-th point is rejected and 0 otherwise. We then define $$f_i(R) = r_i \exp(-c_1 d(x_i, R)) - (1-r_i) \exp(-c_2 d(x_i, R)),$$ where $d(x_i, R)$ is the Euclidean distance from $x_i$ to $R$ (zero when $x_i \in R$).
Here $c_1, c_2 > 0$ are constants that control how much a rejected point is penalized (resp. a non-rejected point is rewarded) for lying inside or close to the approximate region.

In [None]:
import torch

def vol(R):
    """Volume of the box ``R``: the product over rows of ``R[k, 1] - R[k, 0]``.

    NOTE: a side with lower bound > upper bound contributes a negative length,
    so the "volume" can be negative (or positive with an even number of
    inverted sides); watch for this.
    """
    volume = 1.
    for row in R:
        volume *= row[1] - row[0]
    return volume


def dist(x, R):
    """Euclidean distance from ``x`` to the box ``R`` along the last axis of ``x``.

    ``R[:, 0]`` / ``R[:, 1]`` are the per-coordinate lower / upper bounds. For a
    batch of rows, one distance per row is returned (zero for rows inside the box).
    """
    nearest = torch.clamp(x, R[:, 0], R[:, 1])
    return torch.linalg.norm(x - nearest, dim=-1)


def hard_obj(X, R, excluded, reg, c1 = 1., c2 = 1.):
    """Hard-threshold objective, to be maximised over the box ``R``.

    Returns ``vol(R) - reg * sum_i f_i`` where
    ``f_i = excluded_i * exp(-c1 * dist_i) - (1 - excluded_i) * exp(-c2 * dist_i)``
    and ``dist_i`` is the distance from row ``i`` of ``X`` to ``R``.
    """
    reject_term = excluded * torch.exp(-c1 * dist(X, R))
    keep_term = (1 - excluded) * torch.exp(-c2 * dist(X, R))
    return vol(R) - reg * torch.sum(reject_term - keep_term)


def hard_train(X, init_R, excluded, reg, iters, lr, c1 = 1., c2 = 1.):
    """Maximise ``hard_obj`` over the box ``R`` by plain gradient ascent.

    Parameters
    ----------
    X : torch.Tensor
        Points, one per row (``dist`` applies ``torch.clamp`` to it, so it must be a tensor).
    init_R : torch.Tensor of shape (d, 2)
        Starting box, ``[lower, upper]`` per coordinate. It is not modified.
    excluded : torch.Tensor
        Per-point labels: 1 for rejected points, 0 otherwise.
    reg, c1, c2 : float
        Passed through to ``hard_obj``.
    iters : int
        Number of gradient steps.
    lr : float
        Step size.

    Returns
    -------
    torch.Tensor
        The optimised box, a leaf tensor with ``requires_grad=True``.
    """
    # `init_R.clone()` is a non-leaf when init_R requires grad, so autograd never
    # fills in its `.grad`. Detach first so the working copy is a fresh leaf.
    R = init_R.detach().clone().requires_grad_(True)
    for _ in range(iters):
        obj = hard_obj(X, R, excluded, reg, c1, c2)
        if R.grad is not None:
            R.grad.zero_()  # tensors have no `zero_grad_()`; clear the accumulated gradient in place
        obj.backward()
        with torch.no_grad():
            R += lr * R.grad  # ascent step: the objective is maximised
    return R

In [None]:
dir_data = '../data/regr'

names_data = ['Dutch_drinking_inh', 'Dutch_drinking_wm', 'Dutch_drinking_sha', 'Brazil_health_heart', 
              'Brazil_health_stroke', 'Korea_grip', 'China_glucose_women2', 'China_glucose_men2', 
              'Spain_Hair', 'China_HIV']

# Quick survey of the available datasets: name and design-matrix shape.
for name_data in names_data:
    X, y, names_covariates = load_regr_data(name_data, dir_data)
    # `np.float` was removed in NumPy 1.24; it was an alias of the builtin `float` (float64).
    y = y.astype(float)
    print(name_data, X.shape)

In [None]:
X, Y, names_covariates = load_regr_data('Brazil_health_stroke', dir_data)
Y = Y.astype(float)
n = len(Y)       # total number of observations (before the train/test split)
d = X.shape[1]   # number of covariates (excluding the intercept column added below)

# Centre the covariates (so the origin is the data mean), then append an
# intercept column of ones as the LAST column. Out-of-place so integer X works.
X = X - np.mean(X, axis = 0)
X = np.column_stack([X, np.ones(n)])

# Bounding box of all points: column 0 = per-coordinate min, column 1 = max.
B = np.column_stack([np.min(X, axis = 0), np.max(X, axis = 0)])

# Seeded random train/test split.
# TODO confirm the intended test fraction (the original call did not specify one).
TEST_FRAC = 0.2
rng = np.random.default_rng(0)
perm = rng.permutation(n)
n_test = int(TEST_FRAC * n)
test_ind = np.sort(perm[:n_test])
train_ind = np.sort(perm[n_test:])

X_test = X[test_ind]
Y_test = Y[test_ind]

# From here on X, Y hold the training data only.
X = X[train_ind]
Y = Y[train_ind]

assert len(Y_test) + len(Y) == n

In [None]:
# Iteratively refit OLS on the 10% of training points with the smallest
# residuals until the coefficients stabilise (the "core" fit).
n_core = len(Y) // 10  # X, Y hold the training data at this point
beta = np.linalg.solve(X.T @ X, X.T @ Y)
# Named `residuals`, not `res`, so the `res` helper function is not shadowed.
residuals = np.abs(X @ beta - Y)
core_ind = np.argsort(residuals)[:n_core]

core_X = X[core_ind]
core_Y = Y[core_ind]

T = 10  # maximum number of refits
for i in range(T):
    new_beta = np.linalg.solve(core_X.T @ core_X, core_X.T @ core_Y)
    if np.linalg.norm(new_beta - beta) < 1e-3:
        print(f'Found stable core ({i} iterations)')
        break

    # Re-select the core from the new fit's residuals. Previously the new
    # indices were discarded and a `core_y` typo left core_Y unchanged, so the
    # core never moved and "stable" was reported after one refit.
    residuals = np.abs(X @ new_beta - Y)
    core_ind = np.argsort(residuals)[:n_core]
    core_X = X[core_ind]
    core_Y = Y[core_ind]
    beta = new_beta
else:
    print(f'Core did not stabilise within {T} iterations')

In [None]:
core_beta, min_eig = fit(core_X, core_Y)

# Residual standard error of the core fit. core_X already contains the
# intercept column, so the degrees of freedom are n_core - p with
# p = core_X.shape[1] (an extra "- 1" would double-count the intercept).
n_core = len(core_Y)
s_hat = np.sqrt(np.sum((core_X @ core_beta - core_Y) ** 2) / (n_core - core_X.shape[1]))
alpha = 0.05

# Hard-threshold every TRAINING point: 1 = rejected (residual above the cutoff).
# Residuals are computed directly because the name `res` may have been
# rebound to an array in an earlier cell.
n_train = len(Y)
train_res = np.abs(Y - X @ core_beta)
excluded = np.zeros(n_train)
for k in range(n_train):
    if train_res[k] > cutoff(s_hat, min_eig, n_core, n_train, alpha, X[k]):
        excluded[k] = 1

# The region lives in covariate space: drop the intercept (last) column of X
# and the matching row of B, so that the box contains the origin.
X_cov = X[:, :-1]
B_cov = B[:-1]
approx_region = largest_box_heuristic(X_cov[excluded == 1], B_cov)

reg = 1.
iters = 100
lr = 0.1

# Start from the box [-0.1, 0.1]^d around the origin (the covariate mean).
d_cov = X_cov.shape[1]
init_R = 0.1 * torch.ones(d_cov, 2)
init_R[:, 0] *= -1
init_R.requires_grad = True
# hard_train operates on torch tensors; match init_R's dtype.
X_t = torch.as_tensor(X_cov, dtype=init_R.dtype)
excluded_t = torch.as_tensor(excluded, dtype=init_R.dtype)
opt_approx_region = hard_train(X_t, init_R, excluded_t, reg, iters, lr, c1 = 1., c2 = 1.)