Tests for the mod_mod algorithm

In [1]:
import torch
from torch import distributions as dist

import itertools
import pickle
import importlib

import itertools
import random
import math
import numpy as np
import pandas as pd
import scipy

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
# Seaborn theme: white background with the 'paper' scaling context.
sns.set_style('white')
sns.set_context('paper')
# Plot adjustments: set tick, axis-label, legend and title font sizes in one update.
plt.rcParams.update({
    'ytick.labelsize': 15,
    'xtick.labelsize': 15,
    'axes.labelsize': 35,
    'legend.fontsize': 30,
    'axes.titlesize': 16,
})

from gptorch import kernels, models
import helpers, opt, objectives

In [2]:
# Make float64 the default floating-point dtype for newly created tensors.
# torch.set_default_tensor_type is deprecated in recent PyTorch releases;
# torch.set_default_dtype is the supported way to select the default dtype.
torch.set_default_dtype(torch.float64)


In [3]:
# Load the pickled GB1 inputs and unpack the first four entries:
#   X  - later decoded with helpers.decode_X and passed as `one_hots`
#   A  - inputs handed to helpers.get_predictions together with y
#   y  - values indexed alongside A
#   wt - sequence looked up in the sequence -> index map
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('../inputs/gb1.pkl', 'rb') as f:
    t = pickle.load(f)
X, A, y, wt = t[0], t[1], t[2], t[3]

In [4]:
# Seed both RNGs: NumPy drives the training subsample drawn below, and the
# torch seed is fixed before the call into helpers.get_predictions.
np.random.seed(120120)
_ = torch.manual_seed(43298)

# Map each decoded sequence to its row index in X (later rows win on duplicates).
seq_to_x = {helpers.decode_X(x): i for i, x in enumerate(X)}
wt_inds = [seq_to_x[wt]]

# Training indices: the wt row plus n indices drawn *with* replacement,
# so the number of distinct training rows can be smaller than n + 1.
n = 100
train_inds = wt_inds + list(np.random.choice(len(X), n, replace=True))

y_train = y[train_inds]
y_true = y
A_train = A[train_inds]
A_test = A

# Distinct training indices in ascending order.
inds = sorted(set(train_inds))

# Presumably fits on the distinct training rows and predicts for all of A_test;
# `dic` is indexed by decoded sequence in later cells — confirm in helpers.py.
dic, _ = helpers.get_predictions(
    A[inds], y[inds], A_test, one_hots=X, its=3000, lr=1e-3
)

Iteration 3000 of 3000	NLML: 49.3676	sn: 0.365388	

In [47]:
# Residue alphabet (20 one-letter codes) and the ground set of
# (residue, position) pairs over positions 0-3, ordered residue-major.
aas = 'ARNDCQEGHILKMFPSTWYV'
ground = list(itertools.product(aas, range(4)))

In [42]:
# Decode every training row and set its entry in `dic` to 0.0.
# Relies on `inds` still holding the training indices from the cell that built `dic`.
seen_seqs = [helpers.decode_X(X[i]) for i in inds]
dic.update(dict.fromkeys(seen_seqs, 0.0))

In [49]:
# Build the seed set from `dic` (it is later passed as the second argument to
# opt.mod_mod and opt.greedy) and display it.
seed = helpers.get_seed(dic)
seed

[('K', 0), ('Y', 1), ('Q', 2), ('L', 3)]

In [546]:
# 3-site sub-problem: keep the entries of `dic` whose key ends in 'A', re-keyed
# by the key with its last element removed, and derive a seed from them.
dic3 = {key[:-1]: value for key, value in dic.items() if key[-1] == 'A'}
seed3 = helpers.get_seed(dic3)

# Ground-set elements restricted to positions 0-2.
ground3 = [pair for pair in ground if pair[1] < 3]
L = 3

In [544]:
# Display seed3.
# NOTE(review): the recorded output contains ('L', 3), a position that ground3
# (positions < 3) excludes. It was recorded at In [544], before In [546]
# rebuilt seed3 — re-run to refresh.
seed3

[('K', 0), ('Y', 1), ('L', 3)]

In [130]:
# 2-site sub-problem: keep the entries of `dic` whose key ends in 'Q', 'L',
# re-keyed by the key with its last two elements removed.
dic2 = {
    key[:-2]: value
    for key, value in dic.items()
    if key[-1] == 'L' and key[-2] == 'Q'
}
# Seed without its last two elements; ground set restricted to positions 0-1.
seed2 = seed[:-2]
ground2 = [pair for pair in ground if pair[1] < 2]
L = 2

In [535]:
n = 100

