First, makes predictions on a library of 160,000 PhoQ variants using a Gaussian process (GP) with a Matern kernel, then computes the objective. Combines the gp_ssm and objective_ssm notebooks.

Includes functions that compute each of the three baselines:
1. Baseline that builds an optimal sequence from the X's by choosing, at one of the four positions of the wild-type sequence, the amino acid with the maximum y-value while the other three positions are held fixed; it then moves on to the next position, keeping the best amino acid found for the previous position fixed. This is repeated for every ordering of the four positions.
2. Baseline that, for each of the four positions independently, finds the amino acid with the maximum y-value while the other three positions are fixed to the wild-type residues, then combines the best amino acid from each position into one sequence.
3. Baseline that uses a greedy algorithm to maximize the objective. It starts from the best prediction (the seed) and keeps adding amino acids until the objective stops increasing.

In [None]:
import torch
from torch import distributions as dist
import itertools
import pickle

from scipy.stats import norm
import operator

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
sns.set_style('white')
sns.set_context('paper')
# Plot adjustments: larger tick labels, axis labels, legend text and titles.
plt.rcParams.update({
    'ytick.labelsize': 15,
    'xtick.labelsize': 15,
    'axes.labelsize': 35,
    'legend.fontsize': 30,
    'axes.titlesize': 16,
})

from gptorch import kernels, models

In [None]:
# NOTE: pickle.load can execute arbitrary code; only load trusted local files.
with open('../inputs/phoq.pkl', 'rb') as f:
    t = pickle.load(f)

X, T = t[0], t[1]  # one-hot encoding of X, tokenized encoding of X
y = t[2].values    # target values; assumes t[2] is a pandas object (uses .values)

In [None]:
def decode_X(X):
    """ Decode a flat one-hot encoding X into a string of amino acids.

    X is a list, NumPy array, or torch tensor in which each block of 20
    consecutive entries encodes one position, in the amino-acid order
    'ARNDCQEGHILKMFPSTWYV'. Each entry equal to 1.0 contributes one
    character, in index order, so a 4-position encoding yields 4 letters.
    """
    amino_acids = 'ARNDCQEGHILKMFPSTWYV'

    # Convert torch tensors once instead of comparing 0-d tensors one by one,
    # which is very slow when decoding the whole library.
    if hasattr(X, 'detach'):
        X = X.detach().cpu().numpy()

    hot = np.flatnonzero(np.asarray(X) == 1.0)  # indices of the hot entries
    # The index modulo the alphabet size gives the amino acid within its block.
    # This does not depend on every earlier block having exactly one hot entry.
    return ''.join(amino_acids[p % len(amino_acids)] for p in hot)

# test on AWSS
# decode_X([ 1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
#         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
#         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,
#         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
#         0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
#         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,
#         0.,  0.])

In [None]:
## UPDATED VERSION OF GP_train() method

