Implements algorithm 3 (ModMod) from Algorithms for "Approx Min of the Difference Between Submodular Fncs with Applications."

For prob_ssm

In [1]:
import torch
from torch import distributions as dist

import itertools
import pickle
import operator

import random
import math
import numpy as np
import pandas as pd
import scipy

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
sns.set_style('white')
sns.set_context('paper')
# Plot adjustments:
plt.rcParams.update({'ytick.labelsize': 15})
plt.rcParams.update({'xtick.labelsize': 15})
plt.rcParams.update({'axes.labelsize': 35})
plt.rcParams.update({'legend.fontsize': 30})
plt.rcParams.update({'axes.titlesize': 16})

from gptorch import kernels, models

In [2]:
def mod_lower(X, fn, perm, *args, **kwargs):
    """ Modular lower bound of fn(X) for any X contained in ground set V
    with permutation chain perm (aka S).
    
    Expects X as a list of tuples, fn as a Python function, and perm as a list. """
    
    low = 0.0 # lower modular bound
    
    for elem in X:
        i = perm.index(elem)
        if i == 0:
            low += fn([perm[0]], *args, **kwargs)
        else:
            low += fn(perm[0:i + 1], *args, **kwargs) - fn(perm[0:i], *args, **kwargs)
            
    return low

In [3]:
def mod_upper(X, fn, center, *args, **kwargs):
    """ Modular upper bound of fn(X) for any X contained in ground set V, 
    centered at center.
    
    Expects X and center as lists of tuples, and fn as a Python function. """
    
    up = fn(center, *args, **kwargs) # modular upper bound

    for j in center:
        if j not in X:
            center_noj = [x for x in center if x != j]
            up -= fn(center, *args, **kwargs) - fn(center_noj, *args, **kwargs)

#     for j in X:
#         if j not in center:
#             up += fn([j], *args, **kwargs) - fn([], *args, **kwargs)
            
    return up

In [4]:
def make_perm(V, X):
    """ Takes in the ground set V and a set X, and
    returns a random chain permutation that contains X """
    if len(X) == 0:
        indices = list(range(len(V)))
        random.shuffle(indices)
        return [V[i] for i in indices]
    
    ind_X = [i for i, v in enumerate(V) if v in X] # indices of X in V
    rest = [i for i in list(range(len(V))) if i not in ind_X] # rest of indices in V
    
    random.shuffle(ind_X) # shuffle indices
    random.shuffle(rest)
    indices = ind_X + rest # combine
    
    return [V[i] for i in indices] # generate perm based on shuffled indices

In [5]:
def mod_mod(V, fn, g, seed, *args, **kwargs):
    """ Implements algorithm3 (ModMod) from paper. Takes in ground set V,
    and functions fn and g. 
    
    Expects V as a list of tuples, and fn and g as submodular Python functions. """
    
    X = seed
    obj_lst = [] # stores objectives at each time step
    
    while True:
        X_next = X
        perm = make_perm(V, X) # choose permutation
        
        up = mod_upper(X, fn, X, args[0])
        low = mod_lower(X, g, perm, *args)
        print((up, low))
        emp = up - low #mod_upper(X, fn, X, args[0]) - mod_lower(X, g, perm, *args) # obj w X as input
        obj_lst.append(emp)
        print(emp)
        for i in V:
            if i in X: 
                X_noi = [x for x in X if x != i]
                obj = mod_upper(X_noi, fn, X, args[0]) - mod_lower(X_noi, g, perm, *args) # obj w X w/o element in V as input
                print((i, obj))
                if obj < emp:
                    X_next.remove(i)
            else:
                obj = mod_upper(X + [i], fn, X, args[0]) - mod_lower(X + [i], g, perm, *args) # obj w element in V added to X as input
                print((i, obj))
                if obj < emp:
                    X_next.append(i)
                    
        if X_next == X:
            break
        else:
            X = X_next
    
    return X_next, obj_lst

In [6]:
# [['A', 'R'], ['V'], ['P'], ['F', 'Q']] - turn into list of lists first!

# lst = [('F', 3), ('V', 1), ('P', 2), ('A', 0), ('R', 0), ('Q', 3), ]

In [7]:
with open('../inputs/phoq.pkl', 'rb') as f:
    t = pickle.load(f)

X = t[0] # one-hot encoding of X
T = t[1] # tokenized encoding of X
y = t[2].values

In [8]:
def decode_X(X):
    """ Takes in one-hot encoding X and decodes it to
    return a string of four amino acids. """
    
    amino_acids = 'ARNDCQEGHILKMFPSTWYV'
    
    pos_X = [i for i, x in enumerate(X) if x == 1.0] # positions of amino acids
    pos_X = [(p - 20 * i) for i, p in enumerate(pos_X)] # make sure indexing is same as in str amino_acids
    aa_X = [amino_acids[p] for i, p in enumerate(pos_X)] # amino acid chars in X
    return ''.join(aa_X)

