# Installing Facebook embeddings template

In [None]:
!pip install git+https://github.com/facebookresearch/esm.git

In [3]:
import torch
import esm

# Load the 34-layer ESM-1 model (esm1_t34_670M_UR50S) together with its alphabet.
model, alphabet = esm.pretrained.esm1_t34_670M_UR50S()
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = model.to(device)
# Switch to inference mode: this disables dropout so that repeated runs produce
# the same embeddings for the same sequence.
model.eval()

# Turns a list of (label, sequence) pairs into (labels, strings, token tensor).
batch_converter = alphabet.get_batch_converter()


In [4]:
import random
from collections import Counter
from tqdm import tqdm

import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import scipy
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
from sklearn.svm import SVC, SVR
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression, LinearRegression, SGDRegressor

In [5]:
def batch(iterable, n=1):
    """Yield successive slices of ``iterable`` holding at most ``n`` items each.

    ``iterable`` must support ``len()`` and slicing (for example a list or a
    string). The last slice is shorter when the length is not a multiple of
    ``n``; an empty input yields nothing.
    """
    total = len(iterable)
    for start in range(0, total, n):
        stop = start + n
        if stop > total:
            stop = total
        yield iterable[start:stop]

# PART 1: Generating prediction model

# 1.1 Preparing data for training

In [None]:
# Load the training sequences and parse their target values from the FASTA headers.

FASTA_PATH = "./training_seqs.fasta"  # FASTA file with the training sequences

# (header, sequence) pairs in file order.
data = [(header, sequence) for header, sequence in esm.data.read_fasta(FASTA_PATH)]
# The last space-separated field of every header is parsed as the numeric target.
ys = [float(header.split(' ')[-1]) for header, _ in data]
Xs = []  # placeholder; assigned the embedding vectors in the train/test split cell
print(ys)
print(data)

# 1.2 Building embeddings

In [8]:
# Compute one fixed-size embedding vector per training sequence, 10 sequences
# per forward pass.
sequence_embeddings = []
for batch_seqs in batch(data, 10):
    batch_labels, batch_strs, batch_tokens = batch_converter(batch_seqs)

    # Move the token tensor to the device the model lives on.
    batch_tokens_cuda = batch_tokens.to(device, non_blocking=True)
    # Inference only, so no gradients are recorded.
    with torch.no_grad():
        results = model(batch_tokens_cuda, repr_layers=[34])
    # Per-residue representations taken from layer 34.
    token_embeddings = results["representations"][34]

    # Average the per-residue vectors of each sequence into a single vector.
    # Token 0 is always the beginning-of-sequence token, so the residues of a
    # sequence occupy positions 1 .. len(seq).
    sequence_embeddings.extend(
        token_embeddings[row, 1:len(seq) + 1].mean(0)
        for row, (_, seq) in enumerate(batch_seqs)
    )

In [9]:
# Sanity check: print the length of the first per-sequence embedding vector.
print(len(sequence_embeddings[0]))

1280


# 1.3 Creating Training set & Test set

In [10]:
# Split the embeddings and their target values into a training and a test set.

# The prediction cells further down rebind `sequence_embeddings`; fail early
# with a clear message if this cell is re-run after one of them.
assert len(sequence_embeddings) == len(ys), (
    "sequence_embeddings does not match ys - re-run the embedding cell for the training data"
)
# Convert every embedding tensor to a NumPy vector. detach() is the supported
# way to drop autograd tracking (Tensor.data is discouraged).
Xs = [t.detach().cpu().numpy() for t in sequence_embeddings]
train_size = 0.8  # fraction of the samples used for training
# The fixed random_state keeps the split identical between runs.
Xs_train, Xs_test, ys_train, ys_test = train_test_split(Xs, ys, train_size=train_size, random_state=42)

# 1.4 Beginning of the training block

In [11]:
# Hyper-parameter search spaces, one per model family used in the training cell.

