Tests for the mod_mod algorithm

In [15]:
import torch
from torch import distributions as dist

import itertools
import pickle
import importlib

import itertools
import random
import math
import numpy as np
import pandas as pd
import scipy

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
# Global plotting defaults: seaborn 'white' style in the 'paper' context,
# with explicit font sizes for tick labels, axis labels, legend and titles.
sns.set_style('white')
sns.set_context('paper')
plt.rcParams.update({
    'ytick.labelsize': 15,
    'xtick.labelsize': 15,
    'axes.labelsize': 35,
    'legend.fontsize': 30,
    'axes.titlesize': 16,
})

from gptorch import kernels, models
import helpers, opt, objectives

In [2]:
# Load the pickled dataset; `t` is an indexable container of which
# entries 0, 1 and 3 are used below.
with open('../inputs/phoq.pkl', 'rb') as f:
    t = pickle.load(f)

# X: one-hot encoding of X, T: tokenized encoding of X
X, T = t[0], t[1]
# Targets as a plain array (entry 3 exposes `.values`).
y = t[3].values

In [3]:
np.random.seed(1)
# 100 row indices drawn uniformly *with replacement* from the data set;
# these rows form the training sample.
rand_inds = np.random.choice(len(X), 100, replace=True)
X_train, y_train = X[rand_inds], y[rand_inds]

# Predictions are requested for the full data set.
X_test, y_true = X, y

dic, means = helpers.get_predictions(X_train, y_train, X_test, its=500)

Iteration 500 of 500	NLML: 41.3445	sn: 0.544161	

In [4]:
# Ground set: every (amino-acid letter, position index) pair for the 20
# letters in `aas` and position indices 0..3, ordered letter-major.
aas = 'ARNDCQEGHILKMFPSTWYV'
ground = list(itertools.product(aas, range(4)))

In [5]:
len_X = 35
len_center = 5

# X: `len_X` distinct ground-set elements, drawn without replacement (seed 1).
np.random.seed(1)
X = [ground[idx] for idx in np.random.choice(len(ground), len_X, replace=False)]

# center: `len_center` distinct ground-set elements drawn independently (seed 12).
np.random.seed(12)
center = [ground[idx] for idx in np.random.choice(len(ground), len_center, replace=False)]

In [28]:
# Build the starting set from the prediction dictionary `dic` and display it.
# How the seed is chosen is decided inside helpers.get_seed (not shown here).
seed = helpers.get_seed(dic)
seed

[('S', 0), ('A', 1), ('S', 2), ('L', 3)]

In [7]:
# Permutation built from the ground set and `center`; it is passed as the
# `perm` argument of opt.mod_lower in the lower-bound checks.
# NOTE(review): presumably orders the elements of `center` first — confirm in helpers.make_perm.
perm = helpers.make_perm(ground, center)

# 1. Tests for the modular upper bound

In [137]:
opt = importlib.reload(opt)  # re-import opt so edits to the module are picked up
# The modular upper bound built around `center` and evaluated at X must be
# at least as large as the objective evaluated at X.
# NOTE(review): the trailing arguments (4, dic) / (4, dic, 100) are forwarded
# objective arguments; 4 presumably matches the 4 positions in `ground` — confirm.
# obj_LHS
obj_left = opt.obj_LHS(X, 4, dic)
obj_left_upper = opt.mod_upper(X, opt.obj_LHS, center, ground, 4, dic)
assert obj_left_upper >= obj_left 

# obj_RHS
obj_right = opt.obj_RHS(X, 4, dic, 100)
obj_right_upper = opt.mod_upper(X, opt.obj_RHS, center, ground, 4, dic, 100)
assert obj_right_upper >= obj_right

# TODO: tightness of the bound is not checked in this cell.

In [138]:
obj_left_upper, obj_left

(tensor(12.9532), tensor(-7.2711))

