In [1]:
import numpy as onp
import jax.numpy as np

import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
sns.set_style("whitegrid")
plt.rc("axes.spines", top=False, right=False)
sns.set_context("paper")

from sklearn.preprocessing import PolynomialFeatures
from scipy.stats import multivariate_normal

from copy import deepcopy

from jax.scipy.stats import multivariate_normal as mvn
from jax import jit, vmap, random, grad

from locreg import LocalRegression
from locreg import BootstrapLocalRegression
from locreg.main import gridsearch_fn
from locreg.supp import get_inv_infmat_m
from noisetests import nonparametric_test

from sklearn.model_selection import LeaveOneOut, GridSearchCV
from sklearn.linear_model import LogisticRegression
from typing import Callable



In [2]:
from scipy.stats import norm

def get_pvalue(x, mean, scale):
    """Two-sided p-value of ``x`` under a normal null with the given mean and scale.

    The observation is standardised as ``z = (x - mean) / scale`` and the
    result is ``2 * P(Z >= |z|)`` for a standard normal ``Z``.

    Parameters
    ----------
    x : scalar or array
        Observed statistic(s).
    mean : scalar or array
        Mean of the statistic under the null.
    scale : scalar or array
        Standard deviation of the statistic under the null. It is not
        validated here; a zero scale propagates through the division.

    Returns
    -------
    float or ndarray
        Two-sided p-value(s), same shape as the broadcast inputs.
    """
    z = (x - mean) / scale
    # sf(|z|) equals 1 - cdf(z) for z > 0 and cdf(z) for z <= 0, so this is the
    # same quantity as the branchy form. Using the survival function directly
    # avoids the 1 - cdf cancellation in the far tail and also works
    # element-wise on arrays, where `if z > 0` would be ambiguous.
    return 2 * norm.sf(abs(z))

In [3]:
import matplotlib as mpl

from jax.scipy.stats import multivariate_normal as mvn

def get_data_xor(n, d=1, gamma=1):
    """Sample a two-class XOR-style mixture of four 2-D Gaussians.

    The four components are centred at ``(d, d)``, ``(-d, -d)``, ``(d, -d)``
    and ``(-d, d)`` and share the covariance ``gamma * I``. Points from the
    first two components are labelled 0, points from the last two are
    labelled 1.

    Parameters
    ----------
    n : int
        Requested sample size. Each component receives ``n // 4`` points, so
        ``4 * (n // 4)`` samples are returned when ``n`` is not a multiple
        of four.
    d : number, default 1
        Offset of each component centre from the origin along both axes.
    gamma : number, default 1
        Scalar multiplier of the identity covariance.

    Returns
    -------
    X : array of shape ``(4 * (n // 4), 2)``
        Samples, ordered component by component.
    y : numpy.ndarray of shape ``(4 * (n // 4),)``
        Labels aligned with the rows of ``X``.
    optimum_classifier : callable
        Maps point(s) ``z`` to ``(c0 + c1) / (c0 + c1 + c2 + c3)`` where
        ``ci`` is the density of component ``i`` at ``z``, i.e. the share of
        density contributed by the two components labelled 0.
    delta : array
        ``mean0 - mean1``, the vector between the two label-0 centres.
    """
    mean0 = np.array([d, d])
    mean1 = np.array([-d, -d])
    mean2 = np.array([d, -d])
    mean3 = np.array([-d, d])

    cov = gamma * np.eye(2)

    per_cluster = n // 4
    # Draw in the fixed order mean0..mean3 so the global NumPy RNG stream is
    # consumed in a deterministic sequence.
    x0, x1, x2, x3 = (
        onp.random.multivariate_normal(centre, cov, per_cluster)
        for centre in (mean0, mean1, mean2, mean3)
    )

    X = np.concatenate((x0, x1, x2, x3), axis=0)

    # Size the labels from what was actually drawn: sizing them from `n`
    # leaves y longer than X (and the 0/1 boundary misplaced) whenever n is
    # not divisible by four.
    y = onp.ones((4 * per_cluster, ))
    y[:2 * per_cluster] = 0

    def optimum_classifier(z):
        c0 = mvn.pdf(z, mean0, cov)
        c1 = mvn.pdf(z, mean1, cov)
        c2 = mvn.pdf(z, mean2, cov)
        c3 = mvn.pdf(z, mean3, cov)
        return (c0+c1)/(c0+c1+c2+c3)

    return X, y, optimum_classifier, mean0-mean1


