Tests for mod_mod algo

In [1]:
import torch
from torch import distributions as dist

import itertools
import pickle
import importlib

import itertools
import random
import math
import numpy as np
import pandas as pd
import scipy

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
# Seaborn theme: white background with the 'paper' scaling context.
sns.set_style('white')
sns.set_context('paper')
# Plot adjustments, applied in one rcParams update. Kept after the seaborn
# calls, in the same order as before.
plt.rcParams.update({
    'ytick.labelsize': 15,
    'xtick.labelsize': 15,
    'axes.labelsize': 35,
    'legend.fontsize': 30,
    'axes.titlesize': 16,
})

from gptorch import kernels, models
import helpers, opt, objectives

In [2]:
# Load the pickled GB1 inputs.
# NOTE: pickle.load can execute arbitrary code; only use it on trusted files.
with open('../inputs/gb1.pkl', 'rb') as f:
    t = pickle.load(f)
# The first four entries of the payload, in order: X, A, y, wt.
X, A, y, wt = (t[k] for k in range(4))

In [3]:
# Seed the NumPy and torch global RNGs before sampling and fitting.
np.random.seed(120120)
_ = torch.manual_seed(43298)

# Map each decoded sequence to its row index in X (if a sequence repeats,
# the last index seen wins).
seq_to_x = {}
for i, x in enumerate(X):
    seq = helpers.decode_X(x)
    seq_to_x[seq] = i
# Row index of the sequence `wt` loaded from the input pickle.
wt_inds = [seq_to_x[wt]]

# Training rows: the `wt` row plus n rows drawn WITH replacement, so
# duplicates are possible and fewer than n + 1 distinct rows may result.
n = 100
train_inds = wt_inds + list(np.random.choice(len(X), n, replace=True))

y_train = y[train_inds]
y_true = y
A_train = A[train_inds]
A_test = A

# Deduplicated, sorted training indices. itertools.chain over a single
# iterable is a pass-through, so this equals sorted(set(train_inds)).
inds = sorted(set(itertools.chain(train_inds)))


# Call helpers.get_predictions with the unique training rows and with all of
# A (A_test) as the third argument; only the first returned value is kept.
# presumably `dic` maps decoded sequence -> predicted value (a later cell
# assigns into it by decoded sequence) -- confirm in helpers.get_predictions.
dic, _ = helpers.get_predictions(A[inds], y[inds], A_test,
                                         one_hots=X, its=3000, lr=1e-3)

Iteration 3000 of 3000	NLML: 49.9795	sn: 0.495660	

In [4]:
# The 20 one-letter codes, and the ground set of every (letter, position)
# pair for positions 0..3, ordered letter-major.
aas = 'ARNDCQEGHILKMFPSTWYV'
ground = list(itertools.product(aas, range(4)))

In [5]:
# Decode the rows listed in `inds` (the rows passed to get_predictions) and
# overwrite their entries in `dic` with 0.0.
seen_seqs = list(map(helpers.decode_X, (X[idx] for idx in inds)))
for s in seen_seqs:
    dic[s] = 0.0

In [6]:
# Build `seed` from `dic` via helpers.get_seed and display it; in the
# recorded output it is a list of (letter, position) pairs, one per position.
seed = helpers.get_seed(dic)
seed

[('W', 0), ('N', 1), ('Q', 2), ('W', 3)]

In [7]:
# Load a previously saved pair (X_list, obj_list) from disk.
# NOTE: pickle.load can execute arbitrary code; only use it on trusted files.
with open('../outputs/20180828_modmod_not_mono.pkl', 'rb') as f:
    X_list, obj_list = pickle.load(f)

In [34]:
# Two consecutive entries (indices 23 and 24) of the saved X_list.
lib1, lib2 = X_list[23:25]
n = 100
L = 4
# Distinct elements of lib2 that are not in lib1, in lib2's own order.
# The previous list(set(lib2) - set(lib1)) had an arbitrary order that could
# differ between kernel restarts, because string hashing is randomized.
_lib1_members = set(lib1)
diff = [item for item in dict.fromkeys(lib2) if item not in _lib1_members]

def v(A):
    """Return ``objectives.objective`` for library ``A``.

    ``L``, ``dic`` and ``n`` are read from the notebook namespace at call
    time. The parameter ``A`` shadows the notebook-level ``A`` inside this
    function only.
    """
    value = objectives.objective(A, L, dic, n)
    return value

def g(A):
    """Return ``objectives.obj_RHS`` for library ``A``.

    ``L``, ``dic`` and ``n`` are read from the notebook namespace at call
    time. The parameter ``A`` shadows the notebook-level ``A`` inside this
    function only.
    """
    value = objectives.obj_RHS(A, L, dic, n)
    return value

In [17]:
# Display obj_RHS (through g) for the two libraries side by side.
g(lib1), g(lib2)

(tensor(-0.5816), tensor(-0.7754))