# Grid for KNeighborsRegressor.
knn_grid = dict(
    n_neighbors=[5, 10],
    weights=['uniform', 'distance'],
    algorithm=['ball_tree', 'kd_tree', 'brute'],
    leaf_size=[15, 30],
    p=[1, 2],
)

# Grid for SVR.
svm_grid = dict(
    C=[0.1, 1.0, 10.0],
    kernel=['linear', 'poly', 'rbf', 'sigmoid'],
    degree=[3],
    gamma=['scale'],
)

# Grid for RandomForestRegressor.
rfr_grid = dict(
    n_estimators=[100],
    criterion=['squared_error', 'absolute_error'],
    max_features=['sqrt', 'log2'],
    min_samples_split=[2, 10],
    min_samples_leaf=[1, 4],
)

# Empty grid; the training cell does not reference it.
lgr_grid = dict()

In [None]:
# Training block: grid-search the hyper-parameters of every model family with
# cross-validation, scored by R^2.
RANDOM_STATE = 42  # seed for stochastic estimators; same value as the train/test split
cls_list = [KNeighborsRegressor, SVR, RandomForestRegressor]
param_grid_list = [knn_grid, svm_grid, rfr_grid]
result_list = []  # one DataFrame of cross-validation results per model family
grid_list = []    # fitted GridSearchCV objects, in the same order as cls_list
for estimator_cls, param_grid in zip(cls_list, param_grid_list):
    print(estimator_cls)
    estimator = estimator_cls()
    # Seed estimators that expose a random_state parameter (the random forest)
    # so that repeated runs select and fit the same model. Estimators without
    # that parameter are left unchanged.
    if "random_state" in estimator.get_params():
        estimator.set_params(random_state=RANDOM_STATE)
    grid = GridSearchCV(
        estimator=estimator,
        param_grid=param_grid,
        scoring='r2',
        verbose=1,
        n_jobs=-1,  # use all available cores
    )
    grid.fit(Xs_train, ys_train)
    result_list.append(pd.DataFrame.from_dict(grid.cv_results_))
    grid_list.append(grid)

# 1.5 Testing the trained model

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats

# Evaluate every tuned model on the held-out test set. For each model the
# actual/predicted pairs are written to a CSV file and a scatter plot to a PNG.
# Requires grid_list, Xs_test and ys_test from the earlier cells.

for i, grid in enumerate(grid_list):
    print(grid.best_estimator_)
    print()

    # Predictions for the held-out test set.
    preds = grid.predict(Xs_test)

    # Rank correlation between the measured and the predicted values.
    rho, p_value = scipy.stats.spearmanr(ys_test, preds)

    # Persist the raw numbers next to the figure.
    csv_filename = f'grid_element_{i}_kcat_km_data.csv'
    df = pd.DataFrame({'Actual Kcat/Km': ys_test, 'Predicted Kcat/Km': preds})
    df.to_csv(csv_filename, index=False)
    print(f'Data saved to {csv_filename}')

    # Scatter plot drawn through the explicit figure/axes interface.
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.scatter(ys_test, preds, alpha=0.7)
    ax.set_xlabel('Actual Kcat/Km')
    ax.set_ylabel('Predicted Kcat/Km')
    ax.grid(True)

    # Show the correlation in a boxed label inside the axes.
    ax.annotate(f'Spearman\'s rho = {rho:.2f}\nP-value = {p_value:.2e}',
                xy=(0.05, 0.85), xycoords='axes fraction',
                fontsize=12, bbox=dict(boxstyle="round,pad=0.3", edgecolor='black', facecolor='white'))

    # Write the figure to disk and release it so figures do not pile up in memory.
    img_filename = f'grid_element_{i}_kcat_km_plot.png'
    fig.savefig(img_filename)
    print(f'Plot saved to {img_filename}')
    plt.close(fig)

    print('\n', '-' * 80, '\n')