def get_predictions(X_train, y_train, X_test, its=500, *args, batch_size=1000, **kwargs):
    """
    Train a GP regressor (Matern kernel) on X_train / y_train and score X_test.

    y_train is standardized (zero mean, unit std) before fitting, and
    tau = max(standardized y_train). For every row of X_test the predictive
    mean mu and std give P(y > tau) = 1 - Normal(mu, std).cdf(tau).
    NB: for X_test rows that are also in X_train, P ~= 0.

    Args:
        X_train, y_train, X_test: array-like inputs (one-hot rows / targets).
        its: number of fitting iterations passed to ``mo.fit``.
        *args, **kwargs: forwarded to ``mo.fit`` (e.g. ``jitter=1e-5``).
        batch_size: number of X_test rows predicted per forward pass.

    Returns:
        dic: dict mapping decoded sequence (e.g. 'AGHU') -> probability
            (0-d tensor).
        means: list of predicted standardized means (0-d tensors), one per
            X_test row, in X_test order.

    Raises:
        ValueError: if batch_size < 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    ke = kernels.MaternKernel()
    mo = models.GPRegressor(ke)

    # make data into tensors
    X_train = torch.Tensor(X_train)
    X_test = torch.Tensor(np.array(X_test))
    y_arr = np.array(y_train)
    y_std = np.std(y_arr)
    if y_std == 0:  # all targets equal: avoid 0/0 -> NaN targets
        y_std = 1.0
    y_train_scaled = (y_arr - np.mean(y_arr)) / y_std
    y_train_scaled = torch.Tensor(y_train_scaled.reshape(len(y_train_scaled), 1))

    mo.fit(X_train, y_train_scaled, its=its, *args, **kwargs)  # fit model with training set

    dic = {}  # decoded sequence -> P(y > tau)
    means = []
    tau = y_train_scaled.max().float()

    # Predict in batches over X_test (not the global X). no_grad stops every
    # stored probability from keeping its batch's autograd graph alive.
    with torch.no_grad():
        for start in range(0, len(X_test), batch_size):
            batch = X_test[start:start + batch_size]
            mu, var = mo.forward(batch)
            std = torch.sqrt(var.diag())
            mu = mu.reshape(-1)  # stays 1-D even for a batch of one row
            prob = 1 - dist.Normal(mu, std).cdf(tau)

            for j, p in enumerate(prob):
                dic[decode_X(batch[j])] = p
                means.append(mu[j])

    return dic, means

In [None]:
np.random.seed(1)
# 100 random row indices into the library.
# NOTE(review): replace=True can draw the same row more than once.
rand_inds = np.random.choice(len(X), 100, replace=True)
X_train, y_train = X[rand_inds], y[rand_inds]
X_test, y_true = X, y  # score / evaluate on the full library

dic, means = get_predictions(X_train, y_train, X_test, its=500)

# preds[k]: predicted means after round k (round 0 = random training sample).
preds = [means]

In [None]:
# # Computing objective

# def objective1(X, probs, n):
#     """ Takes in library X, probabilities, and batch size n.
    
#     Expects X to be a list of tuples, and probs to be a dictionary.
    
#     Returns objective to be maximized. """
    
#     N = 1 # represents the product of sequence of # aas at each position
#     for i in X:
#         N *= len(i)
    
#     # filter thru probs to find prob of x's in X
#     X.sort(key=lambda tup: tup[1])

#     X_str = [[tup[0] for i, tup in enumerate(X) if tup[1] == j] for j in range(4)] # generate list of lists of strings
#     X_str = [''.join(s) for s in itertools.product(*X_str)] # generate list of strings of 4 aa seqs

#     p = torch.Tensor([probs[key] for key in X_str])
#     obj = torch.sum(p) * (1 - (1 - 1 / N) ** n)
    
#     return obj

In [None]:
# Computing objective LHS and RHS (two supermodular set functions) - objective to be minimized

def _library_probs(X, probs, n_positions=4):
    """ Expand library X into every combinatorial sequence and look up its probability.

    X is a list of (amino_acid, position) tuples and probs a dict mapping
    sequence strings to probabilities. Returns a 1-D tensor with one entry per
    sequence in the product library, or None when some position in
    range(n_positions) has no amino acid. X is not modified. """
    by_pos = [[tup[0] for tup in X if tup[1] == j] for j in range(n_positions)]
    if not all(by_pos):
        return None
    seqs = [''.join(s) for s in itertools.product(*by_pos)]
    return torch.Tensor([probs[key] for key in seqs])


def obj_LHS(X, probs):
    """ LHS term of the objective to be minimized.

    Expects X to be a list of (aa, position) tuples and probs a dictionary.

    Returns -(sum of probabilities of all sequences in the library), or 0
    if some position has no amino acid. X is not modified. """
    p = _library_probs(X, probs)
    if p is None:
        return torch.Tensor([0.0])[0]
    return -1 * torch.sum(p)


