# **SaprotHub: Making Protein Modeling Accessible to All Biologists**

<a href="https://www.biorxiv.org/content/10.1101/2024.05.24.595648v3"><img src="https://img.shields.io/badge/Paper-bioRxiv-green" style="max-width: 100%;"></a>
<a href="https://huggingface.co/SaProtHub"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-red?label=SaprotHub" style="max-width: 100%;"></a>
<a href="https://github.com/westlake-repl/SaprotHub"><img src="https://img.shields.io/badge/Github-black?logo=github" style="max-width: 100%;"></a>
<a href="https://theopmc.github.io/"><img src="https://img.shields.io/badge/Website-OPMC-yellow" style="max-width: 100%;"></a>
<a href="https://cbirt.net/no-coding-required-saprothub-brings-protein-modeling-to-every-biologist/" alt="blog"><img src="https://img.shields.io/badge/Blog-Medium-purple" /></a>
<a href="https://x.com/sokrypton/status/1795525127653986415"><img src="https://img.shields.io/badge/Twitter-blue?logo=twitter" style="max-width: 100%;"></a>


This is **ColabSaprot**, the Colab version of [SaProt](https://github.com/westlake-repl/SaProt), a pre-trained protein language model designed for various downstream protein tasks.

**ColabSaprot** is a platform that makes **Protein Language Models (PLMs)** more accessible and user-friendly for biologists, enabling effortless model training and sharing within the scientific community.

We've established **SaprotHub** ([website](https://huggingface.co/SaProtHub), [paper](https://www.biorxiv.org/content/10.1101/2024.05.24.595648v3)) for storing and sharing models and datasets, where you can explore extensive collections for specific protein prediction tasks.
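
As a rough illustration of what hosting on the Hugging Face Hub means in practice, the sketch below fetches a dataset repository with `huggingface_hub.snapshot_download` and locates its CSV file. The repository ID `SaProtHub/Dataset-Example`, the local `datasets` folder, and both helper names are placeholders chosen for illustration; the `dataset.csv` filename follows the convention this notebook's own code relies on. Treat it as a sketch under those assumptions, not as the official client.

```python
from pathlib import Path


def local_dataset_csv(repo_id: str, dataset_home: Path) -> Path:
    """Return where the dataset CSV is expected after download."""
    return dataset_home / repo_id / "dataset.csv"


def fetch_dataset(repo_id: str, dataset_home: Path) -> Path:
    """Download a dataset repository from the Hugging Face Hub (needs network access)."""
    # Imported lazily so the path helper above works without the package installed.
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id=repo_id,
        repo_type="dataset",
        local_dir=dataset_home / repo_id,
    )
    csv_path = local_dataset_csv(repo_id, dataset_home)
    if not csv_path.exists():
        raise FileNotFoundError(f"Can't find {csv_path}")
    return csv_path


# Placeholder ID for illustration only; substitute a real SaprotHub dataset ID.
print(local_dataset_csv("SaProtHub/Dataset-Example", Path("datasets")))
```

Calling `fetch_dataset` with a real dataset ID would download the repository into the given folder and return the path of its CSV file.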

We hope ColabSaprot and SaprotHub can contribute to advancing biological research, fostering collaboration, and accelerating discoveries in the field. You can access [our paper](https://www.biorxiv.org/content/10.1101/2024.05.24.595648v3) for further details.

Watch these videos ([training](https://www.youtube.com/watch?v=r42z1hvYKfw), [predicting](https://www.youtube.com/watch?v=N5VMBwM_ukQ)) to see how to use ColabSaprot.

Join [**OPMC**](https://theopmc.github.io/) as an author of SaprotHub.

ColabSaprot supports hundreds of [protein prediction tasks](https://github.com/westlake-repl/SaProtHub/blob/main/task_list.md).







In [240]:
#@title **First, click the run button to install SaProt**

#@markdown (Please wait 2-8 minutes for the installation to finish...)
################################################################################
########################### install saprot #####################################
################################################################################
%load_ext autoreload
%autoreload 2

import os
# Use the current working directory as the root, so the paths work both locally and on Google Colab
root_dir = os.getcwd()

from google.colab import output
# output.enable_custom_widget_manager()

try:
  import sys
  sys.path.append(f"{root_dir}/SaprotHub")
  import saprot
  print("SaProt is installed successfully!")
  os.system(f"chmod +x {root_dir}/SaprotHub/bin/*")

except ImportError:
  print("Installing SaProt...")
  os.system(f"rm -rf {root_dir}/SaprotHub")
  # !rm -rf /content/SaprotHub/

  !git clone https://github.com/westlake-repl/SaprotHub.git

  # !pip install /content/SaprotHub/saprot-0.4.7-py3-none-any.whl
  os.system(f"pip install -r {root_dir}/SaprotHub/requirements.txt")
  # !pip install -r /content/SaprotHub/requirements.txt

  os.system(f"pip install {root_dir}/SaprotHub")


  os.system(f"mkdir -p {root_dir}/SaprotHub/LMDB")
  os.system(f"mkdir -p {root_dir}/SaprotHub/bin")
  os.system(f"mkdir -p {root_dir}/SaprotHub/output")
  os.system(f"mkdir -p {root_dir}/SaprotHub/datasets")
  os.system(f"mkdir -p {root_dir}/SaprotHub/adapters/classification/Local")
  os.system(f"mkdir -p {root_dir}/SaprotHub/adapters/regression/Local")
  os.system(f"mkdir -p {root_dir}/SaprotHub/adapters/token_classification/Local")
  os.system(f"mkdir -p {root_dir}/SaprotHub/adapters/pair_classification/Local")
  os.system(f"mkdir -p {root_dir}/SaprotHub/adapters/pair_regression/Local")
  os.system(f"mkdir -p {root_dir}/SaprotHub/structures")
  # !mkdir -p /content/SaprotHub/LMDB
  # !mkdir -p /content/SaprotHub/bin
  # !mkdir -p /content/SaprotHub/output
  # !mkdir -p /content/SaprotHub/datasets
  # !mkdir -p /content/SaprotHub/adapters/classification/Local
  # !mkdir -p /content/SaprotHub/adapters/regression/Local
  # !mkdir -p /content/SaprotHub/adapters/token_classification/Local
  # !mkdir -p /content/SaprotHub/adapters/pair_classification/Local
  # !mkdir -p /content/SaprotHub/adapters/pair_regression/Local
  # !mkdir -p /content/SaprotHub/structures

  # !pip install gdown==v4.6.3 --force-reinstall --quiet
  # os.system(
  #   f"wget 'https://drive.usercontent.google.com/download?id=1B_9t3n_nlj8Y3Kpc_mMjtMdY0OPYa7Re&export=download&authuser=0' -O {root_dir}/SaprotHub/bin/foldseek"
  # )

  os.system(f"chmod +x {root_dir}/SaprotHub/bin/*")
  # !chmod +x /content/SaprotHub/bin/foldseek
  import sys
  sys.path.append(f"{root_dir}/SaprotHub")

  # !mv /content/SaprotHub/ColabSaprotSetup/foldseek /content/SaprotHub/bin/

################################################################################
################################################################################
################################## global ######################################
################################################################################
################################################################################

import ipywidgets
import pandas as pd
import torch
import numpy as np
import lmdb
import base64
import copy
import os
import json
import zipfile
import yaml
import argparse
import pprint
import subprocess
import py3Dmol
import matplotlib.pyplot as plt
import shutil
import torch.nn.functional as F
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

from loguru import logger
from easydict import EasyDict
from colorama import init, Fore, Back, Style
from IPython.display import clear_output
from huggingface_hub import snapshot_download
from ipywidgets import HTML
from IPython.display import display
from google.colab import widgets
from google.colab import files
from pathlib import Path
from tqdm import tqdm
from datetime import datetime
from transformers import AutoTokenizer, EsmForProteinFolding, EsmTokenizer
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
from string import ascii_uppercase,ascii_lowercase
from saprot.utils.mpr import MultipleProcessRunnerSimplifier
from saprot.data.parse import get_chain_ids
from saprot.scripts.training import my_load_model
from safetensors import safe_open

DATASET_HOME = Path(f'{root_dir}/SaprotHub/datasets')
ADAPTER_HOME = Path(f'{root_dir}/SaprotHub/adapters')
STRUCTURE_HOME = Path(f"{root_dir}/SaprotHub/structures")
LMDB_HOME = Path(f'{root_dir}/SaprotHub/LMDB')
OUTPUT_HOME = Path(f'{root_dir}/SaprotHub/output')
UPLOAD_FILE_HOME = Path(f'{root_dir}/SaprotHub/upload_files')
FOLDSEEK_PATH = Path(f"{root_dir}/SaprotHub/bin/foldseek")
aa_set = {"A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y"}
foldseek_struc_vocab = "pynwrqhgdlvtmfsaeikc#"

data_type_list = ["Single AA Sequence",
                  "Single SA Sequence",
                  "Single UniProt ID",
                  "Single PDB/CIF Structure",
                  "Multiple AA Sequences",
                  "Multiple SA Sequences",
                  "Multiple UniProt IDs",
                  "Multiple PDB/CIF Structures",
                  "SaprotHub Dataset",
                  "A pair of AA Sequences",
                  "A pair of SA Sequences",
                  "A pair of UniProt IDs",
                  "A pair of PDB/CIF Structures",
                  "Multiple pairs of AA Sequences",
                  "Multiple pairs of SA Sequences",
                  "Multiple pairs of UniProt IDs",
                  "Multiple pairs of PDB/CIF Structures",]

data_type_list_single = [
    "Single AA Sequence",
    "Single SA Sequence",
    "Single UniProt ID",
    "Single PDB/CIF Structure",
    "A pair of AA Sequences",
    "A pair of SA Sequences",
    "A pair of UniProt IDs",
    "A pair of PDB/CIF Structures",]

data_type_list_multiple = [
    "Multiple AA Sequences",
    "Multiple SA Sequences",
    "Multiple UniProt IDs",
    "Multiple PDB/CIF Structures",
    "Multiple pairs of AA Sequences",
    "Multiple pairs of SA Sequences",
    "Multiple pairs of UniProt IDs",
    "Multiple pairs of PDB/CIF Structures",]

task_type_dict = {
  "Protein-level Classification": "classification",
  "Residue-level Classification" : "token_classification",
  "Protein-level Regression" : "regression",
  "Protein-protein Classification": "pair_classification",
  "Protein-protein Regression": "pair_regression",
}
model_type_dict = {
  "classification" : "saprot/saprot_classification_model",
  "token_classification" : "saprot/saprot_token_classification_model",
  "regression" : "saprot/saprot_regression_model",
  "pair_classification" : "saprot/saprot_pair_classification_model",
  "pair_regression" : "saprot/saprot_pair_regression_model",
}
dataset_type_dict = {
  "classification": "saprot/saprot_classification_dataset",
  "token_classification" : "saprot/saprot_token_classification_dataset",
  "regression": "saprot/saprot_regression_dataset",
  "pair_classification" : "saprot/saprot_pair_classification_dataset",
  "pair_regression" : "saprot/saprot_pair_regression_dataset",
}
training_data_type_dict = {
  "Single AA Sequence": "AA",
  "Single SA Sequence": "SA",
  "Single UniProt ID": "SA",
  "Single PDB/CIF Structure": "SA",
  "Multiple AA Sequences": "AA",
  "Multiple SA Sequences": "SA",
  "Multiple UniProt IDs": "SA",
  "Multiple PDB/CIF Structures": "SA",
  "SaprotHub Dataset": "SA",
  "A pair of AA Sequences": "AA",
  "A pair of SA Sequences": "SA",
  "A pair of UniProt IDs": "SA",
  "A pair of PDB/CIF Structures": "SA",
  "Multiple pairs of AA Sequences": "AA",
  "Multiple pairs of SA Sequences": "SA",
  "Multiple pairs of UniProt IDs": "SA",
  "Multiple pairs of PDB/CIF Structures": "SA",
}


class font:
    RED = '\033[91m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    BLUE = '\033[94m'

    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    RESET = '\033[0m'


################################################################################
############################### adapters #######################################
################################################################################
def get_adapters_list(task_type=None):

    adapters_list = []

    if task_type:
      for file_path in (ADAPTER_HOME / task_type).glob('**/adapter_config.json'):
        adapters_list.append(file_path.relative_to(ADAPTER_HOME / task_type).parent)
    else:
      for file_path in ADAPTER_HOME.glob('**/adapter_config.json'):
        adapters_list.append(file_path.relative_to(ADAPTER_HOME).parent)

    return adapters_list
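
# ------------------------------------------------------------------------------
# Illustrative sketch (not part of the original notebook logic): how the glob in
# get_adapters_list() maps a directory tree to adapter names. The helper name
# `_sketch_list_adapters` and the folder `my_model` are hypothetical.
#
#   adapters/classification/Local/my_model/adapter_config.json
#     -> Path("Local/my_model")                 when task_type == "classification"
#     -> Path("classification/Local/my_model")  when task_type is None
# ------------------------------------------------------------------------------
def _sketch_list_adapters(adapter_home, task_type=None):
  from pathlib import Path

  base = Path(adapter_home) / task_type if task_type else Path(adapter_home)
  # Every folder that contains an adapter_config.json counts as one adapter.
  return sorted(p.relative_to(base).parent for p in base.glob('**/adapter_config.json'))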

def adapters_text(adapters_list):
  # Named `text_input` to avoid shadowing the built-in `input`.
  text_input = ipywidgets.Text(
    value=None,
    placeholder='Enter SaprotHub Model ID',
    # description='Selected:',
    disabled=False)
  text_input.layout.width = '500px'
  display(text_input)

  return text_input

def adapters_dropdown(adapters_list):
  dropdown = ipywidgets.Dropdown(
    # options=[f"{adapter_path.parent.stem}/{adapter_path.stem}" for index, adapter_path in enumerate(adapters_list)],
    options=adapters_list,
    value=None,
    placeholder='Select a Local Model here',
    # description='Selected:',
    disabled=False)
  dropdown.layout.width = '500px'
  display(dropdown)

  return dropdown

def adapters_combobox(adapters_list):
  combobox = ipywidgets.Combobox(
    options=[f"{adapter_path.parent.stem}/{adapter_path.stem}" for index, adapter_path in enumerate(adapters_list)],
    value=None,
    placeholder='Enter SaprotHub Model repository id or select a Local Model here',
    # description='Selected:',
    disabled=False)
  combobox.layout.width = '500px'
  display(combobox)

  return combobox

def adapters_selectmultiple(adapters_list):
  select_multiple = ipywidgets.SelectMultiple(
    # options=[f"{adapter_path.parent.stem}/{adapter_path.stem}" for index, adapter_path in enumerate(adapters_list)],
    options=adapters_list,
    value=[],
    #rows=10,
    placeholder='Select multiple models',
    # description='Fruits',
    disabled=False,
    layout={'width': '500px'})
  display(select_multiple)

  return select_multiple

def adapters_textmultiple(adapters_list):
  textmultiple = ipywidgets.Text(
  value=None,
  placeholder='Enter multiple SaprotHub Model IDs, separated by commas.',
  # description='Fruits',
  disabled=False,
  layout={'width': '500px'})
  display(textmultiple)

  return textmultiple


def select_adapter_from(task_type, use_model_from):
  adapters_list = get_adapters_list(task_type)

  if use_model_from == 'Trained by yourself on ColabSaprot':
    print(Fore.BLUE+f"Local Model ({task_type}):"+Style.RESET_ALL)
    return adapters_dropdown(adapters_list)

  elif use_model_from == 'Shared by peers on SaprotHub':
    print(Fore.BLUE+"SaprotHub Model:"+Style.RESET_ALL)
    return adapters_text(adapters_list)

  elif use_model_from == "Saved in your local computer":
    print(Fore.BLUE+"Click the button to upload the \"Model-<task_name>-<model_size>.zip\" file of your Model:"+Style.RESET_ALL)
    # 1. upload model.zip
    if task_type:
      adapter_upload_path = ADAPTER_HOME / task_type / "Local"
    else:
      adapter_upload_path = ADAPTER_HOME / "Local"

    adapter_zip_path = upload_file(adapter_upload_path)
    adapter_path = adapter_upload_path / adapter_zip_path.stem
    # 2. unzip model.zip
    with zipfile.ZipFile(adapter_zip_path, 'r') as zip_ref:
        zip_ref.extractall(adapter_path)
    os.remove(adapter_zip_path)
    # 3. check adapter_config.json
    adapter_config_path = adapter_path / "adapter_config.json"
    assert adapter_config_path.exists(), f"Can't find {adapter_config_path}"

    # # 4. move to correct folder
    # num_labels, task_type = get_num_labels_and_task_type_by_adapter(adapter_path)
    # shutil.move(adapter_path, ADAPTER_HOME / task_type)

    return EasyDict({"value":  f"Local/{adapter_zip_path.stem}"})

  elif use_model_from == "Multi-models on ColabSaprot":
    # 1. select the list of adapters
    print(Fore.BLUE+f"Local Model ({task_type}):"+Style.RESET_ALL)
    print(Fore.BLUE+f"Multiple values can be selected with \"shift\" and/or \"ctrl\" (or \"command\") pressed and mouse clicks or arrow keys."+Style.RESET_ALL)
    return adapters_selectmultiple(adapters_list)

  elif use_model_from == "Multi-models on SaprotHub":
    # 1. enter the list of adapters
    print(Fore.BLUE+f"SaprotHub Model IDs, separated by commas ({task_type}):"+Style.RESET_ALL)
    return adapters_textmultiple(adapters_list)



################################################################################
########################### download dataset ###################################
################################################################################
def download_dataset(task_name):
  import gdown
  import tarfile

  filepath = LMDB_HOME / f"{task_name}.tar.gz"
  download_links = {
    "ClinVar" : "https://drive.google.com/uc?id=1Le6-v8ddXa1eLJZFo7HPij7NhaBmNUbo",
    "DeepLoc_cls2" : "https://drive.google.com/uc?id=1dGlojkCt1DwUXWiUk4kXRGRNu5sz2uxf",
    "DeepLoc_cls10" : "https://drive.google.com/uc?id=1dGlojkCt1DwUXWiUk4kXRGRNu5sz2uxf",
    "EC" : "https://drive.google.com/uc?id=1VFLFA-jK1tkTZBVbMw8YSsjZqAqlVQVQ",
    "GO_BP" : "https://drive.google.com/uc?id=1DGiGErWbRnEK8jmE2Jpb996By8KVDBfF",
    "GO_CC" : "https://drive.google.com/uc?id=1DGiGErWbRnEK8jmE2Jpb996By8KVDBfF",
    "GO_MF" : "https://drive.google.com/uc?id=1DGiGErWbRnEK8jmE2Jpb996By8KVDBfF",
    "HumanPPI" : "https://drive.google.com/uc?id=1ahgj-IQTtv3Ib5iaiXO_ASh2hskEsvoX",
    "MetalIonBinding" : "https://drive.google.com/uc?id=1rwknPWIHrXKQoiYvgQy4Jd-efspY16x3",
    "ProteinGym" : "https://drive.google.com/uc?id=1L-ODrhfeSjDom-kQ2JNDa2nDEpS8EGfD",
    "Thermostability" : "https://drive.google.com/uc?id=1I9GR1stFDHc8W3FCsiykyrkNprDyUzSz",
  }

  try:
    gdown.download(download_links[task_name], str(filepath), quiet=False)
    with tarfile.open(filepath, 'r:gz') as tar:
      tar.extractall(path=str(LMDB_HOME))
      print(f"Extracted: {filepath}")
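  except Exception as e:
    # Chain the original error so the real cause (unknown task name, network failure, ...) stays visible
    raise RuntimeError("The dataset has not been prepared.") from e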

################################################################################
############################# upload file ######################################
################################################################################
def upload_file(upload_path):
  upload_path = Path(upload_path)
  upload_path.mkdir(parents=True, exist_ok=True)
  basepath = Path().resolve()
  try:
    uploaded = files.upload()
    filenames = []
    for filename in uploaded.keys():
      filenames.append(filename)
      shutil.move(basepath / filename, upload_path / filename)
    if len(filenames) == 0:
      logger.info("The uploading process has been interrupted by the user.")
      raise RuntimeError("The uploading process has been interrupted by the user.")
  except Exception:
    logger.error("File upload failed! Please click the button to run again.")
    raise

  return upload_path / filenames[0]

################################################################################
############################ upload dataset ####################################
################################################################################

def read_csv_dataset(uploaded_csv_path):
  df = pd.read_csv(uploaded_csv_path)
  df.columns = df.columns.str.lower()
  return df

def check_column_label_and_stage(csv_dataset_path):
  df = read_csv_dataset(csv_dataset_path)
  assert {'label', 'stage'}.issubset(df.columns), f"Make sure your CSV dataset includes both `label` and `stage` columns!\nCurrent columns: {df.columns}"
  column_values = set(df['stage'].unique())
  assert all(value in column_values for value in ['train', 'valid', 'test']), f"Ensure your dataset includes samples for all three stages: `train`, `valid` and `test`.\nCurrent stages: {column_values}"
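
# ------------------------------------------------------------------------------
# Illustrative sketch: a minimal CSV that would pass check_column_label_and_stage().
# The sequences, labels and the helper name `_sketch_stage_problems` are made up
# for illustration; only the `label`/`stage` columns and the three stage values
# come from the checks above.
#
#   sequence,label,stage
#   MKTAYIAK,1,train
#   MSLLTEVE,0,valid
#   MAHHHHHH,1,test
#
# The helper below reports every problem at once instead of stopping at the first
# failed assertion. It uses only the standard library.
# ------------------------------------------------------------------------------
def _sketch_stage_problems(csv_text):
  import csv
  import io

  reader = csv.DictReader(io.StringIO(csv_text))
  fieldnames = reader.fieldnames or []
  columns = {name.lower() for name in fieldnames}
  problems = [f"missing column `{c}`" for c in ('label', 'stage') if c not in columns]
  if 'stage' in columns:
    # Column names are matched case-insensitively, as read_csv_dataset() does.
    stage_key = next(name for name in fieldnames if name.lower() == 'stage')
    stages = {row[stage_key] for row in reader}
    missing = {'train', 'valid', 'test'} - stages
    if missing:
      problems.append(f"missing stages: {sorted(missing)}")
  return problems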

def get_data_type(csv_dataset_path):
  # AA, SA, Pair AA, Pair SA
  df = read_csv_dataset(csv_dataset_path)

  # AA, SA
  if 'sequence' in df.columns:
    second_token = df.loc[0, 'sequence'][1]
    if second_token in aa_set:
      return "Multiple AA Sequences"
    elif second_token in foldseek_struc_vocab:
      return "Multiple SA Sequences"
    else:
      raise RuntimeError(f"The sequences in the dataset ({csv_dataset_path}) are neither SA sequences nor AA sequences. Please check carefully.")

  # Pair AA, Pair SA
  elif 'sequence_1' in df.columns and 'sequence_2' in df.columns:
    second_token = df.loc[0, 'sequence_1'][1]
    if second_token in aa_set:
      return "Multiple pairs of AA Sequences"
    elif second_token in foldseek_struc_vocab:
      return "Multiple pairs of SA Sequences"
    else:
      raise RuntimeError(f"The sequences in the dataset ({csv_dataset_path}) are neither SA sequences nor AA sequences. Please check carefully.")

  else:
      raise RuntimeError(f"The data type of the dataset({csv_dataset_path}) should be one of the following types: Multiple AA Sequences, Multiple SA Sequences, Multiple pairs of AA Sequences, Multiple pairs of SA Sequences")
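
# ------------------------------------------------------------------------------
# Illustrative sketch: why the second character is enough for get_data_type() to
# tell the two sequence formats apart. An SA sequence interleaves each amino acid
# (upper case) with a structure token (lower case, or '#', which this notebook
# appends when no structure is available), so its second character is a structure
# token. In a plain AA sequence the second character is another amino acid.
# The helper names and the toy sequences are hypothetical.
# ------------------------------------------------------------------------------
def _sketch_aa_to_sa(aa_seq):
  """Pad a plain AA sequence with '#' after every residue."""
  return ''.join(aa + '#' for aa in aa_seq)

def _sketch_sequence_kind(seq):
  amino_acids = set("ACDEFGHIKLMNPQRSTVWY")
  structure_tokens = set("pynwrqhgdlvtmfsaeikc#")
  if len(seq) < 2:
    raise ValueError("Need at least two characters to tell AA from SA.")
  if seq[1] in amino_acids:
    return "AA"
  if seq[1] in structure_tokens:
    return "SA"
  raise ValueError(f"Unrecognized second character: {seq[1]!r}")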

def check_task_type_and_data_type(original_task_type, data_type):
  if "Protein-protein" in original_task_type:
    assert data_type == "SaprotHub Dataset" or "pair" in data_type, f"The current `data_type`({data_type}) is incompatible with the current `task_type`({original_task_type}). Please use a Pair Sequence Dataset for the {original_task_type} task!"
  else:
    assert "pair" not in data_type, f"The current `data_type`({data_type}) is incompatible with the current `task_type`({original_task_type}). Please avoid using the Pair Sequence Dataset({data_type}) for the {original_task_type} task!"

def input_raw_data_by_data_type(data_type):

  # 0-2. 0. Single AA Sequence, 1. Single SA Sequence, 2. Single UniProt ID
  if data_type in data_type_list[:3]:
    input_seq = ipywidgets.Text(
      value=None,
      placeholder=f'Enter {data_type} here',
      disabled=False)
    input_seq.layout.width = '500px'
    print(Fore.BLUE+f"{data_type}"+Style.RESET_ALL)
    display(input_seq)
    return input_seq

  # 3. Single PDB/CIF Structure
  elif data_type == 'Single PDB/CIF Structure':
    input_chain = ipywidgets.Text(
      value="A",
      placeholder=f'Enter the name of chain here',
      disabled=False)
    input_chain.layout.width = '500px'
    print(Fore.BLUE+"Chain (to be extracted from the structure):"+Style.RESET_ALL)
    display(input_chain)

    print(Fore.BLUE+"Click to upload a .pdb/.cif file"+Style.RESET_ALL)
    pdb_file_path = upload_file(STRUCTURE_HOME)
    print(input_chain)
    return pdb_file_path.stem, input_chain

  # 4-7 & 13-16. Multiple Sequences
  elif data_type in data_type_list_multiple:
    print(Fore.BLUE+f"Please upload the .csv file which contains {data_type}"+Style.RESET_ALL)
    uploaded_csv_path = upload_file(UPLOAD_FILE_HOME)
    print(Fore.BLUE+"Successfully uploaded your .csv file!"+Style.RESET_ALL)
    print("="*100)

    if data_type in ['Multiple PDB/CIF Structures', 'Multiple pairs of PDB/CIF Structures']:
      # upload and unzip PDB files
      print(Fore.BLUE+f"Please upload your .zip file which contains {data_type} files"+Style.RESET_ALL)
      pdb_zip_path = upload_file(UPLOAD_FILE_HOME)
      if pdb_zip_path.suffix != ".zip":
        logger.error("The data type does not match. Please click the run button again to upload a .zip file!")
        raise RuntimeError("The data type does not match.")
      print(Fore.BLUE+"Successfully uploaded your .zip file!"+Style.RESET_ALL)
      print("="*100)

      import zipfile
      with zipfile.ZipFile(pdb_zip_path, 'r') as zip_ref:
        zip_ref.extractall(STRUCTURE_HOME)

    return uploaded_csv_path

  # 8. SaprotHub Dataset
  elif data_type == "SaprotHub Dataset":
    input_repo_id = ipywidgets.Text(
      value=None,
      placeholder=f'Copy and paste the SaprotHub Dataset ID here',
      disabled=False)
    input_repo_id.layout.width = '500px'
    print(Fore.BLUE+f"{data_type}"+Style.RESET_ALL)
    display(input_repo_id)
    return input_repo_id

  # 9-11. A pair of seq
  elif data_type in ["A pair of AA Sequences", "A pair of SA Sequences", "A pair of UniProt IDs"]:
    print()

    seq_type = data_type[len("A pair of "):-1]

    input_seq1 = ipywidgets.Text(
      value=None,
      placeholder=f'Enter the {seq_type} of Sequence 1 here',
      disabled=False)
    input_seq1.layout.width = '500px'
    print(Fore.BLUE+f"Sequence 1:"+Style.RESET_ALL)
    display(input_seq1)

    input_seq2 = ipywidgets.Text(
      value=None,
      placeholder=f'Enter the {seq_type} of Sequence 2 here',
      disabled=False)
    input_seq2.layout.width = '500px'
    print(Fore.BLUE+f"Sequence 2:"+Style.RESET_ALL)
    display(input_seq2)

    return (input_seq1, input_seq2)

  # 12. Pair Single PDB/CIF Structure
  elif data_type == 'A pair of PDB/CIF Structures':
    print("Please provide the structure type, chain and your structure file.")

    dropdown_type1 = ipywidgets.Dropdown(
      value="PDB",
      options=["PDB", "AF2"],
      disabled=False)
    dropdown_type1.layout.width = '500px'
    print(Fore.BLUE+"The first structure type:"+Style.RESET_ALL)
    display(dropdown_type1)

    input_chain1 = ipywidgets.Text(
      value="A",
      placeholder=f'Enter the name of chain of the first structure here',
      disabled=False)
    input_chain1.layout.width = '500px'
    print(Fore.BLUE+"Chain of the first structure:"+Style.RESET_ALL)
    display(input_chain1)

    print(Fore.BLUE+"Please upload a .pdb/.cif file"+Style.RESET_ALL)
    pdb_file_path1 = upload_file(STRUCTURE_HOME)


    dropdown_type2 = ipywidgets.Dropdown(
      value="PDB",
      options=["PDB", "AF2"],
      disabled=False)
    dropdown_type2.layout.width = '500px'
    print(Fore.BLUE+"The second structure type:"+Style.RESET_ALL)
    display(dropdown_type2)

    input_chain2 = ipywidgets.Text(
      value="A",
      placeholder=f'Enter the name of chain of the second structure here',
      disabled=False)
    input_chain2.layout.width = '500px'
    print(Fore.BLUE+"Chain of the second structure:"+Style.RESET_ALL)
    display(input_chain2)

    print(Fore.BLUE+"Please upload a .pdb/.cif file"+Style.RESET_ALL)
    pdb_file_path2 = upload_file(STRUCTURE_HOME)
    return (pdb_file_path1.stem, dropdown_type1, input_chain1, pdb_file_path2.stem, dropdown_type2, input_chain2)

def get_SA_sequence_by_data_type(data_type, raw_data):

  # Multiple sequences
  # raw_data = upload_files/xxx.csv

  # 8. SaprotHub Dataset
  if data_type == "SaprotHub Dataset":
    input_repo_id = raw_data
    REPO_ID = input_repo_id.value

    if REPO_ID.startswith('/'):
      return Path(REPO_ID)

    snapshot_download(repo_id=REPO_ID, repo_type="dataset", local_dir=DATASET_HOME / REPO_ID)
    csv_dataset_path = DATASET_HOME / REPO_ID / 'dataset.csv'
    assert csv_dataset_path.exists(), f"Can't find {csv_dataset_path}"
    protein_df = read_csv_dataset(csv_dataset_path)

    data_type = get_data_type(csv_dataset_path)

    return get_SA_sequence_by_data_type(data_type, csv_dataset_path)

    # # AA, SA
    # if data_type == "Multiple AA Sequences":
    #   for index, value in protein_df['sequence'].items():
    #     sa_seq = ''
    #     for aa in value:
    #       sa_seq += aa + '#'
    #     protein_df.at[index, 'sequence'] = sa_seq

    # # Pair AA, Pair SA
    # elif data_type in ["Multiple pairs of AA Sequences", "Multiple pairs of SA Sequences"]:
    #   for i in ['1', '2']:
    #     if data_type == "Multiple pairs of AA Sequences":
    #       for index, value in protein_df[f'sequence_{i}'].items():
    #         sa_seq = ''
    #         for aa in value:
    #           sa_seq += aa + '#'
    #         protein_df.at[index, f'sequence_{i}'] = sa_seq

    #     protein_df[f'name_{i}'] = f'name_{i}'
    #     protein_df[f'chain_{i}'] = 'A'

    # protein_df.to_csv(csv_dataset_path, index=None)

    # return csv_dataset_path

  elif data_type in data_type_list_multiple:
    uploaded_csv_path = raw_data
    csv_dataset_path = DATASET_HOME / uploaded_csv_path.name
    protein_df = read_csv_dataset(uploaded_csv_path)

    if 'pair' in data_type:
      assert {'sequence_1', 'sequence_2'}.issubset(protein_df.columns), f"The CSV dataset ({uploaded_csv_path}) must contain `sequence_1` and `sequence_2` columns. \n Current columns:{protein_df.columns}"
    else:
      assert 'sequence' in protein_df.columns, f"The CSV Dataset({uploaded_csv_path}) must contain a `sequence` column. \n Current columns:{protein_df.columns}"

    # 4. Multiple AA Sequences
    if data_type == 'Multiple AA Sequences':
      # Append '#' after every residue to turn each AA sequence into an SA sequence
      protein_df['sequence'] = protein_df['sequence'].apply(lambda seq: ''.join(aa + '#' for aa in seq))

      protein_df.to_csv(csv_dataset_path, index=None)
      return csv_dataset_path

    # 5. Multiple SA Sequences
    elif data_type == 'Multiple SA Sequences':
      protein_df.to_csv(csv_dataset_path, index=None)
      return csv_dataset_path

    # 6. Multiple UniProt IDs
    elif data_type == 'Multiple UniProt IDs':
      protein_list = protein_df.loc[:, 'sequence'].tolist()
      uniprot2pdb(protein_list)
      protein_list = [(uniprot_id, "AF2", "A") for uniprot_id in protein_list]
      mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=2, return_results=True)
      outputs = mprs.run()

      protein_df['sequence'] = [output.split("\t")[1] for output in outputs]
      protein_df.to_csv(csv_dataset_path, index=None)
      return csv_dataset_path

    # 7. Multiple PDB/CIF Structures
    elif data_type == 'Multiple PDB/CIF Structures':
      # protein_list = [(uniprot_id, type, chain), ...]
      # protein_list = [item.split('.')[0] for item in protein_df.iloc[:, 0].tolist()]
      # uniprot2pdb(protein_list)
      protein_list = []
      for row_tuple in protein_df.itertuples(index=False):
        assert row_tuple.type in ['PDB', 'AF2'],  "The type of structure must be either \"PDB\" or \"AF2\"!"
        protein_list.append(row_tuple)
      mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=2, return_results=True)
      outputs = mprs.run()

      protein_df['sequence'] = [output.split("\t")[1] for output in outputs]
      protein_df.to_csv(csv_dataset_path, index=None)
      return csv_dataset_path

    # 13. Pair Multiple AA Sequences
    elif data_type == "Multiple pairs of AA Sequences":
      for i in ['1', '2']:
        for index, value in protein_df[f'sequence_{i}'].items():
          sa_seq = ''
          for aa in value:
            sa_seq += aa + '#'
          protein_df.at[index, f'sequence_{i}'] = sa_seq

        protein_df[f'name_{i}'] = f'name_{i}'
        protein_df[f'chain_{i}'] = 'A'

      protein_df.to_csv(csv_dataset_path, index=None)
      return csv_dataset_path

    # 14. Pair Multiple SA Sequences
    elif data_type == "Multiple pairs of SA Sequences":
      for i in ['1', '2']:
        protein_df[f'name_{i}'] = f'name_{i}'
        protein_df[f'chain_{i}'] = 'A'

      protein_df.to_csv(csv_dataset_path, index=None)
      return csv_dataset_path

    # 15. Pair Multiple UniProt IDs
    elif data_type == "Multiple pairs of UniProt IDs":
      for i in ['1', '2']:
        protein_list = protein_df.loc[:, f'sequence_{i}'].tolist()
        uniprot2pdb(protein_list)
        protein_df[f'name_{i}'] = protein_list
        protein_list = [(uniprot_id, "AF2", "A") for uniprot_id in protein_list]
        mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=2, return_results=True)
        outputs = mprs.run()

        protein_df[f'sequence_{i}'] = [output.split("\t")[1] for output in outputs]
        protein_df[f'chain_{i}'] = 'A'

      protein_df.to_csv(csv_dataset_path, index=None)
      return csv_dataset_path

    elif data_type ==  "Multiple pairs of PDB/CIF Structures":
      # columns: sequence_1, sequence_2, type_1, type_2, chain_1, chain_2, label, stage

      # protein_list = [(uniprot_id, type, chain), ...]
      # protein_list = [item.split('.')[0] for item in protein_df.iloc[:, 0].tolist()]
      # uniprot2pdb(protein_list)

      for i in ['1', '2']:
        protein_list = []
        for index, row in protein_df.iterrows():
          assert row[f"type_{i}"] in ['PDB', 'AF2'],  "The type of structure must be either \"PDB\" or \"AF2\"!"
          row_tuple = (row[f"sequence_{i}"], row[f"type_{i}"], row[f"chain_{i}"])
          protein_list.append(row_tuple)
        mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=2, return_results=True)
        outputs = mprs.run()

        # add name column, del type column
        protein_df[f'name_{i}'] = protein_df[f'sequence_{i}'].apply(lambda x: x.split('.')[0])
        protein_df.drop(f"type_{i}", axis=1, inplace=True)
        protein_df[f'sequence_{i}'] = [output.split("\t")[1] for output in outputs]

      # columns: name_1, name_2, chain_1, chain_2, sequence_1, sequence_2, label, stage
      protein_df.to_csv(csv_dataset_path, index=None)
      return csv_dataset_path

  else:
    # 0. Single AA Sequence
    if data_type == 'Single AA Sequence':
      input_seq = raw_data
      aa_seq = input_seq.value

      sa_seq = ''
      for aa in aa_seq:
          sa_seq += aa + '#'
      return sa_seq

    # 1. Single SA Sequence
    elif data_type == 'Single SA Sequence':
      input_seq = raw_data
      sa_seq = input_seq.value

      return sa_seq

    # 2. Single UniProt ID
    elif data_type == 'Single UniProt ID':
      input_seq = raw_data
      uniprot_id = input_seq.value


      protein_list = [(uniprot_id, "AF2", "A")]
      uniprot2pdb([protein_list[0][0]])
      mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=2, return_results=True)
      seqs = mprs.run()
      sa_seq = seqs[0].split('\t')[1]
      return sa_seq

    # 3. Single PDB/CIF Structure
    elif data_type == 'Single PDB/CIF Structure':
      uniprot_id = raw_data[0]
      chain = raw_data[1].value

      protein_list = [(uniprot_id, chain)]
      mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=2, return_results=True)
      seqs = mprs.run()
      assert len(seqs)>0, "Unable to convert to SA sequence. Please check the `type`, `chain`, and `.pdb/.cif file`."
      sa_seq = seqs[0].split('\t')[1]
      return sa_seq

    # 9. Pair Single AA Sequences
    elif data_type == "A pair of AA Sequences":
      input_seq_1, input_seq_2 = raw_data
      sa_seq1 = get_SA_sequence_by_data_type('Single AA Sequence', input_seq_1)
      sa_seq2 = get_SA_sequence_by_data_type('Single AA Sequence', input_seq_2)

      return (sa_seq1, sa_seq2)

    # 10. Pair Single SA Sequences
    elif data_type ==  "A pair of SA Sequences":
      input_seq_1, input_seq_2 = raw_data
      sa_seq1 = get_SA_sequence_by_data_type('Single SA Sequence', input_seq_1)
      sa_seq2 = get_SA_sequence_by_data_type('Single SA Sequence', input_seq_2)

      return (sa_seq1, sa_seq2)

    # 11. Pair Single UniProt IDs
    elif data_type ==  "A pair of UniProt IDs":
      input_seq_1, input_seq_2 = raw_data
      sa_seq1 = get_SA_sequence_by_data_type('Single UniProt ID', input_seq_1)
      sa_seq2 = get_SA_sequence_by_data_type('Single UniProt ID', input_seq_2)

      return (sa_seq1, sa_seq2)

    # 12. Pair Single PDB/CIF Structure
    elif data_type == "A pair of PDB/CIF Structures":
      uniprot_id1 = raw_data[0]
      struc_type1 = raw_data[1].value
      chain1 = raw_data[2].value

      protein_list1 = [(uniprot_id1, struc_type1, chain1)]
      mprs1 = MultipleProcessRunnerSimplifier(protein_list1, pdb2sequence, n_process=2, return_results=True)
      seqs1 = mprs1.run()
      sa_seq1 = seqs1[0].split('\t')[1]

      uniprot_id2 = raw_data[3]
      struc_type2 = raw_data[4].value
      chain2 = raw_data[5].value

      protein_list2 = [(uniprot_id2, struc_type2, chain2)]
      mprs2 = MultipleProcessRunnerSimplifier(protein_list2, pdb2sequence, n_process=2, return_results=True)
      seqs2 = mprs2.run()
      sa_seq2 = seqs2[0].split('\t')[1]
      return sa_seq1, sa_seq2




################################################################################
########################## Download predicted structures #######################
################################################################################
def uniprot2pdb(uniprot_ids, nprocess=20):
  from saprot.utils.downloader import AlphaDBDownloader

  os.makedirs(STRUCTURE_HOME, exist_ok=True)
  af2_downloader = AlphaDBDownloader(uniprot_ids, "pdb", save_dir=STRUCTURE_HOME, n_process=nprocess)
  af2_downloader.run()



################################################################################
############### Form foldseek sequences by multiple processes ##################
################################################################################
# def pdb2sequence(process_id, idx, uniprot_id, writer):
#   from saprot.utils.foldseek_util import get_struc_seq

#   try:
#     pdb_path = f"{STRUCTURE_HOME}/{uniprot_id}.pdb"
#     cif_path = f"{STRUCTURE_HOME}/{uniprot_id}.cif"
#     if Path(pdb_path).exists():
#       seq = get_struc_seq(FOLDSEEK_PATH, pdb_path, ["A"], process_id=process_id)["A"][-1]
#     if Path(cif_path).exists():
#       seq = get_struc_seq(FOLDSEEK_PATH, cif_path, ["A"], process_id=process_id)["A"][-1]

#     writer.write(f"{uniprot_id}\t{seq}\n")
#   except Exception as e:
#     print(f"Error: {uniprot_id}, {e}")

# clear_output(wait=True)
# print("Installation finished!")

def pdb2sequence(process_id, idx, row_tuple, writer):

  # print("="*100)
  # print(row_tuple)
  # print("="*100)
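  uniprot_id = row_tuple[0].split('.')[0]
  # Callers pass either (name, chain) or (name, structure_type, chain),
  # so the chain is always the last element of the tuple.
  chain = row_tuple[-1]
  plddt_mask = True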

  from saprot.utils.foldseek_util import get_struc_seq
  try:
    pdb_path = f"{STRUCTURE_HOME}/{uniprot_id}.pdb"
    cif_path = f"{STRUCTURE_HOME}/{uniprot_id}.cif"
    if Path(pdb_path).exists():
      seq = get_struc_seq(FOLDSEEK_PATH, pdb_path, [chain], process_id=process_id, plddt_mask=plddt_mask)[chain][-1]
    elif Path(cif_path).exists():
      seq = get_struc_seq(FOLDSEEK_PATH, cif_path, [chain], process_id=process_id, plddt_mask=plddt_mask)[chain][-1]
    else:
      raise FileNotFoundError(f"Neither {uniprot_id}.pdb nor {uniprot_id}.cif exists in {STRUCTURE_HOME}!")
    writer.write(f"{uniprot_id}\t{seq}\n")

  except Exception as e:
    print(f"Error: {uniprot_id}, {e}")


pymol_color_list = ["#33ff33","#00ffff","#ff33cc","#ffff00","#ff9999","#e5e5e5","#7f7fff","#ff7f00",
          "#7fff7f","#199999","#ff007f","#ffdd5e","#8c3f99","#b2b2b2","#007fff","#c4b200",
          "#8cb266","#00bfbf","#b27f7f","#fcd1a5","#ff7f7f","#ffbfdd","#7fffff","#ffff7f",
          "#00ff7f","#337fcc","#d8337f","#bfff3f","#ff7fff","#d8d8ff","#3fffbf","#b78c4c",
          "#339933","#66b2b2","#ba8c84","#84bf00","#b24c66","#7f7f7f","#3f3fa5","#a5512b"]

alphabet_list = list(ascii_uppercase+ascii_lowercase)


def convert_outputs_to_pdb(outputs):
	final_atom_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
	outputs = {k: v.to("cpu").numpy() for k, v in outputs.items()}
	final_atom_positions = final_atom_positions.cpu().numpy()
	final_atom_mask = outputs["atom37_atom_exists"]
	pdbs = []
	outputs["plddt"] *= 100

	for i in range(outputs["aatype"].shape[0]):
		aa = outputs["aatype"][i]
		pred_pos = final_atom_positions[i]
		mask = final_atom_mask[i]
		resid = outputs["residue_index"][i] + 1
		pred = OFProtein(
		    aatype=aa,
		    atom_positions=pred_pos,
		    atom_mask=mask,
		    residue_index=resid,
		    b_factors=outputs["plddt"][i],
		    chain_index=outputs["chain_index"][i] if "chain_index" in outputs else None,
		)
		pdbs.append(to_pdb(pred))
	return pdbs


# This function is copied from ColabFold!
def show_pdb(path, show_sidechains=False, show_mainchains=False, color="lDDT"):
  file_type = str(path).split(".")[-1]
  if file_type == "cif":
    # 3Dmol expects "mmcif" as the format name for .cif files
    file_type = "mmcif"

  view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)
  view.addModel(open(path,'r').read(),file_type)

  if color == "lDDT":
    view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}})
  elif color == "rainbow":
    view.setStyle({'cartoon': {'color':'spectrum'}})
  elif color == "chain":
    chains = len(get_chain_ids(path))
    for n,chain,color in zip(range(chains),alphabet_list,pymol_color_list):
       view.setStyle({'chain':chain},{'cartoon': {'color':color}})

  if show_sidechains:
    BB = ['C','O','N']
    view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
                        {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
    view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
                        {'sphere':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
    view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
                        {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
  if show_mainchains:
    BB = ['C','O','N','CA']
    view.addStyle({'atom':BB},{'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})

  view.zoomTo()
  return view


def plot_plddt_legend(dpi=100):
  thresh = ['plDDT:','Very low (<50)','Low (60)','OK (70)','Confident (80)','Very high (>90)']
  plt.figure(figsize=(1,0.1),dpi=dpi)
  ########################################
  for c in ["#FFFFFF","#FF0000","#FFFF00","#00FF00","#00FFFF","#0000FF"]:
    plt.bar(0, 0, color=c)
  plt.legend(thresh, frameon=False,
             loc='center', ncol=6,
             handletextpad=1,
             columnspacing=1,
             markerscale=0.5,)
  plt.axis(False)
  return plt


################################################################################
###############   Download file to local computer   ##################
################################################################################
def file_download(path: str):
  with open(path, "rb") as r:
    res = r.read()

  #FILE
  filename = os.path.basename(path)
  b64 = base64.b64encode(res)
  payload = b64.decode()

  #BUTTONS
  html_buttons = '''<html>
  <head>
  <meta name="viewport" content="width=device-width, initial-scale=1">
  </head>
  <body>
  <a download="{filename}" href="data:text/csv;base64,{payload}" download>
  <button class="p-Widget jupyter-widgets jupyter-button widget-button mod-warning">Download File</button>
  </a>
  </body>
  </html>
  '''

  html_button = html_buttons.format(payload=payload,filename=filename)
  display(HTML(html_button))

  # Automatically download file if the server is from google cloud.
  if root_dir == "/content":
    files.download(path)

################################################################################
############################ MODEL INFO #######################################
################################################################################
def get_base_model(adapter_path):
  adapter_config = Path(adapter_path) / "adapter_config.json"
  with open(adapter_config, 'r') as f:
    adapter_config_dict = json.load(f)
    base_model = adapter_config_dict['base_model_name_or_path']
    if 'SaProt_650M_AF2' in base_model:
      base_model = "westlake-repl/SaProt_650M_AF2"
    elif 'SaProt_35M_AF2' in base_model:
      base_model = "westlake-repl/SaProt_35M_AF2"
    else:
      raise RuntimeError("Please ensure the base model is \"SaProt_650M_AF2\" or \"SaProt_35M_AF2\"")
  return base_model

def check_training_data_type(adapter_path, data_type):
  metadata_path = Path(adapter_path) / "metadata.json"
  if metadata_path.exists():
    with open(metadata_path, 'r') as f:
      metadata = json.load(f)
      required_training_data_type = metadata['training_data_type']
  else:
    required_training_data_type = "SA"

  if (required_training_data_type == "AA") and ("AA" not in data_type):
    print(Fore.RED+f"This model ({adapter_path}) is trained on {required_training_data_type} sequences, and predictions work better with AA sequences."+Style.RESET_ALL)
    print(Fore.RED+f"The current data type ({data_type}) includes structural information, which will not be used for predictions."+Style.RESET_ALL)
    print()
    print('='*100)
  elif (required_training_data_type == "SA") and ("AA" in data_type):
    print(Fore.RED+f"This model ({adapter_path}) is trained on {required_training_data_type} sequences, and predictions work better with SA sequences."+Style.RESET_ALL)
    print(Fore.RED+f"The current data type ({data_type}) does not include structural information, which may lead to weak prediction performance."+Style.RESET_ALL)
    print(Fore.RED+f"If you only have the amino acid sequence, we strongly recommend using AF2 to predict the structure and generate a PDB file before prediction."+Style.RESET_ALL)
    print()
    print('='*100)

  return required_training_data_type

def mask_struc_token(sequence):
    return ''.join('#' if i % 2 == 1 and char.islower() else char for i, char in enumerate(sequence))

def get_num_labels_by_adapter(adapter_path):
    adapter_path = Path(adapter_path)

    if (adapter_path / 'adapter_model.safetensors').exists():
        file_path = adapter_path / 'adapter_model.safetensors'
        with safe_open(file_path, framework="pt") as f:
          if 'base_model.model.classifier.out_proj.bias' in f.keys():
              tensor = f.get_tensor('base_model.model.classifier.out_proj.bias')
          elif 'base_model.model.classifier.bias' in f.keys():
              tensor = f.get_tensor('base_model.model.classifier.bias')
          else:
              raise KeyError(f"Neither 'base_model.model.classifier.out_proj.bias' nor 'base_model.model.classifier.bias' found in the file({file_path}).")

    elif (adapter_path / 'adapter_model.bin').exists():
      file_path = adapter_path / 'adapter_model.bin'
      state_dict = torch.load(file_path)
      if 'base_model.model.classifier.out_proj.bias' in state_dict.keys():
        tensor = state_dict['base_model.model.classifier.out_proj.bias']
      elif 'base_model.model.classifier.bias' in state_dict.keys():
        tensor = state_dict['base_model.model.classifier.bias']
      else:
        raise KeyError(f"Neither 'base_model.model.classifier.out_proj.bias' nor 'base_model.model.classifier.bias' found in the file({file_path}).")

    else:
        raise FileNotFoundError(f"Neither 'adapter_model.safetensors' nor 'adapter_model.bin' found in the provided path({adapter_path}).")

    num_labels = list(tensor.shape)[0]
    return num_labels

def get_num_labels_and_task_type_by_adapter(adapter_path):
    adapter_path = Path(adapter_path)

    task_type = None
    if (adapter_path / 'adapter_model.safetensors').exists():
      file_path = adapter_path / 'adapter_model.safetensors'
      with safe_open(file_path, framework="pt") as f:
        if 'base_model.model.classifier.out_proj.bias' in f.keys():
          tensor = f.get_tensor('base_model.model.classifier.out_proj.bias')
        elif 'base_model.model.classifier.bias' in f.keys():
          task_type = 'token_classification'
          tensor = f.get_tensor('base_model.model.classifier.bias')
        else:
          raise KeyError(f"Neither 'base_model.model.classifier.out_proj.bias' nor 'base_model.model.classifier.bias' found in the file({file_path}).")

    elif (adapter_path / 'adapter_model.bin').exists():
      file_path = adapter_path / 'adapter_model.bin'
      state_dict = torch.load(file_path)
      if 'base_model.model.classifier.out_proj.bias' in state_dict.keys():
        tensor = state_dict['base_model.model.classifier.out_proj.bias']
      elif 'base_model.model.classifier.bias' in state_dict.keys():
        task_type = 'token_classification'
        tensor = state_dict['base_model.model.classifier.bias']
      else:
        raise KeyError(f"Neither 'base_model.model.classifier.out_proj.bias' nor 'base_model.model.classifier.bias' found in the file({file_path}).")

    else:
        raise FileNotFoundError(f"Neither 'adapter_model.safetensors' nor 'adapter_model.bin' found in the provided path({adapter_path}).")

    num_labels = list(tensor.shape)[0]
    if task_type != 'token_classification':
      if num_labels > 1:
        task_type = 'classification'
      elif num_labels == 1:
        task_type = 'regression'

    return num_labels, task_type

################################################################################
############################ INFO ##############################################
################################################################################
clear_output(wait=True)
print("Installation finished!")

Installation finished!
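
The helpers in the cell above all work on structure-aware (SA) sequences, in which every residue takes two characters: an amino-acid letter followed by a structure token (a lowercase 3Di letter, or `#` when the structure is unknown or masked). The following is a minimal sketch of that layout, assuming only what `get_SA_sequence_by_data_type` and `mask_struc_token` show. The names `aa_to_sa` and `split_sa` are illustrative and are not part of SaProt; the toy sequence `MdEvKp` is made up.

```python
def aa_to_sa(aa_seq: str) -> str:
    """Build an SA sequence from a plain amino-acid sequence.

    Every residue gets the placeholder structure token '#',
    as in the 'Single AA Sequence' branch of the notebook.
    """
    return "".join(aa + "#" for aa in aa_seq)


def split_sa(sa_seq: str):
    """Split an SA sequence into (amino-acid string, structure-token string)."""
    if len(sa_seq) % 2 != 0:
        raise ValueError("An SA sequence must have two characters per residue.")
    return sa_seq[0::2], sa_seq[1::2]


def mask_struc_token(sequence: str) -> str:
    """Same logic as the notebook helper: replace lowercase structure tokens with '#'."""
    return "".join(
        "#" if i % 2 == 1 and char.islower() else char
        for i, char in enumerate(sequence)
    )


if __name__ == "__main__":
    toy_sa = "MdEvKp"  # made-up example: residues M, E, K with tokens d, v, p
    print(aa_to_sa("MEK"))
    print(split_sa(toy_sa))
    print(mask_struc_token(toy_sa))
```

Masking every structure token of an SA sequence yields the same string as building it from the amino-acid sequence alone, which is why models trained on AA-only input can still consume SA-formatted data.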


In [236]:
#@title **Click the run button to start**
import markdown
import ipywidgets

from ipywidgets import Button, Label, HTML, Layout
from IPython.display import display, clear_output
from functools import partial


# HINT = Fore.RED + "\nNote: At any time you can click the run button again to start restart." + Style.RESET_ALL
HINT = HTML(markdown.markdown("\n\n<font color=red>Note: At any time you can click the run button <img src='https://github.com/westlake-repl/SaProtHub/blob/dev/Figure/run_button.png?raw=true' height='25px' width='25px' align='center'> again to restart.</font>"))


# Jump to next page
def jump(button, next):
  clear_output()
  next()
  display(HINT)


def custom_display(*args, **kwargs):
  clear_output()
  display(*args, **kwargs)
  display(HINT)


######################################################################
#            Backend functions             #
######################################################################
def generate_download_btn(path: str):
  with open(path, "rb") as r:
    res = r.read()

  #FILE
  filename = os.path.basename(path)
  b64 = base64.b64encode(res)
  payload = b64.decode()

  #BUTTONS
  html_buttons = '''<html>
  <head>
  <meta name="viewport" content="width=device-width, initial-scale=1">
  </head>
  <body>
  <a download="{filename}" href="data:text/csv;base64,{payload}" download>
  <button class="p-Widget jupyter-widgets jupyter-button widget-button mod-warning">Download File</button>
  </a>
  </body>
  </html>
  '''

  html_button = html_buttons.format(payload=payload,filename=filename)
  return HTML(html_button)


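def load_zeroshot_model():
  # Cache the model in a module-level variable so it is loaded only once.
  # Without the global declaration the name is local and the lookup below always fails.
  global zero_shot_model
  try:
    zero_shot_model
  except NameError: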
    from saprot.model.saprot.saprot_foldseek_mutation_model import SaprotFoldseekMutationModel
    base_model = "westlake-repl/SaProt_650M_AF2"
    config = {
      "foldseek_path": None,
      "config_path": base_model,
      "load_pretrained": True,
    }

    zero_shot_model = SaprotFoldseekMutationModel(**config)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    zero_shot_model.to(device)

  return zero_shot_model


# Zero-shot prediction
def predict_mut(sa_seq, mut_info):
  zero_shot_model = load_zeroshot_model()
  score = zero_shot_model.predict_mut(sa_seq, mut_info)
  return score


# Zero-shot prediction for single-site saturation mutagenesis
def predict_all_mut(sa_seq):
  zero_shot_model = load_zeroshot_model()

  timestamp = datetime.now().strftime("%y%m%d%H%M%S")
  output_path = OUTPUT_HOME / f'{timestamp}_prediction_output.csv'

  mut_dicts = []
  aa_seq = sa_seq[0::2]
  for i in tqdm(range(len(aa_seq)), leave=False, desc=f"Predicting"):
    mut_dict = zero_shot_model.predict_pos_mut(sa_seq, i+1)
    mut_dicts.append(mut_dict)

  mut_list = [{'mutation': key, 'score': value} for d in mut_dicts for key, value in d.items()]
  df = pd.DataFrame(mut_list)
  df.to_csv(output_path, index=None)
  return output_path


######################################################################
#          Training or Prediction            #
######################################################################
def train_or_pred():
  question = HTML(markdown.markdown("## Do you want to train your own model or you only want to use existing models to make prediction?"))
  train_btn = Button(description='I want to train my own model', layout=Layout(width='400px', height='30px'))
  pred_btn = Button(description='I want to use existing models to make prediction', layout=Layout(width='400px', height='30px'))

  items = [question, train_btn, pred_btn]

  # Set click events
  train_btn.on_click(partial(jump, next=choose_training_task))
  pred_btn.on_click(partial(jump, next=choose_pred_task))

  display(*items)


######################################################################
#          Choose training task             #
######################################################################
def choose_training_task():
  question = HTML(markdown.markdown("## Please choose the type of your training task:"))
  protein_cls = Button(description='Protein-level classification', layout=Layout(width='500px', height='30px'))
  protein_reg = Button(description='Protein-level regression', layout=Layout(width='500px', height='30px'))
  residue_cls = Button(description='Residue-level classification', layout=Layout(width='500px', height='30px'))
  protein_protein_cls = Button(description='Protein-protein classification', layout=Layout(width='500px', height='30px'))
  protein_protein_reg = Button(description='Protein-protein regression', layout=Layout(width='500px', height='30px'))

  task_type = ipywidgets.Dropdown(
            options=['Protein-level classification', 'Protein-level regression', 'Residue-level classification', "Protein-protein classification", "Protein-protein regression"],
            value='Protein-level classification',
            description='Task type:',
            disabled=False,
          )

  items = [
      question,
      task_type,
      ]

  # Set click events

  display(*items)


######################################################################
#          Choose prediction task            #
######################################################################
def choose_pred_task():
  question = HTML(markdown.markdown("## ColabSaprot supports several prediction tasks, which one would you like to choose?"))
  normal_pred = Button(description='Protein property prediction', layout=Layout(width='500px', height='30px'))
  zeroshot_pred = Button(description='Mutational effect prediction', layout=Layout(width='500px', height='30px'))
  design_pred = Button(description='Protein sequence design', layout=Layout(width='500px', height='30px'))
  repr_pred = Button(description='Obtain protein-level embeddings', layout=Layout(width='500px', height='30px'))
  back_btn = Button(description='Go back', layout=Layout(width='500px', height='30px'))

  items = [
      question,
      normal_pred,
      zeroshot_pred,
      design_pred,
      repr_pred,
      # back_btn
      ]

  # Set click events
  zeroshot_pred.on_click(partial(jump, next=start_mut_pred))
  back_btn.on_click(partial(jump, next=train_or_pred))

  display(*items)


######################################################################
#       Start mutational effect prediction        #
######################################################################
def start_mut_pred():
  question = HTML(markdown.markdown("## Please choose the mutation task you want to perform"))
  single_btn = Button(description='Single-site or Multi-site mutagenesis', layout=Layout(width='500px', height='30px'))
  all_btn = Button(description='Single-site saturation mutagenesis', layout=Layout(width='500px', height='30px'))

  items = [
      question,
      single_btn,
      all_btn,
  ]

  # Set click events
  single_btn.on_click(partial(jump, next=single_mut_pred))
  all_btn.on_click(partial(jump, next=saturatopm_mut_pred))

  display(*items)


######################################################################
#      Single-site or Multi-site mutagenesis        #
######################################################################
def single_mut_pred():
  hint = HTML(markdown.markdown("# Single-site or Multi-site mutagenesis\n\n## Please upload the protein structure\n If you only have the protein sequence, you can use [ColabFold](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/AlphaFold2.ipynb) to predict its structure and upload it here."))

  chain_hint = HTML(markdown.markdown("Chain (to be extracted from the structure):"))
  input_chain = ipywidgets.Text(value="A",placeholder=f'Enter the name of chain here', layout=Layout(width='500px', height='30px'))
  upload_hint = HTML(markdown.markdown("Upload the protein structure:"))
  upload_btn = ipywidgets.FileUpload(accept='',multiple=False)

  items = [
      hint,
      chain_hint,
      input_chain,
      upload_hint,
      upload_btn,
  ]

  # Set click events
  def on_upload_file(change):
    upload_path = Path(STRUCTURE_HOME)
    upload_path.mkdir(parents=True, exist_ok=True)
    basepath = Path().resolve()

    # Write to specific path
    name = list(upload_btn.value.keys())[0]
    content = upload_btn.value[name]["content"]
    save_path = upload_path / name
    with open(save_path, "wb") as wb:
      wb.write(content)

    chain = input_chain.value
    print(chain)
    protein_list = [(name, chain)]
    mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=2, return_results=True, verbose=False)
    seqs = mprs.run()

    sa_seq = seqs[0].split("\t")[-1]
    aa_seq = sa_seq[0::2]
    struc_seq = sa_seq[1::2].replace("#", "\\#")

    seq_info = HTML(markdown.markdown(f"**{name}**\n\n**3Di sequence:**\n\n{struc_seq}\n\n**Amino acid sequence:**\n\n{aa_seq}"))

    # Mutation information box
    mut_hint = HTML(markdown.markdown(
        "**Please input the mutation information:**\n\n"
        "**For a single-site mutation, e.g. M1E means mutating the amino acid M to E at the first position. For a multi-site mutation, separate the individual mutations with ':', e.g. M1E:P2V.**"
        )
    )
    input_mut = ipywidgets.Text(placeholder='Enter Mutation Information here', layout=Layout(width='1000px', height='30px'))
    submit_btn = Button(description='Calculate mutation score', layout=Layout(width='250px', height='30px'))


    new_items = items + [seq_info, mut_hint, input_mut, submit_btn]

    # Click to calculate mutation score
    def calc_mut_score(button):
      mut_info = input_mut.value
      print(f"Predict the mutation score for {mut_info}...")
      score = predict_mut(sa_seq, mut_info)
      score_hint = HTML(markdown.markdown(f"The score for {mut_info} is <font color=red>{score.item()}</font>. A positive score means the mutation is better than the wild type from an evolutionary perspective (the larger the better)."))
      new_items = items + [seq_info, mut_hint, input_mut, submit_btn, score_hint]
      custom_display(*new_items)

    submit_btn.on_click(calc_mut_score)
    custom_display(*new_items)

  upload_btn.observe(on_upload_file, names='value')

  display(*items)


######################################################################
#       Single-site saturation mutagenesis          #
######################################################################
def saturatopm_mut_pred():
  hint = HTML(markdown.markdown("# Single-site saturation mutagenesis\n\n## Please upload the protein structure\n If you only have the protein sequence, you can use [ColabFold](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/AlphaFold2.ipynb) to predict its structure and upload it here."))

  chain_hint = HTML(markdown.markdown("Chain (to be extracted from the structure):"))
  input_chain = ipywidgets.Text(value="A",placeholder=f'Enter the name of chain here', layout=Layout(width='500px', height='30px'))
  upload_hint = HTML(markdown.markdown("Upload the protein structure:"))
  upload_btn = ipywidgets.FileUpload(accept='',multiple=False)

  items = [
      hint,
      chain_hint,
      input_chain,
      upload_hint,
      upload_btn,
  ]

  # Set click events
  def on_upload_file(change):
    upload_path = Path(STRUCTURE_HOME)
    upload_path.mkdir(parents=True, exist_ok=True)
    basepath = Path().resolve()

    # Write to specific path
    name = list(upload_btn.value.keys())[0]
    content = upload_btn.value[name]["content"]
    save_path = upload_path / name
    with open(save_path, "wb") as wb:
      wb.write(content)

    chain = input_chain.value
    print(chain)
    protein_list = [(name, chain)]
    mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=2, return_results=True, verbose=False)
    seqs = mprs.run()

    sa_seq = seqs[0].split("\t")[-1]
    aa_seq = sa_seq[0::2]
    struc_seq = sa_seq[1::2].replace("#", "\\#")

    seq_info = HTML(markdown.markdown(f"**{name}**\n\n**3Di sequence:**\n\n{struc_seq}\n\n**Amino acid sequence:**\n\n{aa_seq}"))
    submit_btn = Button(description='Calculate mutation score for all single-site mutations', layout=Layout(width='500px', height='30px'))


    new_items = items + [seq_info, submit_btn]

    # Click to calculate mutation score
    def calc_mut_score(button):
      print(f"Predicting mutation scores for all single-site mutations...")
      save_path = predict_all_mut(sa_seq)
      score_hint = HTML(markdown.markdown(f"The result has been saved to {save_path}. You can click to download the file.\n\n"
      "**For the mutation score, a positive score means the mutation is better than the wild type from an evolutionary perspective (the larger the better).**"))
      download_btn = generate_download_btn(save_path)

      new_items = items + [seq_info, submit_btn, score_hint, download_btn]
      custom_display(*new_items)

    submit_btn.on_click(calc_mut_score)
    custom_display(*new_items)

  upload_btn.observe(on_upload_file, names='value')

  display(*items)


train_or_pred()
display(HINT)

HTML(value='<h2>Do you want to train your own model or you only want to use existing models to make prediction…

Button(description='I want to train my own model', layout=Layout(height='30px', width='400px'), style=ButtonSt…

Button(description='I want to use existing models to make prediction', layout=Layout(height='30px', width='400…

HTML(value="<p><font color=red>Note: At any time you can click the run button <img src='https://github.com/wes…
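
The mutation form in the cell above accepts strings such as `M1E` (single-site) or `M1E:P2V` (multi-site), with 1-based positions. Before sending such a string to the model, it can help to check it against the amino-acid sequence. The following is a minimal sketch of such a check; `parse_mut_info` is a hypothetical helper, not part of SaProt or ColabSaprot, and it assumes uppercase one-letter residue codes.

```python
import re

# One mutation: wild-type letter, 1-based position, mutant letter (e.g. "M1E").
MUT_PATTERN = re.compile(r"^([A-Z])(\d+)([A-Z])$")


def parse_mut_info(mut_info: str, aa_seq: str):
    """Parse 'M1E' or 'M1E:P2V' into a list of (wild_type, position, mutant) tuples.

    Raises ValueError when a token is malformed, the position is out of range,
    or the stated wild-type residue does not match aa_seq.
    """
    mutations = []
    for token in mut_info.strip().split(":"):
        match = MUT_PATTERN.match(token.strip())
        if match is None:
            raise ValueError(f"Malformed mutation: {token!r}")

        wild_type, position, mutant = match.group(1), int(match.group(2)), match.group(3)
        if not 1 <= position <= len(aa_seq):
            raise ValueError(f"Position {position} is outside the sequence (length {len(aa_seq)}).")
        if aa_seq[position - 1] != wild_type:
            raise ValueError(
                f"{token}: the sequence has {aa_seq[position - 1]!r} at position {position}, not {wild_type!r}."
            )
        mutations.append((wild_type, position, mutant))
    return mutations


if __name__ == "__main__":
    print(parse_mut_info("M1E:P2V", "MPKL"))
```

In the notebook, `aa_seq` would be `sa_seq[0::2]`, so a check like this could run inside `calc_mut_score` before `predict_mut` is called.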

In [289]:
title = HTML(markdown.markdown("## Please finish the setting of your training task"))
WIDTH = "400px"
HEIGHT= "30px"

task_hint = HTML(markdown.markdown("### Task setting:"))
task_type = ipywidgets.Dropdown(
          options=['Protein-level Classification', 'Protein-level Regression', 'Residue-level Classification', "Protein-protein Classification", "Protein-protein Regression"],
          value='Protein-level Classification',
          description='Task type:',
          disabled=False,
          layout=Layout(width=WIDTH, height=HEIGHT)
        )

num_label = ipywidgets.BoundedIntText(
      value=2,
      min=2,
      max=100,
      step=1,
      description='Number of classes:',
      disabled=False,
      style={'description_width': 'initial'},
      layout=Layout(width=WIDTH, height=HEIGHT)
    )

model_hint = HTML(markdown.markdown("### Model setting:"))
model_type = ipywidgets.Dropdown(
          options=['Official SaProt (35M)', "Official SaProt (650M)", "Trained by yourself on ColabSaprot", "Shared by peers on SaprotHub"],
          value='Official SaProt (35M)',
          description='Base model:',
          disabled=False,
          layout=Layout(width=WIDTH, height=HEIGHT)
        )

dataset_hint = HTML(markdown.markdown("### Dataset setting:"))

hyperparameter_hint = HTML(markdown.markdown("### Training hyper-parameters:"))
batch_size = ipywidgets.Dropdown(
          options=["Adaptive", "1", "2", "4", "8", "16", "32", "64", "128", "256"],
          value='Adaptive',
          description='Batch size:',
          disabled=False,
          layout=Layout(width=WIDTH, height=HEIGHT)
          )

epoch = ipywidgets.BoundedIntText(
      value=5,
      min=1,
      max=100,
      step=1,
      description='Epoch:',
      disabled=False,
      style={'description_width': 'initial'},
      layout=Layout(width=WIDTH, height=HEIGHT)
    )

lr = ipywidgets.FloatText(
    value=5e-4,
    description='Learning rate:',
    disabled=False,
    style={'description_width': 'initial'},
    layout=Layout(width=WIDTH, height=HEIGHT)
)

start_btn = Button(description='Start training', layout=Layout(width='400px', height='30px'), button_style="info")


items = [
    title,
    task_hint,
    task_type,
    num_label,
    model_hint,
    model_type,
    dataset_hint,
    hyperparameter_hint,
    batch_size,
    epoch,
    lr,
    start_btn
    ]

# Callbacks for widget value changes and the start button
def change_task_type(change):
  # Only classification tasks need the "Number of classes" field
  now_type = change["new"]
  if "Classification" in now_type:
    num_label.layout.visibility = "visible"
  else:
    num_label.layout.visibility = "hidden"

def change_model_type(change):
  # When training continues from an existing model, show an adapter selector.
  # items[:6] ends at model_type, so the selector lands directly below it.
  model_type_value = change["new"]
  task_type_value = task_type_dict[task_type.value]
  if model_type_value == "Trained by yourself on ColabSaprot":
    adapter_combobox = select_adapter_from(task_type_value, use_model_from=model_type_value)
    adapter_combobox.layout.width = WIDTH
    adapter_combobox.description = "Select your local model"
    adapter_combobox.style = {'description_width': 'initial'}
    new_items = items[:6] + [adapter_combobox] + items[6:]
    custom_display(*new_items)

  elif model_type_value == "Shared by peers on SaprotHub":
    adapter_combobox = select_adapter_from(task_type_value, use_model_from=model_type_value)
    adapter_combobox.layout.width = WIDTH
    adapter_combobox.placeholder = "Enter HuggingFace model id"
    new_items = items[:6] + [adapter_combobox] + items[6:]
    custom_display(*new_items)

  else:
    # Official pretrained models need no adapter, so show the plain form
    custom_display(*items)

def start_training(button):
  ######################################################################
  #                          Start training                            #
  ######################################################################
  # training config
  GPU_batch_size = 0
  accumulate_grad_batches = 0
  num_workers = 2
  seed = 20000812

  # LoRA config: rank, dropout and scaling factor
  r = 8
  lora_dropout = 0.0
  lora_alpha = 16

  # dataset config
  val_check_interval = 0.5
  limit_train_batches = 1.0
  limit_val_batches = 1.0
  limit_test_batches = 1.0

  mask_struc_ratio = None


task_type.observe(change_task_type, names='value')
model_type.observe(change_model_type, names='value')
start_btn.on_click(start_training)


# Display the form
display(*items)

HTML(value='<h2>Please finish the setting of your training task</h2>')

HTML(value='<h3>Task setting:</h3>')

Dropdown(description='Task type:', index=2, layout=Layout(height='30px', width='400px'), options=('Protein-lev…

BoundedIntText(value=2, description='Number of classes:', layout=Layout(height='30px', visibility='visible', w…

HTML(value='<h3>Model setting:</h3>')

Dropdown(description='Base model:', index=1, layout=Layout(height='30px', width='400px'), options=('Official S…

HTML(value='<h3>Dataset setting:</h3>')

HTML(value='<h3>Training hyper-parameters:</h3>')

Dropdown(description='Batch size:', layout=Layout(height='30px', width='400px'), options=('Adaptive', '1', '2'…

BoundedIntText(value=5, description='Epoch:', layout=Layout(height='30px', width='400px'), min=1, style=Descri…

FloatText(value=0.0005, description='Learning rate:', layout=Layout(height='30px', width='400px'), style=Descr…

Button(button_style='info', description='Start training', layout=Layout(height='30px', width='400px'), style=B…

HTML(value="<p><font color=red>Note: At any time you can click the run button <img src='https://github.com/wes…
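The form passes the batch size as a string (`"Adaptive"` or a number), while `start_training` leaves `GPU_batch_size` and `accumulate_grad_batches` at 0. How those two values get filled in is not shown in this cell. Below is a minimal sketch of one way it could be done, not the notebook's actual logic: `resolve_batch_size`, `max_gpu_batch`, and `adaptive_default` are hypothetical names introduced here for illustration.

```python
def resolve_batch_size(choice, max_gpu_batch, adaptive_default=8):
    """Split a requested batch size into (per-step GPU batch, accumulation steps).

    Hypothetical helper: `max_gpu_batch` and `adaptive_default` are assumptions
    made for this sketch and are not taken from the notebook.
    """
    # The dropdown yields strings, so "Adaptive" must be handled before int().
    total = adaptive_default if choice == "Adaptive" else int(choice)
    if total < 1 or max_gpu_batch < 1:
        raise ValueError("batch sizes must be positive")

    # Pick the largest divisor of `total` that fits on the GPU, so that
    # gpu_batch * accumulate equals the requested batch size exactly.
    limit = min(total, max_gpu_batch)
    gpu_batch = max(d for d in range(1, limit + 1) if total % d == 0)
    accumulate = total // gpu_batch
    return gpu_batch, accumulate


print(resolve_batch_size("64", max_gpu_batch=16))
print(resolve_batch_size("4", max_gpu_batch=16))
print(resolve_batch_size("Adaptive", max_gpu_batch=16, adaptive_default=32))
```

Under these assumptions, a requested batch of 64 on a GPU that fits 16 samples becomes 16 samples per step with 4 accumulation steps, so the effective batch size is unchanged.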

In [257]:
task_name = "demo" # @param {type:"string"}
task_type = "Residue-level Classification" # @param ["Protein-level Classification", "Protein-level Regression", "Residue-level Classification", "Protein-protein Classification", "Protein-protein Regression"]
original_task_type = task_type
task_type = task_type_dict[task_type]

base_model = "Trained by yourself on ColabSaprot" # @param ["Official pretrained SaProt (35M)", "Official pretrained SaProt (650M)", "Trained by yourself on ColabSaprot", "Shared by peers on SaprotHub", "Saved in your local computer"]

# continue learning
if base_model in ["Trained by yourself on ColabSaprot", "Shared by peers on SaprotHub", "Saved in your local computer"]:
  continue_learning = True
  adapter_combobox = select_adapter_from(task_type, use_model_from=base_model)
else:
  continue_learning = False
  # No adapter is needed for the official pretrained models; define the name
  # anyway so the print below does not raise a NameError.
  adapter_combobox = None

print(task_type, type(task_type))
print(adapter_combobox)

Local Model (classification):


Dropdown(layout=Layout(width='500px'), options=(PosixPath('Local/Model-demo-35M'),), value=None)

classification <class 'str'>
Dropdown(layout=Layout(width='500px'), options=(PosixPath('Local/Model-demo-35M'),), value=None)
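
The selector returned by `select_adapter_from` is what the form cell's `change_model_type` handler places into the widget list with `items[:6] + [adapter_combobox] + items[6:]`. The sketch below replays that splice with plain strings standing in for the widgets, so it runs without `ipywidgets`. The helper name `insert_after_base_model` is hypothetical.

```python
def insert_after_base_model(items, extra, position=6):
    """Return a new list with `extra` placed at `position`; `items` is untouched."""
    return items[:position] + [extra] + items[position:]


# Strings stand in for the widgets, in the same order as the form's `items` list.
form_items = [
    "title", "task_hint", "task_type", "num_label", "model_hint", "model_type",
    "dataset_hint", "hyperparameter_hint", "batch_size", "epoch", "lr", "start_btn",
]

with_adapter = insert_after_base_model(form_items, "adapter_combobox")

# The selector sits between the base-model dropdown and the dataset hint.
print(with_adapter[5:8])
# The original list is unchanged, so the plain form can still be displayed later.
print(len(form_items), len(with_adapter))
```

Building a new list rather than mutating `items` matters here: the `else` branch of the handler redisplays the original `items`, which would otherwise accumulate stale selectors.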


In [261]:
task_type_dict

{'Protein-level Classification': 'classification',
 'Residue-level Classification': 'token_classification',
 'Protein-level Regression': 'regression',
 'Protein-protein Classification': 'pair_classification',
 'Protein-protein Regression': 'pair_regression'}
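
As a quick illustration of how this mapping can be used, the sketch below translates a form label into the internal task name and decides whether a class count is needed. The dictionary is copied from the output above. The helpers `to_internal_task` and `needs_num_labels` are hypothetical and only mirror the `"Classification" in now_type` check made by the form.

```python
# Mapping copied from the `task_type_dict` output above.
task_type_dict = {
    "Protein-level Classification": "classification",
    "Residue-level Classification": "token_classification",
    "Protein-level Regression": "regression",
    "Protein-protein Classification": "pair_classification",
    "Protein-protein Regression": "pair_regression",
}


def to_internal_task(display_name):
    """Translate a form label into the internal task name, failing loudly."""
    try:
        return task_type_dict[display_name]
    except KeyError:
        raise ValueError(f"Unknown task type: {display_name!r}") from None


def needs_num_labels(display_name):
    """Hypothetical check: classification-style tasks need a class count."""
    return to_internal_task(display_name).endswith("classification")


for label in task_type_dict:
    print(label, "->", to_internal_task(label), "| needs classes:", needs_num_labels(label))
```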