In [None]:
# For every tuned model: print the Spearman correlation on the test set, save
# the actual/predicted pairs to a CSV file and display a scatter plot inline.
for i, grid in enumerate(grid_list):
    print(grid.best_estimator_)
    print()
    preds = grid.predict(Xs_test)

    # Computed once and used for both the printed summary and the annotation.
    spearman_result = scipy.stats.spearmanr(ys_test, preds)
    rho, p_value = spearman_result
    print(f'{spearman_result}')
    print('\n', '-' * 80, '\n')

    # Persist the raw numbers.
    csv_filename = f'grid_element_{i}_data.csv'
    df = pd.DataFrame({'Actual Values': ys_test, 'Predicted Values': preds})
    df.to_csv(csv_filename, index=False)
    print(f'Data saved to {csv_filename}')

    # Scatter plot of predicted against actual values.
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.scatter(ys_test, preds, alpha=0.7)
    ax.set_title('Spearman Correlation between Actual and Predicted Values')
    ax.set_xlabel('Actual Values')
    ax.set_ylabel('Predicted Values')
    ax.grid(True)

    # Show the correlation in a boxed label inside the axes.
    ax.annotate(f'Spearman\'s rho = {rho:.2f}\nP-value = {p_value:.2e}',
                xy=(0.05, 0.75), xycoords='axes fraction',
                fontsize=12, bbox=dict(boxstyle="round,pad=0.3", edgecolor='black', facecolor='white'))

    plt.show()

In [None]:
import matplotlib.pyplot as plt
import scipy.stats


# Scatter plot of predicted against actual test values for one fitted model.
# The predictions are recomputed from the last entry of grid_list instead of
# being read from a `preds` variable left behind by an earlier loop, so this
# cell always plots test-set predictions even if `preds` was rebound elsewhere.
preds = grid_list[-1].predict(Xs_test)

# Calculate Spearman's correlation
rho, p_value = scipy.stats.spearmanr(ys_test, preds)

# Create scatter plot
plt.figure(figsize=(8, 6))
plt.scatter(ys_test, preds, alpha=0.7)
plt.title('Spearman Correlation between Actual and Predicted Values')
plt.xlabel('Actual Values')
plt.ylabel('Predicted Values')
plt.grid(True)

# Annotate with Spearman's rho
plt.annotate(f'Spearman\'s rho = {rho:.2f}\nP-value = {p_value:.2e}',
              xy=(0.05, 0.95), xycoords='axes fraction',
              fontsize=12, bbox=dict(boxstyle="round,pad=0.3", edgecolor='black', facecolor='white'))

plt.show()