In [139]:
# Make sure the modular upper bound is actually modular.
# For a modular function m: m(X) = sum_x m({x}) - (|X| - 1) * m({}).
# The right-hand side is rebuilt from singleton and empty-set evaluations and
# compared with the value computed on the whole of X (obj_*_upper from the
# upper-bound cell).
opt = importlib.reload(opt)

# obj_LHS
mod_up1 = torch.stack([opt.mod_upper([x], opt.obj_LHS, center, ground, 4, dic) for x in X]).sum()
mod_up1 -= (len(X) - 1) * opt.mod_upper([], opt.obj_LHS, center, ground, 4, dic) # empty set
assert np.isclose(np.array(obj_left_upper), np.array(mod_up1))

# obj_RHS
mod_up2 = torch.stack([opt.mod_upper([x], opt.obj_RHS, center, ground, 4, dic, 100) for x in X]).sum()
mod_up2 -= (len(X) - 1) * opt.mod_upper([], opt.obj_RHS, center, ground, 4, dic, 100) # empty set
assert np.isclose(np.array(obj_right_upper), np.array(mod_up2))

In [140]:
# Tightness check: when X itself is passed as the center, the modular upper
# bound evaluated at X should reproduce the objective at X.
#
# The comparison uses np.isclose rather than `==`: both sides are
# floating-point tensors, and the bound may reach the same value through a
# different sequence of additions/subtractions, so exact equality is fragile.
# The compared values are attached to the assertion for easier debugging.

# obj_LHS
tight_left = opt.mod_upper(X, opt.obj_LHS, X, ground, 4, dic)
assert np.isclose(np.array(obj_left), np.array(tight_left)), (obj_left, tight_left)

# obj_RHS
tight_right = opt.mod_upper(X, opt.obj_RHS, X, ground, 4, dic, 100)
assert np.isclose(np.array(obj_right), np.array(tight_right)), (obj_right, tight_right)

# 2. Tests for the modular lower bound

In [129]:
# Make sure the modular lower bound is indeed a lower bound.
# (A single reload is enough for the whole cell.)
opt = importlib.reload(opt)

# The bound is evaluated at `seed`, so it has to be compared with the
# objective evaluated at `seed` as well. `obj_left` / `obj_right` were
# computed on X, a different set, and are not a meaningful reference here.
# `<=` is used because a lower bound is allowed to touch the objective.
# NOTE(review): assumes obj_RHS is deterministic for fixed arguments — confirm.

# obj_LHS
obj_left_seed = opt.obj_LHS(seed, 4, dic)
obj_left_lower = opt.mod_lower(seed, opt.obj_LHS, perm, 4, dic)
assert obj_left_lower <= obj_left_seed, (obj_left_lower, obj_left_seed)

# obj_RHS
obj_right_seed = opt.obj_RHS(seed, 4, dic, 100)
obj_right_lower = opt.mod_lower(seed, opt.obj_RHS, perm, 4, dic, 100)
assert obj_right_lower <= obj_right_seed, (obj_right_lower, obj_right_seed)

# Make sure the modular lower bound is modular: the value on `seed` must
# equal the sum of the values on its singletons.

# obj_LHS
mod_down1 = torch.stack([opt.mod_lower([x], opt.obj_LHS, perm, 4, dic) for x in seed]).sum()
assert np.isclose(np.array(obj_left_lower), np.array(mod_down1))

# obj_RHS
mod_down2 = torch.stack([opt.mod_lower([x], opt.obj_RHS, perm, 4, dic, 100) for x in seed]).sum()
assert np.isclose(np.array(obj_right_lower), np.array(mod_down2))

In [201]:
# Smoke run of opt.mod_mod starting from `seed`, with obj_LHS / obj_RHS as the
# two functions and their extra arguments passed via fn_args / g_args.
# NOTE(review): the result is bound to the single name `p` here, whereas the
# timed run of the same call unpacks it into (lib, obj_lst); `p` is also
# rebound by the make_permuted_indices cell.
opt = importlib.reload(opt)
p = opt.mod_mod(ground, opt.obj_LHS, opt.obj_RHS, seed, 
                           fn_args=(4, dic), g_args=(4, dic, 100), verbose=True)

