Tests for the mod_mod algorithm (compared against greedy, supsub and seeded stochastic USM on the PhoQ data)

In [2]:
import torch
from torch import distributions as dist

import itertools
import pickle
import importlib

import itertools
import random
import math
import numpy as np
import pandas as pd
import scipy
import scipy.optimize

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
# Seaborn base theme, then larger fonts so figures stay legible when exported.
sns.set_style('white')
sns.set_context('paper')
plt.rcParams.update({
    'ytick.labelsize': 15,
    'xtick.labelsize': 15,
    'axes.labelsize': 35,
    'legend.fontsize': 30,
    'axes.titlesize': 16,
})

from gptorch import kernels, models
import helpers, opt, objectives

In [3]:
# Make float64 the default dtype for newly created floating-point tensors.
# torch.set_default_tensor_type is deprecated in recent PyTorch releases;
# set_default_dtype is the supported equivalent for CPU tensors.
torch.set_default_dtype(torch.float64)

In [4]:
WT_ROW = 150614  # row of X holding the wild-type sequence
aas = 'ARNDCQEGHILKMFPSTWYV'  # the 20 amino-acid one-letter codes

# The pickle is indexed positionally; slots 0, 2 and 3 are used (slot 1 is ignored).
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open('../inputs/phoq.pkl', 'rb') as f:
    t = pickle.load(f)

X = t[0]           # rows decoded to sequence strings by helpers.decode_X
A = t[2]           # features passed to helpers.get_predictions below
y = t[3].values    # t[3] presumably a pandas object; .values gives an array
wt = helpers.decode_X(X[WT_ROW])  # wt as string

In [5]:
# Reverse lookup: decoded sequence string -> row index in X
# (a later duplicate sequence overwrites an earlier one).
seq_to_x = {helpers.decode_X(row): idx for idx, row in enumerate(X)}
wt_inds = [seq_to_x[wt]]

In [6]:
np.random.seed(120120)
_ = torch.manual_seed(43298)  # assign to `_` so the return value is not displayed

n = 100  # number of randomly drawn training rows

# All 80 single mutants of wt (4 positions x 20 amino acids). Assuming wt uses
# only these 20 letters, each position's list contains wt itself once.
singles = [wt[:i] + aas[j] + wt[i+1:] for i in range(4) for j in range(20)]
wt_inds = [seq_to_x[wt]]
single_inds = wt_inds + [seq_to_x[seq] for seq in singles]
# single_inds already starts with wt, so it is not prepended a second time.
# np.random.choice(replace=True) can still repeat rows; `inds` below dedups.
train_inds = single_inds + list(np.random.choice(len(X), n, replace=True))

y_train = y[train_inds]
y_true = y
A_train = A[train_inds]
A_test = A

# Unique, sorted row indices actually used to fit the predictor.
inds = sorted(set(train_inds))

dic, _ = helpers.get_predictions(A[inds], y[inds], A_test,
                                 one_hots=X, its=3000, lr=1e-2)

Iteration 3000 of 3000	NLML: 58.9343	sn: 0.259592	

In [7]:
aas = 'ARNDCQEGHILKMFPSTWYV'
# Ground set: every (amino acid, position) pair for the 4 positions, ordered
# amino-acid-major: ('A', 0), ('A', 1), ..., ('V', 3).
ground = list(itertools.product(aas, range(4)))

In [8]:
# Sequences already used for training get a predicted value of 0.0, presumably
# so they contribute nothing when the objective is evaluated.
seen_seqs = [helpers.decode_X(X[row]) for row in inds]
dic.update(dict.fromkeys(seen_seqs, 0.0))

In [9]:
# Starting set for the optimizers; judging by the output, helpers.get_seed
# returns one (amino acid, position) pair per position.
seed = helpers.get_seed(dic)
seed

[('A', 0), ('Q', 1), ('S', 2), ('V', 3)]

In [10]:
# One independent U(0, 1) value per sequence (as a 0-d tensor), drawn from the
# global NumPy RNG in dic's key order — presumably a random baseline objective.
rand_dict = {}
for key in dic:
    rand_dict[key] = torch.tensor(np.random.random())

In [11]:
# 3-site subproblem: keep sequences whose last residue equals the seed's
# residue at the final position, keyed by the remaining 3-residue prefix.
dic3, rand3 = {}, {}
for key, val in dic.items():
    if key[-1] == seed[-1][0]:
        dic3[key[:-1]] = val
        rand3[key[:-1]] = rand_dict[key]
seed3 = helpers.get_seed(dic3)

ground3 = [pair for pair in ground if pair[1] < 3]
L = 3