In [None]:
# Hand-picked (name, sequence) pairs to score with the trained models.
# Every sequence is one uninterrupted string: a line break inside a quoted
# literal is a syntax error, so the one sequence that spans two source lines is
# written as two adjacent literals, which Python concatenates.
topredict = [
    ('wt', 'MEPSSLELPADTVQRIAAELKCHPTDERVALHLDEEDKLRHFRECFYIPKIQDLPPVDLSLVNKDENAIYFLGNSLGLQPKMVKTYLEEELDKWAKIAAYGHEVGKRPWITGDESIVGLMKDIVGANEKEIALMNALTVNLHLLMLSFFKPTPKRYKILLEAKAFPSDHYAIESQLQLHGLNIEESMRMIKPREGEETLRIEDILEVIEKEGDSIAVILFSGVHFYTGQHFNIPAITKAGQAKGCYVGFDLAHAVGNVELYLHDWGVDFACWCSYKYLNAGAGGIAGAFIHEKHAHTIKPALVGWFGHELSTRFKMDNKLQLIPGVCGFRISNPPILLVCSLHASLEIFKQATMKALRKKSVLLTGYLEYLIKHNYGKDKAATKKPVVNIITPSHVEERGCQLTITFSVPNKDVFQELEKRGVVCDKRNPNGIRVAPVPLYNSFHDVYKFTNLLTSILDSAETKN'),
    ('best_patent', 'MEPSSLELPADTVQRIAAELKCHPTDERVALHLDEEDKLRHFRECFYIPKIQDLPPVDLSLVNKDEDAIYFNGNSLGLQPKMVKTYLEEELDKWAKIAINGWFEGDSPWIHYDESIVGLMKDIVGANEKEIVLMNTLTVNLHLLMLSFFKPTPKRYKILLEAKAFPSDHYAIESQLQLHGLNIEESMRIIKPREGEETLRIEDILEVIEKEGDSIAVILFSGIHYYTGQHFNIPAITKAGQAKGCYVGFDLAHAVGNVELYLHDWGVDFACWCGYKYLNSSPGGIAGAFIHEKHAHTIKPALVGWFGHELSTRFKMDNKLQLIPGVCGFRCSTPPILLVCILHASLEIFKQATMKALRKKSVLLTGYLEYLIKHNYGKDKAATKKPVVNIITPSHVEERGCQLTLTFNVPNKDVFQELEKRGVVCDKRNPNGIRVAPVPLYNSFHDVYKFTNLLTSILDSAETKN'),
    ('best_mut', 'MEPSSLELPADTVQRIAAELKCHPTDERVALHLDEEDKLRHFRECFYIPKIQDLPPVDLSLVNKDEDAIYFNGNSLGLQPKMVKTYREEELDKWAKIAINGWFEGDSPWIHYDESIVGLMKDIVGANEKEIVLWYTLTHMLHLLMLSFFKPTPKRYKILLYAKAFPSDHYAIESQLQLHGLNIEESMRIIKPREGEETLRIEDILEVIEKEGDSIAVITFSGIHYMTGQHFNIPAITKALQAKGCYVGFDQAHAVGNVELYLHDWGVDFACNCGYKYLNSSPGWIQGWFCHEKHAHTIKPALVGWFGHELSTRFKMDNKLQLIPGVCGFRCSTPNHWLVCILHAPLENFKQATMKALRKKSVLLTGYLEYLIKHNYGKDKAATKKPVVNIITPSHVEERGCQLTLTFNVPNKDVFQELEKRGVVCDKRNPNGIRVAPVPLYNSFHDVYKFTNLLTSILDSAETKN'),
    ('worst_mut', 'MEPSSLELPADTVQRIAAELKCHPTDERVALHLDEEDKLRHFRECFYIPKIQDLPPVDLSLVNKDEDAIYFNGNSLGLQPKMVKTYYEEELDKWAKIAINGWFEGDSPWIHYDESIVGLMKDIVGANEKEIVLYFTLTDQLHLLMLSFFKPTPKRYKILLNAKAFPSDHYAIESQLQLHGLNIEESMRIIKPREGEETLRIEDILEVIEKEGDSIAVIMFSGIHYETGQHFNIPAITKAMQAKGCYVGFDPAHAVGNVELYLHDWGVDFACVCGYKYLNSSPGIINGRFDHEKHAHTIKPALVGWFGHELSTRFKMDNKLQLIPGVCGFRCSTPKRKLVCILHAHLELFKQATMKALRKKSVLLTGYLEYLIKHNYGKDKAATKKPVVNIITPSHVEERGCQLTLTFNVPNKDVFQELEKRGVVCDKRNPNGIRVAPVPLYNSFHDVYKFTNLLTSILDSAETKN'),
    ('var_93', 'MEPSPLELPADTVQRIASELRCHPTDERVALRLDEEDELRHFREYFYIPKMQDLPP'
               'IDLSLVNKDENAIYFLGNSLGLQPKMVKTYLEEELDKWAKMGAYGHEVGKRPWITGDETIVGLMTDIVGANEKEIALMNGLTVNLHLLLLSFFKPTPKRYKILLEAKAFPSDHYAIESQLQLHGLNVEKSMRIIKPREGEETLRTEDILEVIEKEGDSIAVILFSGVHFYTGQLFNIPAITKAGQAKGCFVGFDLAHAVGNVELHLHDWGVDFACWCSYKYLNSGAGGLAGAFVHEKHAYTIKPALVGWFGHELSTRFKMDNKLQLIPGVNGFRISNPPILLVCSLHASLEIFKQATMKALRRKSILLTGYLEYLIKHYYSKDKAETKKPIVNIITPSRIEERGCQLTLTFSVPMKYVFQELEKRGVVCDKREPNGIRVAPVPLYNSFHDVYKFIELLTSVLDSAETK'),
]