def obj_RHS(X, probs, n):
    """ RHS term of the objective to be minimized, for batch size n.

    Expects X to be a list of (aa, position) tuples and probs a dictionary.

    Returns -(sum of probabilities) * (1 - 1/N) ** n, where N is the number of
    sequences in the library. (1 - 1/N) ** n is presumably the chance that a
    given sequence is missed in n uniform draws from the library. Returns 0
    if some position has no amino acid. X is not modified. """
    p = _library_probs(X, probs)
    if p is None:
        return torch.Tensor([0.0])[0]
    # Library size = product of the number of amino acids at each position.
    # (The old code multiplied len() of each (aa, pos) tuple, i.e. 2 ** len(X).)
    N = len(p)
    obj = torch.sum(p) * (1 - 1 / N) ** n
    return -1 * obj


def objective(X, probs, n):
    """ Objective to be minimized: obj_LHS - obj_RHS, which equals
    -(sum of probabilities) * (1 - (1 - 1/N) ** n). """
    return obj_LHS(X, probs) - obj_RHS(X, probs, n)

In [None]:
# USES/RETURNS STRINGS, RETURNS ALL 24 POSSIBILITIES

def baseline_fixed(wt, X, y): # deterministic
    """ Sequential coordinate-ascent baseline, run for every ordering of positions.

    For each permutation of the positions of wt, start from wt. At each
    position in that order, pick the amino acid with the maximum y among the
    X's that match the current sequence at every other position, and keep it
    for the following positions. The fixed substring is therefore not
    necessarily a substring of the wildtype sequence.

    Note: wildtype sequence expected as string. X expected as an array or list of
    one-hot encodings; y as a sequence of numbers aligned with X.

    Returns a list of len(wt)! variant strings (24 for a 4-aa wildtype), one per
    permutation, possibly with duplicates. Ties in y go to the first matching row of X. """
    X_decode = [decode_X(x) for x in X]
    baseline = []

    for perm in itertools.permutations(range(len(wt))):
        seq = list(wt)  # current variant, updated one position at a time

        for i in perm:
            fixed = ''.join(seq)
            # rows of X identical to the current sequence everywhere except position i
            index = [j for j, x in enumerate(X_decode)
                     if fixed[:i] == x[:i] and fixed[i + 1:] == x[i + 1:]]
            ys = [y[j] for j in index]
            # np.argmax returns the first occurrence of the maximum and, unlike
            # np.where(list == scalar), works for plain Python floats too.
            best = int(np.argmax(ys))
            seq[i] = X_decode[index[best]][i]

        baseline.append(''.join(seq))

    return baseline

wt = decode_X(X[150614]) # wt as string
seqs = baseline_fixed(wt, X, y)
print("aa: {}".format(seqs))

# find y-values corresponding to 24 possible baselines from baseline_fixed() --> take aa seq with max y

# dict.fromkeys removes duplicates but keeps first-seen order. list(set(...))
# order depends on string hashing (PYTHONHASHSEED), so ties in max y were
# broken differently from run to run.
seqs = list(dict.fromkeys(seqs))
X_decode = [decode_X(x) for x in X]  # also used by later cells
ys_baseline = [y[X_decode.index(x)] for x in seqs]
max_baseline = seqs[ys_baseline.index(max(ys_baseline))]

y_seq1 = max(ys_baseline)
print("best baseline: {}".format(max_baseline))
print("y value: {}".format(y_seq1))
print("global max: {}".format(np.max(y)))

In [None]:
# USES/RETURNS STRINGS

