In [9]:
from pathlib import Path

import pandas as pd
import numpy as np
from scipy.stats import spearmanr
from sklearn.model_selection import KFold
import string

import tensorflow as tf
import tensorflow.keras.backend as K

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Reshape, Activation
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
from tensorflow.keras.initializers import Zeros, Constant
from tensorflow.keras.losses import categorical_crossentropy

import matplotlib.pyplot as plt
import seaborn as sns

# Machine-specific absolute locations: change both before running this
# notebook on another machine.
# PROJECT_ROOT presumably points at the mutation-predictability checkout
# (inferred from the path name -- confirm).
PROJECT_ROOT = "/home/zahra/Synthetic_Landscapes/mutation-predictability"
# DATA_ROOT is a plain string (not a Path); file names are appended to it
# with "+" where the MSA is loaded.
DATA_ROOT =  "/home/zahra/Synthetic_Landscapes/Data"

from predictability.models import PottsRegressor, PottsModel
from predictability.utils import update_environment_variables
from predictability.constants import BINARY_RESIDUE_FEATURES

### For Francesca

- Refer to https://github.com/florisvdf/mutation-predictability for info on downloading Gremlin
- Change PROJECT_ROOT and DATA_ROOT
- Change folder/file names
- If loading a new MSA file, rerun the function that converts a2m (downloaded from EVcouplings) to a3m

In [3]:
# Only necessary when Jupyter has not picked up your shell's environment
# variables; replace "zsh" with the name of the shell you use.
update_environment_variables("zsh")

In [20]:
# Run configuration.
seed = 42

# Output directory for this run; it is created together with any missing
# parent directories, and an already existing directory is not an error.
path = "/home/zahra/Synthetic_Landscapes/Potts/results/GB1/"
results_dir = Path(path)
results_dir.mkdir(parents=True, exist_ok=True)

In [7]:
# # Convert a2m to a3m (only needs to be run once per MSA file)
# def format_a2m(msa_path):
#     file_name = Path(msa_path).name
#     directory = Path(msa_path).parent
#     formatted_msa_path = f"{directory}/modified_{file_name}"
#     with open(msa_path, "r") as rf:
#         with open(formatted_msa_path, "w") as wf:
#             for line in rf:
#                 if line.startswith(">"):
#                     wf.write(line)
#                 else:
#                     wf.write(line.upper().replace(".", "-"))
#     return formatted_msa_path

In [6]:
def parse_fasta(filename, a3m=True):
  '''Parse a FASTA / A3M file into parallel lists of headers and sequences.

  Parameters
  ----------
  filename : path of the file to read.
  a3m      : if True, lowercase letters are deleted from every sequence line
             (in a3m files they do not align to the query sequence);
             if False, sequence lines are upper-cased instead.

  Returns
  -------
  (header, sequence) : two lists of equal length. ``header[i]`` is the text
  after ">" of record i, ``sequence[i]`` is the concatenation of all of its
  sequence lines.

  Raises
  ------
  ValueError : if sequence data appears before the first ">" header line.
  '''

  # for a3m files the lowercase letters are removed
  # as these do not align to the query sequence
  rm_lc = str.maketrans(dict.fromkeys(string.ascii_lowercase)) if a3m else None

  header, chunks = [], []
  # "with" guarantees the handle is closed even if parsing raises.
  with open(filename, "r") as handle:
    for line in handle:
      line = line.rstrip()
      if not line:
        # Blank lines (e.g. a trailing empty line) carry no data; indexing
        # line[0] on them used to raise IndexError.
        continue
      if line.startswith(">"):
        header.append(line[1:])
        chunks.append([])
        continue
      if not chunks:
        raise ValueError(
          f"{filename}: sequence data found before the first '>' header line")
      chunks[-1].append(line.translate(rm_lc) if a3m else line.upper())

  return header, [''.join(parts) for parts in chunks]
  