def plot_decision_boundary(clf, X, ax, mode='llr', resolution=0.20,
                           extra=0.25, limits=(-4, 4)):
    """Draw filled contours of a classifier's output over a 2-D grid.

    A regular grid is built over the bounding box of ``X`` (padded by
    ``extra`` on every side), the classifier is evaluated on every grid
    point and the result is drawn on ``ax`` with ``contourf``.

    Parameters
    ----------
    clf : object or callable
        With ``mode='llr'`` the value plotted is ``clf.predict(grid)[:, 1]``
        (assumes ``predict`` returns a 2-D array with at least two columns --
        confirm against the estimator in use). With ``mode='bclf'`` the
        value plotted is ``clf(grid)``.
    X : array of shape (n_samples, 2)
        Data whose extent determines the grid.
    ax : matplotlib axes
        Axes to draw on.
    mode : {'llr', 'bclf'}, default 'llr'
        How ``clf`` is evaluated, see above.
    resolution : float, default 0.20
        Grid spacing along both axes.
    extra : float, default 0.25
        Padding added around the data extent before building the grid.
    limits : (low, high) or None, default (-4, 4)
        Limits applied to both axes. ``None`` uses the grid extent instead.

    Returns
    -------
    (contour_set, ax)
        The object returned by ``ax.contourf`` and the axes passed in.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported values.
    """
    # Fail early with a clear message instead of hitting an unbound local
    # after the grid has already been built.
    if mode not in ('llr', 'bclf'):
        raise ValueError(
            "mode must be 'llr' or 'bclf', got {!r}".format(mode)
        )

    x1_min, x1_max = X[:, 0].min() - extra, X[:, 0].max() + extra
    x2_min, x2_max = X[:, 1].min() - extra, X[:, 1].max() + extra
    xx1, xx2 = np.meshgrid(np.arange(x1_min, x1_max, resolution),
                           np.arange(x2_min, x2_max, resolution))

    # One row per grid point, columns are the two coordinates.
    newx = np.array([xx1.ravel(), xx2.ravel()]).T

    if mode == 'llr':
        Z = clf.predict(newx)[:, 1]
    else:
        Z = clf(newx)

    Z = Z.reshape(xx1.shape)

    # The colormap is passed by name; matplotlib resolves "Spectral" to the
    # same colormap as mpl.cm.Spectral.
    v = ax.contourf(xx1, xx2, Z, alpha=0.3, cmap="Spectral")

    if limits is None:
        ax.set_xlim(xx1.min(), xx1.max())
        ax.set_ylim(xx2.min(), xx2.max())
    else:
        ax.set_xlim(*limits)
        ax.set_ylim(*limits)

    return v, ax

In [4]:
def add_noise(y, alpha, beta):
    """Return a copy of binary labels ``y`` with class-conditional label flips.

    Roughly a fraction ``alpha`` of the labels equal to 1 are set to 0 and a
    fraction ``beta`` of the labels equal to 0 are set to 1. The input is not
    modified.

    Indices are drawn with replacement, so the same label can be picked more
    than once and the realised number of flips can be lower than
    ``int(rate * class_size)``. This is an approximation of an
    independent-flip noise model, not an exact sample from it.

    Parameters
    ----------
    y : array-like of 0/1 labels
    alpha : float
        Flip rate applied to labels equal to 1.
    beta : float
        Flip rate applied to labels equal to 0.

    Returns
    -------
    numpy.ndarray
        Noisy copy of ``y``.
    """
    # `np` in this notebook is jax.numpy, which has no `random` submodule and
    # whose arrays are immutable, so everything here goes through NumPy (`onp`).
    ytilde = onp.array(y, copy=True)

    def _pick(members, rate):
        # Size the draw from the actual class count; assuming each class holds
        # exactly half of the samples would index past the end of `members`
        # for imbalanced labels.
        count = int(rate * len(members))
        if count == 0:
            return members[:0]
        return members[onp.random.choice(len(members), count)]

    pos = onp.where(ytilde == 1)[0]
    neg = onp.where(ytilde == 0)[0]

    # Both index sets are chosen from the clean labels before any flip is
    # applied, so a flipped positive can never be flipped back.
    to_zero = _pick(pos, alpha)
    to_one = _pick(neg, beta)

    ytilde[to_zero] = 0
    ytilde[to_one] = 1

    return ytilde

In [5]:
def prettyfloat(x, precision=4):
    """Round ``x`` to ``precision`` decimal places (4 by default) via ``np.round``."""
    rounded = np.round(x, decimals=precision)
    return rounded

def _logistic(z):
    return 1/(1+np.exp(-z))

In [6]:
# Draw a 100-point XOR sample with component centres at (+/-2, +/-2).
# `bclf` is the closed-form classifier returned by get_data_xor; the fourth
# return value is discarded.
# NOTE(review): no random seed is set before this call, so X and y differ on
# every kernel run.
X, y, bclf, *_ = get_data_xor(100, d=2)



In [7]:
# clf = llr().fit(X, y)

In [8]:
# preds=clf.predict(X)

In [9]:
# import pickle as pkl
# sd = clf.state_dict()
# with open('sd_test.pkl', 'wb') as f:
#     pkl.dump(sd, f)

In [10]:
# with open('sd_test.pkl', 'rb') as f:
#     nsd = pkl.load(f)
# clf2 = llr().load_state_dict(nsd)

In [11]:
# preds2=clf2.predict(X)

In [12]:
# np.array_equal(preds, preds2)

In [13]:
# clf = bllr(n_estimators=5).fit(X, y)

In [14]:
# preds=clf.predict(X)