def ddr(x, n):
    """Second derivative with respect to x of (1 - 1/x)**n.

    Args:
        x: evaluation point (scalar or NumPy array). The expression divides by
            x and by (x - 1), so it is undefined at x = 0 and x = 1.
        n: exponent.

    Returns:
        n * (n - 2x + 1) * (1 - 1/x)**n / ((x - 1)**2 * x**2).
    """
    return n * (n - 2 * x + 1) * (1 - 1 / x) ** n / (x - 1) ** 2 / x ** 2

# alpha is minus the minimum of ddr over x, located with Powell's method from x = 70.
powell_result = scipy.optimize.minimize(ddr, 70, args=(n,), method='Powell')
# OptimizeResult['fun'] may be a Python float or a NumPy scalar/array depending on
# the SciPy version, and the legacy torch.Tensor(...) constructor does not accept
# every such type. Convert to a plain float and build a 0-d tensor of the default
# dtype with torch.tensor.
alpha = -torch.tensor(float(np.squeeze(powell_result['fun'])))
alpha

tensor(1.00000e-05 *
       4.1996)

In [557]:
# Draw 30 distinct indices from range(60) with a fixed seed and pick those
# elements of ground3 (60 = 20 residues x 3 positions when ground3 is built
# from the full ground set).
# NOTE(review): this rebinds `inds` and `A`, which earlier cells use for the
# training indices and for the data loaded from gb1.pkl. Re-running those
# cells after this one, without reloading, would pick up these values.
np.random.seed(1)
inds = np.random.choice(60, 30, replace=False)
A = [ground3[i] for i in inds]

In [574]:
%%time

# Reload the project modules so on-disk edits take effect before the run.
opt = importlib.reload(opt)
objectives = importlib.reload(objectives)
helpers = importlib.reload(helpers)

# Fix the NumPy seed for this run; L = 4 selects the full 4-position problem.
np.random.seed(101645)
L = 4

# Run mod_mod over the full ground set starting from `seed`. dc_lhs / dc_rhs are
# presumably the two parts of a difference decomposition of the objective, taking
# fn_args / g_args respectively — confirm in opt.py / objectives.py.
# The recorded run took about an hour of wall time.
X_list, obj_list, perm_list, perms = opt.mod_mod(ground, seed, objectives.dc_lhs, objectives.dc_rhs, 
                               fn_args=(L, dic, alpha), g_args=(L, dic, n, alpha), verbose=True)

Iteration 0	 obj = -0.014445	
Iteration 1	 obj = -0.064004
Iteration 2	 obj = -0.126820
Iteration 3	 obj = -0.152837
Iteration 4	 obj = -0.170050
Iteration 5	 obj = -0.186076
Iteration 6	 obj = -0.196567
Iteration 7	 obj = -0.209288
Iteration 8	 obj = -0.213174
Iteration 9	 obj = -0.219561
Iteration 10	 obj = -0.220467
CPU times: user 47min 3s, sys: 1min 48s, total: 48min 52s
Wall time: 1h 5min 58s


In [568]:
# Keep the last entry of X_list (the final library of the most recent mod_mod run).
B = X_list[-1]

In [573]:
# Pick up any on-disk edits to the project modules.
objectives = importlib.reload(objectives)
opt = importlib.reload(opt)
helpers = importlib.reload(helpers)

# Greedy run on the full ground set from `seed`, keeping every intermediate
# library (Xs) and its objective value (objs). Uses whatever L and n are
# currently bound in the kernel.
_, _, Xs, objs = opt.greedy(
    ground, seed, objectives.objective, 1,
    obj_args=(L, dic, n), return_all=True,
)

# One line per step: objective value, then helpers.get_N of that library.
for step_lib, step_obj in zip(Xs, objs):
    print(step_obj.item(), helpers.get_N(step_lib, L))

-0.014445329159274656 1
-0.02827841545166343 2
-0.04250175502991774 4
-0.06286361429667228 8
-0.09371428426170933 12
-0.12434356196666106 24
-0.15573745360235977 32
-0.18397867407959734 48
-0.20407077304473778 60
-0.21907020755168066 72
-0.2252375682033628 108


In [577]:
%%time

# Reload the project modules so on-disk edits take effect before the run.
opt = importlib.reload(opt)
objectives = importlib.reload(objectives)
helpers = importlib.reload(helpers)

# Same seed as the 4-position run; L = 3 selects the 3-position sub-problem.
np.random.seed(101645)
L = 3

# Run mod_mod on the 3-position ground set, seed and prediction dict.
# Note: this rebinds X_list, obj_list, perm_list and perms from the earlier run.
X_list, obj_list, perm_list, perms = opt.mod_mod(ground3, seed3, objectives.dc_lhs, objectives.dc_rhs, 
                               fn_args=(L, dic3, alpha), g_args=(L, dic3, n, alpha), verbose=True)