In [9]:
def get_predictions(X_train, y_train, X_test, its=500):
    """
    Train GP regressor on X_train and y_train. 
    Predict mean and std for X_test. 
    Return P(y > y_train_max) as dictionary eg 'AGHU': 0.78
    NB: for X_test in X_train, P ~= 0
    Be careful with normalization
    
    Expects X_train, y_train, and X_test as np.arrays
    """
    
    ke = kernels.MaternKernel()
    mo = models.GPRegressor(ke)
    
    # make data into tensors
    X_train = torch.Tensor(X_train)
    X_test = torch.Tensor(np.array(X_test))
    y_train_scaled = (np.array(y_train) - np.mean(np.array(y_train))) / np.std(np.array(y_train)) # scale y_train
    y_train_scaled = torch.Tensor(y_train_scaled.reshape(len(y_train_scaled), 1)).double() # .float()
    
    his = mo.fit(X_train, y_train_scaled, its=its) # fit model with training set
    
    # make predictions
    dic = {} # use dictionary to store probs
    ind = 0 # index for feeding in batches of X_test
    tau = y_train_scaled.max().float()
    
    for i in range(1000, len(X) + 1000, 1000):
        mu, var = mo.forward(X_test[ind:i]) # make predictions
        std = torch.sqrt(var.diag())
        mu = mu.squeeze()
        prob = 1 - dist.Normal(mu, std).cdf(tau) # compute probabilities for all means, stds

        for j, p in enumerate(prob):
            seq = decode_X(X_test[ind:i][j]) # decode one-hot to get string of seq
            dic[seq] = p # store prob for each seq

        ind = i
        
    return dic

np.random.seed(1)
rand_inds = np.random.choice(len(X), 100, replace=True) # generate random indices for 100 X's to sample from
X_train = X[rand_inds]
y_train = y[rand_inds]
X_test = X
y_true = y

dic = get_predictions(X_train, y_train, X_test, its=500)

Iteration 500 of 500	NLML: 40.0319	

In [10]:
# Computing objective LHS and RHS (two supermodular set functions)

def obj_LHS(X, probs):
    """ Takes in library X, and probabilities.
    
    Expects X to be a list of tuples, and probs to be a dictionary.
    
    Returns LHS of objective to be maximized (a supermodular function):
    sum of probabilities. """
    
    # if X empty or does not have aa at each position, return 0
    if len([i for i in range(4) if i not in [tup[1] for tup in X]]) > 0:
        return torch.Tensor([0.0])[0]
    
    # filter thru probs to find prob of x's in X
    X.sort(key=lambda tup: tup[1])

    X_str = [[tup[0] for i, tup in enumerate(X) if tup[1] == j] for j in range(4)] # generate list of lists of strings
    X_str = [''.join(s) for s in itertools.product(*X_str)] # generate list of strings of 4 aa seqs

    p = torch.Tensor([probs[key] for key in X_str])
    
    return -1 * torch.sum(p)

def obj_RHS(X, probs, n):
    """ Takes in library X, probabilities, and batch size n.
    
    Expects X to be a list of tuples, and probs to be a dictionary.
    
    Returns RHS of objective to be maximized (a supermodular function):
    sum of probabilities times expression with N and n. """
    
    # if X empty or does not have aa at each position, return 0
    if len([i for i in range(4) if i not in [tup[1] for tup in X]]) > 0:
        return torch.Tensor([0.0])[0]
    
    N = 1 # represents the product of sequence of # aas at each position
    for i in X:
        N *= len(i)
    
    # filter thru probs to find prob of x's in X
    X.sort(key=lambda tup: tup[1])

    X_str = [[tup[0] for i, tup in enumerate(X) if tup[1] == j] for j in range(4)] # generate list of lists of strings
    X_str = [''.join(s) for s in itertools.product(*X_str)] # generate list of strings of 4 aa seqs

    p = torch.Tensor([probs[key] for key in X_str])
    obj = torch.sum(p) * (1 - 1 / N) ** n
    
    return -1 * obj

In [11]:
def generate_V():
    """ Returns V: a list of tuples with every possible amino acid
    at each of the four positions. """
    
    amino_acids = 'ARNDCQEGHILKMFPSTWYV'
    return [(aa, i) for i in range(4) for aa in amino_acids]

V = generate_V()

In [12]:
X = [('F', 3), ('V', 1), ('P', 2), ('A', 0), ('R', 0), ('Q', 3)]
center = [('F', 3), ('E', 0), ('N', 1), ('G', 2)]#, ('M', 3), ('S', 2), ('W', 2)]