# Predict the target value of every sequence in `topredict` with each fitted
# model and print one line per sequence: the name followed by one prediction
# per model, in grid_list order.
for batch_seqs in batch(topredict, 1):
    batch_labels, batch_strs, batch_tokens = batch_converter(batch_seqs)
    # Move the token tensor to the device the model lives on.
    batch_tokens_cuda = batch_tokens.to(device, non_blocking=True)

    with torch.no_grad():
        results = model(batch_tokens_cuda, repr_layers=[34])
    token_embeddings = results["representations"][34]

    # Average the per-residue vectors of each sequence into one vector.
    # Token 0 is always the beginning-of-sequence token, so the residues occupy
    # positions 1 .. len(seq). The result is kept under its own name so that
    # the training `sequence_embeddings` list is not overwritten.
    predict_seqs_embeddings = [
        token_embeddings[i, 1:len(seq) + 1].mean(0).detach().cpu().numpy()
        for i, (_, seq) in enumerate(batch_seqs)
    ]

    # One prediction array per fitted model. A dedicated name avoids rebinding
    # the `preds` variable used by the evaluation cells.
    model_predictions = [fitted.predict(predict_seqs_embeddings) for fitted in grid_list]
    for i, (name, _) in enumerate(batch_seqs):
        # Works for any number of fitted models, not only three.
        print(name, *(per_model[i] for per_model in model_predictions))

# PART 2: Building embeddings for completely new data and feeding them into the prediction model

In [None]:
#@title Install colabfold and search training sequences against colabfold database with mmseqs2
!pip install "colabfold[alphafold] @ git+https://github.com/sokrypton/ColabFold"
import requests
import hashlib
import tarfile
import time
import pickle
import os
import re

import random
import tqdm.notebook

import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.patheffects
from matplotlib import collections as mcoll

import os
import requests
import tqdm.notebook
import random
import tarfile

TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