Iteration 0	 obj = -0.006363	
Iteration 1	 obj = -0.017786
Iteration 2	 obj = -0.029355
Iteration 3	 obj = -0.040097
Iteration 4	 obj = -0.047059
Iteration 5	 obj = -0.051602
Iteration 6	 obj = -0.054774
Iteration 7	 obj = -0.058097
Iteration 8	 obj = -0.062154
Iteration 9	 obj = -0.069739
Iteration 10	 obj = -0.073886
Iteration 11	 obj = -0.076899
Iteration 12	 obj = -0.078617
Iteration 13	 obj = -0.080815
Iteration 14	 obj = -0.082132
Iteration 15	 obj = -0.083245
Iteration 16	 obj = -0.084078
Iteration 17	 obj = -0.084872
Iteration 18	 obj = -0.086218
Iteration 19	 obj = -0.087598
Iteration 20	 obj = -0.088953
Iteration 21	 obj = -0.090665
Iteration 22	 obj = -0.091810
Iteration 23	 obj = -0.092703
Iteration 24	 obj = -0.093194
CPU times: user 3min 26s, sys: 11.7 s, total: 3min 38s
Wall time: 4min 28s


In [578]:
# Pick up any on-disk edits to the project modules.
objectives = importlib.reload(objectives)
opt = importlib.reload(opt)
helpers = importlib.reload(helpers)

# Greedy run on the 3-position problem, keeping every intermediate library (Xs)
# and its objective value (objs). Uses whatever L and n are currently bound.
_, _, Xs, objs = opt.greedy(
    ground3, seed3, objectives.objective, 1,
    obj_args=(L, dic3, n), return_all=True,
)

# One line per step: objective value, then helpers.get_N of that library.
for step_lib, step_obj in zip(Xs, objs):
    print(step_obj.item(), helpers.get_N(step_lib, L))

-0.006362772397890626 1
-0.009294416719369591 2
-0.0135901881664697 4
-0.019279320771721082 8
-0.027348733503082978 12
-0.039635361967049296 18
-0.052354384377158565 27
-0.06243370492967245 36
-0.07084752474085987 48
-0.07680222877290716 60
-0.08126210872405523 72
-0.08451755675595073 90
-0.08649325274282776 108
-0.08757385451294562 126
-0.08895308227513692 84
-0.09066500540691927 96
-0.09180968636857108 108
-0.09270319485700912 120
-0.09319378614461538 132


In [514]:
# Take the second-to-last entry of perm_list and find the first permutation
# whose element at index m - 1 equals j.
# NOTE(review): `m` and `j` are not defined in any visible cell — this cell
# depends on leftover kernel state and will fail on a fresh kernel.
# It also rebinds `perms`, which the mod_mod cells assign.
perms = perm_list[-2]
ind = [perm[m - 1] for perm in perms].index(j)
ind

15

In [515]:
# Number of permutations vs. size of the last library in X_list.
len(perms), len(X_list[-1])

(43, 17)

In [501]:
# Element at index m - 1 of every permutation (`m` comes from kernel state;
# it is not defined in any visible cell).
[perm[m - 1] for perm in perms]


[('T', 1),
 ('K', 1),
 ('Y', 0),
 ('G', 0),
 ('D', 0),
 ('V', 0),
 ('Y', 0),
 ('C', 2),
 ('N', 1),
 ('V', 0),
 ('D', 1),
 ('T', 0),
 ('Y', 0),
 ('G', 0),
 ('M', 2),
 ('P', 0),
 ('C', 1),
 ('V', 1),
 ('T', 1),
 ('N', 0),
 ('P', 0),
 ('D', 0),
 ('C', 2),
 ('F', 2),
 ('I', 1),
 ('G', 1),
 ('N', 1),
 ('F', 2),
 ('F', 1),
 ('W', 2),
 ('A', 1),
 ('R', 1),
 ('W', 0),
 ('S', 0),
 ('C', 1),
 ('I', 0),
 ('M', 0),
 ('R', 1),
 ('T', 1),
 ('Q', 1),
 ('L', 2),
 ('K', 1),
 ('T', 0),
 ('W', 0)]

In [525]:
# Check that no single addition or removal improves on the final library:
# for every ground-set element, toggle its membership in `lib` and score the result.
# NOTE(review): this iterates over `ground` while scoring with dic3; the
# 3-position runs use ground3 — confirm which ground set is intended.

objects = []  # one (obj_LHS, obj_RHS, obj_LHS - obj_RHS) tuple per toggled element
lib = X_list[-1]
for i in ground:
    # Drop i when it is already in the library, otherwise append it.
    neighbor = [x for x in lib if x != i] if i in lib else lib + [i]
    obj_left = objectives.obj_LHS(neighbor, L, dic3)
    obj_right = objectives.obj_RHS(neighbor, L, dic3, n)
    objects.append((obj_left, obj_right, obj_left - obj_right))