In [13]:
def get_seed(probs):
    """ Takes in a dictionary of amino acids to probabilities as
    generated by the get_predictions() function, and returns the 
    seed (the four amino acid seq with the best prediction, aka the 
    highest probabilitiy). 
    
    Returns a list of tuples representing the seed.
    
    Currently, 'SSSG' is the seed. """
    
    seq = max(probs.items(), key=operator.itemgetter(1))[0]
    return [(aa, i) for aa, i in zip(seq, range(4))]

seed = get_seed(dic)
seed

[('S', 0), ('S', 1), ('S', 2), ('L', 3)]

In [14]:
perm = make_perm(V, X)

In [15]:
obj_LHS(X, dic)

tensor(1.00000e-04 *
       -1.5783)

In [16]:
mod_upper(X, obj_LHS, center, dic) ## should match

tensor(1.00000e-04 *
       2.7990)

In [17]:
A = torch.stack([mod_upper([x], obj_LHS, center, dic) for x in X]).sum()
A -= (len(X) - 1) * mod_upper([], obj_LHS, center, dic) # takes into account empty set ## should match
A

tensor(1.00000e-04 *
       2.7990)

In [18]:
mod_lower(X, obj_LHS, perm, dic)

tensor(1.00000e-04 *
       -1.5783)

In [19]:
print(torch.stack([mod_lower([x], obj_LHS, perm, dic) for x in X]).sum()) # lower bound

tensor(1.00000e-04 *
       -1.5783)


In [20]:
obj_RHS(seed, dic, 100)

tensor(1.00000e-04 *
       -3.8924)

In [21]:
obj_RHS(X, dic, 100)

tensor(1.00000e-05 *
       -3.2678)

In [22]:
mod_upper(X, obj_RHS, center, dic, 100)

tensor(1.00000e-07 *
       4.4069)

In [23]:
B = torch.stack([mod_upper([x], obj_RHS, center, dic, 100) for x in X]).sum()
B -= (len(X) - 1) * mod_upper([], obj_RHS, center, dic, 100) # takes into account empty set ## should match
B

tensor(1.00000e-07 *
       4.4069)

In [24]:
mod_lower(X, obj_RHS, perm, dic, 100)

tensor(1.00000e-05 *
       -3.2678)

In [25]:
print(torch.stack([mod_lower([x], obj_LHS, perm, dic) for x in X]).sum()) # lower bound

tensor(1.00000e-04 *
       -1.5783)


In [37]:
seed = np.random.choice(np.array([key for key in dic]))
seed

'VQES'

In [39]:
seed1 = [(s, i) for i, s in enumerate(seed)]
seed1

[('V', 0), ('Q', 1), ('E', 2), ('S', 3)]

In [40]:
mod_mod(V, obj_LHS, obj_RHS, seed1, dic, 100)

(tensor(1.00000e-04 *
       -5.1570), tensor(1.00000e-07 *
       -8.1194))
tensor(1.00000e-04 *
       -5.1489)
(('A', 0), tensor(1.1678))
(('R', 0), tensor(0.2610))
(('N', 0), tensor(1.4535))
(('D', 0), tensor(1.00000e-03 *
       2.3682))
(('C', 0), tensor(1.00000e-02 *
       1.1784))
(('Q', 0), tensor(5.3437))
(('E', 0), tensor(1.00000e-03 *
       2.4020))
(('G', 0), tensor(1.1790))
(('H', 0), tensor(0.4103))
(('I', 0), tensor(3.0005))
(('L', 0), tensor(1.0901))
(('K', 0), tensor(0.5654))
(('M', 0), tensor(2.0035))
(('F', 0), tensor(0.6915))
(('P', 0), tensor(0.2170))
(('S', 0), tensor(11.4290))
(('T', 0), tensor(1.7659))
(('W', 0), tensor(1.00000e-02 *
       8.9663))
(('Y', 0), tensor(0.2399))
(('V', 0), tensor(1.00000e-07 *
       8.1194))
(('A', 1), tensor(25.4847))
(('R', 1), tensor(2.3866))
(('N', 1), tensor(1.00000e-03 *
       3.8084))
(('D', 1), tensor(7.3930))
(('C', 1), tensor(1.9832))
(('Q', 1), tensor(1.00000e-07 *
       8.1194))
(('E', 1), tensor(1.3598))
(('G', 1

([('V', 0), ('Q', 1), ('E', 2), ('S', 3)], [tensor(1.00000e-04 *
         -5.1489)])

In [None]:
seed = [('S', 0), ('S', 1), ('S', 2), ('L', 3)]

In [None]:
mod_upper(seed, obj_LHS, seed, dic) - mod_lower(seed, obj_RHS, perm, dic, 100)

In [None]:
mod_upper(seed, obj_LHS, seed, dic)

In [None]:
mod_lower(seed, obj_RHS, perm, dic, 100)