def mk_msa(seqs, alphabet="ARNDCQEGHILKMFPSTWYV"):
  '''One-hot encode an MSA.

  Parameters
  ----------
  seqs     : iterable of equal-length sequence strings (single-byte chars).
  alphabet : ordered characters that define the states; the position of a
             character in ``alphabet`` is its state index. Defaults to the
             20 amino acids.

  Returns
  -------
  float array of shape (N, L, A) with N = number of sequences,
  L = sequence length and A = len(alphabet).

  Notes
  -----
  Every character that is not in ``alphabet`` (gaps "-", "X", ...) is mapped
  to the LAST state. With the default alphabet that means gaps are encoded
  as "V"; pass e.g. ``alphabet="ARNDCQEGHILKMFPSTWYV-"`` to give gaps a
  state of their own.

  Raises
  ------
  ValueError : if the sequences do not all have the same length.
  '''
  seqs = list(seqs)
  if len({len(s) for s in seqs}) > 1:
    raise ValueError("all sequences in the MSA must have the same length")

  alphabet = list(alphabet)
  states = len(alphabet)

  alpha = np.array(alphabet, dtype='|S1').view(np.uint8)
  msa = np.array([list(s) for s in seqs], dtype='|S1').view(np.uint8)

  # Byte value -> state index. Every byte starts out mapped to the last
  # state, so anything outside the alphabet lands there; this also covers
  # raw byte values <= states, which an in-place remap would leave as-is
  # (and a value equal to `states` would index out of range below).
  lookup = np.full(256, states - 1, dtype=np.intp)
  lookup[alpha] = np.arange(states)

  return np.eye(states)[lookup[msa]]

In [7]:
def GREMLIN_simple(msa, msa_weights=None, lam=0.01, 
                   opt=None, opt_rate=None,
                   opt_batch=None, opt_epochs=100):
  '''
  Fit per-position (v) and pairwise (w) parameters to a one-hot encoded MSA
  by training a single dense Keras layer whose output is soft-maxed per
  position and scored against the input itself.

  Side effects: clears the current Keras session before building the model
  and prints the final loss.
  ------------------------------------------------------
  inputs
  ------------------------------------------------------
   msa         : msa input       shape=(N,L,A)
   msa_weights : weight per seq  shape=(N,)  (None = unweighted)
  ------------------------------------------------------
  optional inputs
  ------------------------------------------------------
   lam         : L2 regularization weight    (default 0.01)
   opt         : optimizer class, called as opt(opt_rate)
                                             (default Adam)
   opt_rate    : learning rate               (default 0.1*log(Neff)/L)
   opt_batch   : batch size                  (default N, i.e. full batch)
   opt_epochs  : number of epochs            (default 100)
  ------------------------------------------------------
  outputs 
  ------------------------------------------------------
   v           : conservation    shape=(L,A)
   w           : coevolution     shape=(L,A,L,A)
  ------------------------------------------------------
  '''
  
  # [N]umber of sequences, [L]ength, and size of [A]lphabet
  N,L,A = msa.shape
    
  # reset any open sessions/graphs
  K.clear_session()
  
  #############################
  # the model
  #############################  
  
  # constraints
  def cst_w(x):
    '''Kernel constraint on the (L*A, L*A) weight matrix: symmetrize it and
    zero every block that couples a position with itself.'''
    # average with the transpose -> symmetric matrix
    x = (x + K.transpose(x))/2    
    # (L,1,L,1) mask that is 0 where both position indices are equal and 1
    # elsewhere; it broadcasts over the two alphabet axes below
    zero_mask = K.constant((1-np.eye(L))[:,None,:,None],dtype=tf.float32)
    x = K.reshape(x,(L,A,L,A)) * zero_mask
    return K.reshape(x,(L*A,L*A))
  
  # initialization
  if msa_weights is None:
    # unweighted: every sequence counts once; pssm holds per-position
    # state counts, shape (L,A)
    Neff = N
    pssm = msa.sum(0)
  else:
    # weighted: msa.T is (A,L,N), so the weights broadcast along the
    # sequence axis; summing it out and transposing back gives (L,A)
    Neff = msa_weights.sum()
    pssm = (msa.T*msa_weights).sum(-1).T
  
  # bias starts at the log of the counts plus a pseudocount of
  # lam*log(Neff), then is centered to zero mean over the alphabet axis
  ini_v = np.log(pssm + lam * np.log(Neff))
  ini_v = Constant(ini_v - ini_v.mean(-1,keepdims=True))
  # NOTE(review): the initializer class itself is passed, not an instance;
  # presumably Keras instantiates it -- confirm for the installed version
  ini_w = Zeros
  
  # regularization
  # NOTE(review): both strengths are divided by N, not Neff, even when
  # msa_weights is given -- confirm this is intended
  lam_v = l2(lam/N)
  lam_w = l2(lam/N * (L-1)*(A-1)/2)
  
  # model
  # (N,L,A) input -> flatten to L*A -> dense L*A -> reshape to (L,A)
  # -> softmax (over the last axis, i.e. per position)
  model = Sequential()
  model.add(Flatten(input_shape=(L,A)))
  model.add(Dense(units=L*A,
                  kernel_initializer=ini_w,
                  kernel_regularizer=lam_w,
                  kernel_constraint=cst_w,
                  bias_initializer=ini_v,
                  bias_regularizer=lam_v)) 
  model.add(Reshape((L,A)))
  model.add(Activation("softmax"))
  
  #############################
  # compile model
  #############################
  # loss function = CCE = -Pseudolikelihood
  # summed over positions and states (axes 1,2), one value per sequence;
  # the 1e-8 keeps log() away from zero
  @tf.function
  def CCE(true,pred):
    return K.sum(-true * K.log(pred + 1e-8),axis=(1,2))
  
  # optimizer settings
  if opt is None: opt = Adam
  if opt_rate is None: opt_rate = 0.1 * np.log(Neff)/L
  if opt_batch is None: opt_batch = N

  model.compile(opt(opt_rate),CCE)  
    
  #############################
  # fit model
  #############################
  # the MSA is both the input and the target
  model.fit(msa, msa, sample_weight=msa_weights,
            batch_size=opt_batch, epochs=opt_epochs,
            verbose=False)

  # report loss
  # the scalar returned by evaluate() is scaled by N before printing
  loss = model.evaluate(msa, msa, sample_weight=msa_weights, verbose=False) * N
  print(f"loss: {loss}")
  
  #############################
  # return weights
  #############################
  # the Dense layer is the only layer with weights: w is its kernel
  # (L*A, L*A) and v its bias (L*A,), reshaped below
  w,v = model.get_weights()
  return v.reshape((L,A)), w.reshape((L,A,L,A))