In [12]:
# Seed of the 3-site subproblem; per the outputs it equals the first three
# pairs of the 4-site seed.
seed3

[('A', 0), ('Q', 1), ('S', 2)]

In [13]:
# 2-site subproblem: fix the last two positions to the seed's residues and key
# by the remaining 2-residue prefix.
dic2, rand2 = {}, {}
for key, val in dic.items():
    if key[-1] == seed[-1][0] and key[-2] == seed[-2][0]:
        dic2[key[:-2]] = val
        rand2[key[:-2]] = rand_dict[key]

# NOTE(review): unlike seed3 (helpers.get_seed(dic3)), seed2 is a prefix of the
# 4-site seed rather than helpers.get_seed(dic2) — confirm this is intended.
seed2 = seed[:-2]
ground2 = [pair for pair in ground if pair[1] < 2]
L = 2

In [14]:
# 1-site subproblem: fix the last three positions to the seed's residues; the
# key is the remaining first residue.
dic1 = {}
for key, val in dic.items():
    if key[-1] == seed[-1][0] and key[-2] == seed[-2][0] and key[-3] == seed[-3][0]:
        dic1[key[0]] = val
seed1 = seed[:-3]
ground1 = [pair for pair in ground if pair[1] < 1]
L = 1

In [15]:
n = 100

def ddr(x, n):
    """Second derivative of ``(1 - 1/x) ** n`` with respect to ``x``.

    d²/dx² (1 - 1/x)^n = n (n - 2x + 1) (1 - 1/x)^n / ((x - 1)^2 x^2).
    Undefined at x = 0 and x = 1. Works elementwise on NumPy arrays.
    """
    shrink = (1 - 1 / x) ** n
    numerator = n * (n - 2 * x + 1) * shrink
    return numerator / (x - 1) ** 2 / x ** 2

# alpha: negated minimum of ddr, i.e. the most negative curvature of
# (1 - 1/x)**n, located numerically from the starting point x = 70.
res = scipy.optimize.minimize(ddr, 70, args=(n,), method='Powell')
# Depending on the SciPy version, res.fun may be a Python float or a 0-d /
# 1-element ndarray; torch.Tensor(<float>) raises, so normalize via .item().
alpha = -torch.tensor(np.asarray(res.fun).item())
alpha

tensor(1.00000e-05 *
       4.1996)

In [16]:
# A reproducible random 30-element starting subset of the 4-site ground set.
# Dedicated names keep the training indices `inds` and the feature matrix `A`
# from earlier cells intact (re-running those cells would otherwise break).
np.random.seed(1)
rand_start_inds = np.random.choice(len(ground), 30, replace=False)
rand_start = [ground[i] for i in rand_start_inds]

In [56]:
def get_alpha(dic, n):
    """Return the most negative value of a curvature bound used as ``alpha``.

    For s = 1, 2, ..., len(dic) - 1 it evaluates

        c(s) = [(1 - (1 - 1/(2s))**n) - (1 - (1 - 1/s)**n)] * sum(dic.values())

    keeping the running minimum and stopping at the first s where c(s)
    increases (the early exit assumes c(s) is unimodal in s).

    Args:
        dic: mapping whose values are summed (floats and/or 0-d tensors).
        n: the exponent n in the formula above.

    Returns:
        The smallest c(s) found, or 0 if ``dic`` has fewer than two entries.
    """
    probs = np.sort(np.array([dic[k] for k in dic]))
    S = len(probs)
    # The weight is the sum over the whole (sorted) array, which does not
    # depend on s: compute it once instead of re-summing every iteration.
    # NOTE(review): if a partial sum of the s smallest values was intended,
    # this should be np.sum(probs[:s]) inside the loop — confirm.
    total = np.sum(probs)
    alpha = 0
    for s in range(1, S):
        if s % 1000 == 0:
            print(s)
        candidate = (1 - (1 - 1 / (2 * s)) ** n) - (1 - (1 - 1 / s) ** n)
        candidate *= total
        if candidate > alpha:
            break
        alpha = candidate
    return alpha

    
# alpha4: negated get_alpha bound for the full 4-site problem (100 == n here);
# passed as `alpha` to supsub(dec='ds') further down.
%time alpha4 = -get_alpha(dic, 100)
alpha4

CPU times: user 54.3 s, sys: 884 ms, total: 55.2 s
Wall time: 56.2 s


tensor(138.8483)

In [18]:
%%time

opt = importlib.reload(opt)
objectives = importlib.reload(objectives)
helpers = importlib.reload(helpers)