def baseline_vary(wt, X, y): # deterministic
    """ Independent per-position baseline around the wildtype.

    For each position i of wt, find the X whose sequence matches wt at every
    other position and has the maximum y, and take its amino acid at i. The
    fixed context is always the wildtype, so the chosen positions do not
    influence each other.

    Note: wildtype sequence expected as a string. X expected as an array or list of
    one-hot encodings; y as a sequence of numbers aligned with X.

    Returns the combined variant (as a string). Ties in y go to the first
    matching row of X. """
    X_decode = [decode_X(x) for x in X]
    fixed = ''.join(wt)  # wildtype context; identical for every position
    baseline = []

    for i in range(len(fixed)):  # vary amino acid in each position
        # rows of X identical to the wildtype everywhere except position i
        index = [j for j, x in enumerate(X_decode)
                 if fixed[:i] == x[:i] and fixed[i + 1:] == x[i + 1:]]
        ys = [y[j] for j in index]
        # first occurrence of the maximum; also works when y holds Python floats
        best = int(np.argmax(ys))
        baseline.append(X_decode[index[best]][i])

    return ''.join(baseline)

# Baseline 2: optimize each wildtype position independently.
wt = decode_X(X[150614])  # wildtype sequence as a string
seq2 = baseline_vary(wt, X, y)
print(f"aa: {seq2}")

# X_decode (decoded library) is built in the baseline_fixed cell above.
y_seq2 = y[X_decode.index(seq2)]
print(f"y value: {y_seq2}")
print(f"global max: {np.max(y)}")

In [None]:
def avail_aa(X):
    """ List every (amino_acid, position) pair for positions 0-3 that is not
    already in library X.

    Pairs are ordered by position, then by the amino-acid order
    'ARNDCQEGHILKMFPSTWYV'. """
    amino_acids = 'ARNDCQEGHILKMFPSTWYV'
    available = []
    for pos in range(4):
        for residue in amino_acids:
            candidate = (residue, pos)
            if candidate not in X:
                available.append(candidate)
    return available

In [None]:
def baseline_greedy(probs, seed, n):
    """ Build a library greedily, starting from the seed (the best prediction).

    At each step, add the available (amino_acid, position) pair whose
    inclusion gives the lowest objective(X, probs, n), i.e. the largest value
    of the quantity being maximized. Stop when no addition strictly improves
    the objective, or when every amino acid at every position is already in
    the library.

    Note: probs expected as a dictionary, and seed expected as list of tuples.
    The seed list is not modified.

    Returns the library as a list of (amino_acid, position) tuples. """
    X = list(seed)  # copy, so the caller's seed is not mutated by append

    obj = objective(X, probs, n)
    aa = avail_aa(X)  # amino acids not yet in X, per position

    while aa:  # without this guard, min() of an empty list raises ValueError
        lst = [objective(X + [a], probs, n) for a in aa]  # obj for X plus each candidate

        index, obj_next = min(enumerate(lst), key=operator.itemgetter(1))
        if obj_next >= obj:  # objective no longer strictly decreasing -> stop
            break
        X.append(aa.pop(index))
        obj = obj_next

    return X

In [None]:
def get_seed(probs):
    """ Pick the greedy seed from a dict of sequence -> probability
    (as produced by get_predictions()).

    The seed is the sequence with the highest probability; on ties, the first
    such key in iteration order wins. Returns it as [(aa, position), ...] for
    positions 0-3. """
    best_seq = max(probs, key=probs.get)
    return list(zip(best_seq, range(4)))

# Seed for the greedy baseline: the sequence with the highest probability in
# dic (from get_predictions), as a list of (aa, position) tuples.
seed = get_seed(dic)

In [None]:
def encode_X(X):
    """ Encode a string of (up to) four amino acids as a flat one-hot vector
    of length 80.

    Position i occupies entries [20*i, 20*i + 20). Within a block, amino acids
    follow the order 'ARNDCQEGHILKMFPSTWYV', the same layout decode_X reads.

    Raises:
        ValueError: if X contains a character that is not one of the 20
            amino acids. Previously str.find returned -1 and silently set the
            wrong entry (e.g. the last element for the first position).
    """
    amino_acids = 'ARNDCQEGHILKMFPSTWYV'

    enc = np.zeros(80)
    for i, char in enumerate(X):
        pos = amino_acids.find(char)
        if pos < 0:
            raise ValueError("unknown amino acid {!r} at position {}".format(char, i))
        enc[pos + i * 20] = 1.0
    return enc