Iteration 0	 obj = -0.427525	


In [216]:
# Inspect the first entry returned by opt.make_permuted_indices(ground, seed).
# NOTE(review): the output recorded for this cell has 80 entries but lists 64
# twice and never lists 0, so it is not a permutation of range(80) — verify
# make_permuted_indices (or re-run the cell if the output is stale).
opt = importlib.reload(opt)
p = opt.make_permuted_indices(ground, seed)
p[0]

array([35, 55, 15, 70, 64,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12,
       13, 14, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
       31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
       49, 50, 51, 52, 53, 54, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66,
       67, 68, 69, 71, 72, 73, 74, 75, 76, 77, 78, 79])

In [222]:
%%time 
# Full mod_mod run from `seed`; returns the final library `lib` and `obj_lst`.
# The recorded run took about 2.5 hours of wall time.
# NOTE(review): the recorded objective log is not monotone (it rises at
# iterations 13, 15, 21 and 23, and iterations 20-24 alternate between
# -10.506092 and higher values) — confirm whether this is expected.
opt = importlib.reload(opt)
np.random.seed(10645)

lib, obj_lst = opt.mod_mod(ground, opt.obj_LHS, opt.obj_RHS, seed, 
                           fn_args=(4, dic), g_args=(4, dic, 100), verbose=True)

Iteration 0	 obj = -0.427525	
Iteration 1	 obj = -2.556946
Iteration 2	 obj = -4.130598
Iteration 3	 obj = -4.971603
Iteration 4	 obj = -5.573421
Iteration 5	 obj = -6.597044
Iteration 6	 obj = -7.283101
Iteration 7	 obj = -7.829210
Iteration 8	 obj = -8.293918
Iteration 9	 obj = -8.500544
Iteration 10	 obj = -8.516885
Iteration 11	 obj = -8.520318
Iteration 12	 obj = -8.580872
Iteration 13	 obj = -7.922697
Iteration 14	 obj = -8.932904
Iteration 15	 obj = -8.633686
Iteration 16	 obj = -9.624880
Iteration 17	 obj = -10.172570
Iteration 18	 obj = -10.291595
Iteration 19	 obj = -10.321615
Iteration 20	 obj = -10.506092
Iteration 21	 obj = -8.205212
Iteration 22	 obj = -10.506092
Iteration 23	 obj = -8.475809
Iteration 24	 obj = -10.506092
CPU times: user 2h 26min 33s, sys: 3min 46s, total: 2h 30min 20s
Wall time: 2h 30min 50s


In [143]:
# Evaluate opt.get_N on the library returned by mod_mod.
# NOTE(review): the second argument 4 presumably is the number of positions — confirm in opt.get_N.
opt.get_N(lib, 4)

60

In [24]:
%%time
# Line-profile opt.mod_lower while running mod_mod with the objectives taken
# from the `objectives` module (other cells take them from `opt`).
%load_ext line_profiler

# Reload all three project modules so the profiled code is the current version.
opt = importlib.reload(opt)
objectives = importlib.reload(objectives)
helpers = importlib.reload(helpers)

np.random.seed(10645)

# The statement after `-f opt.mod_lower` is executed under the profiler and
# binds X_list / obj_list in the notebook namespace.
%lprun -f opt.mod_lower X_list, obj_list = opt.mod_mod(ground, objectives.obj_LHS, objectives.obj_RHS, seed, fn_args=(4, dic), g_args=(4, dic, 100), verbose=True)

The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler
Iteration 0	 obj = -1.606132	
CPU times: user 3min 49s, sys: 7.27 s, total: 3min 56s
Wall time: 5min 27s


In [13]:
# Recompute obj_LHS - obj_RHS at the last entry of X_list; in the recorded run
# this equals the objective value logged by mod_mod (-1.6061).
objectives.obj_LHS(X_list[-1], 4, dic) - objectives.obj_RHS(X_list[-1], 4, dic, 100)

tensor(-1.6061)