np.random.seed(101645)
L = 3
S = L * 20
S = torch.tensor(S).double()
beta = 2 * torch.sqrt(S - 1) - torch.sqrt(S) - torch.sqrt(S - 2)

X_list, obj_list, perm_list, perms = opt.mod_mod(ground3, seed3, objectives.objective,
                                                 args=(L, dic3, n), dec='dc', alpha=alpha, beta=beta)

Iteration 0	 obj = -0.171235	
Iteration 1	 obj = -0.356971
Iteration 2	 obj = -0.604633
Iteration 3	 obj = -0.967881
Iteration 4	 obj = -1.342621
Iteration 5	 obj = -1.785980
Iteration 6	 obj = -2.313732
Iteration 7	 obj = -2.831154
Iteration 8	 obj = -3.376843
Iteration 9	 obj = -3.904830
Iteration 10	 obj = -4.231414
Iteration 11	 obj = -4.493796
Iteration 12	 obj = -4.654944
Iteration 13	 obj = -4.756065
Iteration 14	 obj = -4.806947
CPU times: user 2min 4s, sys: 6.95 s, total: 2min 11s
Wall time: 2min 40s


In [19]:
objectives = importlib.reload(objectives)
opt = importlib.reload(opt)
helpers = importlib.reload(helpers)
L = 3

# Greedy baseline on the 3-site subproblem, keeping every intermediate set.
_, _, Xs, objs = opt.greedy(ground3, seed3, objectives.objective, 1,
                            obj_args=(L, dic3, n), return_all=True)

# Objective value next to helpers.get_N(set, L) for each greedy step.
for step_set, step_obj in zip(Xs, objs):
    print(step_obj.item(), helpers.get_N(step_set, L))

-0.17123523886021352 1
-0.31963582552666137 2
-0.5367199181448924 4
-0.8720610205288681 8
-1.2560749694515705 12
-1.8016627814520811 18
-2.421484947267848 27
-3.0203766049551466 36
-3.5691597040610263 48
-3.9749006426680755 60
-4.270636997051267 75
-4.493796290114859 90
-4.654944350682018 105
-4.756065354676568 126
-4.806946572619794 144


In [20]:
# Start greedy from every mod_mod iterate and show the objective it reaches.
# Relies on L = 3 from the previous cell.
for start in X_list:
    _, final_obj, _, _ = opt.greedy(ground3, start, objectives.objective, 1,
                                    obj_args=(L, dic3, n), return_all=True)
    print(final_obj)

tensor(-4.8069)
tensor(-4.6988)
tensor(-4.6988)
tensor(-4.6988)
tensor(-4.6988)
tensor(-4.6988)
tensor(-4.8069)
tensor(-4.8069)
tensor(-4.8069)
tensor(-4.8069)
tensor(-4.8069)
tensor(-4.8069)
tensor(-4.8069)
tensor(-4.8069)
tensor(-4.8069)


In [61]:
opt = importlib.reload(opt)

# Full 4-site problem: S = 20 * L = 80 ground elements.
L = 4
S = torch.tensor(L * 20).double()
beta = 2 * torch.sqrt(S - 1) - torch.sqrt(S) - torch.sqrt(S - 2)

# NOTE(review): `objs` is reassigned by several cells; earlier values are lost.
ss_list, objs = opt.supsub(
    ground, seed, objectives.objective, args=(L, dic, n),
    dec='ds', alpha=alpha4, beta=beta, verbose=True,
)

Iteration 0	 obj = -0.171235	
Iteration 0	 obj = -0.330886	
Iteration 0	 obj = -0.619065	
Iteration 0	 obj = -1.056581	
Iteration 0	 obj = -1.716390	
Iteration 0	 obj = -2.465870	
Iteration 0	 obj = -3.365414	
Iteration 0	 obj = -4.189944	
Iteration 0	 obj = -4.860962	
Iteration 0	 obj = -5.303563	
Iteration 0	 obj = -5.595810	
Iteration 0	 obj = -5.708489	
Iteration 0	 obj = -5.778487	
Iteration 0	 obj = -5.808615	


In [70]:
%%time
opt = importlib.reload(opt)

L = 4
S = L * 20
S = torch.tensor(S).double()
beta = 2 * torch.sqrt(S - 1) - torch.sqrt(S) - torch.sqrt(S - 2)

ss_list2, objs = opt.supsub(ground, seed, objectives.objective, args=(L, dic, n),
                           dec='dc', alpha=alpha, beta=beta, verbose=True)