def run_mmseqs2(x, prefix, use_env=True, use_filter=True,
                use_templates=False, filter=None, host_url="https://api.colabfold.com"):
  """Search sequences through the MMseqs2 web API and return the m8 result lines.

  Results are cached on disk: when ``<prefix>_<mode>/out.tar.gz`` already
  exists no job is submitted and the existing archive is used.

  Args:
    x: one sequence (str) or a list of sequences.
    prefix: prefix of the working directory, which is named ``<prefix>_<mode>``.
    use_env: selects the "env" server modes and additionally reads the
      ``bfd.mgnify30.metaeuk30.smag30.m8`` result file.
    use_filter: selects the filtered ("env"/"all") or the unfiltered
      ("env-nofilter"/"nofilter") server mode.
    use_templates: accepted for interface compatibility; not used here.
    filter: legacy alias; when not None it overrides ``use_filter``.
    host_url: base URL of the API server.

  Returns:
    For a str input, one string holding all m8 result lines of that query.
    For a list input, a list of such strings in input order. A query without
    any result line yields an empty string.

  Raises:
    Exception: when the server reports ERROR or MAINTENANCE.
  """

  def submit(seqs, mode, N=101):
    """Submit all sequences as one job and return the decoded JSON reply."""
    # Queries are numbered consecutively starting at N; the same numbers are
    # used later to match result lines to queries.
    query = "".join(f">{n}\n{seq}\n" for n, seq in enumerate(seqs, start=N))

    while True:
      try:
        # https://requests.readthedocs.io/en/latest/user/advanced/#advanced
        # "good practice to set connect timeouts to slightly larger than a multiple of 3"
        res = requests.post(f'{host_url}/ticket/msa', data={'q':query,'mode': mode}, timeout=6.02)
      except requests.exceptions.Timeout:
        continue  # retry until the request goes through
      break

    # A reply that is not valid JSON is reported as status UNKNOWN.
    try: out = res.json()
    except ValueError: out = {"status":"UNKNOWN"}
    return out

  def status(ID):
    """Return the decoded JSON status of ticket ``ID``."""
    while True:
      try:
        res = requests.get(f'{host_url}/ticket/{ID}', timeout=6.02)
      except requests.exceptions.Timeout:
        continue  # retry until the request goes through
      break

    try: out = res.json()
    except ValueError: out = {"status":"UNKNOWN"}
    return out

  def download(ID, path):
    """Download the result archive of ticket ``ID`` and write it to ``path``."""
    while True:
      try:
        res = requests.get(f'{host_url}/result/download/{ID}', timeout=6.02)
      except requests.exceptions.Timeout:
        continue  # retry until the request goes through
      break

    with open(path,"wb") as out: out.write(res.content)

  # process input x
  seqs = [x] if isinstance(x, str) else x

  # compatibility to old option
  if filter is not None:
    use_filter = filter

  # setup mode
  if use_filter:
    mode = "env" if use_env else "all"
  else:
    mode = "env-nofilter" if use_env else "nofilter"

  mode += "-m8output"

  # define path; exist_ok makes re-runs with the same prefix idempotent
  path = f"{prefix}_{mode}"
  os.makedirs(path, exist_ok=True)

  # call mmseqs2 api
  tar_gz_file = f'{path}/out.tar.gz'
  N,REDO = 101,True

  # deduplicate and keep track of order: Ms holds, for every input sequence,
  # the query number under which it is submitted
  seqs_unique = sorted(set(seqs))
  position = {seq: idx for idx, seq in enumerate(seqs_unique)}
  Ms = [N + position[seq] for seq in seqs]

  # lets do it!
  if not os.path.isfile(tar_gz_file):
    TIME_ESTIMATE = 150 * len(seqs_unique)
    with tqdm.notebook.tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar:
      while REDO:
        pbar.set_description("SUBMIT")

        # Resubmit job until it goes through
        out = submit(seqs_unique, mode, N)
        while out["status"] in ["UNKNOWN","RATELIMIT"]:
          # resubmit after a randomized pause
          time.sleep(5 + random.randint(0,5))
          out = submit(seqs_unique, mode, N)

        if out["status"] == "ERROR":
          raise Exception(f'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later.')

        if out["status"] == "MAINTENANCE":
          raise Exception(f'MMseqs2 API is undergoing maintenance. Please try again in a few minutes.')

        # wait for job to finish
        ID,TIME = out["id"],0
        pbar.set_description(out["status"])
        while out["status"] in ["UNKNOWN","RUNNING","PENDING"]:
          t = 5 + random.randint(0,5)
          time.sleep(t)
          out = status(ID)
          pbar.set_description(out["status"])
          if out["status"] == "RUNNING":
            TIME += t
            pbar.update(n=t)

        if out["status"] == "COMPLETE":
          # fill the progress bar when the job finished faster than estimated
          if TIME < TIME_ESTIMATE:
            pbar.update(n=(TIME_ESTIMATE-TIME))
          REDO = False

        # A job that ends in ERROR would otherwise be resubmitted in an endless
        # loop; stop and report it instead.
        if out["status"] == "ERROR":
          raise Exception(f'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later.')

      # Download results
      download(ID, tar_gz_file)

  # prep list of m8 files
  m8_files = [f"{path}/uniref.m8"]
  if use_env: m8_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.m8")

  # extract m8 files
  # NOTE(review): the archive comes from a remote server and is extracted
  # without member-path checks; only point host_url at a trusted server.
  if not os.path.isfile(m8_files[0]):
    with tarfile.open(tar_gz_file) as tar_gz:
      tar_gz.extractall(path)

  # group the result lines by the query number in their first column
  m8_lines = {}
  for m8_file in m8_files:
    with open(m8_file,"r") as handle:
      for line in handle:
        line = line.replace("\x00","")
        if not line.strip():
          continue  # blank line: no query number to parse
        M = int(line.split()[0])
        m8_lines.setdefault(M, []).append(line)

  # return results (an empty string for queries without any result line)
  m8_lines = ["".join(m8_lines.get(n, [])) for n in Ms]

  if isinstance(x, str):
    return m8_lines[0]
  else:
    return m8_lines