In [29]:
# Greedy baseline: run opt.greedy on objectives.objective starting from `seed`.
# The displayed result is a (library, objective value) pair.
objectives = importlib.reload(objectives)
opt = importlib.reload(opt)
helpers = importlib.reload(helpers)


opt.greedy(objectives.objective, seed, obj_args=(4, dic, 100))

([('S', 0),
  ('A', 1),
  ('S', 2),
  ('L', 3),
  ('I', 2),
  ('S', 1),
  ('W', 3),
  ('C', 2),
  ('T', 1),
  ('K', 3),
  ('S', 3),
  ('Y', 3),
  ('C', 3),
  ('G', 3),
  ('I', 3),
  ('F', 3)],
 tensor(-1.6061))

In [224]:
# Local-optimality probe: for every ground-set element, toggle its membership
# in `lib` (remove it if present, append it otherwise) and evaluate both
# objective terms on the perturbed library.
#
# Each entry of `objects` is (obj_LHS, obj_RHS, obj_LHS - obj_RHS), in
# ground-set order. Loop-local names are used on purpose so that the
# notebook-level `obj_left` / `obj_right` (read by the bound assertions) are
# not overwritten by this cell.
objects = []
for i in ground:
    if i in lib:
        trial_lib = [x for x in lib if x != i]
    else:
        trial_lib = lib + [i]

    lhs_val = opt.obj_LHS(trial_lib, 4, dic)
    rhs_val = opt.obj_RHS(trial_lib, 4, dic, 100)
    objects.append((lhs_val, rhs_val, lhs_val - rhs_val))

In [227]:
# Print every single-element perturbation whose objective is strictly lower
# than the objective of `lib` itself; no output means `lib` cannot be improved
# by adding or removing one element.
# The reference value is computed from `lib` instead of being copied from a
# previous run's log, so it stays correct when `lib` changes.
lib_obj = opt.obj_LHS(lib, 4, dic) - opt.obj_RHS(lib, 4, dic, 100)
for obj in objects:
    if obj[2] < lib_obj:
        print(obj)

In [30]:
# Print each entry of `mod_bounds` next to the matching entry of `objects`.
# NOTE(review): `mod_bounds` is not defined in any cell visible in this
# notebook, so this cell fails on a fresh kernel — restore its definition or
# remove the cell.
for mod, obj in zip(mod_bounds, objects):
    print(mod)
    print(obj)
    print("--------------------")

(tensor(-0.4275), tensor(-78.4915), tensor(78.0639))
(tensor(-0.6064), tensor(1.00000e-31 *
       -4.7839), tensor(-0.6064))
--------------------
(tensor(-0.4275), tensor(-79.2118), tensor(78.7843))
(tensor(-0.7961), tensor(1.00000e-31 *
       -6.2798), tensor(-0.7961))
--------------------
(tensor(-0.4275), tensor(-80.2173), tensor(79.7898))
(tensor(-0.4943), tensor(1.00000e-31 *
       -3.8995), tensor(-0.4943))
--------------------
(tensor(-0.4275), tensor(-77.8094), tensor(77.3819))
(tensor(-0.4997), tensor(1.00000e-31 *
       -3.9422), tensor(-0.4997))
--------------------
(tensor(-0.4275), tensor(-78.3850), tensor(77.9575))
(tensor(-0.4486), tensor(1.00000e-31 *
       -3.5385), tensor(-0.4486))
--------------------
(tensor(-0.4275), tensor(-78.0639), tensor(77.6364))
(tensor(-0.6180), tensor(1.00000e-31 *
       -4.8753), tensor(-0.6180))
--------------------
(tensor(-0.4275), tensor(-78.3268), tensor(77.8993))
(tensor(-0.5536), tensor(1.00000e-31 *
       -4.3671), tensor(-0

In [20]:
# Size of X. NOTE(review): X is bound twice in this notebook (loaded encoding,
# then the sampled list of len_X = 35 ground elements); the recorded 35
# corresponds to the sampled list.
len(X)

35