In [15]:
# import pickle as pkl
# sd = clf.state_dict()
# with open('sd_test.pkl', 'wb') as f:
#     pkl.dump(sd, f)

In [16]:
# with open('sd_test.pkl', 'rb') as f:
#     nsd = pkl.load(f)
# clf2 = bllr().load_state_dict(nsd)

In [17]:
# preds2=clf2.predict(X)

In [18]:
# np.array_equal(preds, preds2)

In [19]:
# param_grid = {
#     'bandwidth': [0.10, 0.50, 0.75]
# }

# gridsearch = gridsearch_help(X, y, param_grid=param_grid)

In [20]:
# clf = gridsearch.best_estimator_

In [21]:
# clf = bllr(n_estimators=1, kernel_kwargs_grid={'bandwidth': [0.10, 0.20, 0.50]}).fit(X, y)

In [22]:

# def nonparametric_test(clf, anchors, invtype):
#     n_anchors = len(anchors)
#     _omegas_sum = 0
#     _preds_sum = 0
#     js = []
#     var = 0.
#     for anchor in anchors:
#         mdl = clf.predict(anchor, return_models=True)

#         valid=True
#         try:
#             invj1, hj2, single_var, pred = get_inv_infmat_m(mdl[0], invtype=invtype)
#         except AssertionError as err:
#             # print(err)
#             print(cf)
#             print(modelid)
#             valid=False
#             break

#         js.append([invj1, hj2])

#         _preds_sum += pred[0][0]
#         assert single_var >= 0, print(single_var)

#         var += single_var

#     if valid:
#         if invtype in [22, 222]:
#             for ii in range(n_anchors):
#                 invj1_0, hj2_0 = js[ii]
#                 for jj in range(ii+1, n_anchors):
#                     invj1_1, hj2_1 = js[jj]
#                     cvar = 2 * (invj1_0 @ hj2_0 @ hj2_1.T @ invj1_1.T)[0, 0]
#                     # assert cvar > 0, print(cvar)
#                     var += cvar

#         scale = np.sqrt(var) / 4
#         scale = scale / n_anchors
#         pvalue = get_pvalue(_preds_sum/n_anchors, 0.50, scale)

#     return pvalue

In [23]:
def getcfs(n, low=-4, high=4):
    """Generate ``n`` random points lying on one of the two coordinate axes.

    Each point starts at the origin; one coordinate (chosen uniformly between
    the two) is then replaced by a draw from ``Uniform(low, high)``, the other
    stays at zero.

    Parameters
    ----------
    n : int
        Number of points to generate.
    low, high : float, default -4 and 4
        Bounds of the uniform draw for the non-zero coordinate.

    Returns
    -------
    list
        ``n`` arrays of shape ``(1, 2)``, one per point.
    """
    cfs = []
    for _ in range(n):
        point = onp.zeros((1, 2))
        # Exactly two random draws per point: which axis, then the position
        # along it.
        axis = onp.random.choice(2)
        point[0, axis] = onp.random.uniform(low, high)
        cfs.append(np.array(point))
    return cfs

# Four random on-axis points; later cells pass them as the second argument of
# nonparametric_test.
cfs = getcfs(4)

In [24]:
# Candidate kernel bandwidths handed to the project's grid-search helper.
param_grid = {
    'bandwidth': [0.10, 0.20, 0.50]
}

# NOTE(review): the meaning of n_instances=-1 is defined inside
# locreg.main.gridsearch_fn and is not visible from this notebook -- confirm.
gridsearch = gridsearch_fn(X, y, param_grid=param_grid, n_instances=-1)
# Keep the estimator the search selected as best.
mdl = gridsearch.best_estimator_

# Run the test on the fitted model at the points in `cfs`.
# NOTE(review): 22 is passed as the third argument (named `invtype` in the
# commented-out reference implementation earlier in this notebook); its exact
# meaning lives in the noisetests / locreg packages -- confirm.
p = nonparametric_test(mdl, cfs, 22)

In [35]:
# Serialise the fitted model's state to disk so it can be reloaded by the
# following cell. The file is written to the current working directory.
import pickle as pkl
sd = mdl.state_dict()
with open('sd_test.pkl', 'wb') as f:
    pkl.dump(sd, f)

In [36]:
# Read the saved state back and rebuild a model from it.
# pickle.load executes arbitrary code from the file, so only load files that
# were written by this notebook.
with open('sd_test.pkl', 'rb') as f:
    nsd = pkl.load(f)
# NOTE(review): this relies on load_state_dict returning the model instance;
# if it returns None, clf2 will be None -- confirm against LocalRegression.
clf2 = LocalRegression().load_state_dict(nsd)

In [37]:

# Repeat the test on the reloaded model with the same points and third
# argument as before. This overwrites `p`, so the value computed from `mdl`
# is no longer available for comparison afterwards.
p = nonparametric_test(clf2, cfs, 22)

In [28]:
# s0 = set(mdl.__dir__())
# s1 = set(clf2.__dir__())

In [29]:
# p0 = mdl.predict(X)

In [30]:
# p1 = clf2.predict(X)

In [31]:
# p0 == p1