In [None]:
%%time
#@title install Dependencies
model_name = "esmfold.model"
import os, subprocess, time
if not os.path.isfile(model_name):
  # download esmfold params
  os.system("apt-get install aria2 -qq")
  # Start the download in the background so it overlaps with the pip installs
  # below. Keeping the process handle (instead of a shell "&") lets us detect a
  # failed download rather than polling forever for a file that never appears.
  # If aria2c is not installed, Popen raises FileNotFoundError right here.
  download = subprocess.Popen(
      ["aria2c", "-q", "-x", "16",
       f"https://colabfold.steineggerlab.workers.dev/esm/{model_name}"])

  # "finished_install" is a marker file: the libraries are only installed once.
  if not os.path.isfile("finished_install"):
    print("installing esmfold...")
    # install libs
    os.system("pip install -q omegaconf \"pytorch_lightning<2\" \"torch<2\" biopython ml_collections einops py3Dmol")
    os.system("pip install -q git+https://github.com/NVIDIA/dllogger.git")

    # install openfold
    commit = "6908936b68ae89f67755240e2f588c09ec31d4c8"
    os.system(f"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}")

    # install esmfold
    os.system(f"pip install -q git+https://github.com/sokrypton/esm.git")
    os.system("touch finished_install")

  # wait for Params to finish downloading...
  if download.poll() is None:
    print("downloading params...")
  exit_code = download.wait()
  if exit_code != 0 or not os.path.isfile(model_name):
    raise RuntimeError(
        f"downloading {model_name} failed (aria2c exit code {exit_code}); "
        "re-run this cell to try again")

In [None]:
#@title Mount Google Drive
# Mount Google Drive at /content/gdrive. The prediction cell's default
# output_folder ("/content/gdrive/My Drive/") lives under this mount point,
# so this cell has to run before any structure is written there.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
#@title Upload fasta file
from Bio import SeqIO
from google.colab import files

# Upload one or more FASTA files; the returned mapping is keyed by file name
# and each file is read back from the current directory below.
uploaded = files.upload()

# Record id -> sequence string, split by sequence length:
#   seq_dict_short: at most 400 residues (sent to the Meta API later on)
#   seq_dict:       more than 400 residues (predicted on the Colab GPU)
seq_dict = {}
seq_dict_short = {}

for fasta_name in uploaded:
  for record in SeqIO.parse(f"./{fasta_name}", "fasta"):
    # "|" is replaced because the id is later used as part of a file name.
    bucket = seq_dict_short if len(record.seq) <= 400 else seq_dict
    bucket[record.id.replace("|","_")] = str(record.seq)

print(f"{len(seq_dict)} sequences will be predicted using Google GPU unit.. ")
print(f"{len(seq_dict_short)} sequences will be predicted using Meta API.. ")

In [None]:
#@title ##run **ESMFold**
%%time

# Directory that receives the final <id>.pdb files (editable Colab form parameter).
output_folder = "/content/gdrive/My Drive/" #@param {type:"string"}
# Remove leading/trailing slashes: the write paths further down are built as
# f'/{output_folder}/<name>.pdb', which adds the leading slash back.
output_folder = output_folder.strip("/")

from string import ascii_uppercase, ascii_lowercase
import hashlib, re, os
import numpy as np
from jax.tree_util import tree_map
import matplotlib.pyplot as plt
from scipy.special import softmax
import os

def parse_output(output):
  """Reduce a (numpy) model output dict to per-residue summary arrays.

  Reads the first batch element of "aligned_confidence_probs", "plddt",
  "distogram_logits" and "atom37_atom_exists", and the last entry along the
  first axis of "positions". Everywhere an atom axis is indexed, index 1 is
  used (presumably the C-alpha slot of the atom37 layout -- confirm).

  Returns a dict with the keys "pae", "plddt", "sm_contacts" and "xyz",
  restricted to residues whose index-1 atom is flagged as existing.
  """
  # Residues to keep: those whose atom at index 1 exists.
  keep = output["atom37_atom_exists"][0,:,1] == 1

  def crop_pairwise(matrix):
    # Apply the residue mask to both axes of an (L, L) array.
    return matrix[keep,:][:,keep]

  # Probabilities weighted by their bin index (0..63), averaged over the bin
  # axis and multiplied by 31.
  weighted_bins = output["aligned_confidence_probs"][0] * np.arange(64)
  pae_full = weighted_bins.mean(-1) * 31

  # Distogram bin values: a leading 0 followed by 63 evenly spaced values from
  # 2.3125 to 21.6875. The probability of all bins below 8 is summed.
  bin_values = np.append(0,np.linspace(2.3125,21.6875,63))
  distogram_probs = softmax(output["distogram_logits"],-1)[0]
  contacts_full = distogram_probs[...,bin_values<8].sum(-1)

  return {"pae":crop_pairwise(pae_full),
          "plddt":output["plddt"][0,:,1][keep],
          "sm_contacts":crop_pairwise(contacts_full),
          "xyz":output["positions"][-1,0,:,1][keep]}

def get_hash(x):
  """Return the hexadecimal SHA-1 digest of the string ``x`` (encoded with str.encode())."""
  hasher = hashlib.sha1()
  hasher.update(x.encode())
  return hasher.hexdigest()