In [40]:
# Reseed NumPy before building the permutations -- presumably
# opt._make_permuted_indices draws from NumPy's global RNG; confirm in opt.
# NOTE(review): `inds` is rebound here. Earlier cells use `inds` for the
# sorted training-row indices, so re-running them after this cell would pick
# up the permutation indices instead.
np.random.seed(0)
inds = opt._make_permuted_indices(ground, lib1)
# Translate each index permutation into the corresponding ground elements.
perms = [[ground[i] for i in p] for p in inds]

In [43]:
# One opt.mod_upper result per ground element, evaluated at lib1 with that
# element toggled: dropped if it is a member, appended otherwise.
fn = objectives.obj_LHS
fn_args = (L, dic)
fn_kwargs = {}
uppers = []
for i in ground:
    toggled = [x for x in lib1 if x != i] if i in lib1 else lib1 + [i]
    uppers.append(opt.mod_upper(toggled, fn, lib1, ground, *fn_args, **fn_kwargs))


In [52]:
%%time
g_args = (L, dic, n)
g_kwargs = {}


def _candidates_for(perm):
    """Run opt._get_candidates for one permutation against lib1 and obj_list[23]."""
    return opt._get_candidates(perm, ground, lib1, uppers, obj_list[23],
                               objectives.obj_RHS, g_args, g_kwargs)


# Evaluate every permutation first, then flatten the per-permutation results
# into a single list.
candidates = list(itertools.chain.from_iterable([_candidates_for(p) for p in perms]))

CPU times: user 11min 43s, sys: 10.8 s, total: 11min 53s
Wall time: 13min 26s


In [54]:
# Keep element 0 of every candidate whose element 1 is truthy, then display.
# NOTE(review): the full repr is very long; consider showing len(candidates_)
# or a short slice instead.
candidates_ = [c[0] for c in candidates if c[1]]
candidates_