# Run mmseqs with the query sequence
m8_output = run_mmseqs2([x[1] for x in data], "tmp")

amino_acid = ["G", "A", "L", "M", "K", "F", "W", "Q", "E", "S", "P", "V", "I", "C", "Y", "H", "R", "N", "D", "T"]
with open("./esm1b_target.m8", 'w') as new:
  for each in m8_output:
    for line in each.split("\n"):
      if line != "":
        seq = line.split()[-1].upper()
        new_seq = ""
        for char in seq:
          if char in amino_acid:
            new_seq += char
        new.write("%s %s\n"%(line.split()[1], new_seq))
!sort -u -k2,2 esm1b_target.m8 | sort -k1,1 | awk 'BEGIN {cnt=2;name="";} {if (name == $1) {print $1"_colab_"cnt, $2;cnt += 1} else {print;name=$1;cnt=2}}' > esm1b_nonredundant.m8

# Gather the search result and make input file for esm1b
#searched_sequences = []
#searched_names = []
with open("./esm1b_nonredundant.m8") as f:
  with open("./esm1b_target_input.fasta", 'w') as new:
    line = f.readline()
    while line:
      if len(line[:-1].split()[1]) < 1023:
        new.write(">%s\n%s\n"%(line.split()[0], line[:-1].split()[1].upper()))
      line = f.readline()

In [18]:
METAGENOME_FASTA_PATH = "./esm1b_target_input.fasta"

# Read all target sequences before the output file is opened, so that a failure
# while reading the FASTA does not truncate an existing prediction file.
topredict = [
    (header, sequence)
    for header, sequence in esm.data.read_fasta(METAGENOME_FASTA_PATH)
]

# Write one tab-separated line per sequence: the header followed by one
# prediction per fitted model, in grid_list order.
with open('./target_prediction.tsv', 'w') as f:
    for batch_seqs in batch(topredict, 1):
        batch_labels, batch_strs, batch_tokens = batch_converter(batch_seqs)
        # Move the token tensor to the device the model lives on.
        batch_tokens_cuda = batch_tokens.to(device, non_blocking=True)

        with torch.no_grad():
            results = model(batch_tokens_cuda, repr_layers=[34])
        token_embeddings = results["representations"][34]

        # Average the per-residue vectors of each sequence into one vector.
        # Token 0 is always the beginning-of-sequence token, so the residues
        # occupy positions 1 .. len(seq). The result is kept under its own name
        # so that the training `sequence_embeddings` list is not overwritten.
        predict_seqs_embeddings = [
            token_embeddings[i, 1:len(seq) + 1].mean(0).detach().cpu().numpy()
            for i, (_, seq) in enumerate(batch_seqs)
        ]

        # One prediction array per fitted model. A dedicated name avoids
        # rebinding the `preds` variable used by the evaluation cells.
        model_predictions = [fitted.predict(predict_seqs_embeddings) for fitted in grid_list]
        for i, (name, _) in enumerate(batch_seqs):
            # Works for any number of fitted models, not only three.
            columns = [name] + [str(per_model[i]) for per_model in model_predictions]
            f.write("\t".join(columns) + "\n")

# PART 3: Emptying cuda cache

In [None]:

# Release cached GPU memory that is no longer in use; skipped when no CUDA
# device is available.
if torch.cuda.is_available():
    torch.cuda.empty_cache()