In [11]:
%%time

# Read the GB1 alignment (a3m: lowercase insertion columns are dropped).
names, seqs = parse_fasta(
    DATA_ROOT + "/GB1/GB1_b0.3_b0.5_joined.a3m",
    a3m=True,
)

# One-hot encode the aligned sequences.
msa = mk_msa(seqs)

#potts_model = PottsRegressor(msa_path=str(DATA_ROOT / "amylase/msa.a3m"))
# Fit without per-sequence weights; V has shape (L,A), W has shape (L,A,L,A).
V, W = GREMLIN_simple(msa=msa, msa_weights=None)

2023-12-11 13:50:33.048328: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1886] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 19868 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:3b:00.0, compute capability: 8.6
2023-12-11 13:50:33.049792: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1886] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 22280 MB memory:  -> device: 1, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:af:00.0, compute capability: 8.6
2023-12-11 13:50:38.113735: I tensorflow/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
2023-12-11 13:50:46.218321: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7f43080f44d0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-12-11 13:50:46.218417: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA GeForce RTX 3090, Compute 

loss: 731118.1656723022
CPU times: user 49.6 s, sys: 39.1 s, total: 1min 28s
Wall time: 56 s


In [21]:
# Write the fitted parameters to <results_dir>/output.npz.
# The arrays are passed positionally, so numpy stores them under its default
# keys: V as "arr_0" and W as "arr_1". np.savez returns None, so `data` is None.
path = "".join((str(results_dir), '/output.npz'))
data = np.savez(path, V, W)

/home/zahra/Synthetic_Landscapes/Potts/results/GB1/output.npz