[[('W', 0),
  ('C', 0),
  ('I', 0),
  ('P', 0),
  ('R', 0),
  ('L', 0),
  ('Q', 0),
  ('G', 0),
  ('Y', 0),
  ('E', 0),
  ('K', 0),
  ('F', 0),
  ('N', 0),
  ('V', 0),
  ('N', 1),
  ('R', 1),
  ('G', 1),
  ('F', 1),
  ('S', 1),
  ('K', 1),
  ('D', 1),
  ('C', 1),
  ('L', 1),
  ('Q', 2),
  ('G', 2),
  ('V', 2),
  ('W', 3),
  ('R', 3),
  ('E', 3),
  ('K', 3),
  ('C', 3),
  ('H', 3),
  ('I', 3),
  ('Y', 3),
  ('L', 3),
  ('F', 3),
  ('S', 3),
  ('G', 3),
  ('A', 3),
  ('P', 3),
  ('M', 3),
  ('D', 2)],
 [('W', 0),
  ('C', 0),
  ('I', 0),
  ('P', 0),
  ('R', 0),
  ('L', 0),
  ('Q', 0),
  ('G', 0),
  ('Y', 0),
  ('E', 0),
  ('K', 0),
  ('F', 0),
  ('N', 0),
  ('V', 0),
  ('N', 1),
  ('R', 1),
  ('G', 1),
  ('F', 1),
  ('S', 1),
  ('K', 1),
  ('D', 1),
  ('C', 1),
  ('L', 1),
  ('Q', 2),
  ('G', 2),
  ('V', 2),
  ('W', 3),
  ('R', 3),
  ('E', 3),
  ('K', 3),
  ('C', 3),
  ('H', 3),
  ('I', 3),
  ('Y', 3),
  ('L', 3),
  ('F', 3),
  ('S', 3),
  ('G', 3),
  ('A', 3),
  ('P', 3),
  ('M', 3),
  (

In [56]:
# Print objectives.objective for each retained candidate library.
for c in candidates_:
    print(objectives.objective(c, L, dic, n))

tensor(1.00000e-02 *
       -1.0262)
tensor(1.00000e-02 *
       -1.0260)
tensor(1.00000e-02 *
       -1.0188)
tensor(1.00000e-02 *
       -1.0262)


In [57]:
# Print the positions in `candidates` whose element 1 is truthy.
for i, c in enumerate(candidates):
    if c[1]:
        print(i)

27
37
57
73


In [105]:
# Pick one permutation and one retained candidate to inspect.
# NOTE(review): indices 18 and 1 are hard-coded -- presumably perms[18] is
# the permutation that produced candidates_[1]; confirm.
perm = perms[18]
lib3 = candidates_[1]

# Elements of lib3 that lib1 lacks, printed as they are found.
added = []
print('added')
for c in lib3:
    if c in lib1:
        continue
    print(c)
    added.append(c)

# Elements of lib1 that lib3 lacks, printed as they are found.
removed = []
print('removed')
for c in lib1:
    if c in lib3:
        continue
    print(c)
    removed.append(c)

added
('D', 2)
('F', 2)
removed


In [109]:
# Element at index 41 of the selected permutation, i.e. the first one after
# a 41-element prefix.
perm[41]

('I', 2)

In [110]:
# Evaluate both bound helpers at lib3: h is opt.mod_lower on obj_RHS along
# `perm`; m is opt.mod_upper on obj_LHS with lib1 and ground as further
# arguments. The cell displays m[1] - h as a Python float.
# NOTE(review): in the recorded outputs h (-0.96402) is larger than
# obj_RHS(lib3) (-0.96443), so h does not lower-bound obj_RHS at lib3 for
# this permutation -- presumably the failure discussed in the notes below.
h = opt.mod_lower(lib3, objectives.obj_RHS, perm, *g_args, **g_kwargs)
m = opt.mod_upper(lib3, objectives.obj_LHS, lib1, ground, *fn_args)
(m[1] - h).item()

-0.010666966438293457

In [83]:
# Saved objective at index 23 of obj_list (the same entry that is passed to
# opt._get_candidates), as a Python float.
obj_list[23].item()

-0.01034998893737793

In [111]:
# objectives.objective evaluated directly at lib3, as a Python float.
objectives.objective(lib3, L, dic, n).item()

-0.010260343551635742

In [112]:
# objectives.obj_LHS evaluated directly at lib3, for comparison with m[1].
objectives.obj_LHS(lib3, L, dic).item()

-0.9746916890144348

In [117]:
# Element 1 of the opt.mod_upper result at lib3; the recorded value equals
# the recorded obj_LHS(lib3).
m[1].item()

-0.9746916890144348

In [114]:
# objectives.obj_RHS evaluated directly at lib3, for comparison with h.
objectives.obj_RHS(lib3, L, dic, n).item()

-0.9644313454627991

In [115]:
# The opt.mod_lower value at lib3 as a Python float; compare with the
# obj_RHS(lib3) value shown just above.
h.item()

-0.9640247225761414

In [92]:
# The lower bound fails for some permutations...
# Does changing the items at the end of the permutation affect anything? Should it?

In [103]:
# Element at index 42 of the selected permutation.
perm[42]

('T', 2)

In [98]:
# First 41 elements of the selected permutation.
# NOTE(review): the recorded output for this cell is the integer 41, which a
# slice would not display -- the output looks stale (perhaps from an earlier
# len(perm[:41])); re-run to refresh.
perm[:41]

41

In [118]:
# Sanity check: every permutation must begin with the members of lib1, in
# any order. The prefix length comes from lib1 itself instead of a
# hard-coded 41, and the working names are private so the cell no longer
# rebinds the notebook-level `A` (input data from the gb1 pickle) or `B`.
_lib1_sorted = tuple(sorted(lib1))
for p in perms:
    _prefix_sorted = tuple(sorted(p[:len(lib1)]))
    assert _prefix_sorted == _lib1_sorted, 'permutation does not start with lib1'

In [125]:
# Compare the change in obj_RHS from adding x to the 41-element prefix of
# permutation 13 (A) with the change from adding x to the 42-element prefix
# (B, which contains A).
# NOTE(review): these assignments rebind the notebook-level A (input data
# from the gb1 pickle); cells that read A as data would see this list if
# re-run afterwards.
# presumably a diminishing-returns (submodularity) probe, which would need
# del1 >= del2 -- confirm intent; the recorded values are -0.1890 and -0.1888.
A = perms[13][:41]
B = perms[13][:42]
x = perms[13][42]
del1 = objectives.obj_RHS(A + [x], L, dic, n) - objectives.obj_RHS(A, L, dic, n)
del2 = objectives.obj_RHS(B + [x], L, dic, n) - objectives.obj_RHS(B, L, dic, n)
del1, del2

(tensor(-0.1890), tensor(-0.1888))

In [126]:
# helpers.get_N for A and L. Here A is the 41-element permutation prefix
# bound in the cell directly above, not the data loaded from the gb1 pickle.
helpers.get_N(A, L)

5670

In [131]:
# Scratch cell: draw a 10 x 4 tensor of standard-normal samples. The result
# is only displayed, not stored, but the call still advances torch's global RNG.
torch.randn(size=(10, 4))

tensor([[-0.7386,  0.5874,  0.7542, -0.5349],
        [ 0.1037, -0.2985, -0.5463,  1.0685],
        [ 0.1611,  1.6024, -0.1498,  0.5626],
        [-0.7317,  1.0714, -0.7512, -1.0312],
        [-0.6830,  0.3442,  0.5877,  1.4528],
        [-0.2115, -1.2897,  1.4389, -0.2047],
        [ 0.3044, -0.2894,  0.7601, -0.1009],
        [ 0.6436,  0.6820,  0.6647,  1.2048],
        [ 0.0238,  0.1157,  0.5046, -0.0459],
        [-0.2615, -0.4721,  2.2720,  0.0629]])