Tests for the mod_mod algorithm

In [1]:
import torch
from torch import distributions as dist

import itertools
import pickle

import itertools
import random
import math
import numpy as np
import pandas as pd
import scipy

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
# Seaborn theme first, so the matplotlib overrides below take precedence.
sns.set_style('white')
sns.set_context('paper')
# Plot adjustments: font sizes for ticks, axis labels, legend and titles.
plt.rcParams.update({
    'ytick.labelsize': 15,
    'xtick.labelsize': 15,
    'axes.labelsize': 35,
    'legend.fontsize': 30,
    'axes.titlesize': 16,
})

from gptorch import kernels, models
import helpers, opt

In [2]:
# Load the pickled dataset. pickle.load executes arbitrary code from the file,
# so only ever point this at a trusted, locally produced pickle.
with open('../inputs/phoq.pkl', 'rb') as f:
    t = pickle.load(f)

# t[0]: one-hot encoding of X, t[1]: tokenized encoding of X
X, T = t[0], t[1]
# t[2] exposes `.values`; keep only the underlying array of targets
y = t[2].values

In [3]:
# Fix the NumPy RNG so the training subset is reproducible.
np.random.seed(1)
# 100 random row indices into X, drawn with replacement (duplicates are possible)
rand_inds = np.random.choice(len(X), 100, replace=True)
X_train, y_train = X[rand_inds], y[rand_inds]
# Predictions are requested for the full dataset.
X_test, y_true = X, y

dic, means = helpers.get_predictions(X_train, y_train, X_test, its=500)

Iteration 500 of 500	NLML: 37.3695	sn: 0.179451	

In [19]:
# Build a random candidate set of (amino acid, position) pairs for the bound tests.
# NOTE: this rebinds X, which previously held the one-hot encodings loaded from the pickle.

# Draw 20 amino acids (with replacement); consecutive groups of five are assigned
# to positions 0, 1, 2 and 3 respectively.
aa = np.random.choice(np.array(list('ARNDCQEGHILKMFPSTWYV')), 20)

# De-duplicate through a set, then sort: plain list(set(...)) yields an order that
# depends on string hash randomization and therefore changes between kernel restarts.
X = sorted({(a, pos) for pos in range(4) for a in aa[5 * pos:5 * pos + 5]})

In [5]:
# Reference set of (amino acid, position) pairs used as the center of the bounds.
# Previously also tried: ('M', 3), ('S', 2), ('W', 2)
center = [
    ('F', 3),
    ('E', 0),
    ('N', 1),
    ('G', 2),
]

# Ground set: every amino acid at each of the four positions,
# ordered amino-acid-major (A0, A1, A2, A3, R0, ...).
aas = 'ARNDCQEGHILKMFPSTWYV'
ground = list(itertools.product(aas, range(4)))

In [6]:
# Derived from `dic` (first value returned by helpers.get_predictions) and later
# passed to opt.mod_mod as its fourth positional argument.
# NOTE(review): presumably the initial library for the optimizer rather than an RNG seed — confirm.
seed = helpers.get_seed(dic)

In [7]:
# Built from the ground set and `center`; passed as the third argument to
# opt.mod_lower in the lower-bound tests below.
# NOTE(review): presumably an ordering of `ground` that lists the center elements first — confirm.
perm = opt.make_perm(ground, center)

# 1. Tests for modular upper bound

In [20]:
# Make sure the modular upper bound is indeed an upper bound on each objective.
# The comparison is non-strict: a valid upper bound may coincide with the objective
# (a later cell checks that it does when X itself is used as the center).

# obj_LHS
obj_left = opt.obj_LHS(X, 4, dic)
obj_left_upper = opt.mod_upper(X, opt.obj_LHS, center, 4, dic)
assert obj_left_upper >= obj_left, "modular upper bound is below obj_LHS"

# obj_RHS
obj_right = opt.obj_RHS(X, 4, dic, 100)
obj_right_upper = opt.mod_upper(X, opt.obj_RHS, center, 4, dic, 100)
assert obj_right_upper >= obj_right, "modular upper bound is below obj_RHS"

# check tightness?

In [21]:
# Make sure the modular upper bound is modular: its value on X must equal the sum
# of its values on the singletons {x}, minus (|X| - 1) copies of its value on the
# empty set (which is otherwise counted once per singleton).
n_extra_empty = len(X) - 1

# obj_LHS
singles_lhs = [opt.mod_upper([x], opt.obj_LHS, center, 4, dic) for x in X]
empty_lhs = opt.mod_upper([], opt.obj_LHS, center, 4, dic)  # empty set
mod_up1 = torch.stack(singles_lhs).sum() - n_extra_empty * empty_lhs
assert np.isclose(np.array(obj_left_upper), np.array(mod_up1))

# obj_RHS
singles_rhs = [opt.mod_upper([x], opt.obj_RHS, center, 4, dic, 100) for x in X]
empty_rhs = opt.mod_upper([], opt.obj_RHS, center, 4, dic, 100)  # empty set
mod_up2 = torch.stack(singles_rhs).sum() - n_extra_empty * empty_rhs
assert np.isclose(np.array(obj_right_upper), np.array(mod_up2))

In [22]:
# Make sure that when X itself is passed in as the center, the bound returns the objective.
# Compared with np.isclose (as in the modularity checks) rather than exact ==,
# so a last-bit floating-point difference does not fail the test.

