# **ColabProTrek**

<a href="https://www.biorxiv.org/content/10.1101/2024.05.30.596740v1"><img src="https://img.shields.io/badge/Paper-bioRxiv-green" style="max-width: 100%;"></a>
<a href="https://huggingface.co/ProTrekHub"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-red?label=ProTrekHub" style="max-width: 100%;"></a>
<a href="https://huggingface.co/spaces/westlake-repl/Demo_ProTrek_650M_UniRef50"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-red?label=Demo" style="max-width: 100%;"></a>
<a href="https://huggingface.co/westlake-repl/ProTrek_650M_UniRef50"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-yellow?label=Model" style="max-width: 100%;"></a>
<a href="https://cbirt.net/charting-the-protein-universe-with-protreks-tri-modal-contrastive-learning/" alt="blog"><img src="https://img.shields.io/badge/Blog-Medium-purple" /></a>


This is **ColabProTrek**, the Colab version of [ProTrek](https://github.com/westlake-repl/ProTrek) [(paper)](https://www.biorxiv.org/content/10.1101/2024.05.30.596740v1).

ProTrek is also a retrieval model for pairwise searches across protein sequence, structure, and function. [Try the demo](https://huggingface.co/spaces/westlake-repl/Demo_ProTrek_650M_UniRef50).

**ColabProTrek** is a platform that makes **Protein Language Models (PLMs)** more accessible and user-friendly for biologists, enabling effortless model training and sharing within the scientific community.

ColabProTrek is a member of the OPMC family. You might also be interested in its siblings, [ColabSaprot](https://colab.research.google.com/github/westlake-repl/SaprotHub/blob/main/colab/SaprotHub_v2.ipynb) and [ColabProtT5](https://colab.research.google.com/github/westlake-repl/SaprotHub/blob/main/colab/ColabProtT5.ipynb).

<!-- If you find our ColabProTrek useful for your research, please also consider citing our [OPMC literature](https://www.biorxiv.org/content/10.1101/2024.05.24.595648v4). We have invested significant effort in developing these platforms. -->





## ColabProTrek

| Function                             | Tutorial                                                     | Video                                                        |
| ------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| <a href="#train">Train your model</a>                     | [How to train your model](https://github.com/westlake-repl/SaprotHub/wiki/ColabProTrek-&-ColabProtT5#21-Train-your-model) | -  |
| <a href="#prediction">Classification/Regression Prediction</a> | [How to use model for classification/regression prediction](https://github.com/westlake-repl/SaprotHub/wiki/ColabProTrek-&-ColabProtT5#31-Classification-Regression-Prediction) | - |

<br>

<font color=red>**To view the content, please click on the first option in the left sidebar.**</font>



# **1: Installation**

## ⚠️SWITCH YOUR RUNTIME TYPE TO GPU
Before installing ProTrek, please **<font color=red>SWITCH YOUR RUNTIME TYPE TO GPU!!!</font>**

> 📍Please check this [page](https://github.com/westlake-repl/SaprotHub/wiki/ColabProTrek-&-ColabProtT5#11-Switch-your-runtime-type-to-GPU) to learn **how to switch your runtime type to GPU**.
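
As a quick sanity check after switching, the following minimal sketch reports whether a GPU appears to be attached to the runtime. It uses only the Python standard library and assumes that the NVIDIA tool `nvidia-smi` is on the `PATH` whenever a GPU runtime is active. The helper name `gpu_runtime_available` is hypothetical and is not part of ProTrek or SaprotHub.

```python
import shutil
import subprocess


def gpu_runtime_available() -> bool:
    """Return True if `nvidia-smi` exists and lists at least one GPU."""
    nvidia_smi = shutil.which("nvidia-smi")
    if nvidia_smi is None:
        # No NVIDIA tooling found: most likely a CPU-only runtime.
        return False
    try:
        result = subprocess.run(
            [nvidia_smi, "-L"],  # "-L" lists the detected GPUs
            capture_output=True,
            text=True,
            timeout=30,
        )
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0 and "GPU" in result.stdout


if gpu_runtime_available():
    print("GPU runtime detected.")
else:
    print("No GPU detected. Please switch the runtime type to GPU.")
```

If the check reports no GPU, change the runtime type first and then rerun the installation cell.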

## ⚠️Maximum Runtime and Idle Timeout

To ensure your program finishes properly, please avoid letting your computer go to **<font color=red>sleep</font>** or remain **<font color=red>idle</font>** for long periods.

Please be aware of **<font color=red>the maximum runtime</font>**, as your program may be automatically terminated when this limit is reached.

| Plan            | Maximum Runtime | Idle Timeout | Additional Features                        |
|-----------------|------------------|--------------|--------------------------------------------|
| **Free**        | 12 hours         | Yes          | -                                          |
| **Colab Pro**   | Based on availability and usage patterns | Yes          | Increased compute availability             |
| **Pay As You Go**| Based on availability and usage patterns | Yes          | Increased compute availability             |
| **Colab Pro+**  | Up to 24 hours   | No           | Background execution, continuous code execution |


In [None]:
#@title **1.1: ▶️ Click the run button to install ProTrek**

#@markdown (Please wait 2-8 minutes for the installation to finish...)
################################################################################
########################### install saprot #####################################
################################################################################
# %load_ext autoreload
# %autoreload 2

import os
# Use the current working directory as the installation root (works on a local server or on Google Colab)
root_dir = os.getcwd()

from google.colab import output
output.enable_custom_widget_manager()

try:
  import sys
  sys.path.append(f"{root_dir}/SaprotHub")
  import saprot
  print("ProTrek is installed successfully!")
  os.system(f"chmod +x {root_dir}/SaprotHub/bin/*")

except ImportError:
  print("Installing ProTrek...")
  os.system(f"rm -rf {root_dir}/SaprotHub")
  # !rm -rf /content/SaprotHub/

  !echo "Cloning into 'ProTrekHub'..."
  !git clone https://github.com/westlake-repl/SaprotHub.git > /dev/null 2>&1
  !pip install huggingface_hub==0.23.2 > /dev/null 2>&1

  # !pip install /content/SaprotHub/saprot-0.4.7-py3-none-any.whl
  os.system(f"pip install -r {root_dir}/SaprotHub/requirements.txt")
  # !pip install -r /content/SaprotHub/requirements.txt

  os.system(f"pip install {root_dir}/SaprotHub")


  os.system(f"mkdir -p {root_dir}/SaprotHub/LMDB")
  os.system(f"mkdir -p {root_dir}/SaprotHub/bin")
  os.system(f"mkdir -p {root_dir}/SaprotHub/output")
  os.system(f"mkdir -p {root_dir}/SaprotHub/datasets")
  os.system(f"mkdir -p {root_dir}/SaprotHub/adapters/classification/Local")
  os.system(f"mkdir -p {root_dir}/SaprotHub/adapters/regression/Local")
  os.system(f"mkdir -p {root_dir}/SaprotHub/adapters/token_classification/Local")
  os.system(f"mkdir -p {root_dir}/SaprotHub/adapters/pair_classification/Local")
  os.system(f"mkdir -p {root_dir}/SaprotHub/adapters/pair_regression/Local")
  os.system(f"mkdir -p {root_dir}/SaprotHub/structures")
  # !mkdir -p /content/SaprotHub/LMDB
  # !mkdir -p /content/SaprotHub/bin
  # !mkdir -p /content/SaprotHub/output
  # !mkdir -p /content/SaprotHub/datasets
  # !mkdir -p /content/SaprotHub/adapters/classification/Local
  # !mkdir -p /content/SaprotHub/adapters/regression/Local
  # !mkdir -p /content/SaprotHub/adapters/token_classification/Local
  # !mkdir -p /content/SaprotHub/adapters/pair_classification/Local
  # !mkdir -p /content/SaprotHub/adapters/pair_regression/Local
  # !mkdir -p /content/SaprotHub/structures

  # !pip install gdown==v4.6.3 --force-reinstall --quiet
  # os.system(
  #   f"wget 'https://drive.usercontent.google.com/download?id=1B_9t3n_nlj8Y3Kpc_mMjtMdY0OPYa7Re&export=download&authuser=0' -O {root_dir}/SaprotHub/bin/foldseek"
  # )

  os.system(f"chmod +x {root_dir}/SaprotHub/bin/*")
  # !chmod +x /content/SaprotHub/bin/foldseek
  import sys
  sys.path.append(f"{root_dir}/SaprotHub")


  # IMPORTANT!!!! Used to fix the error caused by the mismatch of the versions of third-party libraries!!
  import matplotlib.pyplot as plt
  plt.figure(figsize=(1, 1))
  plt.plot([], [], marker='o', linestyle='-', color='b')
  plt.show()

################################################################################
################################################################################
################################## global ######################################
################################################################################
################################################################################

# IMPORTANT: remove cached numpy modules from sys.modules so that the imports below
# load the newly installed version (fixes errors caused by mismatched third-party library versions).
import sys
keys=[]
for k in sys.modules.keys():
  for sub_str in ["numpy"]:
    if sub_str in k:
      keys.append(k)
for k in keys:
  del sys.modules[k]

from transformers import AutoTokenizer, EsmForProteinFolding, EsmTokenizer
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
import ipywidgets
import numpy as np
import pandas as pd
import torch
import lmdb
import base64
import copy
import os
import json
import zipfile
import yaml
import argparse
import pprint
import subprocess
import py3Dmol
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

from loguru import logger
from easydict import EasyDict
from colorama import init, Fore, Back, Style
from IPython.display import clear_output
from saprot.utils.mpr import MultipleProcessRunnerSimplifier
from huggingface_hub import snapshot_download
from ipywidgets import HTML
from IPython.display import display
from google.colab import widgets
from pathlib import Path
from tqdm import tqdm
from datetime import datetime
from google.colab import files
from string import ascii_uppercase,ascii_lowercase
from saprot.data.parse import get_chain_ids
import torch.nn.functional as F

DATASET_HOME = Path(f'{root_dir}/SaprotHub/datasets')
ADAPTER_HOME = Path(f'{root_dir}/SaprotHub/adapters')
STRUCTURE_HOME = Path(f"{root_dir}/SaprotHub/structures")
LMDB_HOME = Path(f'{root_dir}/SaprotHub/LMDB')
OUTPUT_HOME = Path(f'{root_dir}/SaprotHub/output')
UPLOAD_FILE_HOME = Path(f'{root_dir}/SaprotHub/upload_files')
FOLDSEEK_PATH = Path(f"{root_dir}/SaprotHub/bin/foldseek")
aa_set = {"A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y"}
foldseek_struc_vocab = "pynwrqhgdlvtmfsaeikc#"

data_type_list = ["Single AA Sequence",
                  "Single SA Sequence",
                  "Single UniProt ID",
                  "Single PDB/CIF Structure",
                  "Multiple AA Sequences",
                  "Multiple SA Sequences",
                  "Multiple UniProt IDs",
                  "Multiple PDB/CIF Structures",
                  "SaprotHub Dataset",
                  "A pair of AA Sequences",
                  "A pair of SA Sequences",
                  "A pair of UniProt IDs",
                  "A pair of PDB/CIF Structures",
                  "Multiple pairs of AA Sequences",
                  "Multiple pairs of SA Sequences",
                  "Multiple pairs of UniProt IDs",
                  "Multiple pairs of PDB/CIF Structures",]

data_type_list_single = [
    "Single AA Sequence",
    "Single SA Sequence",
    "Single UniProt ID",
    "Single PDB/CIF Structure",
    "A pair of AA Sequences",
    "A pair of SA Sequences",
    "A pair of UniProt IDs",
    "A pair of PDB/CIF Structures",]

data_type_list_multiple = [
    "Multiple AA Sequences",
    "Multiple SA Sequences",
    "Multiple UniProt IDs",
    "Multiple PDB/CIF Structures",
    "Multiple pairs of AA Sequences",
    "Multiple pairs of SA Sequences",
    "Multiple pairs of UniProt IDs",
    "Multiple pairs of PDB/CIF Structures",]

task_type_dict = {
  "Protein-level Classification": "classification",
  "Residue-level Classification" : "token_classification",
  "Protein-level Regression" : "regression",
  "Protein-protein Classification": "pair_classification",
  "Protein-protein Regression": "pair_regression",
}
model_type_dict = {
  "classification" : "saprot/saprot_classification_model",
  "token_classification" : "saprot/saprot_token_classification_model",
  "regression" : "saprot/saprot_regression_model",
  "pair_classification" : "saprot/saprot_pair_classification_model",
  "pair_regression" : "saprot/saprot_pair_regression_model",
}
dataset_type_dict = {
  "classification": "saprot/saprot_classification_dataset",
  "token_classification" : "saprot/saprot_token_classification_dataset",
  "regression": "saprot/saprot_regression_dataset",
  "pair_classification" : "saprot/saprot_pair_classification_dataset",
  "pair_regression" : "saprot/saprot_pair_regression_dataset",
}
training_data_type_dict = {
  "Single AA Sequence": "AA",
  "Single SA Sequence": "SA",
  "Single UniProt ID": "SA",
  "Single PDB/CIF Structure": "SA",
  "Multiple AA Sequences": "AA",
  "Multiple SA Sequences": "SA",
  "Multiple UniProt IDs": "SA",
  "Multiple PDB/CIF Structures": "SA",
  "SaprotHub Dataset": "SA",
  "A pair of AA Sequences": "AA",
  "A pair of SA Sequences": "SA",
  "A pair of UniProt IDs": "SA",
  "A pair of PDB/CIF Structures": "SA",
  "Multiple pairs of AA Sequences": "AA",
  "Multiple pairs of SA Sequences": "SA",
  "Multiple pairs of UniProt IDs": "SA",
  "Multiple pairs of PDB/CIF Structures": "SA",
}


class font:
    RED = '\033[91m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    BLUE = '\033[94m'

    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    RESET = '\033[0m'


################################################################################
############################### adapters #######################################
################################################################################
def get_adapters_list(task_type=None):

    adapters_list = []

    if task_type:
      for file_path in (ADAPTER_HOME / task_type).glob('**/adapter_config.json'):
        adapters_list.append(file_path.parent)
    else:
      for file_path in ADAPTER_HOME.glob('**/adapter_config.json'):
        adapters_list.append(file_path.parent)

    return adapters_list

def adapters_text(adapters_list):
  input = ipywidgets.Text(
    value=None,
    placeholder='Enter ProTrekHub Model ID',
    # description='Selected:',
    disabled=False)
  input.layout.width = '500px'
  display(input)

  return input

def adapters_dropdown(adapters_list):
  dropdown = ipywidgets.Dropdown(
    options=[f"{adapter_path.parent.stem}/{adapter_path.stem}" for index, adapter_path in enumerate(adapters_list)],
    value=None,
    placeholder='Select a Local Model here',
    # description='Selected:',
    disabled=False)
  dropdown.layout.width = '500px'
  display(dropdown)

  return dropdown

def adapters_combobox(adapters_list):
  combobox = ipywidgets.Combobox(
    options=[f"{adapter_path.parent.stem}/{adapter_path.stem}" for index, adapter_path in enumerate(adapters_list)],
    value=None,
    placeholder='Enter ProTrekHub Model repository id or select a Local Model here',
    # description='Selected:',
    disabled=False)
  combobox.layout.width = '500px'
  display(combobox)

  return combobox

def adapters_selectmultiple(adapters_list):
  select_multiple = ipywidgets.SelectMultiple(
    options=[f"{adapter_path.parent.stem}/{adapter_path.stem}" for index, adapter_path in enumerate(adapters_list)],
    value=[],
    #rows=10,
    placeholder='Select multiple models',
    # description='Fruits',
    disabled=False,
    layout={'width': '500px'})
  display(select_multiple)

  return select_multiple

def adapters_textmultiple(adapters_list):
  textmultiple = ipywidgets.Text(
  value=None,
  placeholder='Enter multiple ProTrekHub Model IDs, separated by commas.',
  # description='Fruits',
  disabled=False,
  layout={'width': '500px'})
  display(textmultiple)

  return textmultiple


def select_adapter_from(task_type, use_model_from):
  adapters_list = get_adapters_list(task_type)

  if use_model_from == 'Trained by yourself on ColabProTrek':
    print(Fore.BLUE+f"Local Model ({task_type}):"+Style.RESET_ALL)
    return adapters_dropdown(adapters_list)

  elif use_model_from == 'Shared by peers on ProTrekHub':
    print(Fore.BLUE+"ProTrekHub Model:"+Style.RESET_ALL)
    return adapters_text(adapters_list)

  elif use_model_from == "Saved in your local computer":
    print(Fore.BLUE+"Click the button to upload the \"Model-<task_name>-<model_size>.zip\" file of your Model:"+Style.RESET_ALL)
    # 1. upload model.zip
    adapter_upload_path = ADAPTER_HOME / task_type / "Local"
    adapter_zip_path = upload_file(adapter_upload_path)
    adapter_path = adapter_upload_path / adapter_zip_path.stem
    # 2. unzip model.zip
    with zipfile.ZipFile(adapter_zip_path, 'r') as zip_ref:
        zip_ref.extractall(adapter_path)
    os.remove(adapter_zip_path)
    # 3. check adapter_config.json
    adapter_config_path = adapter_path / "adapter_config.json"
    assert adapter_config_path.exists(), f"Can't find {adapter_config_path}"

    return EasyDict({"value":  f"Local/{adapter_zip_path.stem}"})

  elif use_model_from == "Multi-models on ColabProTrek":
    # 1. select the list of adapters
    print(Fore.BLUE+f"Local Model ({task_type}):"+Style.RESET_ALL)
    print(Fore.BLUE+f"Multiple values can be selected with \"shift\" and/or \"ctrl\" (or \"command\") pressed and mouse clicks or arrow keys."+Style.RESET_ALL)
    return adapters_selectmultiple(adapters_list)

  elif use_model_from == "Multi-models on ProTrekHub":
    # 1. enter the list of adapters
    print(Fore.BLUE+f"ProTrekHub Model IDs, separated by commas ({task_type}):"+Style.RESET_ALL)
    return adapters_textmultiple(adapters_list)



################################################################################
########################### download dataset ###################################
################################################################################
def download_dataset(task_name):
  import gdown
  import tarfile

  filepath = LMDB_HOME / f"{task_name}.tar.gz"
  download_links = {
    "ClinVar" : "https://drive.google.com/uc?id=1Le6-v8ddXa1eLJZFo7HPij7NhaBmNUbo",
    "DeepLoc_cls2" : "https://drive.google.com/uc?id=1dGlojkCt1DwUXWiUk4kXRGRNu5sz2uxf",
    "DeepLoc_cls10" : "https://drive.google.com/uc?id=1dGlojkCt1DwUXWiUk4kXRGRNu5sz2uxf",
    "EC" : "https://drive.google.com/uc?id=1VFLFA-jK1tkTZBVbMw8YSsjZqAqlVQVQ",
    "GO_BP" : "https://drive.google.com/uc?id=1DGiGErWbRnEK8jmE2Jpb996By8KVDBfF",
    "GO_CC" : "https://drive.google.com/uc?id=1DGiGErWbRnEK8jmE2Jpb996By8KVDBfF",
    "GO_MF" : "https://drive.google.com/uc?id=1DGiGErWbRnEK8jmE2Jpb996By8KVDBfF",
    "HumanPPI" : "https://drive.google.com/uc?id=1ahgj-IQTtv3Ib5iaiXO_ASh2hskEsvoX",
    "MetalIonBinding" : "https://drive.google.com/uc?id=1rwknPWIHrXKQoiYvgQy4Jd-efspY16x3",
    "ProteinGym" : "https://drive.google.com/uc?id=1L-ODrhfeSjDom-kQ2JNDa2nDEpS8EGfD",
    "Thermostability" : "https://drive.google.com/uc?id=1I9GR1stFDHc8W3FCsiykyrkNprDyUzSz",
  }

  try:
    gdown.download(download_links[task_name], str(filepath), quiet=False)
    with tarfile.open(filepath, 'r:gz') as tar:
      tar.extractall(path=str(LMDB_HOME))
      print(f"Extracted: {filepath}")
  except Exception as e:
    # Chain the original exception so the root cause stays visible in the traceback
    raise RuntimeError("The dataset has not been prepared.") from e

################################################################################
############################# upload file ######################################
################################################################################
def upload_file(upload_path):
  import shutil
  import os
  from pathlib import Path
  import sys

  upload_path = Path(upload_path)
  upload_path.mkdir(parents=True, exist_ok=True)
  basepath = Path().resolve()
  try:
    uploaded = files.upload()
    filenames = []
    for filename in uploaded.keys():
      filenames.append(filename)
      shutil.move(basepath / filename, upload_path / filename)
    if len(filenames) == 0:
      logger.info("The uploading process has been interrupted by the user.")
      raise RuntimeError("The uploading process has been interrupted by the user.")
  except Exception:
    logger.error("File upload failed! Please click the run button to try again.")
    raise

  return upload_path / filenames[0]

################################################################################
############################ upload dataset ####################################
################################################################################

def read_csv_dataset(uploaded_csv_path):
  df = pd.read_csv(uploaded_csv_path)
  df.columns = df.columns.str.lower()
  return df

def check_column_label_and_stage(csv_dataset_path):
  df = read_csv_dataset(csv_dataset_path)
  for column in df.columns:
    if 'sequence' in column:
        df[column] = df[column].apply(lambda x: ''.join([c for c in x if not (c.islower() or c == '#')]))
  assert {'label', 'stage'}.issubset(df.columns), f"Make sure your CSV dataset includes both `label` and `stage` columns!\nCurrent columns: {df.columns}"
  column_values = set(df['stage'].unique())
  assert all(value in column_values for value in ['train', 'valid', 'test']), f"Ensure your dataset includes samples for all three stages: `train`, `valid` and `test`.\nCurrent stages: {column_values}"
  output_file = os.path.join(os.path.dirname(csv_dataset_path), 'cleaned_dataset.csv')
  df.to_csv(output_file, index=False)

  return output_file

def get_data_type(csv_dataset_path):
  # AA, SA, Pair AA, Pair SA
  df = read_csv_dataset(csv_dataset_path)
  df = df.rename(columns={
    "protein_1": "sequence_1",
    "protein_2": "sequence_2",
    "protein": "sequence"})

  # AA, SA
  if 'sequence' in df.columns:
    second_token = df.loc[0, 'sequence'][1]
    if second_token in aa_set:
      return "Multiple AA Sequences"
    elif second_token in foldseek_struc_vocab:
      return "Multiple SA Sequences"
    else:
      raise RuntimeError(f"The sequences in the dataset ({csv_dataset_path}) are neither SA sequences nor AA sequences. Please check carefully.")

  # Pair AA, Pair SA
  elif 'sequence_1' in df.columns and 'sequence_2' in df.columns:
    second_token = df.loc[0, 'sequence_1'][1]
    if second_token in aa_set:
      return "Multiple pairs of AA Sequences"
    elif second_token in foldseek_struc_vocab:
      return "Multiple pairs of SA Sequences"
    else:
      raise RuntimeError(f"The sequences in the dataset ({csv_dataset_path}) are neither SA sequences nor AA sequences. Please check carefully.")

  else:
      raise RuntimeError(f"The data type of the dataset({csv_dataset_path}) should be one of the following types: Multiple AA Sequences, Multiple SA Sequences, Multiple pairs of AA Sequences, Multiple pairs of SA Sequences")

def check_task_type_and_data_type(original_task_type, data_type):
  if "Protein-protein" in original_task_type:
    assert data_type == "SaprotHub Dataset" or "pair" in data_type, f"The current `data_type`({data_type}) is incompatible with the current `task_type`({original_task_type}). Please use a Pair Sequence Dataset for the {original_task_type} task!"
  else:
    assert "pair" not in data_type, f"The current `data_type`({data_type}) is incompatible with the current `task_type`({original_task_type}). Please avoid using the Pair Sequence Dataset({data_type}) for the {original_task_type} task!"

def input_raw_data_by_data_type(data_type):
  print(Fore.BLUE+"Dataset: "+Style.RESET_ALL, end='')

  # 0-2. 0. Single AA Sequence, 1. Single SA Sequence, 2. Single UniProt ID
  if data_type in data_type_list[:3]:
    input_seq = ipywidgets.Text(
      value=None,
      placeholder=f'Enter {data_type} here',
      disabled=False)
    input_seq.layout.width = '500px'
    print(Fore.BLUE+f"{data_type}"+Style.RESET_ALL)
    display(input_seq)
    return input_seq

  # 3. Single PDB/CIF Structure
  elif data_type == 'Single PDB/CIF Structure':
    print("Please provide the structure type, chain and your structure file.")

    dropdown_type = ipywidgets.Dropdown(
      value="AF2",
      options=["PDB", "AF2"],
      disabled=False)
    dropdown_type.layout.width = '500px'
    print(Fore.BLUE+"Structure type:"+Style.RESET_ALL)
    display(dropdown_type)

    input_chain = ipywidgets.Text(
      value="A",
      placeholder=f'Enter the name of chain here',
      disabled=False)
    input_chain.layout.width = '500px'
    print(Fore.BLUE+"Chain:"+Style.RESET_ALL)
    display(input_chain)

    print(Fore.BLUE+"Please upload a .pdb/.cif file"+Style.RESET_ALL)
    pdb_file_path = upload_file(STRUCTURE_HOME)
    return pdb_file_path.stem, dropdown_type, input_chain

  # 4-7 & 13-16. Multiple Sequences
  elif data_type in data_type_list_multiple:
    print(Fore.BLUE+f"Please upload the .csv file which contains {data_type}"+Style.RESET_ALL)
    uploaded_csv_path = upload_file(UPLOAD_FILE_HOME)
    print(Fore.BLUE+"Successfully uploaded your .csv file!"+Style.RESET_ALL)
    print("="*100)

    if data_type in ['Multiple PDB/CIF Structures', 'Multiple pairs of PDB/CIF Structures']:
      # upload and unzip PDB files
      print(Fore.BLUE+f"Please upload your .zip file which contains {data_type} files"+Style.RESET_ALL)
      pdb_zip_path = upload_file(UPLOAD_FILE_HOME)
      if pdb_zip_path.suffix != ".zip":
        logger.error("The data type does not match. Please click the run button again to upload a .zip file!")
        raise RuntimeError("The data type does not match.")
      print(Fore.BLUE+"Successfully uploaded your .zip file!"+Style.RESET_ALL)
      print("="*100)

      import zipfile
      with zipfile.ZipFile(pdb_zip_path, 'r') as zip_ref:
        zip_ref.extractall(STRUCTURE_HOME)

    return uploaded_csv_path

  # 8. SaprotHub Dataset
  elif data_type == "SaprotHub Dataset":
    input_repo_id = ipywidgets.Text(
      value=None,
      placeholder=f'Copy and paste the SaprotHub Dataset ID here',
      disabled=False)
    input_repo_id.layout.width = '500px'
    print(Fore.BLUE+f"{data_type}"+Style.RESET_ALL)
    display(input_repo_id)
    return input_repo_id

  # 9-11. A pair of seq
  elif data_type in ["A pair of AA Sequences", "A pair of SA Sequences", "A pair of UniProt IDs"]:
    print()

    seq_type = data_type[len("A pair of "):-1]

    input_seq1 = ipywidgets.Text(
      value=None,
      placeholder=f'Enter the {seq_type} of Sequence 1 here',
      disabled=False)
    input_seq1.layout.width = '500px'
    print(Fore.BLUE+f"Sequence 1:"+Style.RESET_ALL)
    display(input_seq1)

    input_seq2 = ipywidgets.Text(
      value=None,
      placeholder=f'Enter the {seq_type} of Sequence 2 here',
      disabled=False)
    input_seq2.layout.width = '500px'
    print(Fore.BLUE+f"Sequence 2:"+Style.RESET_ALL)
    display(input_seq2)

    return (input_seq1, input_seq2)

  # 12. Pair Single PDB/CIF Structure
  elif data_type == 'A pair of PDB/CIF Structures':
    print("Please provide the structure type, chain and your structure file.")

    dropdown_type1 = ipywidgets.Dropdown(
      value="PDB",
      options=["PDB", "AF2"],
      disabled=False)
    dropdown_type1.layout.width = '500px'
    print(Fore.BLUE+"The first structure type:"+Style.RESET_ALL)
    display(dropdown_type1)

    input_chain1 = ipywidgets.Text(
      value="A",
      placeholder=f'Enter the name of chain of the first structure here',
      disabled=False)
    input_chain1.layout.width = '500px'
    print(Fore.BLUE+"Chain of the first structure:"+Style.RESET_ALL)
    display(input_chain1)

    print(Fore.BLUE+"Please upload a .pdb/.cif file"+Style.RESET_ALL)
    pdb_file_path1 = upload_file(STRUCTURE_HOME)


    dropdown_type2 = ipywidgets.Dropdown(
      value="PDB",
      options=["PDB", "AF2"],
      disabled=False)
    dropdown_type2.layout.width = '500px'
    print(Fore.BLUE+"The second structure type:"+Style.RESET_ALL)
    display(dropdown_type2)

    input_chain2 = ipywidgets.Text(
      value="A",
      placeholder=f'Enter the name of chain of the second structure here',
      disabled=False)
    input_chain2.layout.width = '500px'
    print(Fore.BLUE+"Chain of the second structure:"+Style.RESET_ALL)
    display(input_chain2)

    print(Fore.BLUE+"Please upload a .pdb/.cif file"+Style.RESET_ALL)
    pdb_file_path2 = upload_file(STRUCTURE_HOME)
    return (pdb_file_path1.stem, dropdown_type1, input_chain1, pdb_file_path2.stem, dropdown_type2, input_chain2)

def get_SA_sequence_by_data_type(data_type, raw_data):

  # Multiple sequences
  # raw_data = upload_files/xxx.csv

  # 8. SaprotHub Dataset
  if data_type == "SaprotHub Dataset":
    input_repo_id = raw_data
    REPO_ID = input_repo_id.value

    if REPO_ID.startswith('/'):
      return Path(REPO_ID)

    snapshot_download(repo_id=REPO_ID, repo_type="dataset", local_dir=DATASET_HOME / REPO_ID)
    csv_dataset_path = DATASET_HOME / REPO_ID / 'dataset.csv'
    assert csv_dataset_path.exists(), f"Can't find {csv_dataset_path}"
    protein_df = read_csv_dataset(csv_dataset_path)

    data_type = get_data_type(csv_dataset_path)

    return get_SA_sequence_by_data_type(data_type, csv_dataset_path)

    # # AA, SA
    # if data_type == "Multiple AA Sequences":
    #   for index, value in protein_df['sequence'].items():
    #     sa_seq = ''
    #     for aa in value:
    #       sa_seq += aa + '#'
    #     protein_df.at[index, 'sequence'] = sa_seq

    # # Pair AA, Pair SA
    # elif data_type in ["Multiple pairs of AA Sequences", "Multiple pairs of SA Sequences"]:
    #   for i in ['1', '2']:
    #     if data_type == "Multiple pairs of AA Sequences":
    #       for index, value in protein_df[f'sequence_{i}'].items():
    #         sa_seq = ''
    #         for aa in value:
    #           sa_seq += aa + '#'
    #         protein_df.at[index, f'sequence_{i}'] = sa_seq

    #     protein_df[f'name_{i}'] = f'name_{i}'
    #     protein_df[f'chain_{i}'] = 'A'

    # protein_df.to_csv(csv_dataset_path, index=None)

    # return csv_dataset_path

  elif data_type in data_type_list_multiple:
    uploaded_csv_path = raw_data
    csv_dataset_path = DATASET_HOME / uploaded_csv_path.name
    protein_df = read_csv_dataset(uploaded_csv_path)
    protein_df = protein_df.rename(columns={
    "protein_1": "sequence_1",
    "protein_2": "sequence_2",
    "protein": "sequence"})

    if 'pair' in data_type:
      assert {'sequence_1', 'sequence_2'}.issubset(protein_df.columns), f"The CSV dataset ({uploaded_csv_path}) must contain `sequence_1` and `sequence_2` columns. \n Current columns:{protein_df.columns}"
    else:
      assert 'sequence' in protein_df.columns, f"The CSV Dataset({uploaded_csv_path}) must contain a `sequence` column. \n Current columns:{protein_df.columns}"

    # 4. Multiple AA Sequences
    if data_type == 'Multiple AA Sequences':
      for index, value in protein_df['sequence'].items():
        sa_seq = ''
        for aa in value:
          sa_seq += aa + '#'
        protein_df.at[index, 'sequence'] = sa_seq

      protein_df.to_csv(csv_dataset_path, index=None)
      return csv_dataset_path

    # 5. Multiple SA Sequences
    elif data_type == 'Multiple SA Sequences':
      protein_df.to_csv(csv_dataset_path, index=None)
      return csv_dataset_path

    # 6. Multiple UniProt IDs
    elif data_type == 'Multiple UniProt IDs':
      protein_list = protein_df.loc[:, 'sequence'].tolist()
      uniprot2pdb(protein_list)
      protein_list = [(uniprot_id, "AF2", "A") for uniprot_id in protein_list]
      mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=2, return_results=True)
      outputs = mprs.run()

      protein_df['sequence'] = [output.split("\t")[1] for output in outputs]
      protein_df.to_csv(csv_dataset_path, index=None)
      return csv_dataset_path

    # 7. Multiple PDB/CIF Structures
    elif data_type == 'Multiple PDB/CIF Structures':
      # protein_list = [(uniprot_id, type, chain), ...]
      # protein_list = [item.split('.')[0] for item in protein_df.iloc[:, 0].tolist()]
      # uniprot2pdb(protein_list)
      protein_list = []
      for row_tuple in protein_df.itertuples(index=False):
        assert row_tuple.type in ['PDB', 'AF2'],  "The type of structure must be either \"PDB\" or \"AF2\"!"
        protein_list.append(row_tuple)
      mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=2, return_results=True)
      outputs = mprs.run()

      protein_df['sequence'] = [output.split("\t")[1] for output in outputs]
      protein_df.to_csv(csv_dataset_path, index=None)
      return csv_dataset_path

    # 13. Pair Multiple AA Sequences
    elif data_type == "Multiple pairs of AA Sequences":
      for i in ['1', '2']:
        # Append the '#' placeholder structure token after every residue.
        protein_df[f'sequence_{i}'] = protein_df[f'sequence_{i}'].apply(lambda seq: ''.join(aa + '#' for aa in seq))

        protein_df[f'name_{i}'] = f'name_{i}'
        protein_df[f'chain_{i}'] = 'A'

      protein_df.to_csv(csv_dataset_path, index=None)
      return csv_dataset_path

    # 14. Pair Multiple SA Sequences
    elif data_type == "Multiple pairs of SA Sequences":
      for i in ['1', '2']:
        protein_df[f'name_{i}'] = f'name_{i}'
        protein_df[f'chain_{i}'] = 'A'

      protein_df.to_csv(csv_dataset_path, index=None)
      return csv_dataset_path

    # 15. Pair Multiple UniProt IDs
    elif data_type == "Multiple pairs of UniProt IDs":
      for i in ['1', '2']:
        protein_list = protein_df.loc[:, f'sequence_{i}'].tolist()
        uniprot2pdb(protein_list)
        protein_df[f'name_{i}'] = protein_list
        protein_list = [(uniprot_id, "AF2", "A") for uniprot_id in protein_list]
        mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=2, return_results=True)
        outputs = mprs.run()

        protein_df[f'sequence_{i}'] = [output.split("\t")[1] for output in outputs]
        protein_df[f'chain_{i}'] = 'A'

      protein_df.to_csv(csv_dataset_path, index=None)
      return csv_dataset_path

    elif data_type ==  "Multiple pairs of PDB/CIF Structures":
      # columns: sequence_1, sequence_2, type_1, type_2, chain_1, chain_2, label, stage

      # protein_list = [(uniprot_id, type, chain), ...]
      # protein_list = [item.split('.')[0] for item in protein_df.iloc[:, 0].tolist()]
      # uniprot2pdb(protein_list)

      for i in ['1', '2']:
        protein_list = []
        for index, row in protein_df.iterrows():
          assert row[f"type_{i}"] in ['PDB', 'AF2'],  "The type of structure must be either \"PDB\" or \"AF2\"!"
          row_tuple = (row[f"sequence_{i}"], row[f"type_{i}"], row[f"chain_{i}"])
          protein_list.append(row_tuple)
        mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=2, return_results=True)
        outputs = mprs.run()

        # add name column, del type column
        protein_df[f'name_{i}'] = protein_df[f'sequence_{i}'].apply(lambda x: x.split('.')[0])
        protein_df.drop(f"type_{i}", axis=1, inplace=True)
        protein_df[f'sequence_{i}'] = [output.split("\t")[1] for output in outputs]

      # columns: name_1, name_2, chain_1, chain_2, sequence_1, sequence_2, label, stage
      protein_df.to_csv(csv_dataset_path, index=None)
      return csv_dataset_path

  else:
    # 0. Single AA Sequence
    if data_type == 'Single AA Sequence':
      input_seq = raw_data
      aa_seq = input_seq.value

      # Append the '#' placeholder structure token after every residue.
      sa_seq = ''.join(aa + '#' for aa in aa_seq)
      return sa_seq

    # 1. Single SA Sequence
    elif data_type == 'Single SA Sequence':
      input_seq = raw_data
      sa_seq = input_seq.value

      return sa_seq

    # 2. Single UniProt ID
    elif data_type == 'Single UniProt ID':
      input_seq = raw_data
      uniprot_id = input_seq.value


      protein_list = [(uniprot_id, "AF2", "A")]
      uniprot2pdb([protein_list[0][0]])
      mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=2, return_results=True)
      seqs = mprs.run()
      sa_seq = seqs[0].split('\t')[1]
      return sa_seq

    # 3. Single PDB/CIF Structure
    elif data_type == 'Single PDB/CIF Structure':
      uniprot_id = raw_data[0]
      struc_type = raw_data[1].value
      chain = raw_data[2].value

      protein_list = [(uniprot_id, struc_type, chain)]
      mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=2, return_results=True)
      seqs = mprs.run()
      assert len(seqs)>0, "Unable to convert to SA sequence. Please check the `type`, `chain`, and `.pdb/.cif file`."
      sa_seq = seqs[0].split('\t')[1]
      return sa_seq

    # 9. Pair Single AA Sequences
    elif data_type == "A pair of AA Sequences":
      input_seq_1, input_seq_2 = raw_data
      sa_seq1 = get_SA_sequence_by_data_type('Single AA Sequence', input_seq_1)
      sa_seq2 = get_SA_sequence_by_data_type('Single AA Sequence', input_seq_2)

      return (sa_seq1, sa_seq2)

    # 10. Pair Single SA Sequences
    elif data_type ==  "A pair of SA Sequences":
      input_seq_1, input_seq_2 = raw_data
      sa_seq1 = get_SA_sequence_by_data_type('Single SA Sequence', input_seq_1)
      sa_seq2 = get_SA_sequence_by_data_type('Single SA Sequence', input_seq_2)

      return (sa_seq1, sa_seq2)

    # 11. Pair Single UniProt IDs
    elif data_type ==  "A pair of UniProt IDs":
      input_seq_1, input_seq_2 = raw_data
      sa_seq1 = get_SA_sequence_by_data_type('Single UniProt ID', input_seq_1)
      sa_seq2 = get_SA_sequence_by_data_type('Single UniProt ID', input_seq_2)

      return (sa_seq1, sa_seq2)

    # 12. Pair Single PDB/CIF Structure
    elif data_type == "A pair of PDB/CIF Structures":
      uniprot_id1 = raw_data[0]
      struc_type1 = raw_data[1].value
      chain1 = raw_data[2].value

      protein_list1 = [(uniprot_id1, struc_type1, chain1)]
      mprs1 = MultipleProcessRunnerSimplifier(protein_list1, pdb2sequence, n_process=2, return_results=True)
      seqs1 = mprs1.run()
      sa_seq1 = seqs1[0].split('\t')[1]

      uniprot_id2 = raw_data[3]
      struc_type2 = raw_data[4].value
      chain2 = raw_data[5].value

      protein_list2 = [(uniprot_id2, struc_type2, chain2)]
      mprs2 = MultipleProcessRunnerSimplifier(protein_list2, pdb2sequence, n_process=2, return_results=True)
      seqs2 = mprs2.run()
      sa_seq2 = seqs2[0].split('\t')[1]
      return sa_seq1, sa_seq2




################################################################################
########################## Download predicted structures #######################
################################################################################
def uniprot2pdb(uniprot_ids, nprocess=20):
  from saprot.utils.downloader import AlphaDBDownloader

  os.makedirs(STRUCTURE_HOME, exist_ok=True)
  # Honor the caller's `nprocess` argument instead of a hard-coded value.
  af2_downloader = AlphaDBDownloader(uniprot_ids, "pdb", save_dir=STRUCTURE_HOME, n_process=nprocess)
  af2_downloader.run()



################################################################################
############### Form foldseek sequences by multiple processes ##################
################################################################################
# def pdb2sequence(process_id, idx, uniprot_id, writer):
#   from saprot.utils.foldseek_util import get_struc_seq

#   try:
#     pdb_path = f"{STRUCTURE_HOME}/{uniprot_id}.pdb"
#     cif_path = f"{STRUCTURE_HOME}/{uniprot_id}.cif"
#     if Path(pdb_path).exists():
#       seq = get_struc_seq(FOLDSEEK_PATH, pdb_path, ["A"], process_id=process_id)["A"][-1]
#     if Path(cif_path).exists():
#       seq = get_struc_seq(FOLDSEEK_PATH, cif_path, ["A"], process_id=process_id)["A"][-1]

#     writer.write(f"{uniprot_id}\t{seq}\n")
#   except Exception as e:
#     print(f"Error: {uniprot_id}, {e}")

# clear_output(wait=True)
# print("Installation finished!")

def pdb2sequence(process_id, idx, row_tuple, writer):

  # print("="*100)
  # print(row_tuple)
  # print("="*100)
  uniprot_id = row_tuple[0].split('.')[0]     #
  struc_type = row_tuple[1]                   # PDB or AF2
  chain = row_tuple[2]

  if struc_type=="AF2":
    plddt_mask= True
    chain = 'A'
  else:
    plddt_mask= False

  from saprot.utils.foldseek_util import get_struc_seq

  try:
    pdb_path = f"{STRUCTURE_HOME}/{uniprot_id}.pdb"
    cif_path = f"{STRUCTURE_HOME}/{uniprot_id}.cif"
    if Path(pdb_path).exists():
      seq = get_struc_seq(FOLDSEEK_PATH, pdb_path, [chain], process_id=process_id, plddt_mask=plddt_mask)[chain][-1]
    elif Path(cif_path).exists():
      seq = get_struc_seq(FOLDSEEK_PATH, cif_path, [chain], process_id=process_id, plddt_mask=plddt_mask)[chain][-1]
    else:
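      # FileNotFoundError is caught by the `except Exception` handler below and reported there.
      raise FileNotFoundError(f"Neither {uniprot_id}.pdb nor {uniprot_id}.cif exists in {STRUCTURE_HOME}!")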
    writer.write(f"{uniprot_id}\t{seq}\n")

  except Exception as e:
    print(f"Error: {uniprot_id}, {e}")


pymol_color_list = ["#33ff33","#00ffff","#ff33cc","#ffff00","#ff9999","#e5e5e5","#7f7fff","#ff7f00",
          "#7fff7f","#199999","#ff007f","#ffdd5e","#8c3f99","#b2b2b2","#007fff","#c4b200",
          "#8cb266","#00bfbf","#b27f7f","#fcd1a5","#ff7f7f","#ffbfdd","#7fffff","#ffff7f",
          "#00ff7f","#337fcc","#d8337f","#bfff3f","#ff7fff","#d8d8ff","#3fffbf","#b78c4c",
          "#339933","#66b2b2","#ba8c84","#84bf00","#b24c66","#7f7f7f","#3f3fa5","#a5512b"]

alphabet_list = list(ascii_uppercase+ascii_lowercase)


def convert_outputs_to_pdb(outputs):
	final_atom_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
	outputs = {k: v.to("cpu").numpy() for k, v in outputs.items()}
	final_atom_positions = final_atom_positions.cpu().numpy()
	final_atom_mask = outputs["atom37_atom_exists"]
	pdbs = []
	outputs["plddt"] *= 100

	for i in range(outputs["aatype"].shape[0]):
		aa = outputs["aatype"][i]
		pred_pos = final_atom_positions[i]
		mask = final_atom_mask[i]
		resid = outputs["residue_index"][i] + 1
		pred = OFProtein(
		    aatype=aa,
		    atom_positions=pred_pos,
		    atom_mask=mask,
		    residue_index=resid,
		    b_factors=outputs["plddt"][i],
		    chain_index=outputs["chain_index"][i] if "chain_index" in outputs else None,
		)
		pdbs.append(to_pdb(pred))
	return pdbs


# This function is copied from ColabFold!
def show_pdb(path, show_sidechains=False, show_mainchains=False, color="lddt"):
  # Use the file extension ("pdb" or "cif") as the format name passed to py3Dmol.
  file_type = str(path).split(".")[-1]

  view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)
  view.addModel(open(path,'r').read(),file_type)

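  # Compare case-insensitively so that both the default "lddt" and "lDDT" select pLDDT coloring.
  if color.lower() == "lddt":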
    view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}})
  elif color == "rainbow":
    view.setStyle({'cartoon': {'color':'spectrum'}})
  elif color == "chain":
    chains = len(get_chain_ids(path))
    for n,chain,color in zip(range(chains),alphabet_list,pymol_color_list):
       view.setStyle({'chain':chain},{'cartoon': {'color':color}})

  if show_sidechains:
    BB = ['C','O','N']
    view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
                        {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
    view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
                        {'sphere':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
    view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
                        {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
  if show_mainchains:
    BB = ['C','O','N','CA']
    view.addStyle({'atom':BB},{'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})

  view.zoomTo()
  return view


def plot_plddt_legend(dpi=100):
  thresh = ['plDDT:','Very low (<50)','Low (60)','OK (70)','Confident (80)','Very high (>90)']
  plt.figure(figsize=(1,0.1),dpi=dpi)
  ########################################
  for c in ["#FFFFFF","#FF0000","#FFFF00","#00FF00","#00FFFF","#0000FF"]:
    plt.bar(0, 0, color=c)
  plt.legend(thresh, frameon=False,
             loc='center', ncol=6,
             handletextpad=1,
             columnspacing=1,
             markerscale=0.5,)
  plt.axis(False)
  return plt


################################################################################
###############   Download file to local computer   ##################
################################################################################
def file_download(path: str):
  with open(path, "rb") as r:
    res = r.read()

  #FILE
  filename = os.path.basename(path)
  b64 = base64.b64encode(res)
  payload = b64.decode()

  #BUTTONS
  html_buttons = '''<html>
  <head>
  <meta name="viewport" content="width=device-width, initial-scale=1">
  </head>
  <body>
  <a download="{filename}" href="data:text/csv;base64,{payload}" download>
  <button class="p-Widget jupyter-widgets jupyter-button widget-button mod-warning">Download File</button>
  </a>
  </body>
  </html>
  '''

  html_button = html_buttons.format(payload=payload,filename=filename)
  display(HTML(html_button))

  # Automatically download file if the server is from google cloud.
  if root_dir == "/content":
    files.download(path)

################################################################################
############################ MODEL INFO #######################################
################################################################################
def get_base_model(adapter_path):
  adapter_config = Path(adapter_path) / "adapter_config.json"
  with open(adapter_config, 'r') as f:
    adapter_config_dict = json.load(f)
    base_model = adapter_config_dict['base_model_name_or_path']
    if 'Protein_Encoder_650M' in base_model:
      base_model = "ProTrekHub/Protein_Encoder_650M"
    elif 'Protein_Encoder_35M' in base_model:
      base_model = "ProTrekHub/Protein_Encoder_35M"
    else:
      raise RuntimeError("Please ensure the base model is \"Protein_Encoder_650M\" or \"Protein_Encoder_35M\"")
  return base_model

def check_training_data_type(adapter_path, data_type):
  metadata_path = Path(adapter_path) / "metadata.json"
  if metadata_path.exists():
    with open(metadata_path, 'r') as f:
      metadata = json.load(f)
      required_training_data_type = metadata['training_data_type']
  else:
    required_training_data_type = "SA"

  if (required_training_data_type == "AA") and ("AA" not in data_type):
    print(Fore.RED+f"This model ({adapter_path}) is trained on {required_training_data_type} sequences, and predictions work better with AA sequences."+Style.RESET_ALL)
    print(Fore.RED+f"The current data type ({data_type}) includes structural information, which will not be used for predictions."+Style.RESET_ALL)
    print()
    print('='*100)
  elif (required_training_data_type == "SA") and ("AA" in data_type):
    print(Fore.RED+f"This model ({adapter_path}) is trained on {required_training_data_type} sequences, and predictions work better with SA sequences."+Style.RESET_ALL)
    print(Fore.RED+f"The current data type ({data_type}) does not include structural information, which may lead to weak prediction performance."+Style.RESET_ALL)
    print(Fore.RED+f"If you only have the amino acid sequence, we strongly recommend using AF2 to predict the structure and generate a PDB file before prediction."+Style.RESET_ALL)
    print()
    print('='*100)

  return required_training_data_type

def mask_struc_token(sequence):
    return ''.join('#' if i % 2 == 1 and char.islower() else char for i, char in enumerate(sequence))

def get_num_labels_by_adapter(adapter_path):
    adapter_path = Path(adapter_path)

    if (adapter_path / 'adapter_model.safetensors').exists():
        file_path = adapter_path / 'adapter_model.safetensors'
        with safe_open(file_path, framework="pt") as f:
          if 'base_model.model.classifier.out_proj.bias' in f.keys():
              tensor = f.get_tensor('base_model.model.classifier.out_proj.bias')
          elif 'base_model.model.classifier.bias' in f.keys():
              tensor = f.get_tensor('base_model.model.classifier.bias')
          else:
              raise KeyError(f"Neither 'base_model.model.classifier.out_proj.bias' nor 'base_model.model.classifier.bias' found in the file({file_path}).")

    elif (adapter_path / 'adapter_model.bin').exists():
      file_path = adapter_path / 'adapter_model.bin'
      state_dict = torch.load(file_path)
      if 'base_model.model.classifier.out_proj.bias' in state_dict.keys():
        tensor = state_dict['base_model.model.classifier.out_proj.bias']
      elif 'base_model.model.classifier.bias' in state_dict.keys():
        tensor = state_dict['base_model.model.classifier.bias']
      else:
        raise KeyError(f"Neither 'base_model.model.classifier.out_proj.bias' nor 'base_model.model.classifier.bias' found in the file({file_path}).")

    else:
        raise FileNotFoundError(f"Neither 'adapter_model.safetensors' nor 'adapter_model.bin' found in the provided path({adapter_path}).")

    num_labels = list(tensor.shape)[0]
    return num_labels

def get_num_labels_and_task_type_by_adapter(adapter_path):
    adapter_path = Path(adapter_path)

    task_type = None
    if (adapter_path / 'adapter_model.safetensors').exists():
      file_path = adapter_path / 'adapter_model.safetensors'
      with safe_open(file_path, framework="pt") as f:
        if 'base_model.model.classifier.out_proj.bias' in f.keys():
          tensor = f.get_tensor('base_model.model.classifier.out_proj.bias')
        elif 'base_model.model.classifier.bias' in f.keys():
          task_type = 'token_classification'
          tensor = f.get_tensor('base_model.model.classifier.bias')
        else:
          raise KeyError(f"Neither 'base_model.model.classifier.out_proj.bias' nor 'base_model.model.classifier.bias' found in the file({file_path}).")

    elif (adapter_path / 'adapter_model.bin').exists():
      file_path = adapter_path / 'adapter_model.bin'
      state_dict = torch.load(file_path)
      if 'base_model.model.classifier.out_proj.bias' in state_dict.keys():
        tensor = state_dict['base_model.model.classifier.out_proj.bias']
      elif 'base_model.model.classifier.bias' in state_dict.keys():
        task_type = 'token_classification'
        tensor = state_dict['base_model.model.classifier.bias']
      else:
        raise KeyError(f"Neither 'base_model.model.classifier.out_proj.bias' nor 'base_model.model.classifier.bias' found in the file({file_path}).")

    else:
        raise FileNotFoundError(f"Neither 'adapter_model.safetensors' nor 'adapter_model.bin' found in the provided path({adapter_path}).")

    num_labels = list(tensor.shape)[0]
    if task_type != 'token_classification':
      if num_labels > 1:
        task_type = 'classification'
      elif num_labels == 1:
        task_type = 'regression'

    return num_labels, task_type

clear_output(wait=True)
print("Installation finished!")

# **2: Train and Share your Protein Model** <a name="train"></a>

You can **train** a model starting from the pre-trained ProTrek, or **continue training** a fine-tuned model from ProTrekHub.
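
As a rough sketch of what a training file might look like, assuming the CSV layout used for Multiple AA Sequences (`sequence`, `label`, and `stage` columns, with `stage` set to `train`, `valid`, or `test`), the snippet below builds a toy regression dataset and assigns stages at random in an 8:1:1 ratio. The helper names `add_stage_column` and `write_training_csv` and the toy sequences are made up for illustration; they are not part of ColabProTrek.

```python
import csv
import io
import random


def add_stage_column(rows, seed=0, train_frac=0.8, valid_frac=0.1):
    """Return copies of `rows` with a random `stage` of train/valid/test."""
    rng = random.Random(seed)
    order = list(range(len(rows)))
    rng.shuffle(order)
    n_train = int(len(rows) * train_frac)
    n_valid = int(len(rows) * valid_frac)
    staged = [dict(row) for row in rows]
    for rank, idx in enumerate(order):
        if rank < n_train:
            staged[idx]["stage"] = "train"
        elif rank < n_train + n_valid:
            staged[idx]["stage"] = "valid"
        else:
            # Whatever is left after train and valid goes to test.
            staged[idx]["stage"] = "test"
    return staged


def write_training_csv(rows, handle):
    """Write rows to an open text handle with the three expected columns."""
    writer = csv.DictWriter(handle, fieldnames=["sequence", "label", "stage"])
    writer.writeheader()
    writer.writerows(rows)


# Toy regression data: invented sequences and labels, for illustration only.
toy_rows = [{"sequence": "MKT" + "A" * i, "label": float(i)} for i in range(10)]
staged_rows = add_stage_column(toy_rows)

# Write to an in-memory buffer; pass a real file handle to save to disk.
buffer = io.StringIO()
write_training_csv(staged_rows, buffer)
csv_text = buffer.getvalue()
print(csv_text.splitlines()[0])  # header line
```

For classification tasks the `label` column would hold category indices starting from zero rather than floats.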



<!-- ## Training Dataset

For the training dataset, **two additional columns** are required in the CSV file: `label` and `stage`.

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Multiple_AA_Sequences_data_format_training.png?raw=true" height="200" width="400px" align="center">

### Column `label`

The content of column `label` depends on your **task type**:

| Task Type                         | Content in the Column                          |
|-----------------------------------|------------------------------------------------|
| Classification tasks              | Category index starting from zero              |
| Amino Acid Classification tasks   | A list of category indices for each amino acid |
| Regression tasks                  | Numerical values                               |

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/label_format.png?raw=true" height="300" width="800px" align="center">
<br>


### Column `stage`

The column `stage` indicates whether the sample is used for training, validation, or testing. Ensure your dataset includes samples for all three stages. The values are: `train`, `valid`, `test`.

<br>

### **Note:**

1. **Examples are available** at /content/SaprotHub/upload_files (if you connect to your local server, then the path is /SaprotHub/upload_files). Download to review their format, and then upload them for a trial.

2.  <a href="#get_sa">Here</a> you can **convert your data into SA Sequence** format.

3. <a href="#fa2csv">Here</a> you can **convert your .fa/.fasta file to a .csv file**, which corresponds to the data format for Multiple AA Sequences.

4. <a href="#split_dataset">Here</a> you can **randomly split your .csv dataset**, which means to add a `stage` column, where the ratio of `train`:`valid`:`test` is 8:1:1.

4. The maximum input length of the model is 1024, and protein sequences exceeding this length will only retain the first 1024 amino acids. -->
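
When `batch_size` is set to `Adaptive`, the training cell below picks a target effective batch size from the number of training samples and divides it by the per-GPU batch size to get the number of gradient accumulation steps. The sketch below restates that arithmetic in compact form, using the same thresholds as `get_accumulate_grad_samples`. The function names `target_batch_size` and `accumulation_steps` are illustrative and do not appear in the notebook.

```python
# (sample-count threshold, target effective batch size), largest first.
# These mirror the branches of get_accumulate_grad_samples in the training cell.
THRESHOLDS = [(3200, 64), (1600, 32), (800, 16), (400, 8), (200, 4), (100, 2)]


def target_batch_size(num_samples):
    """Effective batch size chosen for a training set of `num_samples`."""
    for threshold, size in THRESHOLDS:
        if num_samples > threshold:
            return size
    return 1


def accumulation_steps(num_samples, gpu_batch_size):
    """Gradient accumulation steps needed to reach the target batch size."""
    return max(target_batch_size(num_samples) // gpu_batch_size, 1)


# Example: 1000 training samples with a per-GPU batch size of 2.
print(target_batch_size(1000))      # 16
print(accumulation_steps(1000, 2))  # 8
```

Small datasets therefore update the weights more often per epoch, while large ones average gradients over more samples per step.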


In [None]:
#@title **2.1: Train your Model** <a name="train"></a>

################################################################################
############################# ADVANCED CONFIG ##################################
################################################################################

# training config
GPU_batch_size = 0
accumulate_grad_batches = 0
num_workers = 2
seed = 20000812

# lora config
r = 8
lora_dropout = 0.0
lora_alpha = 16

# dataset config
val_check_interval=0.5
limit_train_batches=1.0
limit_val_batches=1.0
limit_test_batches=1.0


mask_struc_ratio=None

################################################################################
################################## MARKDOWN #################################
################################################################################
#@markdown ⚠️If you want to **interrupt** the training, **do not** click the run button again. Please refer to [here](https://github.com/westlake-repl/SaprotHub/wiki/ColabProTrek-&-ColabProtT5#interrupt-training-to-avoid-overfitting).

#@markdown > 📍Please see the [tutorial](https://github.com/westlake-repl/SaprotHub/wiki/ColabProTrek-&-ColabProtT5#21-Train-your-model)

if not torch.cuda.is_available():
  raise RuntimeError("Please refer to Section 1.1 to switch your Runtime to a GPU!")

################################################################################
################################## TASK CONFIG #################################
################################################################################
#@markdown # 1. Task
task_name = "demo" # @param {type:"string"}
task_type = "Protein-level Regression" # @param ["Protein-level Classification", "Protein-level Regression", "Residue-level Classification", "Protein-protein Classification", "Protein-protein Regression"]
original_task_type = task_type
task_type = task_type_dict[task_type]

if task_type in ["classification", 'token_classification', 'pair_classification']:

  print(Fore.BLUE+'Enter the number of categories in your training dataset here:'+Style.RESET_ALL)
  num_of_categories = ipywidgets.BoundedIntText(
                                              # value=7,
                                              min=2,
                                              max=1000000,
                                              step=1,
                                              # description='num_of_category: \n',
                                              disabled=False)
  num_of_categories.layout.width = "100px"
  display(num_of_categories)

#@markdown <br>

################################################################################
#################################### MODEL CONFIG #####################################
################################################################################
#@markdown # 2. Model

base_model = "Official pretrained ProTrek (35M)" # @param ["Official pretrained ProTrek (35M)", "Official pretrained ProTrek (650M)", "Trained by yourself on ColabProTrek", "Shared by peers on ProTrekHub", "Saved in your local computer"]

# continue learning
if base_model in ["Trained by yourself on ColabProTrek", "Shared by peers on ProTrekHub", "Saved in your local computer"]:
  continue_learning = True
  adapter_combobox = select_adapter_from(task_type, use_model_from=base_model)
else:
  continue_learning = False

#@markdown <br>

################################################################################
################################### DATASET CONFIG ####################################
################################################################################
#@markdown # 3. Dataset

data_type = "Multiple AA Sequences" # @param ["SaprotHub Dataset","Multiple AA Sequences","Multiple UniProt IDs","Multiple pairs of AA Sequences","Multiple pairs of UniProt IDs"]
check_task_type_and_data_type(original_task_type, data_type)

raw_data = input_raw_data_by_data_type(data_type)

#@markdown <br>

################################################################################
################################### TRAIN CONFIG ####################################
################################################################################
#@markdown # 4. Training

batch_size = "Adaptive" # @param ["Adaptive", "1", "2", "4", "8", "16", "32", "64", "128", "256"]
max_epochs = 10 # @param ["10", "20", "50"] {type:"raw", allow-input: true}
learning_rate = 1.0e-4 # @param ["1.0e-3", "5.0e-4", "1.0e-4"] {type:"raw", allow-input: true}

#@markdown <br>


################################################################################
################################# CONFIG #######################################
################################################################################

from saprot.config.config_dict import Default_config
config = copy.deepcopy(Default_config)

################################################################################
################################### TRAIN ####################################
################################################################################

def train(button):
  global base_model
  global GPU_batch_size
  global accumulate_grad_batches

  button.disabled = True
  button.description = 'Training...'
  button.button_style = ''

################################################################################
################################### DATASET CONFIRM ####################################
################################################################################
  csv_dataset_path = get_SA_sequence_by_data_type(data_type, raw_data)
  new_path = check_column_label_and_stage(csv_dataset_path)
  from saprot.utils.construct_lmdb import construct_lmdb
  construct_lmdb(new_path, LMDB_HOME, task_name, task_type)
  lmdb_dataset_path = LMDB_HOME / task_name

################################################################################
################################### MODEL CONFIRM ####################################
################################################################################

  # base_model
  if continue_learning:
    adapter_path = ADAPTER_HOME / task_type / adapter_combobox.value
    print(f"Training on an existing model: {adapter_path}")

    if base_model == "Shared by peers on ProTrekHub":
      if not adapter_path.exists():
        snapshot_download(repo_id=adapter_combobox.value, repo_type="model", local_dir=adapter_path)

    adapter_config_path = Path(adapter_path) / "adapter_config.json"
    assert adapter_config_path.exists(), f"Can't find {adapter_config_path}"
    with open(adapter_config_path, 'r') as f:
      adapter_config = json.load(f)
      base_model = adapter_config['base_model_name_or_path']

  elif base_model == "Official pretrained ProTrek (35M)":
    base_model = "ProTrekHub/Protein_Encoder_35M"

  elif base_model == "Official pretrained ProTrek (650M)":
    base_model = "ProTrekHub/Protein_Encoder_650M"

  elif base_model == "Official pretrained ProtT5 (2.8B)":
    base_model = "Rostlab/prot_t5_xl_uniref50"

  # model size and model name
  if base_model == "ProTrekHub/Protein_Encoder_650M":
    model_size = "650M"
    model_name = f"Model-{task_name}-{model_size}"
  elif base_model == "ProTrekHub/Protein_Encoder_35M":
    model_size = "35M"
    model_name = f"Model-{task_name}-{model_size}"
  elif base_model == "Rostlab/prot_t5_xl_uniref50":
    model_size = "2.8B"
    model_name = f"Model-{task_name}-{model_size}"

  config.setting.run_mode = "train"
  config.setting.seed = seed

################################################################################
################################# MODEL ########################################
################################################################################

  if task_type in ["classification", "token_classification", "pair_classification"]:
    config.model.kwargs.num_labels = num_of_categories.value

  config.model.model_py_path = model_type_dict[task_type]
  config.model.kwargs.config_path = base_model
  config.dataset.kwargs.tokenizer = base_model

  config.model.save_path = str(ADAPTER_HOME / f"{task_type}" / "Local" / model_name)

  if task_type in ["regression", "pair_regression"]:
    config.model.kwargs.extra_config = {}
    config.model.kwargs.extra_config.attention_probs_dropout_prob=0
    config.model.kwargs.extra_config.hidden_dropout_prob=0

  config.model.kwargs.lora_kwargs = EasyDict({
    "is_trainable": True,
    "num_lora": 1,
    "r": r,
    "lora_dropout": lora_dropout,
    "lora_alpha": lora_alpha,
    "config_list": []})
  if continue_learning:
    config.model.kwargs.lora_kwargs.config_list.append({"lora_config_path": adapter_path})

################################################################################
################################# DATASET ######################################
################################################################################

  config.dataset.dataset_py_path = dataset_type_dict[task_type]

  config.dataset.train_lmdb = str(lmdb_dataset_path / "train")
  config.dataset.valid_lmdb = str(lmdb_dataset_path / "valid")
  config.dataset.test_lmdb = str(lmdb_dataset_path / "test")

  # num_workers
  config.dataset.dataloader_kwargs.num_workers = num_workers

  # mask_struc
  # config.dataset.kwargs.mask_struc_ratio= mask_struc_ratio

  ################################################################################
  ######################## batch size ############################################
  ################################################################################
  def get_accumulate_grad_samples(num_samples):
      if num_samples > 3200:
          return 64
      elif 1600 < num_samples <= 3200:
          return 32
      elif 800 < num_samples <= 1600:
          return 16
      elif 400 < num_samples <= 800:
          return 8
      elif 200 < num_samples <= 400:
          return 4
      elif 100 < num_samples <= 200:
          return 2
      else:
          return 1

  # advanced config
  if (GPU_batch_size > 0) and (accumulate_grad_batches > 0):
    config.dataset.dataloader_kwargs.batch_size = GPU_batch_size
    config.Trainer.accumulate_grad_batches= accumulate_grad_batches

  elif (GPU_batch_size == 0) and (accumulate_grad_batches == 0):

    # batch_size
    # Match the 650M encoder by the same name that the model-confirmation step above assigns to `base_model`.
    if "Protein_Encoder_650M" in base_model and root_dir == "/content":
      GPU_batch_size = 1
    else:
      GPU_batch_size_dict = {
        "Tesla T4": 2,
        "NVIDIA L4": 2,
        "NVIDIA A100-SXM4-40GB": 4,
        }
      GPU_name = torch.cuda.get_device_name(0)
      GPU_batch_size = GPU_batch_size_dict[GPU_name] if GPU_name in GPU_batch_size_dict else 2

      if task_type in ["pair_classification", "pair_regression"]:
        GPU_batch_size = int(max(GPU_batch_size / 2, 1))

    config.dataset.dataloader_kwargs.batch_size = GPU_batch_size

    # accumulate_grad_batches
    if batch_size == "Adaptive":

      env = lmdb.open(config.dataset.train_lmdb, readonly=True)

      with env.begin() as txn:
        stat = txn.stat()
        num_samples = stat['entries']

      accumulate_grad_samples = get_accumulate_grad_samples(num_samples)

    else:
      accumulate_grad_samples = int(batch_size)

    accumulate_grad_batches = max(int(accumulate_grad_samples / GPU_batch_size), 1)

    config.Trainer.accumulate_grad_batches= accumulate_grad_batches

  else:
    raise ValueError(f"`GPU_batch_size` ({GPU_batch_size}) and `accumulate_grad_batches` ({accumulate_grad_batches}) must either both be 0 (automatic) or both be greater than zero!")

  ################################################################################
  ############################## TRAINER #########################################
  ################################################################################

  config.Trainer.accelerator = "gpu" if torch.cuda.is_available() else "cpu"

  # epoch
  config.Trainer.max_epochs = max_epochs
  # test only: load the existing model
  if config.Trainer.max_epochs == 0 and continue_learning:
    config.model.save_path = config.model.kwargs.lora_kwargs.config_list[0]['lora_config_path']

  # learning rate
  config.model.lr_scheduler_kwargs.init_lr = learning_rate

  # trainer
  config.Trainer.limit_train_batches = limit_train_batches
  config.Trainer.limit_val_batches = limit_val_batches
  config.Trainer.limit_test_batches = limit_test_batches
  config.Trainer.val_check_interval = val_check_interval

  # strategy
  strategy = {
      # - deepspeed
      # 'class': 'DeepSpeedStrategy',
      # 'stage': 2

      # - None
      # 'class': None,

      # - DP
      # 'class': 'DataParallelStrategy',

      # - DDP
      # 'class': 'DDPStrategy',
      # 'find_unused_parameter': True
  }
  config.Trainer.strategy = strategy

  ################################################################################
  ############################## Run the task ####################################
  ################################################################################

  print('='*100)
  print(Fore.BLUE+f"Training task type: {task_type}"+Style.RESET_ALL)
  print(Fore.BLUE+f"Dataset: {lmdb_dataset_path}"+Style.RESET_ALL)
  print(Fore.BLUE+f"Base Model: {config.model.kwargs.config_path}"+Style.RESET_ALL)
  if continue_learning:
    print(Fore.BLUE+f"Existing model: {config.model.kwargs.lora_kwargs.config_list[0]['lora_config_path']}"+Style.RESET_ALL)
  print('='*100)
  pprint.pprint(config)
  print('='*100)

  from saprot.scripts.training import finetune
  finetune(config)


  ################################################################################
  ############################## Save the adapter ################################
  ################################################################################

  def add_training_data_type_to_config(metadata_path, training_data_type):
    if not metadata_path.exists():
      config_data = {
          'training_data_type': training_data_type
          }
      with open(metadata_path, 'w') as file:
          json.dump(config_data, file, indent=4)

    else:
      with open(metadata_path, 'r') as file:
          config_data = json.load(file)

      config_data['training_data_type'] = training_data_type

      with open(metadata_path, 'w') as file:
          json.dump(config_data, file, indent=4)

  metadata_path = Path(config.model.save_path) / "metadata.json"
  training_data_type = training_data_type_dict[data_type]
  add_training_data_type_to_config(metadata_path, training_data_type)

  print(Fore.BLUE)
  print(f"Model is saved to \"{config.model.save_path}\" on the Colab server")
  print(Style.RESET_ALL)


  adapter_zip = Path(config.model.save_path) / f"{model_name}.zip"
  !cd $config.model.save_path && zip -r $adapter_zip "adapter_config.json" "adapter_model.safetensors" "README.md" "metadata.json"
  # !cd $config.model.save_path && zip -r $adapter_zip "adapter_config.json" "adapter_model.safetensors" "adapter_model.bin" "README.md" "metadata.json"
  print("Click to download the model to your local computer")
  if adapter_zip.exists():
    # files.download(adapter_zip)
    file_download(adapter_zip)



  ################################################################################
  ############################### Modify README ##################################
  ################################################################################
  name = model_name
  description = '<slot name=\'description\'>'
  label_meanings = '<slot name=\'label_meanings\'>'

  with open(f'{config.model.save_path}/adapter_config.json', 'r') as f:
    lora_config = json.load(f)

  markdown = f'''
---

base_model: {base_model} \n
library_name: peft

---
\n

# Model Card for {name}
{description}

## Task type
{original_task_type}

## Model input type
{training_data_type_dict[data_type]} Sequence

## Label meanings
{label_meanings}

## LoRA config

- **r:** {lora_config['r']}
- **lora_dropout:** {lora_config['lora_dropout']}
- **lora_alpha:** {lora_config['lora_alpha']}
- **target_modules:** {lora_config['target_modules']}
- **modules_to_save:** {lora_config['modules_to_save']}

## Training config

- **optimizer:**
  - **class:** AdamW
  - **betas:** (0.9, 0.98)
  - **weight_decay:** 0.01
- **learning rate:** {config.model.lr_scheduler_kwargs.init_lr}
- **epoch:** {config.Trainer.max_epochs}
- **batch size:** {config.dataset.dataloader_kwargs.batch_size * config.Trainer.accumulate_grad_batches}
- **precision:** 16-mixed \n
'''

  # Write the markdown output to a file
  with open(f"{config.model.save_path}/README.md", "w") as file:
    file.write(markdown)


button_train = ipywidgets.Button(
    description='Start Training',
    disabled=False,
    button_style='success', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Apply',
    # icon='check' # (FontAwesome names without the `fa-` prefix)
    )
button_train.on_click(train)
button_train.layout.width = '300px'
display(button_train)
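
The training cell above derives the number of gradient-accumulation steps from a target number of samples per optimizer step and the per-GPU batch size. The sketch below reproduces only that arithmetic. `plan_batches` is a hypothetical helper written for illustration (it is not part of ColabProTrek), and the numbers passed to it are made up.

```python
def plan_batches(target_samples: int, gpu_batch_size: int, pair_task: bool = False):
    """Return (gpu_batch_size, accumulate_grad_batches, effective_batch_size).

    Hypothetical helper mirroring the arithmetic in the training cell:
    pair tasks halve the per-GPU batch size (never below 1), and the number
    of accumulation steps is the target divided by the per-GPU batch size,
    rounded down and clamped to at least 1.
    """
    if pair_task:
        gpu_batch_size = max(gpu_batch_size // 2, 1)
    accumulate_grad_batches = max(target_samples // gpu_batch_size, 1)
    effective_batch_size = gpu_batch_size * accumulate_grad_batches
    return gpu_batch_size, accumulate_grad_batches, effective_batch_size


# Illustrative values only: a target of 8 samples per optimizer step
# with a per-GPU batch size of 2.
print(plan_batches(8, 2))                  # (2, 4, 8)
print(plan_batches(8, 2, pair_task=True))  # (1, 8, 8)
print(plan_batches(1, 4))                  # (4, 1, 4)
```

The last call shows that the effective batch size can exceed the requested target when the target is smaller than the per-GPU batch size, because the accumulation count is clamped to 1.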

In [None]:
#@title **2.2: Log in to Hugging Face to Upload Your Model (Optional)** <a name="upload_model"></a>
################################################################################
###################### Login HuggingFace #######################################
################################################################################

from huggingface_hub import notebook_login
notebook_login()


In [None]:
#@title **2.3: Upload your Model (Optional)**

#@markdown > 📍Please see the [tutorial](https://github.com/westlake-repl/SaprotHub/wiki/ColabProTrek-&-ColabProtT5#22-Upload-your-model)

# #@markdown Your Huggingface adapter repository names follow the format `<username>/<task_name>`.

################################################################################
########################## Metadata  ###########################################
################################################################################
name = "demo_cls" # @param {type:"string"}
description = "This model is used for a demo classification task" # @param {type:"string"}



# #@markdown > 0: Nucleus <br>
# #@markdown > 1: Cytoplasm <br>
# #@markdown > 2: Extracellular <br>
# #@markdown > ... <br>
# #@markdown > 9: Peroxisome <br>

label_meanings = "A, B" #@param {type:"string"}

################################################################################
########################### Move Files  ########################################
################################################################################

from huggingface_hub import HfApi, Repository

api = HfApi()

user = api.whoami()

if name == "":
  name = model_name
repo_name = user['name'] + '/' + name
local_dir = Path("/content/SaprotHub/model_to_push") / repo_name
local_dir.mkdir(parents=True, exist_ok=True)

# Create the repository only if it does not exist yet.
# `exist_ok=True` avoids listing repositories with `ModelFilter`,
# which may be unavailable in newer huggingface_hub releases.
api.create_repo(repo_name, private=False, exist_ok=True)

repo = Repository(local_dir=local_dir, clone_from=repo_name)

command = f"cp {config.model.save_path}/* {local_dir}/"
subprocess.run(command, shell=True)

################################################################################
########################## Modify README  ######################################
################################################################################
import json

md_path = local_dir / "README.md"


if task_type in ["classification", "token_classification", "pair_classification"]:
    label_meanings_md = ''
    # Split on commas only; surrounding spaces are removed by strip() below
    for index, label in enumerate(label_meanings.split(',')):
      label_meanings_md += f'''
{index}: {label.strip()}
'''
    label_meanings = label_meanings_md

replace_data = {
    '<slot name=\'description\'>': description,
    '<slot name=\'label_meanings\'>': label_meanings
}

with open(md_path, "r") as file:
    content = file.read()

for key, value in replace_data.items():
    if value != "":
        content = content.replace(key, value)

with open(md_path, "w") as file:
    file.write(content)

################################################################################
########################## Upload Model  #######################################
################################################################################


repo.push_to_hub(commit_message="Upload adapter model")
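
The upload cell above fills two placeholder slots in `README.md` with the description and the label meanings entered in the form. A minimal, self-contained sketch of that substitution for a classification task follows. `fill_model_card` is a hypothetical helper, the template string is a shortened stand-in for the real model card, and the label lines are formatted more compactly than in the cell itself.

```python
def fill_model_card(template: str, description: str, label_meanings: str) -> str:
    """Replace the two placeholder slots in a model-card template.

    Hypothetical helper; the slot strings match the placeholders that the
    training cell writes into README.md.
    """
    # Build one "index: label" line per comma-separated label.
    labels = [label.strip() for label in label_meanings.split(",") if label.strip()]
    label_block = "\n".join(f"{i}: {label}" for i, label in enumerate(labels))

    replacements = {
        "<slot name='description'>": description,
        "<slot name='label_meanings'>": label_block,
    }
    for slot, value in replacements.items():
        if value != "":  # leave the slot untouched when nothing was provided
            template = template.replace(slot, value)
    return template


# Shortened stand-in for the generated README.md
template = (
    "# Model Card for demo_cls\n"
    "<slot name='description'>\n\n"
    "## Label meanings\n"
    "<slot name='label_meanings'>\n"
)
card = fill_model_card(template, "A demo classifier", "A, B")
print(card)
```

Leaving a field empty keeps its slot in place, so the card can still be completed by a later run.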

# **3: Use ProTrek to Predict**
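
Before tokenization, the prediction cell in 3.1 reduces every input to a plain amino-acid sequence by keeping only uppercase letters, which drops `#` and lowercase characters. The sketch below shows that filtering in isolation. `strip_structure_tokens` is a hypothetical name, and the input string is a made-up example rather than a real protein.

```python
def strip_structure_tokens(sa_sequence: str) -> str:
    """Keep only uppercase amino-acid letters; drop lowercase letters and '#'."""
    return "".join(ch for ch in sa_sequence if ch.isupper())


# Made-up input mixing uppercase residues with lowercase letters and '#'
print(strip_structure_tokens("MdEvVpQ#L#"))  # MEVQL

# A plain amino-acid sequence passes through unchanged
print(strip_structure_tokens("MEVQL"))       # MEVQL
```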

In [None]:
#@title **3.1: Classification & Regression Prediction** <a name="prediction"></a>

#@markdown > 📍Please see the [tutorial](https://github.com/westlake-repl/SaprotHub/wiki/ColabProTrek-&-ColabProtT5#31-Classification-Regression-Prediction)

from transformers import EsmTokenizer
import torch
import copy
import sys
from saprot.scripts.training import my_load_model

################################################################################
################################# TASK #########################################
################################################################################
#@markdown # 1. Task

task_type = "Protein-level Regression" # @param ["Protein-level Classification", "Protein-level Regression", "Residue-level Classification", "Protein-protein Classification", "Protein-protein Regression"]
original_task_type = task_type
task_type = task_type_dict[task_type]

if task_type in ["classification", 'token_classification', 'pair_classification']:

  print(Fore.BLUE+'The number of categories in your classification task:'+Style.RESET_ALL)
  num_of_categories = ipywidgets.BoundedIntText(
                                              # value=7,
                                              min=2,
                                              # max=10,
                                              step=1,
                                              # description='num_of_category: \n',
                                              disabled=False)
  num_of_categories.layout.width = "100px"
  display(num_of_categories)

#@markdown <br>


################################################################################
################################## MODEL #######################################
################################################################################
#@markdown # 2. Model

use_model_from = "Trained by yourself on ColabProTrek" # @param ["Trained by yourself on ColabProTrek", "Shared by peers on ProTrekHub", "Saved in your local computer", "Multi-models on ProTrekHub"]
multi_lora = (use_model_from == "Multi-models on ProTrekHub")

adapter_input = select_adapter_from(task_type, use_model_from)
#@markdown <br>

################################################################################
################################ DATASET #######################################
################################################################################
#@markdown # 3. Dataset
data_type = "Single AA Sequence" # @param ["Single AA Sequence","Single UniProt ID","Multiple AA Sequences","Multiple UniProt IDs","A pair of AA Sequences","A pair of UniProt IDs","Multiple pairs of AA Sequences","Multiple pairs of UniProt IDs"]
check_task_type_and_data_type(original_task_type, data_type)

mode = "Multiple Sequences" if (data_type in data_type_list_multiple) else "Single Sequence"

raw_data = input_raw_data_by_data_type(data_type)


################################################################################
##################################### PREDICT ###################################
################################################################################
def predict(button):
  button.disabled = True
  button.description = 'Predicting...'
  button.button_style = ''

  print('\n')
  print('='*100)

  ##############################################################################
  ################################# MODEL ###################################
  ##############################################################################
  if multi_lora:
    if use_model_from == "Multi-models on ColabProTrek":
      config_list = [EasyDict({'lora_config_path': ADAPTER_HOME / task_type / lora_config_path}) for lora_config_path in list(adapter_input.value)]
    elif use_model_from == "Multi-models on ProTrekHub":
      #1. get adapter_list
      repo_id_list = adapter_input.value.replace(" ", "").split(',')
      #2. download adapters
      for repo_id in repo_id_list:
        snapshot_download(repo_id=repo_id, repo_type="model", local_dir=ADAPTER_HOME / task_type / repo_id)
      config_list = [EasyDict({'lora_config_path': ADAPTER_HOME / task_type / repo_id}) for repo_id in repo_id_list]

    assert len(config_list) > 0, "Please select your models from the dropdown menu on the output of 3.1!"
    # lora_config_path already includes ADAPTER_HOME / task_type
    base_model = get_base_model(config_list[0].lora_config_path)

    required_training_data_type_list = []
    for lora_config in config_list:
      required_training_data_type_list.append(check_training_data_type(lora_config.lora_config_path, data_type))
    # assert len(set(required_training_data_type_list)) == 1, f"Error: The input data types of these models are not identical: {required_training_data_type_list}"
    required_training_data_type = required_training_data_type_list[0]

    lora_kwargs = EasyDict({
      "is_trainable": False,
      "num_lora": len(config_list),
      "config_list": config_list
    })

  else:
    if use_model_from == "Shared by peers on ProTrekHub":
      snapshot_download(repo_id=adapter_input.value, repo_type="model", local_dir=ADAPTER_HOME / task_type / adapter_input.value)

    adapter_path = ADAPTER_HOME / task_type / adapter_input.value
    base_model = get_base_model(adapter_path)
    required_training_data_type = check_training_data_type(adapter_path, data_type)
    lora_kwargs = {
      "is_trainable": False,
      "num_lora": 1,
      "config_list": [{"lora_config_path": adapter_path}]
    }

  ##############################################################################
  ################################# DATASET ###################################
  ##############################################################################
  def transform_single_sa_to_aa(sequence):
    """
    Keep only the uppercase amino-acid letters, dropping "#" and lowercase characters.
    """
    return ''.join(char for char in sequence if char.isupper())

  def transform_sa_to_aa(df):
    for column in df.columns:
      if 'sequence' in column:
        df[column] = df[column].apply(transform_single_sa_to_aa)
    return df

  if data_type in data_type_list_multiple:
    csv_dataset_path = get_SA_sequence_by_data_type(data_type, raw_data)
    df = read_csv_dataset(csv_dataset_path)
    df = transform_sa_to_aa(df)
  else:
    single_sa_seq = get_SA_sequence_by_data_type(data_type, raw_data)
    if task_type in ["pair_classification", "pair_regression"]:
      single_sa_seq[0] = transform_single_sa_to_aa(single_sa_seq[0])
      single_sa_seq[1] = transform_single_sa_to_aa(single_sa_seq[1])
      df = pd.DataFrame({
          'sequence_1': [single_sa_seq[0]],
          'sequence_2': [single_sa_seq[1]]
      })
      df = transform_sa_to_aa(df)
    else:
      single_sa_seq = transform_single_sa_to_aa(single_sa_seq)
      df = pd.DataFrame({
          'sequence': [single_sa_seq]
      })
      df = transform_sa_to_aa(df)

  # def mask_struc_token(sequence):
  #   return ''.join('#' if i % 2 == 1 and char.islower() else char for i, char in enumerate(sequence))

  # if (required_training_data_type == "AA") and ("AA" not in data_type):
  #   if 'sequence' in df.columns:
  #     df['sequence'] = df['sequence'].apply(mask_struc_token)
  #   elif 'sequence_1' in df.columns and 'sequence_2' in df.columns:
  #     df['sequence_1'] = df['sequence_1'].apply(mask_struc_token)
  #     df['sequence_2'] = df['sequence_2'].apply(mask_struc_token)

  ################################################################################
  ##################################### CONFIG ###################################
  ################################################################################
  from saprot.config.config_dict import Default_config
  config = copy.deepcopy(Default_config)

  # task
  if task_type in [ "classification", "token_classification", "pair_classification"]:
    config.model.kwargs.num_labels = num_of_categories.value
  # base model
  config.model.model_py_path = model_type_dict[task_type]
  config.model.kwargs.config_path = base_model
  # lora
  config.model.kwargs.lora_kwargs = lora_kwargs

  ################################################################################
  ################################### LOAD MODEL ##################################
  ################################################################################
  model = my_load_model(config.model)
  tokenizer = EsmTokenizer.from_pretrained(config.model.kwargs.config_path)
  device = "cuda" if torch.cuda.is_available() else "cpu"
  model.to(device)
  model.eval()  # inference only: disable dropout

  ################################################################################
  ################################### INFO #######################################
  ################################################################################
  # clear_output(wait=True)
  print('\n')
  print('='*100)

  print(Fore.BLUE+f"Task Type: {original_task_type}"+Style.RESET_ALL)

  print(Fore.BLUE+f"Model ({use_model_from}):"+Style.RESET_ALL)
  if multi_lora:
    print(Fore.BLUE+f"  Base Model: {base_model}"+Style.RESET_ALL)
    print(Fore.BLUE+f"  Adapter:"+Style.RESET_ALL)
    for lora_config in lora_kwargs.config_list:
      print(Fore.BLUE+f"    {lora_config.lora_config_path}"+Style.RESET_ALL)
  else:
    print(Fore.BLUE+f"  Base Model: {base_model}"+Style.RESET_ALL)
    print(Fore.BLUE+f"  Adapter: {adapter_path}"+Style.RESET_ALL)

  print(Fore.BLUE+f'Dataset ({data_type}):' +Style.RESET_ALL)
  if mode == "Multiple Sequences":
    print(Fore.BLUE+f"  CSV Dataset Path: {csv_dataset_path}"+Style.RESET_ALL)
  else:
    if "A pair of" in data_type:
      print(Fore.BLUE+f"  Sequence 1: {single_sa_seq[0]}"+Style.RESET_ALL)
      print(Fore.BLUE+f"  Sequence 2: {single_sa_seq[1]}"+Style.RESET_ALL)
    else:
      print(Fore.BLUE+f"  Sequence: {single_sa_seq}"+Style.RESET_ALL)

  ################################################################################
  ################################### INFERENCE ##################################
  ################################################################################
  print()
  print('='*100)
  print(Fore.BLUE+"Prediction Result:"+Style.RESET_ALL)

  outputs_list=[]
  if task_type in ["pair_classification", "pair_regression"]:
    for index, row in tqdm(df.iterrows(), total=df.shape[0]):
      input_1 = tokenizer(row['sequence_1'], return_tensors="pt")
      input_1 = {k: v.to(device) for k, v in input_1.items()}
      input_2 = tokenizer(row['sequence_2'], return_tensors="pt")
      input_2 = {k: v.to(device) for k, v in input_2.items()}

      with torch.no_grad():
        outputs = model(input_1, input_2)
      outputs_list.append(outputs)
  else:
    for index in tqdm(range(len(df))):
      seq = df['sequence'].iloc[index]
      inputs = tokenizer(seq, return_tensors="pt")
      inputs = {k: v.to(device) for k, v in inputs.items()}
      with torch.no_grad():
        outputs = model(inputs)
      outputs_list.append(outputs)

  ################################################################################
  ################################### RESULT ##################################
  ################################################################################
  timestamp = str(datetime.now().strftime("%Y%m%d%H%M%S"))
  output_file = OUTPUT_HOME / f'output_{timestamp}.csv'

  if task_type == "pair_classification":
    softmax_output_list = [F.softmax(output, dim=1).squeeze().tolist() for output in outputs_list]
    print()
    for index, output in enumerate(softmax_output_list):
      print(f"For Sequence Pair {index}, Category {output.index(max(output))}, Probability: {output}")
      df.loc[index, 'result'] = output.index(max(output))
      df.loc[index, 'probability'] = ', '.join(map(str, output))
    df.to_csv(output_file, index=False)

  elif task_type == "pair_regression":
    print()
    for index, output in enumerate(outputs_list):
      print(f"For Sequence Pair {index}, Value {output.cpu().item()}")
    df['score'] = [output.cpu().item() for output in outputs_list]
    df.to_csv(output_file, index=False)

  elif task_type == "classification":
    print()
    softmax_output_list = [F.softmax(output, dim=1).squeeze().tolist() for output in outputs_list]
    for index, output in enumerate(softmax_output_list):
      print(f"For Sequence {index}, Category {output.index(max(output))}, Probability: {output}")
      df.loc[index, 'result'] = output.index(max(output))
      df.loc[index, 'probability'] = ', '.join(map(str, output))
    df.to_csv(output_file, index=False)

  elif task_type == "regression":
    print()
    for index, output in enumerate(outputs_list):
      print(f"For Sequence {index}, Value {output.item()}")
    df['score'] = [output.cpu().item() for output in outputs_list]
    df.to_csv(output_file, index=False)

  elif task_type == "token_classification":
    seq_prob_df_list = []
    softmax_output_list = [F.softmax(output, dim=-1).squeeze().tolist() for output in outputs_list]
    # print("The probability of each category:")
    for seq_index, seq in enumerate(softmax_output_list):
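      # Drop the first and last positions (special tokens added by the tokenizer);
      # copy so the column assignments below do not act on a slice view
      seq_prob_df = pd.DataFrame(seq)[1:-1].copy()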
      # print('='*100)
      # print(f'Sequence {seq_index + 1}:')
      # print(seq_prob_df.to_string())
      seq_prob_df['seq_index'] = seq_index
      seq_prob_df['aa_index'] = seq_prob_df.index
      seq_prob_df['sequence'] = df.loc[seq_index, 'sequence']
      seq_prob_df_list.append(seq_prob_df)
    combined_df = pd.concat(seq_prob_df_list, ignore_index=False)
    combined_df.to_csv(output_file, index=True)

  print()
  print('='*100)
  print(Fore.BLUE+f"The prediction result is saved to {output_file} on the Colab server and will be downloaded to your local computer."+Style.RESET_ALL)
  file_download(output_file)

################################################################################
#################################### BUTTON #################################
################################################################################
button_predict = ipywidgets.Button(
    description='Make Prediction',
    disabled=False,
    button_style='success', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Apply',
    # icon='check' # (FontAwesome names without the `fa-` prefix)
    )
button_predict.on_click(predict)
# button_predict.layout.width = '500px'
display(button_predict)
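
For classification tasks, the prediction cell above turns each model output into class probabilities with a softmax, reports the index of the largest probability as the category, and stores the probabilities as a comma-separated string. The sketch below walks through those steps on a made-up logit vector. It uses a pure-Python softmax as a stand-in for `F.softmax`, and `softmax` and `summarize` are hypothetical helper names.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    shift = max(logits)  # subtract the maximum to avoid overflow in exp()
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def summarize(logits):
    """Return (predicted category index, comma-separated probability string)."""
    probs = softmax(logits)
    category = probs.index(max(probs))  # index of the most probable class
    return category, ", ".join(map(str, probs))


# Made-up logits for a three-class task
category, prob_string = summarize([0.2, 1.5, -0.3])
print(f"Category {category}, Probability: {prob_string}")
```

The `result` and `probability` columns of the output CSV hold these two values for each input row.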