# test on AWSS
# [ 1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
#   0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
#   0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,
#   0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
#   0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
#   0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,
#   0.,  0.])

# encode_X("AWSS")

In [None]:
# ##it should eventually come to a place where adding 
# #items to the set doesn't improve the objective (or it's added everything to the set)

# lib_prev = []
# lst_ys = []
# it = 0
# while True:
#     it += 1
#     print(it)
#     lib = baseline_greedy(dic, seed, 100)
#     lib.sort(key=lambda tup: tup[1])

#     lib_str = [[tup[0] for i, tup in enumerate(lib) if tup[1] == j] for j in range(4)] # generate list of lists of strings
#     lib_str = [''.join(s) for s in itertools.product(*lib_str)] # generate list of strings of 4 aa seqs
#     encs = np.array([encode_X(x) for x in lib_str]) # generate np array of one-hot encodings
    
#     X_decode = [decode_X(x) for x in X]
    
#     index = [X_decode.index(x) for x in lib_str]
#     y_new = np.array([y[i] for i in index])
#     lst_ys.append(y_new)
    
#     dic, means = get_predictions(encs, y_new, X, its=500)
#     preds.append(means)
#     seed = get_seed(dic)
#     lib = baseline_greedy(dic, seed, 100)
#     print(lib)
    
#     if lib == lib_prev:
#         break
#     if it == 5:
#         break
        
#     lib_prev = lib

In [None]:
##it should eventually come to a place where adding 
#items to the set doesn't improve the objective (or it's added everything to the set)


## FIRST ITERATION
lst_ys, libs = [], []  # per-round true ys and sequence lists of each greedy library

lib = baseline_greedy(dic, seed, 100)  # greedy library from the random-sample model
print(lib)
lib.sort(key=operator.itemgetter(1))  # order (aa, position) pairs by position

# Group amino acids by position, then take the cartesian product -> all 4-aa sequences.
lib_str = [[tup[0] for tup in lib if tup[1] == j] for j in range(4)]
lib_str = [''.join(s) for s in itertools.product(*lib_str)]
libs.append(lib_str)
encs = np.array([encode_X(x) for x in lib_str])  # one-hot encodings of the library

X_decode = [decode_X(x) for x in X]

# Look up the true y of every library sequence.
index = [X_decode.index(x) for x in lib_str]
y_new = np.array([y[i] for i in index])
lst_ys.append(y_new)

# Refit the GP on this library only and take the next seed from its predictions.
# NOTE(review): data from earlier rounds is not accumulated; confirm that is intended.
dic, means = get_predictions(encs, y_new, X, its=500)
preds.append(means)
seed = get_seed(dic)
print(seed)

In [None]:
## SECOND ITERATION
lib = baseline_greedy(dic, seed, 100)  # greedy library from the previous round's model
print(lib)
lib.sort(key=operator.itemgetter(1))  # order (aa, position) pairs by position

# Group amino acids by position, then take the cartesian product -> all 4-aa sequences.
lib_str = [[tup[0] for tup in lib if tup[1] == j] for j in range(4)]
lib_str = [''.join(s) for s in itertools.product(*lib_str)]
libs.append(lib_str)
encs = np.array([encode_X(x) for x in lib_str])  # one-hot encodings of the library

# Look up the true y of every library sequence (X_decode built in an earlier cell).
index = [X_decode.index(x) for x in lib_str]
y_new = np.array([y[i] for i in index])
lst_ys.append(y_new)

# Refit the GP on this library and take the next seed from its predictions.
dic, means = get_predictions(encs, y_new, X, its=500)
preds.append(means)
seed = get_seed(dic)
print(seed)

In [None]:
## THIRD ITERATION
lib = baseline_greedy(dic, seed, 100)  # greedy library from the previous round's model
print(lib)
lib.sort(key=operator.itemgetter(1))  # order (aa, position) pairs by position