In [526]:
# Print every toggled-element result whose objective (third tuple entry) is
# strictly below the objective of `lib` itself.
# The objective of `lib` does not depend on the loop variable, so it is
# evaluated once instead of once per entry of `objects`.
lib_objective = objectives.objective(lib, L, dic3, n)
for obj in objects:
    if obj[2] < lib_objective:
        print(obj)

In [25]:
# helpers.get_N of `lib` for 4 positions (presumably the number of sequences
# the library encodes — confirm in helpers.py).
helpers.get_N(lib, 4)

5600

In [26]:
# Save the libraries and objective values of the current mod_mod run.
with open('../outputs/20180828_modmod_not_mono.pkl', 'wb') as f:
    pickle.dump((X_list, obj_list), f)

In [88]:
# Reload the saved libraries and objective values, rebinding X_list / obj_list.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('../outputs/20180828_modmod_not_mono.pkl', 'rb') as f:
    X_list, obj_list = pickle.load(f)

In [116]:
# Print the elements that differ between libraries 23 and 24:
# first those only in X_list[23], then those only in X_list[24].
for source, other in ((X_list[23], X_list[24]), (X_list[24], X_list[23])):
    for a in source:
        if a not in other:
            print(a)

('F', 2)


In [113]:
# Size of library 24.
len(X_list[24])

42

In [106]:
# Indices i at which the objective rises from step i to step i + 1, i.e. where
# the mod_mod trace is not monotonically decreasing.
# Reads obj_list (reloaded together with X_list) rather than `objs`: the greedy
# cells bind `objs` to a different trace, while the resulting indices are used
# to index X_list in the following cells.
# Note: this rebinds `inds`, which earlier cells use for training indices.
inds = [i for i in range(len(obj_list) - 1) if obj_list[i] < obj_list[i + 1]]
inds

[23, 29]

In [109]:
# Check whether toggling any single ground-set element in library 29 yields a
# strictly lower objective than the library itself.
objects = []  # objective of each single-element neighbour of lib
lib = X_list[29]
n = 100
L = 4
for i in ground:
    # A dedicated local name is used for the candidate so that the notebook-level
    # `A` (data loaded from gb1.pkl) is not overwritten by this loop.
    if i in lib:
        neighbor = [x for x in lib if x != i]
    else:
        neighbor = lib + [i]
    objects.append(objectives.objective(neighbor, L, dic, n))

# The library's own objective is loop-invariant: evaluate it once, with the same
# L and n used for the neighbours.
lib_objective = objectives.objective(lib, L, dic, n)
for obj in objects:
    if obj < lib_objective:
        print(obj)

In [118]:
# Objective of library 23 under the current L, dic and n.
objectives.objective(X_list[23], L, dic, n)

tensor(1.00000e-02 *
       -1.0326)

In [121]:
%%time

# Reload the project modules so on-disk edits take effect before the run.
opt = importlib.reload(opt)
objectives = importlib.reload(objectives)
helpers = importlib.reload(helpers)

np.random.seed(10645)

# mod_mod with obj_LHS / obj_RHS as the two functions (4 positions, n = 100).
# The other mod_mod cells in this notebook unpack four return values, so the
# tail is star-unpacked here: the call works whether mod_mod returns three
# values or more.
X_list, obj_list, perm_list, *_extra = opt.mod_mod(
    ground, seed, objectives.obj_LHS, objectives.obj_RHS,
    fn_args=(4, dic), g_args=(4, dic, 100), verbose=True,
)

Iteration 0	 obj = -0.000105	
Iteration 1	 obj = -0.003231
Iteration 2	 obj = -0.007021
Iteration 3	 obj = -0.008408
Iteration 4	 obj = -0.009028
Iteration 5	 obj = -0.009695
Iteration 6	 obj = -0.009872
Iteration 7	 obj = -0.010028
Iteration 8	 obj = -0.010114
Iteration 9	 obj = -0.010169
Iteration 10	 obj = -0.010197
Iteration 11	 obj = -0.010223
Iteration 12	 obj = -0.010272
Iteration 13	 obj = -0.010291
Iteration 14	 obj = -0.010308
Iteration 15	 obj = -0.010322
Iteration 16	 obj = -0.010334
Iteration 17	 obj = -0.010339
Iteration 18	 obj = -0.010344
Iteration 19	 obj = -0.010347
Iteration 20	 obj = -0.010348
Iteration 21	 obj = -0.010349
Iteration 22	 obj = -0.010350
Iteration 23	 obj = -0.010350
Iteration 24	 obj = -0.010326
Iteration 25	 obj = -0.010357
Iteration 26	 obj = -0.010361
Iteration 27	 obj = -0.010365
Iteration 28	 obj = -0.010368
Iteration 29	 obj = -0.010370
Iteration 30	 obj = -0.010359
Iteration 31	 obj = -0.010377
Iteration 32	 obj = -0.010386
Iteration 33	 obj =