<img src="https://raw.githubusercontent.com/sokrypton/ColabFold/main/.github/ColabFold_Marv_Logo_Small.png" height="200" align="right" style="height:240px">

## Enzyme search using ESM-1b and ColabFold database
Find homologous proteins with similar or better enzymatic properties. You need to provide a FASTA file in which every sequence is labeled with a numerical value, e.g. kcat/Km or toxicity. The Colab searches the [ColabFold](https://github.com/sokrypton/ColabFold) database (through the ColabFold MMseqs2 server) for proteins similar to the input enzymes, and ranks the hits by the value predicted by an [ESM-1b](https://github.com/facebookresearch/esm)-based regressor.

1.    Prepare a FASTA file containing each sequence and its numerical value. The header line holds the sequence name followed by the value, separated by a space (see the example below).
```
>seq_1 8.104
PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK
>seq_2 2.04
MIAQIHILEGRSDEDKETLIRRVSEAISRSLDAPLTSVRVIIMEMAKGHFGGELASK
```

2.   Click `Runtime` -> `Run all`, then upload the file when the upload button appears under the "File upload" cell.
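As a sketch of what step 1 expects, the hypothetical helper below reads headers of the form `>name value`. It treats the first token as the name and the last token as the value, the same convention the File upload cell uses. It is plain Python and not part of the notebook. Sequences may span several lines, and Windows line endings are tolerated.

```python
def parse_labeled_fasta(text):
    """Parse FASTA whose headers look like '>name value' (hypothetical helper).

    Returns a list of (name, value, sequence) tuples with upper-cased sequences.
    """
    records = []
    name, value, chunks = None, None, []
    for raw in text.splitlines():
        line = raw.strip()  # drops "\r" from Windows line endings
        if not line:
            continue
        if line.startswith(">"):
            if name is not None:
                records.append((name, value, "".join(chunks).upper()))
            fields = line[1:].split()
            # first token is the sequence name, last token is its numerical label
            name, value, chunks = fields[0], float(fields[-1]), []
        else:
            chunks.append(line)
    if name is not None:
        records.append((name, value, "".join(chunks).upper()))
    return records


example = """>seq_1 8.104
PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK
>seq_2 2.04
MIAQIHILEGRSDEDKETLIRRVSEAISRSLDAPLTSVRVIIMEMAKGHFGGELASK
"""
for name, value, seq in parse_labeled_fasta(example):
    print(name, value, len(seq))
```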


<b>Caution!</b><br>
If many of your sequences share the same numerical value, the Pearson correlation computed during five-fold validation may be undefined and cause an error!
<br><br>
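Here is a small pre-check, offered as a sketch. Five-fold validation computes a Pearson correlation on each test fold. That correlation is undefined when either the true values or the predictions in a fold are all equal, and NumPy's `np.corrcoef` returns NaN in that case. The helpers `pearson_r` and `most_common_fraction` are hypothetical and written in plain Python rather than NumPy. Checking how dominant the most frequent label is gives a rough idea of how likely such constant folds are.

```python
import math
from collections import Counter


def pearson_r(xs, ys):
    """Pearson correlation; returns NaN when either input has zero variance."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sxx = sum((x - mx) ** 2 for x in xs)
    syy = sum((y - my) ** 2 for y in ys)
    if sxx == 0 or syy == 0:
        return float("nan")  # correlation is undefined for a constant input
    return sxy / math.sqrt(sxx * syy)


def most_common_fraction(values):
    """Fraction of labels equal to the single most frequent label."""
    _, count = Counter(values).most_common(1)[0]
    return count / len(values)


labels = [1.0, 1.0, 1.0, 1.0, 2.5]
print(most_common_fraction(labels))        # 0.8
test_fold = [1.0, 1.0]                     # a test fold whose true values are all equal
print(pearson_r(test_fold, [0.9, 1.3]))    # nan
```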

For more details on what this Colab does, see the explanation at the bottom.


In [None]:
#@title File upload.
from google.colab import files
import io
import warnings
#separator = "" #@param {type:"string"}
#if separator == "":
#  separator = "\t"
#@markdown - Run this cell and upload a fasta file
#@markdown - Then select the third cell (i.e. skip the cell directly below this one) and choose `Runtime` -> `Run after`
# Read in the FASTA file: header lines are ">name value", followed by sequence lines
uploaded = files.upload()
names = []
values = []
name_value = {}
sequences = []
for k, v in uploaded.items():
  seq = ""  # reset per file so sequences never leak across uploaded files
  for line in v.decode("utf-8").splitlines():
    line = line.strip()  # also drops Windows "\r" line endings
    if line.startswith(">"):
      fields = line[1:].split()
      names.append(fields[0])
      values.append(float(fields[-1]))
      name_value[fields[0]] = float(fields[-1])
      if seq != "":
        sequences.append(seq.upper())
        seq = ""
    elif line != "":
      seq += line
  # store the last sequence of this file
  if seq != "":
    sequences.append(seq.upper())
assert len(names) == len(values)
assert len(values) == len(sequences)
if len(set(values)) == 1:
  warnings.warn("Warning: only one value found. Changing procedure to get distance, not a regressor score")
  do_distance = True

# Write a FASTA file for ESM-1b, keeping only proteins of <= 1022 residues
# (ESM-1b's 1024-token limit includes the start and end tokens)
cnt = 0
with open("/content/esm1b_input.fasta", 'w') as new:
  for i in range(len(names)):
    if len(sequences[i]) < 1023:
      new.write(">%s\n%s\n"%(names[i], sequences[i]))
      cnt += 1

# Warn if the training set is small
if cnt < 40:
  warnings.warn("Warning: fewer than 40 training sequences. Result quality cannot be ensured with such a small training set")
if cnt < 10:
  warnings.warn("Warning: fewer than 10 training sequences. Five-fold validation will be skipped and the last layer of ESM-1b will be used")

In [None]:
#@title Options
do_distance = False #@param {type:"boolean"}
#@markdown  - Select `do_distance` when you only have positive samples with no numerical labels.<br>If you select this option, please set all the numerical values to 1.
# All-identical labels cannot train a regressor, so fall back to the distance mode
if len(set(values)) == 1:
  do_distance = True
Regressor = "Random Forest Regressor" #@param ["Random Forest Regressor", "Linear Regression", "K Neighbors Regressor", "Logistic Regression"]
#@markdown - Use Logistic Regression only when your labels are 0 and 1
iteration = 5 #@param {type:"integer"}
#@markdown - Number of times five-fold validation is repeated<br>
n_neighbors = 5 #@param {type:"integer"}
#@markdown - Set `n_neighbors` for K Neighbors Regressor; it cannot exceed the number of training sequences


In [None]:
#@title Install ColabFold and ESM-1b (ProtT5 installation is currently disabled)
!pip install git+https://github.com/facebookresearch/esm.git
!pip install "colabfold[alphafold] @ git+https://github.com/sokrypton/ColabFold"
#!pip install -q SentencePiece transformers

from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.neighbors import KNeighborsRegressor
from sklearn.ensemble import RandomForestRegressor

import requests
import hashlib
import tarfile
import time
import pickle
import os
import re

import random
import tqdm.notebook

import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.patheffects
from matplotlib import collections as mcoll

import torch

from esm import Alphabet, FastaBatchedDataset, ProteinBertModel, pretrained, MSATransformer
# Load ESM-1b only once; re-running this cell reuses the already loaded model
try:
    model_esm
except NameError:
    model_esm, alphabet_esm = pretrained.load_model_and_alphabet("esm1b_t33_650M_UR50S")

TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

def run_mmseqs2(x, prefix, use_env=True, use_filter=True,
                use_templates=False, filter=None, host_url="https://api.colabfold.com"):
  
  def submit(seqs, mode, N=101):    
    n,query = N,""
    for seq in seqs: 
      query += f">{n}\n{seq}\n"
      n += 1
      
    while True:
      try:
        # https://requests.readthedocs.io/en/latest/user/advanced/#advanced
        # "good practice to set connect timeouts to slightly larger than a multiple of 3"
        res = requests.post(f'{host_url}/ticket/msa', data={'q':query,'mode': mode}, timeout=6.02)
      except requests.exceptions.Timeout:
        continue
      break

    try: out = res.json()
    except ValueError: out = {"status":"UNKNOWN"}
    return out

  def status(ID):
    while True:
      try:
        res = requests.get(f'{host_url}/ticket/{ID}', timeout=6.02)
      except requests.exceptions.Timeout:
        continue
      break

    try: out = res.json()
    except ValueError: out = {"status":"UNKNOWN"}
    return out

  def download(ID, path):
    while True:
      try:
        res = requests.get(f'{host_url}/result/download/{ID}', timeout=6.02)
      except requests.exceptions.Timeout:
        continue
      break

    with open(path,"wb") as out: out.write(res.content)
  
  # process input x
  seqs = [x] if isinstance(x, str) else x
  
  # compatibility to old option
  if filter is not None:
    use_filter = filter
    
  # setup mode
  if use_filter:
    mode = "env" if use_env else "all"
  else:
    mode = "env-nofilter" if use_env else "nofilter"

  mode += "-m8output"
  
  # define path
  path = f"{prefix}_{mode}"
  if not os.path.isdir(path): os.mkdir(path)

  # call mmseqs2 api
  tar_gz_file = f'{path}/out.tar.gz'
  N,REDO = 101,True
  
  # deduplicate and keep track of order
  seqs_unique = sorted(list(set(seqs)))
  Ms = [N+seqs_unique.index(seq) for seq in seqs]
  
  # lets do it!
  if not os.path.isfile(tar_gz_file):
    TIME_ESTIMATE = 150 * len(seqs_unique)
    with tqdm.notebook.tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar:
      while REDO:
        pbar.set_description("SUBMIT")
        
        # Resubmit job until it goes through
        out = submit(seqs_unique, mode, N)
        while out["status"] in ["UNKNOWN","RATELIMIT"]:
          # resubmit
          time.sleep(5 + random.randint(0,5))
          out = submit(seqs_unique, mode, N)

        if out["status"] == "ERROR":
          raise Exception(f'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later.')

        if out["status"] == "MAINTENANCE":
          raise Exception(f'MMseqs2 API is undergoing maintenance. Please try again in a few minutes.')

        # wait for job to finish
        ID,TIME = out["id"],0
        pbar.set_description(out["status"])
        while out["status"] in ["UNKNOWN","RUNNING","PENDING"]:
          t = 5 + random.randint(0,5)
          time.sleep(t)
          out = status(ID)    
          pbar.set_description(out["status"])
          if out["status"] == "RUNNING":
            TIME += t
            pbar.update(n=t)
          #if TIME > 900 and out["status"] != "COMPLETE":
          #  # something failed on the server side, need to resubmit
          #  N += 1
          #  break
        
        if out["status"] == "COMPLETE":
          if TIME < TIME_ESTIMATE:
            pbar.update(n=(TIME_ESTIMATE-TIME))
          REDO = False

      # Download results
      download(ID, tar_gz_file)

  # prep list of m8 files
  m8_files = [f"{path}/uniref.m8"]
  if use_env: m8_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.m8")
  
  # extract m8 files
  if not os.path.isfile(m8_files[0]):
    with tarfile.open(tar_gz_file) as tar_gz:
      tar_gz.extractall(path)  

  m8_lines = {}
  for m8_file in m8_files:
    with open(m8_file, "r") as handle:
      for line in handle:
        line = line.replace("\x00", "")
        if line.strip() == "":
          continue
        M = int(line.split()[0])
        m8_lines.setdefault(M, []).append(line)
  
  # return results (an empty string for queries without hits)
  m8_lines = ["".join(m8_lines.get(n, [])) for n in Ms]
  
  if isinstance(x, str):
    return m8_lines[0]
  else:
    return m8_lines

# Code for embedding extraction (esm-1b)
# Extract for every layer

def esm_embedding(input_file, model, alphabet, nogpu=False, 
                  repr_layers=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33]):
    #model, alphabet = pretrained.load_model_and_alphabet("esm1b_t33_650M_UR50S")
    model.eval()
    if isinstance(model, MSATransformer):
        raise ValueError(
            "This script currently does not handle models with MSA input (MSA Transformer)."
        )
    if torch.cuda.is_available() and not nogpu:
        model = model.cuda()
        print("Transferred model to GPU")

    dataset = FastaBatchedDataset.from_file(input_file)
    batches = dataset.get_batch_indices(4096, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset, collate_fn=alphabet.get_batch_converter(), batch_sampler=batches
    )
    print(f"Read input fasta file with {len(dataset)} sequences")

    assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in repr_layers)
    repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in repr_layers]

    result_list = []
    esm_names = []
    for each in range(len(repr_layers)):
      result_list.append([])
    # This means inference with maximum performance layer
    if len(repr_layers) == 1:
      result_list = []
    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(
                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
            )
            if torch.cuda.is_available() and not nogpu:
                toks = toks.to(device="cuda", non_blocking=True)

            out = model(toks, repr_layers=repr_layers, return_contacts=False)

            logits = out["logits"].to(device="cpu")
            representations = {
                layer: t.to(device="cpu") for layer, t in out["representations"].items()
            }

            for i, label in enumerate(labels):
                esm_names.append(label)
                result = {"label": label}
                # Call clone on tensors to ensure tensors are not views into a larger representation
                # See https://github.com/pytorch/pytorch/issues/1995
                # Get mean representations
                result["mean_representations"] = {
                    layer: t[i, 1 : len(strs[i]) + 1].mean(0).clone()
                    for layer, t in representations.items()
                }
                if len(repr_layers) > 1:
                  # result_list[k] collects the k-th requested layer
                  for k, layer in enumerate(repr_layers):
                    result_list[k].append(result["mean_representations"][layer].tolist())
                else:
                  result_list.append(result["mean_representations"][repr_layers[0]].tolist())

                """
                torch.save(
                    result,
                    args.output_file,
                )
                """
    # Output the result
    return result_list, esm_names


# Code for embedding extraction (prott5)
# Extract for every layer
# Disabled for now: it requires too many resources (the exact cause has not been investigated yet)
"""
from transformers import T5EncoderModel, T5Tokenizer
import gc

def prott5_embedding(sequence_list, model, tokenizer):
  #tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False )
  #model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
  gc.collect()
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  model = model.to(device)
  model = model.eval()
  sequences_Example = [re.sub(r"[UZOB]", "X", sequence) for sequence in sequence_list]
  ids = tokenizer.batch_encode_plus(sequences_Example, add_special_tokens=True, padding=True)
  input_ids = torch.tensor(ids['input_ids']).to(device)
  attention_mask = torch.tensor(ids['attention_mask']).to(device)
  with torch.no_grad():
    embedding = model(input_ids=input_ids,attention_mask=attention_mask)
  embedding = embedding.last_hidden_state.cpu().numpy()
  features = [] 
  for seq_num in range(len(embedding)):
      seq_len = (attention_mask[seq_num] == 1).sum()
      seq_emd = embedding[seq_num][:seq_len-1]
      features.append(seq_emd)
  return features
"""

def sort_criterion(elem):
    return elem[1]

In [None]:
#@title Compute Embeddings of input sequences from ESM-1b
#@markdown - ESM-1b only accepts sequences of up to 1022 amino acids<br>Longer proteins are excluded
#model_name = "esm-1b, prott5" #@param ["esm-1b", "prott5", "esm-1b, prott5"]
#model_name = "esm-1b"
esm1b_query_embedding = None
prott5_query_embedding = None
# Run esm-1b
#if model_name == "esm-1b, prott5" or model_name == "esm-1b":
#model_esm, alphabet_esm = pretrained.load_model_and_alphabet("esm1b_t33_650M_UR50S")
esm1b_query_embedding, esm1b_names = esm_embedding("/content/esm1b_input.fasta", model_esm, alphabet_esm)

In [None]:
#@title Five Fold Validation with Machine Learning
#@markdown - Compute the layer showing best performance
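# KNN cannot use more neighbors than the smallest training fold (cnt minus the largest test fold)
n_neighbors = max(1, min(n_neighbors, cnt - (cnt + 4) // 5))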
regressor = None
if Regressor == "Random Forest Regressor":
  regressor = RandomForestRegressor()
elif Regressor == "Linear Regression":
  regressor = LinearRegression()
elif Regressor == "Logistic Regression":
  regressor = LogisticRegression()
elif Regressor == "K Neighbors Regressor":
  regressor = KNeighborsRegressor(n_neighbors=n_neighbors)

# Let's do five fold validation!
# Skip five-fold validation when there are fewer than 10 training sequences, and use the last layer (33) instead
max_layer = None
if do_distance:
  max_layer = len(esm1b_query_embedding) - 1
  regressor = KNeighborsRegressor(n_neighbors=1)
elif cnt < 10:
  warnings.warn("Warning: Size of data smaller than 10, skipping five fold validation, last layer of ESM-1b is chosen.")
  max_layer = len(esm1b_query_embedding) - 1
else:
  # Training data size is big enough. Let's do the five fold validation for {iteration} times!
  ff_average_esm = []
  #ff_average_t5 = []
  #for i in range(len(t5_embeddings_all_layers)):
    #ff_average_t5.append([])
  for i in range(len(esm1b_query_embedding)):
    ff_average_esm.append([])
  for ff in range(iteration):
    correlations_esm=[]
    mean_esm = []
    std_esm = []
    for layer in range(len(esm1b_query_embedding)):
      esm_embeddings = esm1b_query_embedding[layer]
      # pair each embedding with its value, shuffle randomly, then split into training and test folds (esm)
      all_data = []
      for i in range(len(esm_embeddings)):
        all_data.append([esm_embeddings[i], name_value[esm1b_names[i]]])
      seed = np.random.randint(1,100000)
      np.random.seed(seed)
      np.random.shuffle(all_data)

      #Actual Five fold validation
      pearson_r = []
      for fold in range(5):
        training_set = []
        test_set = []
        for i in range(len(all_data)):
          if i % 5 == fold:
            test_set.append(all_data[i])
          else:
            training_set.append(all_data[i])

        train_x = []
        train_y = []
        for i in range(len(training_set)):
          train_x.append(training_set[i][0])
          train_y.append(training_set[i][1])

        test_x = []
        test_y = []
        for i in range(len(test_set)):
          test_x.append(test_set[i][0])
          test_y.append(test_set[i][1])
        regressor.fit(train_x, train_y)
        test_result = regressor.predict(test_x)
        pearson_r.append(np.corrcoef(np.array(test_y), np.array(test_result))[0, 1])
      pearson_r = np.array(pearson_r)
      for cor in pearson_r:
        correlations_esm.append(cor)
      mean_esm.append(np.mean(pearson_r))
      ff_average_esm[layer].append(np.mean(pearson_r))
      std_esm.append(np.std(pearson_r))
      #print("Layer: ", layer)
      #print("All cor: ", pearson_r, " mean: ", np.mean(pearson_r), " std: ", np.std(pearson_r))
    x_t5 = []
    x_error_t5 = []
    x_esm = []
    x_error_esm = []
    #num = 0
    #for i in range(len(t5_embeddings_all_layers)):
    #  x_error_t5.append(num)
    #  for j in range(5):
    #    x_t5.append(num)
    #  num += 1
    num = 0
    for i in range(len(esm1b_query_embedding)):
      x_error_esm.append(num)
      for j in range(5):
        x_esm.append(num)
      num += 1
    """
    fig1 = plt.figure(2 * ff)
    plt.scatter(x_t5, correlations_t5)
    plt.errorbar(x_error_t5, mean_t5, std_t5, linestyle='None', marker='^', ecolor='red', mfc='red', mec='red', ms=10)
    plt.title("Prottrans t5")
    plt.ylim(0,1)
    plt.xlabel("Layer")
    plt.ylabel("Correlation")
    """
    fig2 = plt.figure(2 * ff + 1)
    plt.scatter(x_esm, correlations_esm)
    plt.errorbar(x_error_esm, mean_esm, std_esm, linestyle='None', marker='^', ecolor='red', mfc='red', mec='red', ms=10)
    plt.title("esm1b")
    plt.ylim(-1, 1)  # Pearson r ranges from -1 to 1
    plt.xlabel("Layer")
    plt.ylabel("Correlation")
    plt.show()
    print("%dth iteration"%(ff + 1))

  # average over all repeats and report the best layer
  #ff_average_t5 = np.array(ff_average_t5)
  ff_average_esm = np.array(ff_average_esm)
  #ff_average_t5_final = np.mean(ff_average_t5, axis=1)
  ff_average_esm_final = np.mean(ff_average_esm, axis=1)
  #print("ff_average_t5 max: ", np.max(ff_average_t5_final), ", argmax: ", np.argmax(ff_average_t5_final))
  print("ff_average_esm max: ", np.max(ff_average_esm_final), ", argmax: ", np.argmax(ff_average_esm_final))
  max_layer = np.argmax(ff_average_esm_final)

In [None]:
#@title Search query sequences against Colabfold database and report the top ranking targets
amino_acid = ["G", "A", "L", "M", "K", "F", "W", "Q", "E", "S", "P", "V", "I", "C", "Y", "H", "R", "N", "D", "T"]
#target_number = 20 #@param = {type:"integer"}
m8_output = run_mmseqs2(sequences, "tmp")
with open("/content/esm1b_target.m8", 'w') as new:
  for each in m8_output:
    for line in each.split("\n"):
      if line != "":
        seq = line.split()[-1].upper()
        new_seq = ""
        for char in seq:
          if char in amino_acid:
            new_seq += char
        new.write("%s %s\n"%(line.split()[1], new_seq))
!sort -u -k2,2 esm1b_target.m8 | sort -k1,1 | awk 'BEGIN {cnt=2;name="";} {if (name == $1) {print $1"_colab_"cnt, $2;cnt += 1} else {print;name=$1;cnt=2}}' > esm1b_nonredundant.m8

# Gather the search result and make input file for esm1b
#searched_sequences = []
#searched_names = []
with open("/content/esm1b_nonredundant.m8") as f:
  with open("/content/esm1b_target_input.fasta", 'w') as new:
    line = f.readline()
    while line:
      fields = line.split()
      # skip hits whose sequence is empty or longer than ESM-1b's 1022-residue limit
      if len(fields) == 2 and len(fields[1]) < 1023:
        new.write(">%s\n%s\n"%(fields[0], fields[1].upper()))
      line = f.readline()
#assert len(searched_names) == len(searched_sequences)

# Get the embeddings of the search result
target_embeddings, target_names = esm_embedding("/content/esm1b_target_input.fasta", model_esm, alphabet_esm, repr_layers=[max_layer])

# Retrain the regressor with all input data
esm_embeddings = esm1b_query_embedding[max_layer]
all_x = []
all_y = []
for i in range(len(esm_embeddings)):
  all_x.append(esm_embeddings[i])
  all_y.append(name_value[esm1b_names[i]])
regressor.fit(all_x, all_y)

# Get the prediction with regressor and sort
target_result = []
if do_distance:
  # Sum the distances from each target to every input sequence; smaller means more similar
  target_results = regressor.kneighbors(target_embeddings, n_neighbors=cnt)
  target_dist = np.sum(target_results[0], axis=1)
  for i in range(len(target_dist)):
    target_result.append([target_names[i], target_dist[i]])
  target_result.sort(key=sort_criterion)
else:
  # Predict the numerical value of each target; larger values are ranked first
  target_results = regressor.predict(target_embeddings)
  for i in range(len(target_results)):
    target_result.append([target_names[i], target_results[i]])
  target_result.sort(key=sort_criterion, reverse=True)


# Save the result
with open("/content/top_ranked_target_tmp.tsv",'w') as new:
  for i in range(len(target_result)):
    new.write("%s\t%f\n"%(target_result[i][0], target_result[i][1]))
# Get the sequence also
!awk 'BEGIN {OFS = "\t";} NR==FNR {if (substr($1, 1, 1) == ">") {name = substr($1, 2, length($1) - 1);getline;a[name]=$0;next;}} {if ($1 in a) {print $1, $2, a[$1]} else {print "No sequence found for " $1 > "/dev/stderr"}}' esm1b_target_input.fasta top_ranked_target_tmp.tsv > top_ranked_target_list.tsv

In [None]:
#@title Download the result file
#@markdown There is only one result file:<br>
#@markdown - <b>top_ranked_target_list.tsv</b>, with three tab-separated columns: target name, predicted value (summed distance when `do_distance` is selected), and sequence. Rows are sorted by predicted value, highest first (by distance, lowest first, when `do_distance` is selected).<br>
#@markdown - Some names may end in "\_colab_%d". This suffix is added when several different sequences share the same name.

#!zip -FSr enzyme_colabfold_search.result.zip "top_ranked_target.list"
files.download("top_ranked_target_list.tsv")

## What happens in this Colab
1.   Compute the embeddings of every layer of [ESM-1b](https://github.com/facebookresearch/esm) for each input sequence (averaged over residues)
2.   Repeat five-fold validation `iteration` times (default: 5) with the regressor selected in Options (default: Random Forest Regressor), and choose the layer with the highest average Pearson correlation
3.   Search the [ColabFold database](https://colabfold.mmseqs.com/) with MMseqs2, using the input sequences as queries
4.   Compute the embeddings of the selected ESM-1b layer for each search hit
5.   Predict the numerical value of each search hit with a regressor retrained on all input sequences
6.   Rank the search hits by predicted value and download the result file (see the "Download the result file" cell for details)
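The following is a toy, dependency-free sketch of steps 2, 5 and 6 above. It assumes each layer's embeddings are plain lists of floats. In place of scikit-learn and real ESM-1b embeddings, it swaps in a hand-written k-nearest-neighbour regressor and synthetic data. All helpers (`pearson_r`, `knn_predict`, `five_fold_r`, `select_layer`, `rank_targets`) are hypothetical names, not functions from the notebook. Treating an undefined (NaN) correlation as the worst score is also an assumption of this sketch.

```python
import math
import random


def pearson_r(xs, ys):
    """Pearson correlation; NaN when either input has zero variance."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sxx = sum((x - mx) ** 2 for x in xs)
    syy = sum((y - my) ** 2 for y in ys)
    if sxx == 0 or syy == 0:
        return float("nan")
    return sxy / math.sqrt(sxx * syy)


def knn_predict(train_x, train_y, query, k=3):
    """Mean label of the k training vectors closest to `query` (Euclidean distance)."""
    nearest = sorted(range(len(train_x)), key=lambda i: math.dist(train_x[i], query))[:k]
    return sum(train_y[i] for i in nearest) / len(nearest)


def five_fold_r(xs, ys, rng, k=3):
    """Mean Pearson r over five folds; after a shuffle, position i goes to fold i % 5."""
    order = list(range(len(xs)))
    rng.shuffle(order)
    scores = []
    for fold in range(5):
        test = [idx for pos, idx in enumerate(order) if pos % 5 == fold]
        train = [idx for pos, idx in enumerate(order) if pos % 5 != fold]
        preds = [knn_predict([xs[j] for j in train], [ys[j] for j in train], xs[t], k)
                 for t in test]
        scores.append(pearson_r([ys[t] for t in test], preds))
    return sum(scores) / len(scores)


def select_layer(layers, ys, iterations=5, seed=0):
    """Index of the layer whose five-fold r, averaged over all repeats, is highest."""
    rng = random.Random(seed)
    averages = []
    for xs in layers:
        avg = sum(five_fold_r(xs, ys, rng) for _ in range(iterations)) / iterations
        averages.append(-math.inf if math.isnan(avg) else avg)  # undefined r ranks last
    return max(range(len(layers)), key=lambda i: averages[i])


def rank_targets(query_x, query_y, targets, k=3):
    """Fit on all queries, predict every (name, vector) target, highest prediction first."""
    scored = [(name, knn_predict(query_x, query_y, vec, k)) for name, vec in targets]
    return sorted(scored, key=lambda item: item[1], reverse=True)


# Toy data: layer 0 is pure noise, layer 1 encodes the label in its first coordinate.
rng = random.Random(42)
labels = [float(i) for i in range(20)]
noise_layer = [[rng.random(), rng.random()] for _ in labels]
signal_layer = [[y / 20 + rng.gauss(0, 0.01), rng.random() * 0.01] for y in labels]

best = select_layer([noise_layer, signal_layer], labels)
ranking = rank_targets(signal_layer, labels, [("t_low", [0.1, 0.005]), ("t_high", [0.9, 0.005])])
print("best layer:", best)
print("ranking:", ranking)
```

In the notebook's `do_distance` mode, each target is instead scored by the sum of its distances to all input embeddings, and the list is sorted in ascending order. Replacing `knn_predict` with such a sum and dropping `reverse=True` would mirror that behaviour in this sketch.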