# Group amino acids by position, then take the cartesian product -> all 4-aa sequences.
lib_str = [[tup[0] for tup in lib if tup[1] == j] for j in range(4)]
lib_str = [''.join(s) for s in itertools.product(*lib_str)]
libs.append(lib_str)
encs = np.array([encode_X(x) for x in lib_str])  # one-hot encodings of the library

# Look up the true y of every library sequence (X_decode built in an earlier cell).
index = [X_decode.index(x) for x in lib_str]
y_new = np.array([y[i] for i in index])
lst_ys.append(y_new)

# Refit the GP on this library and take the next seed from its predictions.
dic, means = get_predictions(encs, y_new, X, its=500)
preds.append(means)
seed = get_seed(dic)
print(seed)

In [None]:
## FOURTH ITERATION
lib = baseline_greedy(dic, seed, 100)  # greedy library from the previous round's model
print(lib)
lib.sort(key=operator.itemgetter(1))  # order (aa, position) pairs by position

# Group amino acids by position, then take the cartesian product -> all 4-aa sequences.
lib_str = [[tup[0] for tup in lib if tup[1] == j] for j in range(4)]
lib_str = [''.join(s) for s in itertools.product(*lib_str)]
libs.append(lib_str)
encs = np.array([encode_X(x) for x in lib_str])  # one-hot encodings of the library

# Look up the true y of every library sequence (X_decode built in an earlier cell).
index = [X_decode.index(x) for x in lib_str]
y_new = np.array([y[i] for i in index])
lst_ys.append(y_new)

# Refit the GP with a larger jitter (forwarded to mo.fit; "increase jitter by one mag").
dic, means = get_predictions(encs, y_new, X, its=500, jitter=1e-5)
preds.append(means)
seed = get_seed(dic)
print(seed)

In [None]:
## FIFTH ITERATION
lib = baseline_greedy(dic, seed, 100)  # greedy library from the previous round's model
print(lib)
lib.sort(key=operator.itemgetter(1))  # order (aa, position) pairs by position

# Group amino acids by position, then take the cartesian product -> all 4-aa sequences.
lib_str = [[tup[0] for tup in lib if tup[1] == j] for j in range(4)]
lib_str = [''.join(s) for s in itertools.product(*lib_str)]
libs.append(lib_str)
encs = np.array([encode_X(x) for x in lib_str])  # one-hot encodings of the library

# Look up the true y of every library sequence (X_decode built in an earlier cell).
index = [X_decode.index(x) for x in lib_str]
y_new = np.array([y[i] for i in index])
lst_ys.append(y_new)

# Refit the GP with a larger jitter (forwarded to mo.fit; "increase jitter by one mag").
dic, means = get_predictions(encs, y_new, X, its=500, jitter=1e-5)
preds.append(means)
seed = get_seed(dic)
print(seed)

In [None]:
def get_mean_abs_err(X, y, mu, lib):
    """ Absolute prediction errors on the sequences the model was NOT trained on.

    Args:
        X: one-hot encodings of the full library.
        y: true y values aligned with X.
        mu: predictions aligned with X (numbers, NumPy scalars or 0-d tensors).
        lib: strings of the four-aa sequences used for training; these are
            excluded from the error computation.

    Returns:
        ((y_test, errs), mean_abs_err), where y_test are the true ys of the
        held-out sequences (in X order), errs their absolute errors, and
        mean_abs_err the mean of errs.
    """
    lib_set = set(lib)  # O(1) membership instead of scanning the list per row
    str_x = [decode_X(x) for x in X]
    # Keep, rather than pop, the held-out indices. Popping in ascending order
    # shifted later indices and removed the wrong elements.
    keep = [i for i, x in enumerate(str_x) if x not in lib_set]

    y_test = [y[i] for i in keep]
    mu_test = [mu[i] for i in keep]

    errs = [float(abs(m - t)) for m, t in zip(mu_test, y_test)]
    return (y_test, errs), np.mean(np.array(errs))