# obj_LHS
assert np.isclose(np.array(obj_left),
                  np.array(opt.mod_upper(X, opt.obj_LHS, X, 4, dic)))

# obj_RHS
assert np.isclose(np.array(obj_right),
                  np.array(opt.mod_upper(X, opt.obj_RHS, X, 4, dic, 100)))

# 2. Tests for modular lower bound

In [23]:
# Make sure the modular lower bound is indeed a lower bound on each objective.
# The comparison is non-strict: a valid lower bound may coincide with the objective.
# Relies on obj_left / obj_right computed for X in the upper-bound test above.

# obj_LHS
obj_left_lower = opt.mod_lower(X, opt.obj_LHS, perm, 4, dic)
assert obj_left_lower <= obj_left, "modular lower bound is above obj_LHS"

# obj_RHS
obj_right_lower = opt.mod_lower(X, opt.obj_RHS, perm, 4, dic, 100)
assert obj_right_lower <= obj_right, "modular lower bound is above obj_RHS"

# check tightness?

In [24]:
# Make sure the modular lower bound is modular: its value on X must equal the
# sum of its values on the singletons {x}.

# obj_LHS
singles_lower_lhs = [opt.mod_lower([x], opt.obj_LHS, perm, 4, dic) for x in X]
mod_down1 = torch.stack(singles_lower_lhs).sum()
assert np.isclose(np.array(obj_left_lower), np.array(mod_down1))

# obj_RHS
singles_lower_rhs = [opt.mod_lower([x], opt.obj_RHS, perm, 4, dic, 100) for x in X]
mod_down2 = torch.stack(singles_lower_rhs).sum()
assert np.isclose(np.array(obj_right_lower), np.array(mod_down2))

In [28]:
# Run mod_mod over the full ground set on obj_LHS and obj_RHS, starting from `seed`.
# Returns the selected library and the list of objective values it recorded.
lib, obj_lst = opt.mod_mod(
    ground,
    opt.obj_LHS,
    opt.obj_RHS,
    seed,
    fn_args=(4, dic),
    g_args=(4, dic, 100),
    verbose=False,
)

In [35]:
# Check that adding or removing any single ground-set element cannot improve on
# the final objective: obj_lst[-1] must lie below the modular surrogate (upper
# bound of obj_LHS centered at lib, minus lower bound of obj_RHS) of every
# one-element neighbour of lib.
#
# Loop-local names are used on purpose: reusing obj_left / obj_right here would
# overwrite the values the bound tests above compare against and make those
# cells give wrong results when re-run.
#
# NOTE(review): `perm` was built from `center`, not from `lib`; confirm this is
# the permutation intended for the lower bound in this check.

mod_bounds = []  # store upper bound, lower bound, difference btw the two
objects = []  # store obj_LHS, obj_RHS, difference btw the two
for elem in ground:
    # One-element neighbour of lib: drop elem if it is selected, otherwise add it.
    if elem in lib:
        neighbour = [x for x in lib if x != elem]
    else:
        neighbour = lib + [elem]

    nb_up = opt.mod_upper(neighbour, opt.obj_LHS, lib, 4, dic)
    nb_down = opt.mod_lower(neighbour, opt.obj_RHS, perm, 4, dic, 100)
    nb_mod = nb_up - nb_down

    mod_bounds.append((nb_up, nb_down, nb_mod))
    assert obj_lst[-1] < nb_mod, "toggling %r beats the final objective" % (elem,)

    # True objective at the neighbour, recorded for inspection alongside the bounds.
    nb_left = opt.obj_LHS(neighbour, 4, dic)
    nb_right = opt.obj_RHS(neighbour, 4, dic, 100)
    objects.append((nb_left, nb_right, nb_left - nb_right))

In [37]:
# Show, for every neighbour, the (upper, lower, difference) bound triple followed
# by the (obj_LHS, obj_RHS, difference) triple, separated by a rule.
for bound_triple, objective_triple in zip(mod_bounds, objects):
    print(bound_triple, objective_triple, "--------------------", sep="\n")

(tensor(-0.4275), tensor(-99.8211), tensor(99.3936))
(tensor(-0.6064), tensor(1.00000e-31 *
       -4.7840), tensor(-0.6064))
--------------------
(tensor(-0.4275), tensor(-95.0813), tensor(94.6537))
(tensor(-0.7961), tensor(1.00000e-31 *
       -6.2798), tensor(-0.7961))
--------------------
(tensor(-0.4275), tensor(-96.0420), tensor(95.6145))
(tensor(-0.4943), tensor(1.00000e-31 *
       -3.8995), tensor(-0.4943))
--------------------
(tensor(-0.4275), tensor(-96.1003), tensor(95.6728))
(tensor(-0.4997), tensor(1.00000e-31 *
       -3.9422), tensor(-0.4997))
--------------------
(tensor(-0.4275), tensor(-94.9949), tensor(94.5674))
(tensor(-0.4486), tensor(1.00000e-31 *
       -3.5385), tensor(-0.4486))
--------------------
(tensor(-0.4275), tensor(-97.3731), tensor(96.9456))
(tensor(-0.6180), tensor(1.00000e-31 *
       -4.8753), tensor(-0.6180))
--------------------
(tensor(-0.4275), tensor(-97.8773), tensor(97.4498))
(tensor(-0.5536), tensor(1.00000e-31 *
       -4.3671), tensor(-0