Iteration 0	 obj = -0.171235	
tensor(-0.3663) tensor(-3.2155)
Iteration 1	 obj = -3.215478	
tensor(-0.2989) tensor(-5.0284)
Iteration 2	 obj = -5.028404	
tensor(-0.5408) tensor(-5.2698)
Iteration 3	 obj = -5.269850	
tensor(-0.4517) tensor(-5.4225)
Iteration 4	 obj = -5.422475	
tensor(-0.3096) tensor(-5.4246)
Iteration 5	 obj = -5.424572	
tensor(-0.3847) tensor(-5.5689)
Iteration 6	 obj = -5.568863	
tensor(-0.3157) tensor(-5.6618)
Iteration 7	 obj = -5.661762	
tensor(-0.3746) tensor(-5.7157)
Iteration 8	 obj = -5.715697	
tensor(-0.3717) tensor(-5.7391)
Iteration 9	 obj = -5.739143	
tensor(-0.2755) tensor(-5.7492)
Iteration 10	 obj = -5.749152	
tensor(-0.4279) tensor(-5.7727)
Iteration 11	 obj = -5.772691	
tensor(-0.2682) tensor(-5.7727)
CPU times: user 5h 28min 30s, sys: 11min 24s, total: 5h 39min 54s
Wall time: 5h 41min 53s


In [67]:
opt = importlib.reload(opt)

# Seeded stochastic USM baseline on the full 4-site problem. L is set here
# explicitly instead of relying on whichever cell last assigned it.
L = 4
Y, ob = opt.seeded_stochastic_usm(ground, seed, objectives.objective,
                                  obj_args=(L, dic, n))
ob

tensor(1.00000e-02 *
       -4.6548)

In [64]:
# NOTE(review): the values shown below match the supsub(dec='ds') trace printed
# by the In [61] cell, not the 15-step mod_mod obj_list produced by In [18].
# obj_list appears to have been reassigned by a cell no longer in the notebook,
# so Restart & Run All will display different values here.
obj_list

[tensor(-0.1712),
 tensor(-0.3309),
 tensor(-0.6191),
 tensor(-1.0566),
 tensor(-1.7164),
 tensor(-2.4659),
 tensor(-3.3654),
 tensor(-4.1899),
 tensor(-4.8610),
 tensor(-5.3036),
 tensor(-5.5958),
 tensor(-5.7085),
 tensor(-5.7785),
 tensor(-5.8086)]

In [33]:
%%time
np.random.seed(101)
best = torch.tensor(0.0)
L = 4
for _ in range(100):
    M = np.random.random() * 76 + 4
    M = int(M)
    print(M)
    inds = np.random.choice(80, M, replace=False)
    A = [ground[i] for i in inds]
    _, obj, _, _ = opt.greedy(ground, A, objectives.objective, 1, obj_args=(L, dic, n), return_all=True)
    if obj < best:
        best = obj
        print(best)
    

43
tensor(-5.7845)
39
33
39
75
51
49
33
tensor(-5.8086)
56
11
tensor(-5.8086)
7
52
17
41
33
5
47
45
38
42
47
7
tensor(-5.8086)
39
55
22
20
78
25
22
54
73
53
48
36
11
50
31
19
61
29
76
65
21
5
73
5
tensor(-5.8086)
79
12
29
79
53
15
21
72
51
14
79
78
74
33
47
22
54
77
9
17
6
46
26
64
23
47
51
7
17
32
71
58
32
77
20
76
71
14
56
37
51
5
54
35
65
13
6
12
54
40
13
17
30
64
CPU times: user 58min 54s, sys: 1min 2s, total: 59min 56s
Wall time: 59min 59s


In [302]:
# NOTE(review): repeats the In [19] greedy cell verbatim (identical output);
# candidate for removal.
objectives = importlib.reload(objectives)
opt = importlib.reload(opt)
helpers = importlib.reload(helpers)
L = 3

# Greedy baseline on the 3-site subproblem, keeping every intermediate set.
_, _, Xs, objs = opt.greedy(ground3, seed3, objectives.objective, 1,
                            obj_args=(L, dic3, n), return_all=True)

# Objective value next to helpers.get_N(set, L) for each greedy step.
for step_set, step_obj in zip(Xs, objs):
    print(step_obj.item(), helpers.get_N(step_set, L))

-0.17123523886021352 1
-0.31963582552666137 2
-0.5367199181448924 4
-0.8720610205288681 8
-1.2560749694515705 12
-1.8016627814520811 18
-2.421484947267848 27
-3.0203766049551466 36
-3.5691597040610263 48
-3.9749006426680755 60
-4.270636997051267 75
-4.493796290114859 90
-4.654944350682018 105
-4.756065354676568 126
-4.806946572619794 144