In [None]:
# Plot y vs mean error (for each iteration)

# preds[0] comes from the model trained on the random sample X_train, and
# preds[k] (k >= 1) from the model trained on libs[k - 1]. Zipping preds with
# libs directly paired each prediction set with the wrong training library.
train_sets = [[decode_X(x) for x in X_train]] + libs
errs = [get_mean_abs_err(X, y, mu, lib)[1] for mu, lib in zip(preds, train_sets)]

_ = plt.title("Mean absolute error for iterations of Greedy Algorithm")
_ = plt.plot(np.arange(len(errs)), errs, marker='o', linestyle='none')

_ = plt.show() 

In [None]:
 # Plot mean error for y's vs. y's sorted (for each iteration)

# preds[0] comes from the model trained on the random sample X_train, and
# preds[k] (k >= 1) from the model trained on libs[k - 1].
train_sets = [[decode_X(x) for x in X_train]] + libs
abs_errs = [get_mean_abs_err(X, y, mu, lib)[0] for mu, lib in zip(preds, train_sets)]
_ = plt.title("Absolute error vs y's tested on for iterations of Greedy Algorithm")

for err in abs_errs: # each err is a tuple of y_test, abs errs
    sorted_ind = sorted(range(len(err[0])), key=lambda k: err[0][k]) # get indexes of y sorted
    y_sort = np.sort(err[0]) # sort y
    err_sort = [err[1][i] for i in sorted_ind]
    _ = plt.plot(y_sort, err_sort, marker='.')

_ = plt.show()

In [None]:
def ecdf(data):
    """Return (x, y) for the empirical CDF of one-dimensional data.

    x is the data sorted ascending; y[k] = (k + 1) / n for n data points."""
    count = len(data)
    return np.sort(data), np.arange(1, count + 1) / count

In [None]:
# Compute ECDF
x_val, y_val = ecdf(y)  # empirical CDF of every y in the library

# Scatter the ECDF and mark where each deterministic baseline's y falls on it
# (ECDF height taken at the first sorted y equal to the baseline value).
_ = plt.title("ECDF of y's and deterministic baselines")
_ = plt.plot(x_val, y_val, marker='.', linestyle='none', label="orig y's")
_ = plt.plot(y_seq1, y_val[np.flatnonzero(x_val == y_seq1)[0]],
             marker='o', label='baseline_fixed')
_ = plt.plot(y_seq2, y_val[np.flatnonzero(x_val == y_seq2)[0]],
             marker='o', label='baseline_vary')
_ = plt.legend(loc='upper center', bbox_to_anchor=(1.45, 0.8), shadow=True, ncol=1)

_ = plt.ylabel('ECDF')
_ = plt.xlabel('y')

_ = plt.show()

In [None]:
# Plot Greedy algorithm baseline with deterministic baselines too

# Long-format table: one row per sampled y, tagged with its greedy iteration index.
d = {
    'Iterations': [it for it, ys in enumerate(lst_ys) for _ in ys],
    'Sampled ys': [val for ys in lst_ys for val in ys],
}
df = pd.DataFrame(data=d)  # input for the swarmplot below

# NOTE(review): sns.set() applies seaborn defaults and probably overrides the
# rcParams adjustments made in the setup cell; confirm this is intended.
sns.set(style="whitegrid")
_ = plt.title('Baselines: deterministic and Greedy algorithm')
# swarmplot spreads overlapping points so each cluster of ys stays visible
ax = sns.swarmplot(x="Iterations", y="Sampled ys", data=df)
_ = ax.axhline(y_seq1, color='purple', label='baseline_fixed')
_ = ax.axhline(y_seq2, color='black', label='baseline_vary')
_ = plt.legend(loc='upper center', bbox_to_anchor=(1.45, 0.8), shadow=True, ncol=1)