alphabet_list = list(ascii_uppercase+ascii_lowercase)

# Fixed prediction settings, identical for every sequence (hoisted out of the loop).
copies = 1          # each cleaned sequence is repeated this many times, joined by ":"
num_recycles = 3    # forwarded to model.infer(num_recycles=...)
chain_linker = 25   # model.infer receives "X" repeated this many times as chain_linker

# Load the model once, and only when there is something to predict on the GPU.
# The `dir()` check reuses a model left in the kernel by an earlier run of this cell.
if seq_dict:
  import torch
  if "model" not in dir():
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load a params file obtained from a trusted source.
    model = torch.load("esmfold.model")
    model.eval().cuda().requires_grad_(False)

for job in seq_dict:
  # Name of the local result directory: word characters only, at most 50 of them.
  jobname = re.sub(r'\W+', '', job)[:50]

  # Normalise the sequence: "/" and ":" both separate chains; everything that
  # is not an upper-case letter or ":" is dropped, repeated separators are
  # collapsed and leading/trailing separators removed.
  sequence = seq_dict[job]
  sequence = re.sub("[^A-Z:]", "", sequence.replace("/",":").upper())
  sequence = re.sub(":+",":",sequence)
  sequence = re.sub("^[:]+","",sequence)
  sequence = re.sub("[:]+$","",sequence)
  sequence = ":".join([sequence] * copies)

  # The short hash keeps directories of different sequences with the same name apart.
  ID = jobname+"_"+get_hash(sequence)[:5]
  seqs = sequence.split(":")
  lengths = [len(s) for s in seqs]
  length = sum(lengths)
  print("length",length)

  u_seqs = list(set(seqs))
  if len(seqs) == 1: mode = "mono"
  elif len(u_seqs) == 1: mode = "homo"
  else: mode = "hetero"

  # optimized for Tesla T4
  if length > 700:
    model.set_chunk_size(64)
  else:
    model.set_chunk_size(128)

  torch.cuda.empty_cache()
  output = model.infer(sequence,
                      num_recycles=num_recycles,
                      chain_linker="X"*chain_linker,
                      residue_index_offset=512)

  pdb_str = model.output_to_pdb(output)[0]
  # Move every tensor of the output to the CPU as a numpy array.
  output = tree_map(lambda x: x.cpu().numpy(), output)
  ptm = output["ptm"][0]
  plddt = output["plddt"][0,...,1].mean()
  O = parse_output(output)
  print(f'ptm: {ptm:.3f} plddt: {plddt:.3f}')

  # Local copy: <ID>/ptm..._default.pae.txt and .pdb in the working directory.
  os.makedirs(ID, exist_ok=True)
  prefix = f"{ID}/ptm{ptm:.3f}_r{num_recycles}_default"
  np.savetxt(f"{prefix}.pae.txt",O["pae"],"%.3f")
  with open(f"{prefix}.pdb","w") as out:
    out.write(pdb_str)

  # Final copy in the output folder, named after the FASTA id. A "/" inside the
  # id would otherwise be read as a (non-existent) sub-directory and make
  # open() fail, so it is replaced; ids without "/" keep their file name.
  safe_job = job.replace("/", "_")
  with open(f'/{output_folder}/{safe_job}.pdb', 'w') as f:
    f.write(pdb_str)

#Short sequences prediction
# Each short sequence is POSTed to the ESM Atlas folding endpoint. A reply that
# starts with "HEADER" is taken as a valid PDB file; anything else is treated
# as a failure and retried in a later round.
import subprocess, time

api_url = "https://api.esmatlas.com/foldSequence/v1/pdb/"
max_rounds = 20      # upper bound on passes, so a permanently failing sequence cannot loop forever
retry_pause_s = 10   # pause between passes instead of hammering the API

done = set()
for round_no in range(max_rounds):
  # Only sequences that have not succeeded yet are requested again.
  pending = [seq_ID for seq_ID in seq_dict_short if seq_ID not in done]
  if not pending:
    break
  if round_no:
    time.sleep(retry_pause_s)

  for seq_ID in pending:
    # Argument list instead of a shell string: ids/sequences containing quotes
    # or other shell metacharacters can neither break nor inject into the command.
    reply = subprocess.run(
        ["curl", "-s", "-X", "POST", "--data", seq_dict_short[seq_ID], api_url],
        capture_output=True, text=True)
    data = reply.stdout

    # "/" in an id would be read as a directory separator in the paths below.
    safe_ID = seq_ID.replace("/", "_")
    # Keep the raw reply in the working directory (also for failed requests),
    # as the original shell redirect did.
    with open(f"{safe_ID}.pdb", "w") as file:
      file.write(data)

    if data.startswith("HEADER"):
      with open(f'/{output_folder}/{safe_ID}.pdb', 'w') as f:
        f.write(data)
      done.add(seq_ID)
      print(f"{seq_ID} done..")

failed = [seq_ID for seq_ID in seq_dict_short if seq_ID not in done]
if failed:
  print(f"no structure returned for {len(failed)} sequence(s) after {max_rounds} rounds: {', '.join(failed)}")