# **ColabSaProt: Collaborative Protein Language Modeling**

This is the Colab version of [SaProt](https://github.com/westlake-repl/SaProt), a pre-trained protein language model designed for various downstream protein tasks.

**ColabSaProt** is a platform that makes **Protein Language Models (PLMs)** more accessible and user-friendly for biologists, enabling effortless model training and sharing within the scientific community.

We've established the [SaProtHub](https://huggingface.co/SaProtHub) for storing and sharing models and datasets, where you can explore extensive collections for specific protein prediction tasks.

We hope ColabSaProt and SaProtHub can contribute to advancing biological research, fostering collaboration, and accelerating discoveries in the field. You can access [our paper](https://www.biorxiv.org/content/10.1101/2023.10.01.560349v2) for further details.

For detailed steps of each section, please refer to the <a href="#manual">manual</a>.









## ColabSaProt Content

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/content.png?raw=true" height="400" width="800px" align="center">

<font color=red>**To view the content, please click on the first option in the left sidebar.**</font>



## SaProt Hub

Find awesome models and datasets for specific protein tasks on [SaProtHub](https://huggingface.co/SaProtHub)!

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/SaProtHub.png?raw=true" height="500" width="800px" align="center">

# 0: Instruction

Before you begin training and utilizing your model, here are some important details about **Task**, **Dataset** and **Basic Colab knowledge** that you need to be aware of.

<br>





## 0.1: Task

Different models are designed for different tasks, so it's essential to understand **which type your task belongs to**.

You can recognize your task type based on your task description and objectives.

<br>

<!-- ### Task Type

- **Classification Task**: classify protein sequences.
- **Regression Task**: predict the value of some property of a protein sequence.
- **Amino Acid Classification Task**: classify the amino acids in a protein sequence.  -->


| Task Type                             | Description                                                  |
| ------------------------------------- | ------------------------------------------------------------ |
| **Classification Task**               | Classify protein sequences.                                  |
| **Regression Task**                   | Predict the value of some property of a protein sequence.    |
| **Amino Acid Classification Task**    | Classify the amino acids in a protein sequence.              |
| **Mutational Effect Prediction Task** | Predict the mutational effect based on the wild-type sequence and mutation information. |
| **Inverse Folding Prediction**        | Predict the residue sequence of a structure-aware sequence with masked amino acids (which could be all masked or partially masked). |


<br>

Here are some example tasks and their task type:

| Task Type | Example |
| --- | --- |
| **Classification Task** | **Subcellular Location Prediction**: predict which location category the protein belongs to. |
| **Classification Task** | **Metal Ion Binding Detection**: predict whether there are metal ion–binding sites in the protein. |
| **Regression Task** | **Thermostability Prediction**: predict the thermostability value of a protein. |
| **Amino Acid Classification Task** | **Binding Site Detection**: predict whether the amino acid is a binding site or not. |

<!-- <br>

### Use your models or shared models on SaProtHub

You can use

- your trained model
- or shared models on SaProtHub
- pre-trained protein language model

to make some prediction -->


<br>

## 0.2: Dataset <a name="data_format"></a>

You can use your private data to train and predict. Below are the various data formats corresponding to different **data types**.


<br>

### Data Type

1. Single AA Sequence
2. Single SA Sequence
3. Single UniProt ID
4. Single PDB/CIF Structure
5. Multiple AA Sequences
6. Multiple SA Sequences
7. Multiple UniProt IDs
8. Multiple PDB/CIF Structures
9. Huggingface Dataset

<br>

### Data Format <a name='data_format'></a>

#### For `Single AA Sequence`, `Single SA Sequence`, and `Single UniProt ID` (first three data types)
An input box will appear after running the cell. Please enter the sequence (or UniProt ID) in the required format.

<br>

#### For `Single PDB/CIF Structure` (fourth data type)
A file upload button will appear after running the cell. Please upload a .pdb or .cif file.



<br>

#### For `Multiple AA Sequences`, `Multiple SA Sequences`, `Multiple UniProt IDs` (fifth to seventh data types)
A file upload button will appear after running the cell. Please upload a .csv file and ensure that the column name in the .csv file is `Sequence`.

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Multiple_Sequences_data_format.png?raw=true" height="200" width="800px" align="center">
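If your sequences are stored in a FASTA file, the sketch below shows one way to produce the required single-column .csv with plain Python. The helper names `read_fasta` and `fasta_to_csv` are hypothetical (this is not the notebook's own converter), and FASTA headers are simply dropped because only the `Sequence` column is required.

```python
import csv


def read_fasta(path):
    """Yield (header, sequence) pairs; multi-line records are joined into one string."""
    header, chunks = None, []
    with open(path) as handle:
        for line in handle:
            line = line.strip()
            if not line:
                continue
            if line.startswith(">"):
                if header is not None:
                    yield header, "".join(chunks)
                header, chunks = line[1:], []
            else:
                chunks.append(line.upper())
    if header is not None:
        yield header, "".join(chunks)


def fasta_to_csv(fasta_path, csv_path):
    """Write one row per FASTA record under a `Sequence` header; return the row count."""
    count = 0
    with open(csv_path, "w", newline="") as out:
        writer = csv.writer(out)
        writer.writerow(["Sequence"])
        for _, seq in read_fasta(fasta_path):
            writer.writerow([seq])
            count += 1
    return count
```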

<br>
<br>

#### For `Multiple PDB/CIF Structures` (eighth data type)
A file upload button will appear after running the cell. Please upload a .csv file containing three columns: `Sequence`, `type` and `chain`:

- `Sequence`: The name of the structure file. The part of the name before the first `.` must match a .pdb/.cif file in the .zip archive you upload next.
- `type`: Indicates whether the structure is a real PDB structure or an AlphaFold 2 (AF2) predicted structure. For AF2 structures, we apply pLDDT masking. The value must be either "PDB" or "AF2".
- `chain`: For real PDB structures, one .pdb file may contain multiple chains, so you must specify which chain to use. For AF2 structures, the chain is assumed to be A by default.

After the .csv file has been uploaded successfully, a second file upload button will appear. Please upload a .zip file containing all the corresponding .pdb/.cif files.


<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Multiple_PDB_CIF_Structures_data_format.png?raw=true" height="200" width="500px" align="center">
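The sketch below packages local structure files into the .csv/.zip pair described above. `package_structures` is a hypothetical helper, not part of SaProt. Files are stored at the top level of the archive because the notebook extracts the .zip into its structure folder and looks each file up by the part of its name before the first `.`; for the same reason, avoid extra dots in file names.

```python
import csv
import zipfile
from pathlib import Path


def package_structures(structure_files, csv_path, zip_path):
    """Write the .csv (Sequence, type, chain) and the matching flat .zip archive.

    structure_files: list of (file_path, struc_type, chain) tuples, where
    struc_type is "PDB" or "AF2" and chain is a chain ID such as "A".
    """
    with open(csv_path, "w", newline="") as out, zipfile.ZipFile(zip_path, "w") as zf:
        writer = csv.writer(out)
        writer.writerow(["Sequence", "type", "chain"])
        for file_path, struc_type, chain in structure_files:
            file_path = Path(file_path)
            if struc_type not in ("PDB", "AF2"):
                raise ValueError(f"type must be 'PDB' or 'AF2', got {struc_type!r}")
            if file_path.suffix not in (".pdb", ".cif"):
                raise ValueError(f"unsupported structure file: {file_path.name}")
            # Store only the file name so the archive extracts without subfolders
            zf.write(file_path, arcname=file_path.name)
            writer.writerow([file_path.name, struc_type, chain])
```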

<br>
<br>

#### For `Huggingface Dataset` (ninth data type)
An input box will appear after running the cell. Please enter the repo_id of the Huggingface Dataset. You can find some datasets in the [official SaProtHub repository](https://huggingface.co/SaProtHub).

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Huggingface_ID.png?raw=true" height="200" width="600px" align="center">

<br>
<br>

### SA (Structure-Aware) Sequence

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/SA_Sequence.png?raw=true" height="300" width="600px" align="center">
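As a rough illustration, an SA sequence interleaves each amino acid (uppercase) with a Foldseek 3Di structure token (lowercase), and `#` fills the structure slot when no structure is available; this is what the notebook does for AA-sequence inputs. The helpers below are hypothetical sketches. `mask_residues` additionally assumes that a masked residue is written as `#` in the amino-acid slot, as in inverse folding inputs.

```python
def to_sa_sequence(aa_seq, struc_seq=None):
    """Interleave residues with 3Di tokens, e.g. 'MK' + 'dv' -> 'MdKv'.

    Without a structure, every 3Di slot is '#' (as for AA-sequence inputs).
    """
    if struc_seq is None:
        struc_seq = "#" * len(aa_seq)
    if len(aa_seq) != len(struc_seq):
        raise ValueError("amino acid and structure sequences must have equal length")
    return "".join(aa + s.lower() for aa, s in zip(aa_seq.upper(), struc_seq))


def split_sa_sequence(sa_seq):
    """Recover (aa_seq, struc_seq) from an SA sequence made of 2-character tokens."""
    if len(sa_seq) % 2:
        raise ValueError("SA sequence length must be even")
    return sa_seq[0::2], sa_seq[1::2]


def mask_residues(sa_seq, positions):
    """Replace the residue letter at the given 0-based positions with '#' (assumed mask token)."""
    tokens = [sa_seq[i:i + 2] for i in range(0, len(sa_seq), 2)]
    for pos in positions:
        tokens[pos] = "#" + tokens[pos][1]
    return "".join(tokens)
```

For example, `to_sa_sequence("MK")` gives `"M#K#"`, the same result the notebook produces for a single AA sequence.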

## 0.3: Colab

<br>

### Cell Running Status

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Cell_status.png?raw=true" height="400" width="600px" align="center">



# **1: Installation**


In [None]:
#@title 1.1: ⚠️ Switch your Runtime type to <font color=red>**GPU!!!**</font>

#@markdown You can check the current runtime type in <font color=red>**the upper right corner of the page**</font>. If the current runtime type is CPU, you need to <font color=red>**switch it to GPU (either the free T4 or the paid A100)**</font> for a better training experience.

#@markdown <img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Runtime.png?raw=true" height="300" width="400px" align="center">

#@markdown #### Please follow the steps below to switch the runtime to GPU:

#@markdown 1. Click the dropdown button
#@markdown 2. Select option "Change runtime type"
#@markdown 3. Select a GPU
#@markdown 4. Click "Save" button
#@markdown 5. <font color=red>Each time you switch the runtime, all code blocks need to be **re-executed**.</font>


#@markdown <img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Switch_Runtime.png?raw=true" height="400" width="800px" align="center">

In [None]:
#@title 1.2: Click the run button ▶️ to install SaProt

#@markdown (Please wait 2-8 minutes for the installation to finish...)
################################################################################
########################### install saprot #####################################
################################################################################

try:
  import saprot
  print("SaProt is already installed!")
except ImportError:
  print("Installing SaProt...")
  !mkdir -p /content/saprot/LMDB
  !mkdir -p /content/saprot/bin
  !mkdir -p /content/saprot/output
  !mkdir -p /content/saprot/adapters/classification
  !mkdir -p /content/saprot/adapters/regression
  !mkdir -p /content/saprot/adapters/token_classification
  !mkdir -p /content/saprot/structures

  !pip install colorama --quiet
  !pip install gdown==v4.6.3 --force-reinstall --quiet
  !gdown https://drive.google.com/drive/folders/1ECKe5clJXs4POlScVggRQDrFo5HJpGBN?usp=drive_link -O /content/saprot/ --folder  --quiet && pip install /content/saprot/ColabSaProtSetup/saprot-0.4.5-py3-none-any.whl --quiet
  !chmod +x /content/saprot/ColabSaProtSetup/foldseek

  !rsync -a --remove-source-files /content/saprot/ColabSaProtSetup/upload_files /content/saprot
  !rsync -a --remove-source-files /content/saprot/ColabSaProtSetup/datasets /content/saprot
  !mv /content/saprot/ColabSaProtSetup/foldseek /content/saprot/bin/

################################################################################
################################################################################
################################## global ######################################
################################################################################
################################################################################

import ipywidgets
from google.colab import widgets
from pathlib import Path
import pandas as pd
import torch
import numpy as np

import copy
import os
from tqdm import tqdm
from datetime import datetime
from google.colab import files
import zipfile
from loguru import logger

import yaml
import argparse

from easydict import EasyDict

from colorama import init, Fore, Back, Style
from IPython.display import clear_output

import subprocess

from saprot.utils.mpr import MultipleProcessRunnerSimplifier
from huggingface_hub import snapshot_download
import json

DATASET_HOME = Path('/content/saprot/datasets')
ADAPTER_HOME = Path('/content/saprot/adapters')
STRUCTURE_HOME = Path("/content/saprot/structures")
LMDB_HOME = Path('/content/saprot/LMDB')
OUTPUT_HOME = Path('/content/saprot/output')
UPLOAD_FILE_HOME = Path('/content/saprot/upload_files')
FOLDSEEK_PATH = Path("/content/saprot/bin/foldseek")
aa_set = {"A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y"}
foldseek_struc_vocab = "pynwrqhgdlvtmfsaeikc#"

task_type_dict = {
  "Classify protein sequences (classification)" : "classification",
  "Classify each Amino Acid (amino acid classification), e.g. Binding site detection" : "token_classification",
  "Predict protein fitness (regression), e.g. Predict the Thermostability of a protein" : "regression",
}
model_type_dict = {
  "classification" : "saprot/saprot_classification_model",
  "token_classification" : "saprot/saprot_token_classification_model",
  "regression" : "saprot/saprot_regression_model",
}
dataset_type_dict = {
  "classification": "saprot/saprot_classification_dataset",
  "token_classification" : "saprot/saprot_token_classification_dataset",
  "regression": "saprot/saprot_regression_dataset",
}
class font:
    RED = '\033[91m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    BLUE = '\033[94m'

    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    RESET = '\033[0m'

def get_adapters_list():

    adapters_list = []

    for file_path in ADAPTER_HOME.glob('**/adapter_config.json'):
        adapters_list.append(file_path.parent)

    return adapters_list


def show_adapters_info(adapters_list):
  grid = widgets.Grid(len(adapters_list)+1, 2, header_row=True, header_column=True)

  with grid.output_to(0, 0):
    print("ID")

  with grid.output_to(0, 1):
    print("Local Model")

  # with grid.output_to(0, 2):
  #   print("Adapter Path")

  for i in range(len(adapters_list)):
    with grid.output_to(i+1, 0):
      print(i)
    with grid.output_to(i+1, 1):
      print(adapters_list[i].stem)
    # with grid.output_to(i+1, 2):
    #   print(adapters_list[i])

def adapters_text(adapters_list):
  input = ipywidgets.Text(
    value=None,
    placeholder='Enter Huggingface Model ID',
    # description='Selected:',
    disabled=False)
  input.layout.width = '500px'
  display(input)

  return input

def adapters_dropdown(adapters_list):
  dropdown = ipywidgets.Dropdown(
    options=[f"{adapter_path.parent.stem}/{adapter_path.stem}" for index, adapter_path in enumerate(adapters_list)],
    value=None,
    placeholder='Select a Local Model here',
    # description='Selected:',
    disabled=False)
  dropdown.layout.width = '500px'
  display(dropdown)

  return dropdown

def adapters_combobox(adapters_list):
  combobox = ipywidgets.Combobox(
    options=[f"{adapter_path.parent.stem}/{adapter_path.stem}" for index, adapter_path in enumerate(adapters_list)],
    value=None,
    placeholder='Enter Huggingface Model repository id or select a Local Model here',
    # description='Selected:',
    disabled=False)
  combobox.layout.width = '500px'
  display(combobox)

  return combobox

def select_adapter():
  adapters_list = get_adapters_list()
  print(Fore.BLUE+"Existing Models:"+Style.RESET_ALL)
  # print("="*100)
  # show_adapters_info(adapters_list)
  # print("="*100)
  return adapters_combobox(adapters_list)

def select_adapter_from(use_model_from):
  adapters_list = get_adapters_list()

  if use_model_from == 'Local Models':
    print(Fore.BLUE+"Local Model:"+Style.RESET_ALL)
    return adapters_dropdown(adapters_list)
  elif use_model_from == 'SaProtHub Models':
    print(Fore.BLUE+"Huggingface Model:"+Style.RESET_ALL)
    return adapters_text(adapters_list)

################################################################################
########################### download dataset ###################################
################################################################################
def download_dataset(task_name):
  import gdown
  import tarfile

  filepath = LMDB_HOME / f"{task_name}.tar.gz"
  download_links = {
    "ClinVar" : "https://drive.google.com/uc?id=1Le6-v8ddXa1eLJZFo7HPij7NhaBmNUbo",
    "DeepLoc_cls2" : "https://drive.google.com/uc?id=1dGlojkCt1DwUXWiUk4kXRGRNu5sz2uxf",
    "DeepLoc_cls10" : "https://drive.google.com/uc?id=1dGlojkCt1DwUXWiUk4kXRGRNu5sz2uxf",
    "EC" : "https://drive.google.com/uc?id=1VFLFA-jK1tkTZBVbMw8YSsjZqAqlVQVQ",
    "GO_BP" : "https://drive.google.com/uc?id=1DGiGErWbRnEK8jmE2Jpb996By8KVDBfF",
    "GO_CC" : "https://drive.google.com/uc?id=1DGiGErWbRnEK8jmE2Jpb996By8KVDBfF",
    "GO_MF" : "https://drive.google.com/uc?id=1DGiGErWbRnEK8jmE2Jpb996By8KVDBfF",
    "HumanPPI" : "https://drive.google.com/uc?id=1ahgj-IQTtv3Ib5iaiXO_ASh2hskEsvoX",
    "MetalIonBinding" : "https://drive.google.com/uc?id=1rwknPWIHrXKQoiYvgQy4Jd-efspY16x3",
    "ProteinGym" : "https://drive.google.com/uc?id=1L-ODrhfeSjDom-kQ2JNDa2nDEpS8EGfD",
    "Thermostability" : "https://drive.google.com/uc?id=1I9GR1stFDHc8W3FCsiykyrkNprDyUzSz",
  }

  try:
    gdown.download(download_links[task_name], str(filepath), quiet=False)
    with tarfile.open(filepath, 'r:gz') as tar:
      tar.extractall(path=str(LMDB_HOME))
      print(f"Extracted: {filepath}")
  except Exception as e:
    raise RuntimeError(f"The dataset {task_name} has not been prepared.") from e

################################################################################
############################# upload file ######################################
################################################################################
def upload_file(upload_path):
  import shutil
  import os
  from pathlib import Path
  import sys

  upload_path = Path(upload_path)
  upload_path.mkdir(parents=True, exist_ok=True)
  basepath = Path().resolve()
  try:
    uploaded = files.upload()
    filenames = []
    for filename in uploaded.keys():
      filenames.append(filename)
      shutil.move(basepath / filename, upload_path / filename)
    if len(filenames) == 0:
      logger.info("The uploading process has been interrupted by the user.")
      raise RuntimeError("The uploading process has been interrupted by the user.")
  except Exception:
    logger.error("File upload failed! Please click the run button to try again.")
    raise

  return upload_path / filenames[0]

################################################################################
############################ upload dataset ####################################
################################################################################

data_type_list = ["Single AA Sequence",
                  "Single SA Sequence",
                  "Single UniProt ID",
                  "Single PDB/CIF Structure",
                  "Multiple AA Sequences",
                  "Multiple SA Sequences",
                  "Multiple UniProt IDs",
                  "Multiple PDB/CIF Structures",
                  "Huggingface Dataset"]

def input_raw_data_by_data_type(data_type):
  print(Fore.BLUE+"Dataset: "+Style.RESET_ALL, end='')

  # 0-2. 0. Single AA Sequence, 1. Single SA Sequence, 2. Single UniProt ID
  if data_type in data_type_list[:3]:
    input_seq = ipywidgets.Text(
      value=None,
      placeholder=f'Enter {data_type} here',
      disabled=False)
    input_seq.layout.width = '500px'
    print(Fore.BLUE+f"{data_type}"+Style.RESET_ALL)
    display(input_seq)
    return input_seq

  # 3. Single PDB/CIF Structure
  elif data_type == data_type_list[3]:
    print(Fore.BLUE+"Please upload a .pdb/.cif file"+Style.RESET_ALL)
    pdb_file_path = upload_file(STRUCTURE_HOME)
    return pdb_file_path.stem

  # 4-7. Multiple Sequences
  elif data_type in data_type_list[4:8]:
    print(Fore.BLUE+f"Please upload the .csv file which contains {data_type}"+Style.RESET_ALL)
    uploaded_csv_path = upload_file(UPLOAD_FILE_HOME)
    print(Fore.BLUE+"Successfully uploaded your .csv file!"+Style.RESET_ALL)
    print("="*100)

    if data_type == data_type_list[7]:
      # upload and unzip PDB files
      print(Fore.BLUE+f"Please upload your .zip file which contains {data_type} files"+Style.RESET_ALL)
      pdb_zip_path = upload_file(UPLOAD_FILE_HOME)
      if pdb_zip_path.suffix != ".zip":
        logger.error("The data type does not match. Please click the run button again to upload a .zip file!")
        raise RuntimeError("The data type does not match.")
      print(Fore.BLUE+"Successfully uploaded your .zip file!"+Style.RESET_ALL)
      print("="*100)

      import zipfile
      with zipfile.ZipFile(pdb_zip_path, 'r') as zip_ref:
        zip_ref.extractall(STRUCTURE_HOME)

    return uploaded_csv_path

  # 8. Huggingface Dataset
  elif data_type == data_type_list[8]:
    input_repo_id = ipywidgets.Text(
      value=None,
      placeholder=f'Enter {data_type} repository id here',
      disabled=False)
    input_repo_id.layout.width = '500px'
    print(Fore.BLUE+f"{data_type}"+Style.RESET_ALL)
    display(input_repo_id)
    return input_repo_id


def get_SA_sequence_by_data_type(data_type, raw_data):

  # 0. Single AA Sequence
  if data_type == data_type_list[0]:
    input_seq = raw_data
    aa_seq = input_seq.value

    sa_seq = ''
    for aa in aa_seq:
        sa_seq += aa + '#'
    return sa_seq

  # 1. Single SA Sequence
  if data_type == data_type_list[1]:
    input_seq = raw_data
    sa_seq = input_seq.value

    return sa_seq

  # 2. Single UniProt ID
  if data_type == data_type_list[2]:
    input_seq = raw_data
    uniprot_id = input_seq.value

    protein_list = [uniprot_id]
    uniprot2pdb(protein_list)
    mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=1, return_results=True)
    seqs = mprs.run()
    sa_seq = seqs[0].split('\t')[1]
    return sa_seq

  # 3. Single PDB/CIF Structure
  if data_type == data_type_list[3]:
    uniprot_id = raw_data

    protein_list = [uniprot_id]
    mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=1, return_results=True)
    seqs = mprs.run()
    sa_seq = seqs[0].split('\t')[1]
    return sa_seq

  # raw_data = upload_files/xxx.csv
  if data_type in data_type_list[4:8]:
    uploaded_csv_path = raw_data
    csv_dataset_path = DATASET_HOME / uploaded_csv_path.name

  # 4. Multiple AA Sequences
  if data_type == data_type_list[4]:
    protein_df = pd.read_csv(uploaded_csv_path)
    for index, value in protein_df['Sequence'].items():
      sa_seq = ''
      for aa in value:
        sa_seq += aa + '#'
      protein_df.at[index, 'Sequence'] = sa_seq

    protein_df.to_csv(csv_dataset_path, index=None)
    return csv_dataset_path

  # 5. Multiple SA Sequences
  if data_type == data_type_list[5]:
    protein_df = pd.read_csv(uploaded_csv_path)
    protein_df.to_csv(csv_dataset_path, index=None)
    return csv_dataset_path

  # 6. Multiple UniProt IDs
  if data_type == data_type_list[6]:
    protein_df = pd.read_csv(uploaded_csv_path)
    protein_list = protein_df.iloc[:, 0].tolist()
    uniprot2pdb(protein_list)
    mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=1, return_results=True)
    outputs = mprs.run()

    protein_df['Sequence'] = [output.split("\t")[1] for output in outputs]
    protein_df.to_csv(csv_dataset_path, index=None)
    return csv_dataset_path

  # 7. Multiple PDB/CIF Structures
  if data_type == data_type_list[7]:
    protein_df = pd.read_csv(uploaded_csv_path)
    # protein_list = [(uniprot_id, type, chain), ...]
    # protein_list = [item.split('.')[0] for item in protein_df.iloc[:, 0].tolist()]
    # uniprot2pdb(protein_list)
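    protein_list = []
    # Each row is a namedtuple with `Sequence` (structure file name), `type` and `chain`
    for tuple_row in protein_df.itertuples(index=False):
      if tuple_row.type not in ['PDB', 'AF2']:
        raise ValueError("The type of structure must be either \"PDB\" or \"AF2\"!")
      protein_list.append(tuple_row)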
    mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=1, return_results=True)
    outputs = mprs.run()

    protein_df['Sequence'] = [output.split("\t")[1] for output in outputs]
    protein_df.to_csv(csv_dataset_path, index=None)
    return csv_dataset_path

  # # 7. Multiple PDB/CIF Structures (AF2)
  # if data_type == data_type_list[7]:
  #   protein_df = pd.read_csv(uploaded_csv_path)
  #   protein_list = [item.split('.')[0] for item in protein_df.iloc[:, 0].tolist()]
  #   # uniprot2pdb(protein_list)
  #   mprs = MultipleProcessRunnerSimplifier(protein_list, pdb2sequence, n_process=1, return_results=True)
  #   outputs = mprs.run()

  #   protein_df['Sequence'] = [output.split("\t")[1] for output in outputs]
  #   protein_df.to_csv(csv_dataset_path, index=None)
  #   return csv_dataset_path

  # 8. Huggingface Dataset
  if data_type == data_type_list[8]:
    input_repo_id = raw_data
    REPO_ID = input_repo_id.value

    if REPO_ID.startswith('/'):
      return Path(REPO_ID)

    snapshot_download(repo_id=REPO_ID, repo_type="dataset", local_dir=LMDB_HOME/REPO_ID)

    return LMDB_HOME/REPO_ID


# # return a SA Sequence or a csv dataset path
# def get_raw_dataset(data_type, raw_data):
#   if data_type in data_type_list[:3]:
#     raw_dataset = get_SA_sequence_by_data_type(data_type, raw_data.value)
#   elif data_type == data_type_list[3]:
#     raw_dataset = get_SA_sequence_by_data_type(data_type, raw_data)
#   elif data_type in data_type_list[4:8]:
#     raw_dataset = get_SA_sequence_by_data_type(data_type, raw_data)
#   elif data_type in data_type_list[8]:
#     raw_dataset = get_SA_sequence_by_data_type(data_type, raw_data.value)

#   return raw_dataset

# def upload_dataset(data_type):
#   print(Fore.BLUE+f"Please upload the .csv file which contains {data_type}"+Style.RESET_ALL)
#   uploaded_csv_path = upload_file(UPLOAD_FILE_HOME)
#   print(Fore.BLUE+"Successfully upload your .csv file!"+Style.RESET_ALL)
#   print("="*100)

#   # selected_csv_dataset = DATASET_HOME / f"[DATASET]{Path(uploaded_csv_path).stem}.csv"
#   # get_SASequence_by_data_type(data_type, uploaded_csv_path, selected_csv_dataset)
#   # get_SA_sequence_by_data_type(data_type, uploaded_csv_path)
#   # print()
#   # print("="*100)
#   # print(Fore.BLUE+"Successfully upload your dataset!"+Style.RESET_ALL)

#   return uploaded_csv_path



################################################################################
########################## Download predicted structures #######################
################################################################################
def uniprot2pdb(uniprot_ids, nprocess=20):
  from saprot.utils.downloader import AlphaDBDownloader

  os.makedirs(STRUCTURE_HOME, exist_ok=True)
  af2_downloader = AlphaDBDownloader(uniprot_ids, "pdb", save_dir=STRUCTURE_HOME, n_process=nprocess)
  af2_downloader.run()


################################################################################
############### Form foldseek sequences by multiple processes ##################
################################################################################
# def pdb2sequence(process_id, idx, uniprot_id, writer):
#   from saprot.utils.foldseek_util import get_struc_seq

#   try:
#     pdb_path = f"{STRUCTURE_HOME}/{uniprot_id}.pdb"
#     cif_path = f"{STRUCTURE_HOME}/{uniprot_id}.cif"
#     if Path(pdb_path).exists():
#       seq = get_struc_seq(FOLDSEEK_PATH, pdb_path, ["A"], process_id=process_id)["A"][-1]
#     if Path(cif_path).exists():
#       seq = get_struc_seq(FOLDSEEK_PATH, cif_path, ["A"], process_id=process_id)["A"][-1]

#     writer.write(f"{uniprot_id}\t{seq}\n")
#   except Exception as e:
#     print(f"Error: {uniprot_id}, {e}")

# clear_output(wait=True)
# print("Installation finished!")

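def pdb2sequence(process_id, idx, row_tuple, writer):
  from saprot.utils.foldseek_util import get_struc_seq

  # `row_tuple` is either a CSV row with `Sequence` (structure file name), `type` and
  # `chain` fields (Multiple PDB/CIF Structures), or a plain name string (Single
  # PDB/CIF Structure, UniProt IDs), which is treated as an AF2 structure on chain A.
  if isinstance(row_tuple, str):
    uniprot_id, struc_type, chain = row_tuple.split('.')[0], "AF2", "A"
  else:
    uniprot_id = row_tuple.Sequence.split('.')[0]
    struc_type = row_tuple.type                              # "PDB" or "AF2"
    chain = "A" if struc_type == "AF2" else row_tuple.chain  # AF2 structures default to chain A

  try:
    pdb_path = Path(f"{STRUCTURE_HOME}/{uniprot_id}.pdb")
    cif_path = Path(f"{STRUCTURE_HOME}/{uniprot_id}.cif")
    if pdb_path.exists():
      struc_path = pdb_path
    elif cif_path.exists():
      struc_path = cif_path
    else:
      raise FileNotFoundError(f"No .pdb or .cif file found for {uniprot_id}")

    # The last element of the chain's entry is used as the SA sequence
    seq = get_struc_seq(FOLDSEEK_PATH, str(struc_path), [chain], process_id=process_id)[chain][-1]
    writer.write(f"{uniprot_id}\t{seq}\n")
  except Exception as e:
    print(f"Error: {uniprot_id}, {e}")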

clear_output(wait=True)
print("Installation finished!")



# **2: Train and Share your Protein Model**

## Training Dataset

For the training dataset, **two additional columns** are required in the CSV file: `label` and `stage`.

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Multiple_AA_Sequences_data_format_training.png?raw=true" height="200" width="400px" align="center">

- `label`: The content of this column depends on your **task type**:

| Task Type                         | Content in the Column                          |
|-----------------------------------|------------------------------------------------|
| Classification tasks              | Category index starting from zero              |
| Amino Acid Classification tasks   | A list of category indices for each amino acid |
| Regression tasks                  | Numerical values                               |

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/label_format.png?raw=true" height="300" width="800px" align="center">
<br>

- `stage`: Indicates whether the sample is used for training, validation, or testing. The allowed values are `train`, `valid` and `test`; make sure your dataset includes samples for all three stages.
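Before uploading, a quick check along these lines can catch format mistakes early. `check_training_csv` is a hypothetical helper, not part of ColabSaProt. It assumes that per-residue labels for amino acid classification are stored as a list literal such as `[0, 1, 0]`, and that `Sequence` holds a plain amino-acid sequence, so the label list must have one entry per character.

```python
import ast
import math

STAGES = {"train", "valid", "test"}


def _is_index(value):
    """True if value is a non-negative whole number (a category index)."""
    try:
        number = float(value)
    except (TypeError, ValueError):
        return False
    return number.is_integer() and number >= 0


def _is_number(value):
    """True if value parses as a float and is not NaN."""
    try:
        return not math.isnan(float(value))
    except (TypeError, ValueError):
        return False


def check_training_csv(df, task_type):
    """Return a list of problems in a training DataFrame (empty list if none).

    task_type: "classification", "token_classification" or "regression".
    """
    missing = {"Sequence", "label", "stage"} - set(df.columns)
    if missing:
        return [f"missing columns: {sorted(missing)}"]

    problems = []
    stages = set(df["stage"])
    if stages - STAGES:
        problems.append(f"unknown stage values: {sorted(map(str, stages - STAGES))}")
    if STAGES - stages:
        problems.append(f"no samples for stage(s): {sorted(STAGES - stages)}")

    for i, (seq, label) in enumerate(zip(df["Sequence"], df["label"])):
        if task_type == "classification" and not _is_index(label):
            problems.append(f"row {i}: label must be a category index starting from 0")
        elif task_type == "regression" and not _is_number(label):
            problems.append(f"row {i}: label must be a numerical value")
        elif task_type == "token_classification":
            try:
                labels = ast.literal_eval(label) if isinstance(label, str) else label
            except (ValueError, SyntaxError):
                labels = None
            if not isinstance(labels, (list, tuple)) or not all(_is_index(x) for x in labels):
                problems.append(f"row {i}: label must be a list of category indices")
            elif len(labels) != len(seq):
                problems.append(f"row {i}: {len(labels)} labels for {len(seq)} residues")
    return problems
```

For example, `check_training_csv(pd.read_csv("my_dataset.csv"), "regression")` (with a hypothetical file name) returns an empty list when no problems are found.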

<br>

**Note:**

1. Example files are available in `/content/saprot/upload_files`. Download them to review the format, then upload them for a trial run.

2. <a href="#fa2csv">Here</a> you can convert your .fa/.fasta file to a .csv file, which corresponds to the data format for Multiple AA Sequences.

3. <a href="#split_dataset">Here</a> you can **randomly split your .csv dataset**, which adds a `stage` column with a `train`:`valid`:`test` ratio of 8:1:1.
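If you prefer to split the data yourself, here is a minimal pandas/NumPy sketch of an 8:1:1 random split. `random_split` is a hypothetical helper and may differ from the notebook's own split cell (for example in rounding or seeding).

```python
import numpy as np
import pandas as pd


def random_split(df, ratios=(0.8, 0.1, 0.1), seed=42):
    """Return a copy of df with a `stage` column of train/valid/test in the given ratios."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(df))  # shuffled row positions
    n_train = int(round(ratios[0] * len(df)))
    n_valid = int(round(ratios[1] * len(df)))

    stages = np.empty(len(df), dtype=object)
    stages[order[:n_train]] = "train"
    stages[order[n_train:n_train + n_valid]] = "valid"
    stages[order[n_train + n_valid:]] = "test"  # the remainder goes to test

    out = df.copy()
    out["stage"] = stages
    return out
```

With very small datasets, rounding can leave the `valid` or `test` stage empty, so check the counts with `out["stage"].value_counts()` afterwards.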

In [None]:
#@title 2.1: Task Config

################################################################################
################################## TASK CONFIG #################################
################################################################################
#@markdown # 1. Task
task_name = "demo" # @param {type:"string"}
task_objective = "Predict protein fitness (regression), e.g. Predict the Thermostability of a protein" # @param ["Classify protein sequences (classification)", "Predict protein fitness (regression), e.g. Predict the Thermostability of a protein", "Classify each Amino Acid (amino acid classification), e.g. Binding site detection"]
task_type = task_type_dict[task_objective]

if task_type in ["classification", 'token_classification']:

  print(Fore.BLUE+'Enter the number of categories in your training dataset here:'+Style.RESET_ALL)
  num_of_categories = ipywidgets.BoundedIntText(
                                              # value=7,
                                              min=2,
                                              max=1000000,
                                              step=1,
                                              # description='num_of_category: \n',
                                              disabled=False)
  num_of_categories.layout.width = "100px"
  display(num_of_categories)

#@markdown <br>



################################################################################
#################################### MODEL #####################################
################################################################################
#@markdown # 2. Model

##@markdown We use Parameter-Efficient Fine-Tuning Technique for model training. It enables us to store model weights in a small **adapter** without changing the original model weights during training. After training, you can get an adapter specific to your task.
##@markdown As we use a Parameter-Efficient Fine-Tuning technique, which stores model weights in a small adapter without adjusting the original model weights during training, it's necessary to specify both the original model and the adapter for prediction.
##@markdown
##@markdown 1. Select a **base model** from the dropdown box `model_path` below.
##@markdown
##@markdown 2. If you want to **train on existing adapters**, check the box `use_your_data_to_train_on_an_existing_model` below. By running this cell, you will see an **adapter combobox**. We provide two ways to select your adapter:
##@markdown  - Select a **Local Model** from the combobox.
##@markdown   - Enter a **huggingface repository name** to the combobox. (e.g. "SaProtAdapters/DeepLoc_cls10_35M")
##@markdown
##@markdown You can also find some official adapters [here](https://huggingface.co/SaProtAdapters)
base_model = "Trained by yourself" # @param ["Official pretrained SaProt (35M)", "Official pretrained SaProt (650M)", "Trained by yourself", "Trained by peers"]
# base_model = "westlake-repl/SaProt_35M_AF2" # @param ["westlake-repl/SaProt_35M_AF2", "westlake-repl/SaProt_650M_AF2", "Trained by yourself", "Trained by peers"]
# print(Fore.BLUE+f"Model: {base_model}"+Style.RESET_ALL)

# use_your_data_to_train_on_an_existing_model = True # @param {type:"boolean"}
if base_model == "Official pretrained SaProt (35M)":
  base_model = "westlake-repl/SaProt_35M_AF2"
if base_model == "Official pretrained SaProt (650M)":
  base_model = "westlake-repl/SaProt_650M_AF2"

if base_model in ["Trained by yourself", "Trained by peers"]:
  use_your_data_to_train_on_an_existing_model = True
else:
  use_your_data_to_train_on_an_existing_model = False

if use_your_data_to_train_on_an_existing_model:
  if base_model == "Trained by yourself":
    use_model_from = 'Local Models'
  elif base_model == "Trained by peers":
    use_model_from = 'SaProtHub Models'
  adapter_combobox = select_adapter_from(use_model_from)

#@markdown <br>

################################################################################
################################### DATASET ####################################
################################################################################
#@markdown # 3. Dataset

data_type = "Multiple AA Sequences" # @param ["Multiple AA Sequences", "Multiple SA Sequences", "Multiple UniProt IDs", "Multiple PDB/CIF Structures", "Huggingface Dataset"]
mode = "Multiple Sequences" if data_type in data_type_list[4:8] else "Single Sequence"

raw_data = input_raw_data_by_data_type(data_type)

#@markdown <br>


################################################################################
################################################################################
################################################################################
################################################################################
################################################################################




##@markdown Complete some task configs and run this cell to Finetune SaProt on your dataset. <br>

# def get_num_of_labels(selected_csv_dataset):
#   df = pd.read_csv(selected_csv_dataset)
#   num_of_labels = len(df['label'].unique())

#   return num_of_labels


##@markdown <br>

################################################################################

################################################################################
############################### custom config ##################################
################################################################################

##@markdown ---
##@markdown # <center>Training Task Config</center>


# num_of_categories = 10 # @param {type:"number"}
# #@markdown <font face="Consolas" size=2 color='gray'>(Ignoring `num_of_categories` if predicting a value)


  # print(Fore.BLUE+'It\'s normal not to receive feedback once inputting is finished. Let\'s move on to the next step.'+Style.RESET_ALL)


################################################################################
#@title 2.3: Select Model
################################################################################

# #@markdown We use **LoRA** (a Parameter-Efficient Fine-Tuning technique), which stores the trained weights in a small adapter without changing the original model weights during training.
# #@markdown

# #@markdown After training, you can obtain an adapter for your task.

##@markdown ---
##@markdown # <center>Model</center>



# if use_your_data_to_train_on_an_existing_model:
#   print(Fore.BLUE+f"Loaded Adapter: {adapter_combobox.value}"+Style.RESET_ALL)




In [None]:
#@title 2.2: Train your Model

################################################################################
############################## advance config ##################################
################################################################################

batch_size = 4 # @param ["1", "2", "4", "8", "16", "32", "64", "128", "256"] {type:"raw", allow-input: true}
max_epochs = 1 # @param ["10", "20", "50"] {type:"raw", allow-input: true}
learning_rate = 1.0e-3 # @param ["1.0e-3", "5.0e-4", "1.0e-4"] {type:"raw", allow-input: true}

limit_train_batches=1.0
limit_val_batches=1.0
limit_test_batches=1.0

val_check_interval=0.5

use_lora = True
num_workers = 2

mask_struc_ratio=None
# mask_struc_ratio=1.0

download_adapter_to_your_computer = True

################################################################################
################################# MARKDOWN #####################################
################################################################################

#@markdown - <font face="Consolas" size=2 color='gray'> `batch_size` depends on the number of training samples. If your training dataset is large enough, we recommend 32, 64, 128, 256, ...; otherwise, use 8, 4, or 2. (Note that you cannot use a larger batch size on the default Colab T4 GPU. We strongly suggest subscribing to Colab Pro for an A100 GPU.)

# #@markdown |  Recommended batch size   | T4  |  A100   |
# #@markdown | ---                       | --- |  ---    |
# #@markdown | SaProt_35M_AF2            |  4  |    16   |
# #@markdown | SaProt_650M_AF2           |  -  |    8    |


#@markdown - <font face="Consolas" size=2 color='gray'>`max_epochs` is the maximum number of training epochs (full passes over the training set). A larger value needs more training time. The best model is saved after each epoch.
#@markdown You can adjust `max_epochs` to control training duration. (Note that the maximum Colab runtime is 12 hours for free users and 24 hours for Colab Pro+ users.) <br>
#@markdown

# download_adapter_to_your_computer = True #@param {type:"boolean"}
#@markdown - <font face="Consolas" size=2 color='gray'>`learning_rate` affects the convergence speed of the model.
#@markdown Through experimentation, we have found that `5.0e-4` is a good default value for base model `SaProt_650M_AF2` and `1.0e-3` for `SaProt_35M_AF2`.

################################################################################
################################# CONFIG #######################################
################################################################################

from saprot.config.config_dict import Default_config
config = copy.deepcopy(Default_config)

config.setting.run_mode = "train"


################################################################################
################################# ADAPTER ######################################
################################################################################
config.model.kwargs.use_lora = use_lora

if base_model in ["Trained by yourself", "Trained by peers"]:

  adapter_path = ADAPTER_HOME / task_type / adapter_combobox.value

  if not adapter_path.exists():
    snapshot_download(repo_id=adapter_combobox.value, repo_type="model", local_dir=adapter_path)

  adapter_config = Path(adapter_path) / "adapter_config.json"
  with open(adapter_config, 'r') as f:
    base_model = json.load(f)['base_model_name_or_path']

  config.model.kwargs.lora_config_path = adapter_path

else:
  config.model.kwargs.lora_config_path = None

if config.setting.run_mode == 'train':
  config.model.kwargs.lora_inference = False
if config.setting.run_mode == 'test':
  config.model.kwargs.lora_inference = True

################################################################################
################################# MODEL ########################################
################################################################################

if task_type in ["classification", "token_classification"]:
  # config.model.kwargs.num_labels = get_num_of_labels(selected_csv_dataset)
  config.model.kwargs.num_labels = num_of_categories.value



config.model.model_py_path = model_type_dict[task_type]

config.model.kwargs.config_path = base_model
config.dataset.kwargs.tokenizer = base_model

config.model.save_path = str(ADAPTER_HOME / f"{task_type}" / "Local" / f"{task_name}")


################################################################################
################################# DATASET ######################################
################################################################################

if data_type == data_type_list[8]:
  lmdb_dataset_path = get_SA_sequence_by_data_type(data_type, raw_data)
else:
  csv_dataset_path = get_SA_sequence_by_data_type(data_type, raw_data)

  from saprot.utils.construct_lmdb import construct_lmdb
  construct_lmdb(csv_dataset_path, LMDB_HOME, task_name, task_type)
  lmdb_dataset_path = LMDB_HOME / task_name


config.dataset.dataset_py_path = dataset_type_dict[task_type]

config.dataset.train_lmdb = str(lmdb_dataset_path / "train")
config.dataset.valid_lmdb = str(lmdb_dataset_path / "valid")
config.dataset.test_lmdb = str(lmdb_dataset_path / "test")

#  batch size, num_workers
config.dataset.dataloader_kwargs.batch_size = batch_size
config.dataset.dataloader_kwargs.num_workers = num_workers

# mask_struc
# config.dataset.kwargs.mask_struc_ratio= mask_struc_ratio

################################################################################
############################## TRAINER #########################################
################################################################################

config.Trainer.accelerator = "gpu" if torch.cuda.is_available() else "cpu"

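# accumulate gradients so the effective batch size is about 64;
# clamp to at least 1 so batch sizes above 64 remain valid
config.Trainer.accumulate_grad_batches = max(1, 64 // batch_size)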

# epoch
config.Trainer.max_epochs = max_epochs
# test only: load the existing model
if config.Trainer.max_epochs == 0:
  config.model.save_path = config.model.kwargs.lora_config_path

# learning rate
config.model.lr_scheduler_kwargs.init_lr = learning_rate

# trainer
config.Trainer.limit_train_batches=limit_train_batches
config.Trainer.limit_val_batches=limit_val_batches
config.Trainer.limit_test_batches=limit_test_batches
config.Trainer.val_check_interval=val_check_interval

# strategy
strategy = {
    # - deepspeed
    # 'class': 'DeepSpeedStrategy',
    # 'stage': 2

    # - None
    # 'class': None,

    # - DP
    # 'class': 'DataParallelStrategy',

    # - DDP
    # 'class': 'DDPStrategy',
    # 'find_unused_parameters': True
}
config.Trainer.strategy = strategy

################################################################################
############################## CONFIG ##########################################
################################################################################


################################################################################
############################## Run the task ####################################
################################################################################

print('='*100)
print(Fore.BLUE+f"Training task type: {task_type}"+Style.RESET_ALL)
print(Fore.BLUE+f"Dataset: {lmdb_dataset_path}"+Style.RESET_ALL)
print(Fore.BLUE+f"Base Model: {config.model.kwargs.config_path}"+Style.RESET_ALL)
if use_your_data_to_train_on_an_existing_model:
  print(Fore.BLUE+f"Existing model: {config.model.kwargs.lora_config_path}"+Style.RESET_ALL)
print('='*100)

from saprot.scripts.training import finetune
print(config)
finetune(config)


################################################################################
############################## Save the adapter ################################
################################################################################

print(Fore.BLUE)
print(f"Model is saved to \"{config.model.save_path}\" on the Colab server")
print(Style.RESET_ALL)

if download_adapter_to_your_computer:
  adapter_zip = Path(config.model.save_path) / f"{task_name}.adapter.zip"
  !cd $config.model.save_path && zip -r $adapter_zip "adapter_config.json" "adapter_model.safetensors" "adapter_model.bin" "README.md"
  # with zipfile.ZipFile(adapter_zip, 'w') as zipf:
  #   zip_files = [str(file_path) for file_path in Path(config.model.save_path).glob("*")]
  #   print(zip_files)
  #   for file in zip_files:
  #     zipf.write(file, Path(file).name)

  print("Downloading adapter to your local computer")
  if adapter_zip.exists():
    files.download(adapter_zip)
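The training cell sets `config.Trainer.accumulate_grad_batches` from `batch_size` so that gradients from several small batches are accumulated before each optimizer update. This keeps the effective batch at roughly 64 sequences even when the GPU fits only a few at a time. The sketch below reproduces that arithmetic in plain Python. `accumulation_steps` and `optimizer_steps_per_epoch` are illustrative helpers, not part of SaProt. The per-epoch count assumes that a final partial group of batches still triggers an update.

```python
import math

TARGET_EFFECTIVE_BATCH = 64  # the value hard-coded in the training cell


def accumulation_steps(batch_size: int, target: int = TARGET_EFFECTIVE_BATCH) -> int:
    """Number of batches whose gradients are summed before one optimizer step.

    Never returns less than 1, so batch sizes above the target still train.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    return max(1, target // batch_size)


def optimizer_steps_per_epoch(num_samples: int, batch_size: int) -> int:
    """Optimizer updates per epoch (assumes a final partial group also updates)."""
    batches = math.ceil(num_samples / batch_size)
    return math.ceil(batches / accumulation_steps(batch_size))


for bs in (4, 16, 64, 256):
    steps = accumulation_steps(bs)
    print(f"batch_size={bs:>3} -> accumulate {steps:>2} batches -> effective batch {bs * steps}")
```

For the 128 and 256 options, plain `int(64 / batch_size)` truncates to 0. This is why the sketch clamps the result to at least 1.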

In [None]:
#@title **2.3: Log in to HuggingFace to upload your model (Optional)**
################################################################################
###################### Login HuggingFace #######################################
################################################################################

from huggingface_hub import notebook_login
notebook_login()


In [None]:
#@title **2.4: Upload Model (Optional)**

# #@markdown Your Huggingface adapter repository names follow the format `<username>/<task_name>`.

################################################################################
########################## Metadata  ###########################################
################################################################################
#@markdown You can add a description to your model.
model_name = "task_demo" # @param {type:"string"}
description = "This model is used for a demo task" # @param {type:"string"}

#@markdown For the classification model, please provide detailed information about the meanings of all labels.

#@markdown For example, in a subcellular localization classification task with 10 categories, label=0 means the protein is located in the Nucleus, label=1 means it is located in the Cytoplasm, and so on. List the label names in index order, separated by commas, as follows:

#@markdown `Nucleus, Cytoplasm, Extracellular, Mitochondrion, Cell.membrane, Endoplasmic.reticulum, Plastid, Golgi.apparatus, Lysosome/Vacuole, Peroxisome`


# #@markdown > 0: Nucleus <br>
# #@markdown > 1: Cytoplasm <br>
# #@markdown > 2: Extracellular <br>
# #@markdown > ... <br>
# #@markdown > 9: Peroxisome <br>

label_meanings = "" #@param {type:"string"}



################################################################################
########################### Move Files  ########################################
################################################################################

from huggingface_hub import HfApi, Repository

api = HfApi()

user = api.whoami()

if model_name == "":
  model_name = task_name
repo_name = user['name'] + '/' + model_name
local_dir = Path("/content/saprot/model_to_push") / repo_name
local_dir.mkdir(parents=True, exist_ok=True)

# Create the repository if it does not exist yet; `exist_ok=True` replaces the
# ModelFilter-based lookup (ModelFilter is deprecated in recent huggingface_hub releases)
api.create_repo(repo_name, private=False, exist_ok=True)

repo = Repository(local_dir=local_dir, clone_from=repo_name)

command = f"cp {config.model.save_path}/* {local_dir}/"
subprocess.run(command, shell=True)

################################################################################
########################## Modify README  ######################################
################################################################################
import json

md_path = local_dir / "README.md"

if task_type in ["classification", "token_classification"] and label_meanings.strip():
  label_meanings_md = ''
  # split on commas and strip spaces, so "A,B" and "A, B" are treated the same
  labels = [label.strip() for label in label_meanings.split(',') if label.strip()]
  for index, label in enumerate(labels):
    label_meanings_md += f"{index}: {label} <br> "

  description = description + "<br><br> The numeric labels mean: <br>" + label_meanings_md

replace_data = {
    "<!-- Provide a quick summary of what the model is/does. -->": description
}

with open(md_path, "r") as file:
    content = file.read()

for key, value in replace_data.items():
    if value != "":
        content = content.replace(key, value)

# new_md_path = "README.md"
with open(md_path, "w") as file:
    file.write(content)

################################################################################
########################## Upload Model  #######################################
################################################################################


repo.push_to_hub(commit_message="Upload adapter model")
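In the upload cell, `label_meanings` is split on commas and appended to the model card description as `index: label` lines. A minimal standalone sketch of the same formatting follows. `format_label_meanings` and `build_description` are hypothetical names. Unlike a split on `', '`, this version also accepts labels typed without spaces after the commas.

```python
def format_label_meanings(label_meanings: str) -> str:
    """Turn 'Nucleus, Cytoplasm' into '0: Nucleus <br> 1: Cytoplasm <br>'.

    Splits on commas and strips whitespace, so 'A,B' and 'A, B' give the same
    result; empty entries (e.g. from a trailing comma) are dropped.
    """
    labels = [label.strip() for label in label_meanings.split(",") if label.strip()]
    return " ".join(f"{index}: {label} <br>" for index, label in enumerate(labels))


def build_description(description: str, label_meanings: str) -> str:
    """Append the label legend to the model description, if any labels were given."""
    legend = format_label_meanings(label_meanings)
    if not legend:
        return description
    return description + "<br><br> The numeric labels mean: <br>" + legend


print(build_description("This model is used for a demo task", "Nucleus, Cytoplasm,Extracellular"))
```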

# **3: Use SaProt to Predict**

## 3.1: Classification & Regression Prediction

<br>


### Dataset

For the prediction dataset, **only** the `Sequence` column is required in the CSV file.

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Multiple_Sequences_data_format.png?raw=true" height="200" width="800px" align="center">

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Multiple_PDB_CIF_Structures_data_format.png?raw=true" height="200" width="500px" align="center">

You can refer to the <a href='#data_format'>instruction</a> for detailed data formats.
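If your sequences are already in Python, you can write a prediction CSV with the standard `csv` module. This sketch assumes that the `Sequence` column is the only one needed, as stated above. `write_prediction_csv` and `read_sequences` are hypothetical helpers, and the two sequences are toy examples rather than real proteins.

```python
import csv
import tempfile
from pathlib import Path


def write_prediction_csv(sequences, path):
    """Write a prediction dataset with a single `Sequence` column, one protein per row."""
    path = Path(path)
    with path.open("w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=["Sequence"])
        writer.writeheader()
        for seq in sequences:
            writer.writerow({"Sequence": seq})
    return path


def read_sequences(path):
    """Read the `Sequence` column back, failing loudly if the column is missing."""
    with Path(path).open(newline="") as handle:
        reader = csv.DictReader(handle)
        if "Sequence" not in (reader.fieldnames or []):
            raise ValueError(f"{path} has no 'Sequence' column")
        return [row["Sequence"] for row in reader]


# Toy amino acid sequences, for illustration only
csv_path = write_prediction_csv(["MEVQLVQYK", "MKTAYIAKQR"], Path(tempfile.gettempdir()) / "prediction_input.csv")
print(read_sequences(csv_path))
```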

<br>


In [None]:
#@title 3.1.1: Task Config

from transformers import EsmTokenizer
import torch
import copy

################################################################################
################################# TASK #########################################
################################################################################
#@markdown # 1. Task

task_objective = "Predict protein fitness (regression), e.g. Predict the Thermostability of a protein" # @param ["Classify protein sequences (classification)", "Predict protein fitness (regression), e.g. Predict the Thermostability of a protein", "Classify each Amino Acid (amino acid classification), e.g. Binding site detection"]
task_type = task_type_dict[task_objective]

if task_type in ["classification", 'token_classification']:

  print(Fore.BLUE+'The number of categories in your classification task:'+Style.RESET_ALL)
  num_of_categories = ipywidgets.BoundedIntText(
                                              # value=7,
                                              min=2,
                                              # max=10,
                                              step=1,
                                              # description='num_of_category: \n',
                                              disabled=False)
  num_of_categories.layout.width = "100px"
  display(num_of_categories)

#@markdown <br>


################################################################################
################################## MODEL #######################################
################################################################################
#@markdown # 2. Model

##@markdown Because we use a Parameter-Efficient Fine-Tuning technique, which stores the trained weights in a small adapter without changing the original model weights during training, you need to specify both the original model and the adapter for prediction.
##@markdown
##@markdown 1. Select a **base model**
##@markdown
##@markdown 2. By running this cell, you will see a **model combobox**. We provide two ways to select your adapter:
##@markdown   - Select a **local model** from the combobox.
##@markdown   - Enter a **huggingface repository name** into the combobox (e.g. "SaProtHub/DeepLoc_cls10_35M").
##@markdown
##@markdown You can also find some official adapters [here](https://huggingface.co/SaProtHub).
# base_model = "westlake-repl/SaProt_35M_AF2" #@param ['westlake-repl/SaProt_35M_AF2', 'westlake-repl/SaProt_650M_AF2'] {allow-input:true}
use_model_from = "Local Models" # @param ["Local Models", "SaProtHub Models"]

# use_existing_model = True # @param {type:"boolean"}
# use_existing_model = True
# if use_existing_model:
#   adapter_combobox = select_adapter()

adapter_input = select_adapter_from(use_model_from)
#@markdown <br>

################################################################################
################################################################################
################################################################################

# # @markdown Please ensure that the selected task type aligns with the training task type of the model you intend to utilize.

## @markdown If you are conducting inference on a classification task, please ensure that the `num_of_category` matches the number of categories in the training dataset. Otherwise, you do not need to assign `num_of_category`.


##@markdown You have two options to provide your protein sequences:
##@markdown - **Single Sequence: Enter a single SA sequence** into the input box. You can get an SA sequence by clicking <a href="#get_SA_seq">here</a>.
##@markdown - **Multiple Sequences: Select a dataset**. You can upload a dataset from <a href="#upload_dataset">here</a>.


##@markdown <img src="https://github.com/LUCKYDOGQAQ/ColabSaProt_dev/raw/main/InferenceFileFormat.png" height="256" align="center" style="height:256px">


# print(Fore.BLUE+f"Data type: {data_type}"+Style.RESET_ALL)


################################################################################
################################ DATASET #######################################
################################################################################
#@markdown # 3. Dataset
data_type = "Multiple PDB/CIF Structures" # @param ["Single AA Sequence", "Single SA Sequence", "Single UniProt ID", "Single PDB/CIF Structure", "Multiple AA Sequences", "Multiple SA Sequences", "Multiple UniProt IDs", "Multiple PDB/CIF Structures"]
mode = "Multiple Sequences" if data_type in data_type_list[4:8] else "Single Sequence"

raw_data = input_raw_data_by_data_type(data_type)


#@markdown <br>

In [None]:
#@title 3.1.2: Get your Result
from transformers import EsmTokenizer
import torch
import copy
import sys
from saprot.scripts.training import my_load_model


################################################################################
################################# 0. MARKDOWN ##################################
################################################################################


# @markdown Click the run button to make prediction.

# @markdown <font color="red">**Note:**</font> When predicting a category, category indices start from zero.

################################################################################
################################# 1. DATASET ##################################
################################################################################

if mode == "Multiple Sequences":
  csv_dataset_path = get_SA_sequence_by_data_type(data_type, raw_data)
else:
  single_sa_seq = get_SA_sequence_by_data_type(data_type, raw_data)


################################################################################
################################# 2. MODEL ##################################
################################################################################
# base_model = "westlake-repl/SaProt_35M_AF2"
if use_model_from == "SaProtHub Models":
  snapshot_download(repo_id=adapter_input.value, repo_type="model", local_dir=ADAPTER_HOME / task_type / adapter_input.value)

adapter_path = ADAPTER_HOME / task_type / adapter_input.value
adapter_config = Path(adapter_path) / "adapter_config.json"
with open(adapter_config, 'r') as f:
  base_model = json.load(f)['base_model_name_or_path']


# if use_existing_model:
#   if adapter_combobox.value =='':
#     print("Please select a model!")
#     sys.exit()

#   if ". " in adapter_combobox.value:
#     adapter_path = ADAPTER_HOME / task_type / adapter_combobox.value
#   else:
#     adapter_path = adapter_combobox.value

################################################################################
##################################### config ###################################
################################################################################
from saprot.config.config_dict import Default_config
config = copy.deepcopy(Default_config)

# task
if task_type in ["classification", "token_classification"]:
  config.model.kwargs.num_labels = num_of_categories.value

# base model
config.model.model_py_path = model_type_dict[task_type]
# config.model.save_path = model_save_path
config.model.kwargs.config_path = base_model

# lora
config.model.kwargs.lora_config_path = adapter_path
config.model.kwargs.use_lora = True
config.model.kwargs.lora_inference = True

################################################################################
################################### inference ##################################
################################################################################
from peft import PeftModelForSequenceClassification

model = my_load_model(config.model)
tokenizer = EsmTokenizer.from_pretrained(config.model.kwargs.config_path)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()  # inference mode: disable dropout so predictions are deterministic


clear_output(wait=True)

# print("#"*100)
print(Fore.BLUE+f"Inference task type: {task_type}"+Style.RESET_ALL)
if mode == "Multiple Sequences":
  print(Fore.BLUE+f"Dataset: {csv_dataset_path}"+Style.RESET_ALL)
else:
  print(Fore.BLUE+f"Dataset: {raw_data}"+Style.RESET_ALL)

print(Fore.BLUE+f"Model: {base_model} - {adapter_path}"+Style.RESET_ALL)
# if use_existing_model:
#   print(Fore.BLUE+f"Adapter: {adapter_path}"+Style.RESET_ALL)

outputs_list=[]

if mode == "Multiple Sequences":
  # local import so this works however `datetime` was imported earlier in the notebook
  from datetime import datetime as dt
  timestamp = dt.now().strftime("%Y%m%d%H%M%S")
  output_file = OUTPUT_HOME / f'output_{timestamp}.csv'
  df = pd.read_csv(csv_dataset_path)
  for index in tqdm(range(len(df))):
    seq = df['Sequence'].iloc[index]
    inputs = tokenizer(seq, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():  # inference only: do not build autograd graphs
      outputs = model(inputs)
    outputs_list.append(outputs)

  df['score'] = [output.cpu().tolist() for output in outputs_list]
  df.to_csv(output_file, index=False)
  files.download(output_file)

  print(Fore.BLUE+f"\nThe prediction result is saved to {output_file} and your local computer."+Style.RESET_ALL)

else:
  # print("You are making inference based on a sequence that you entered")
  inputs = tokenizer(single_sa_seq, return_tensors="pt")
  inputs = {k: v.to(device) for k, v in inputs.items()}
  with torch.no_grad():  # inference only: do not build autograd graphs
    outputs = model(inputs)
  outputs_list.append(outputs)

################################################################################
##################################### output ###################################
################################################################################

print()
print("#"*100)
print(Fore.BLUE+"outputs:"+Style.RESET_ALL)

if task_type == "classification":
  import torch.nn.functional as F
  softmax_output_list = [F.softmax(output, dim=1).squeeze().tolist() for output in outputs_list]
  for index, output in enumerate(softmax_output_list):
    print(f"For Sequence {index}, Prediction: Category {output.index(max(output))}, Probability: {output}")
elif task_type == "regression":
  output_list = [output.squeeze().tolist() for output in outputs_list]
  for index, output in enumerate(output_list):
    print(f"For Sequence {index}, Prediction: Value {output}")
elif task_type == "token_classification":
  import torch.nn.functional as F
  softmax_output_list = [F.softmax(output, dim=-1).squeeze().tolist() for output in outputs_list]
  # print(softmax_output_list)
  print("The probability of each category:")
  for seq_index, seq in enumerate(softmax_output_list):
    seq_prob_df = pd.DataFrame(seq)
    print('='*100)
    print(f'Sequence {seq_index + 1}:')
    print(seq_prob_df[1:-1])
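The output step of cell 3.1.2 applies a softmax to the model's raw outputs (logits) and reports the highest-probability category, counting from 0. For amino acid classification it also drops the first and last positions (`seq_prob_df[1:-1]`), which correspond to the tokenizer's start and end special tokens. Below is a pure-Python sketch of that post-processing on made-up logits. No model is loaded, and the function names are illustrative.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_category(logits):
    """Return (zero-based category index, probabilities) for one sequence."""
    probs = softmax(logits)
    return probs.index(max(probs)), probs


def per_residue_categories(token_logits):
    """Per-residue predictions, skipping the first and last token positions
    (assumed to be the start/end special tokens, as in `seq_prob_df[1:-1]`)."""
    return [predict_category(row)[0] for row in token_logits[1:-1]]


# Made-up logits for a 3-category classification task
category, probs = predict_category([0.2, 2.0, -1.0])
print(f"Prediction: Category {category}, Probability: {[round(p, 3) for p in probs]}")
```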




## 3.2: Mutational Effect Prediction

<br>

### Mutation Task
- Single-site or Multi-site mutagenesis
- Saturation mutagenesis
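Saturation mutagenesis scores every possible single-site substitution of a sequence. The sketch below only enumerates the candidate mutation strings in the `H87Y` notation described under Mutation Information. It does not call SaProt's `predict_pos_mut`, whose exact output keys are not shown here. The sketch assumes the 20 standard amino acids. It also assumes the SaProt SA format, in which each residue is an uppercase amino acid letter followed by a lowercase structure token (e.g. `MdEvVp`); this is why the notebook loops over `len(seq) / 2` positions.

```python
STANDARD_AA = "ACDEFGHIKLMNPQRSTVWY"  # the 20 canonical amino acids


def residues(seq: str, is_sa: bool = False) -> str:
    """Return the amino acid letters; for an SA sequence, keep every other character."""
    return seq[0::2] if is_sa else seq


def saturation_mutations(seq: str, is_sa: bool = False):
    """All single-site substitutions as strings like 'M1A' (1-based positions)."""
    aa_seq = residues(seq, is_sa)
    return [
        f"{wt}{pos}{mt}"
        for pos, wt in enumerate(aa_seq, start=1)
        for mt in STANDARD_AA
        if mt != wt  # skip the "mutation" to the wild-type residue itself
    ]


muts = saturation_mutations("MdEvVp", is_sa=True)
print(len(muts), muts[:3])
```

Each position contributes 19 candidates, so a protein of length L yields 19 × L mutations. This product is why saturation results are written to a CSV rather than printed.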

<br>

### Mutation Dataset

For `Single-site or Multi-site mutagenesis`, **one additional column** is required in the CSV file: `mutation`.

<img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Multiple_AA_Sequences_data_format_mutation.png?raw=true" height="200" width="500px" align="center">

- `mutation` column contains the **mutation information**.
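Before running a multiple-sequence mutation prediction, it can help to check that the CSV has both columns and no empty `mutation` cells. Below is a minimal sketch using Python's `csv` module. `load_mutation_rows` is a hypothetical helper, and the toy dataset written to a temporary file is illustrative only.

```python
import csv
import tempfile
from pathlib import Path

REQUIRED_COLUMNS = {"Sequence", "mutation"}


def load_mutation_rows(path):
    """Return (sequence, mutation) pairs from a mutation dataset CSV.

    Raises ValueError if a required column is missing or a mutation cell is empty.
    """
    with Path(path).open(newline="") as handle:
        reader = csv.DictReader(handle)
        missing = REQUIRED_COLUMNS - set(reader.fieldnames or [])
        if missing:
            raise ValueError(f"missing column(s): {sorted(missing)}")
        rows = []
        for line_no, row in enumerate(reader, start=2):  # line 1 is the header
            mutation = (row["mutation"] or "").strip()
            if not mutation:
                raise ValueError(f"empty mutation on line {line_no}")
            rows.append((row["Sequence"], mutation))
        return rows


# A toy dataset: the same sequence with a single-site and a multi-site mutation
demo = Path(tempfile.gettempdir()) / "mutation_input.csv"
demo.write_text("Sequence,mutation\nMEVQLVQYK,E2K\nMEVQLVQYK,E2K:L5A\n")
print(load_mutation_rows(demo))
```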

<br>

### Mutation Information

Here is the detail about the representation of **mutation information**: <a name="mutation info"></a>

| mode | mutation information|
| --- | --- |
| Single-site mutagenesis | H87Y |
| Multi-site mutagenesis | H87Y:V162M:P179L:P179R |

- For `Single-site mutagenesis`, we use a term like "H87Y" to denote the mutation. The first letter is the **original amino acid**, the number in the middle is the **mutation site** (indexed starting from 1), and the last letter is the **mutated amino acid**.
- For `Multi-site mutagenesis`, we join the single-site mutations with a colon ":", as in "H87Y:V162M:P179L:P179R".
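You can parse this notation with a small regular expression and check it against the wild-type sequence before submitting a dataset. This sketch assumes a plain amino acid sequence; for an SA sequence, compare against every other character. It checks every wild-type letter against the original, unmutated sequence. `parse_mutation` and `apply_mutation` are illustrative helpers, not ColabSaProt functions.

```python
import re

# One substitution: wild-type letter, 1-based position, mutant letter (e.g. "H87Y")
SINGLE_MUTATION = re.compile(r"([A-Z])(\d+)([A-Z])")


def parse_mutation(mutation: str):
    """Split 'H87Y:V162M' into [('H', 87, 'Y'), ('V', 162, 'M')]."""
    parsed = []
    for part in mutation.split(":"):
        match = SINGLE_MUTATION.fullmatch(part.strip())
        if match is None:
            raise ValueError(f"cannot parse mutation {part!r}")
        wild_type, position, mutant = match.groups()
        parsed.append((wild_type, int(position), mutant))
    return parsed


def apply_mutation(seq: str, mutation: str) -> str:
    """Return the mutated AA sequence, checking each wild-type letter against `seq`."""
    residues = list(seq)
    for wild_type, position, mutant in parse_mutation(mutation):
        if not 1 <= position <= len(seq):
            raise ValueError(f"position {position} is outside a sequence of length {len(seq)}")
        if seq[position - 1] != wild_type:  # compare with the original, unmutated sequence
            raise ValueError(f"expected {wild_type} at position {position}, found {seq[position - 1]}")
        residues[position - 1] = mutant
    return "".join(residues)


print(parse_mutation("H87Y:V162M"))
print(apply_mutation("MEVQLVQYK", "E2K:L5A"))
```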

<!-- ### Prediction Result -->


<!-- ### How to use your model for Mutational Effect Prediction -->

<!--## 1. Input and Output

 You have four different combinations of **mutation task** and **mode** to choose from: -->

<!--
 |Combination| Input | Output |
 | --- | --- | --- |
 |`Single-site or Multi-site mutagenesis` + `Single Sequence`| Enter **a SA sequence** and **a mutation information**| a score of the mutation |
 |`Single-site or Multi-site mutagenesis` + `Multiple Sequences`| Select **a dataset** and upload **a .csv file containing mutation information**| a .csv file containing the scores of mutations |
 |`Saturation mutagenesis` + `Single Sequence`| Enter **a SA sequence**| a .csv file containing the scores of all mutation on every position of the sequence |
 |`Saturation mutagenesis` + `Multiple Sequences`| Select **a dataset**| a .zip file containing the .csv files of the Saturation mutagenesis on every sequence |
  -->

 <!-- <img src="https://github.com/LUCKYDOGQAQ/ColabSaProt_dev/raw/main/mutation_input_output.png" height="500" width="800px" align="center"> -->

<!-- ### 2. Format of the uploaded .csv file containing mutation information

For Multiple Sequences, you are required to **upload an additional .csv file** as your mutation information.
<font color=red>Please ensure that each mutation in the mutation CSV file corresponds to each Sequence in the dataset CSV file.</font>
 <img src="https://github.com/LUCKYDOGQAQ/ColabSaProt_dev/raw/main/MutationFormat.png" height="256" align="center" style="height:256px"> -->

In [None]:
#@title 3.2.1: Task Config

mutation_task = "Single-site or Multi-site mutagenesis" #@param ["Single-site or Multi-site mutagenesis", "Saturation mutagenesis"]

data_type = "Single AA Sequence" # @param ["Single AA Sequence", "Single SA Sequence", "Single UniProt ID", "Single PDB/CIF Structure", "Multiple AA Sequences", "Multiple SA Sequences", "Multiple UniProt IDs", "Multiple PDB/CIF Structures"]
raw_data = input_raw_data_by_data_type(data_type)

mode = "Multiple Sequences" if data_type in data_type_list[4:8] else "Single Sequence"

if mutation_task == "Single-site or Multi-site mutagenesis":
  if mode == "Single Sequence":
    input_mut = ipywidgets.Text(
      value=None,
      placeholder='Enter Single Mutation Information here',
      # description='SA Sequence:',
      disabled=False)
    print(Fore.BLUE+"Mutation:"+Style.RESET_ALL)
    input_mut.layout.width = '500px'
    display(input_mut)


In [None]:
#@title 3.2.2: Get your Result

################################################################################
################################# DATASET ###################################
################################################################################
if mode == "Single Sequence":
  seq = get_SA_sequence_by_data_type(data_type, raw_data)
else:
  dataset_csv_path = get_SA_sequence_by_data_type(data_type, raw_data)

################################################################################
################################# Task Info ####################################
################################################################################
base_model = "westlake-repl/SaProt_35M_AF2"

print(Fore.BLUE)
print(f"Mutation task: {mutation_task}")
print(f"Mode: {mode}")
print(f"Model: {base_model}")
if mode == "Multiple Sequences":
  print(Fore.BLUE+f"Dataset: {dataset_csv_path}"+Style.RESET_ALL)
else:
  print(Fore.BLUE+f"Dataset: {seq}"+Style.RESET_ALL)

print(Style.RESET_ALL)

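print("Predicting...")
# local import so this works however `datetime` was imported earlier in the notebook
from datetime import datetime as dt
timestamp = dt.now().strftime("%y%m%d%H%M%S")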

################################################################################
################################# load model ###################################
################################################################################

from saprot.model.saprot.saprot_foldseek_mutation_model import SaprotFoldseekMutationModel

config = {
    "foldseek_path": None,
    "config_path": base_model,
    "load_pretrained": True,
}
model = SaprotFoldseekMutationModel(**config)
tokenizer = model.tokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


################################################################################
########################### Single Sequence ####################################
################################################################################
if mode == "Single Sequence":

  if mutation_task == "Single-site or Multi-site mutagenesis":
    mut = input_mut.value
    score = model.predict_mut(seq, mut)

    print()
    print("="*100)
    print(Fore.BLUE+"Output:"+Style.RESET_ALL)
    print(f"The score of mutation {mut} is {Fore.BLUE}{score}{Style.RESET_ALL}")

  if mutation_task=="Saturation mutagenesis":
    timestamp = datetime.now().strftime("%y%m%d%H%M%S")
    output_path = OUTPUT_HOME / f'{timestamp}_prediction_output.csv'

    mut_dicts = []
    for pos in range(1, int(len(seq) / 2)+1):
      mut_dict = model.predict_pos_mut(seq, pos)
      mut_dicts.append(mut_dict)

    mut_list = [{'mutation': key, 'score': value} for d in mut_dicts for key, value in d.items()]
    df = pd.DataFrame(mut_list)
    df.to_csv(output_path, index=None)

    print()
    print("="*100)
    print(Fore.BLUE+"Output:"+Style.RESET_ALL)
    files.download(output_path)
    print(f"\n{Fore.BLUE}The result has been saved to {output_path} and your local computer.{Style.RESET_ALL}")

################################################################################
########################### Multiple Sequences #################################
################################################################################
if mode == "Multiple Sequences":

  dataset_df = pd.read_csv(dataset_csv_path)
  results = []

  if mutation_task=="Single-site or Multi-site mutagenesis":
    for index, row in tqdm(dataset_df.iterrows(), total=len(dataset_df), leave=False, desc="Predicting"):
      seq = row['Sequence']
      mut_info = row['mutation']
      results.append(model.predict_mut(seq, mut_info))

    print()
    print("="*100)
    print(Fore.BLUE+"Output:"+Style.RESET_ALL)

    # result_df = pd.DataFrame()
    # result_df['Sequence'] = dataset_df['Sequence']
    # result_df['mutation'] = dataset_df['mutation']
    dataset_df['score'] = results

    output_path = OUTPUT_HOME / f"{timestamp}_prediction_output_{Path(dataset_csv_path).stem}.csv"
    dataset_df.to_csv(output_path, index=None)
    files.download(output_path)
    print(f"{Fore.BLUE}The result has been saved to {output_path} and your local computer {Style.RESET_ALL}")

  else:
    for index, row in tqdm(dataset_df.iterrows(), total=len(dataset_df), leave=False, desc=f"Predicting"):
      seq = row['Sequence']
      mut_dicts = []
      for pos in range(1, int(len(seq) / 2)+1):
        mut_dict = model.predict_pos_mut(seq, pos)
        mut_dicts.append(mut_dict)
      mut_list = [{'mutation': key, 'score': value} for d in mut_dicts for key, value in d.items()]
      result_df = pd.DataFrame(mut_list)
      results.append(result_df)

    print()
    print("="*100)
    print(Fore.BLUE+"Output:"+Style.RESET_ALL)

    zip_files = []
    for i in range(len(results)):
      output_path = OUTPUT_HOME / f"{timestamp}_prediction_output_{Path(dataset_csv_path).stem}_Sequence{i+1}.csv"
      results[i].to_csv(output_path, index=None)
      zip_files.append(output_path)

    # Zip the per-sequence result files and download the archive to the local computer
    import zipfile
    zip_path = OUTPUT_HOME / f"{timestamp}_{Path(dataset_csv_path).stem}.zip"
    with zipfile.ZipFile(zip_path, 'w') as zipf:
        for file in zip_files:
            zipf.write(file, os.path.basename(file))
    files.download(zip_path)
    print(f"{Fore.BLUE}The result has been saved to {zip_path} and your local computer{Style.RESET_ALL}")

## 3.3: Inverse Folding Prediction

Predict the amino acid sequence from protein backbone structure.

<br>

### Dataset

The protein backbone structure should be provided in .pdb/.cif file format.
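ColabSaProt works on structure-aware (SA) sequences, in which each residue contributes two characters: an uppercase amino-acid letter followed by a lowercase 3Di structure token (the cells below recover the two halves with `sa_seq[0::2]` and `sa_seq[1::2]`). Masking a residue means replacing its amino-acid character with `#` while keeping its structure token. The sketch below does this directly on an SA string; the helper name `mask_sa_positions` and the use of 1-based positions are our own assumptions, not part of the SaProt API.

```python
def mask_sa_positions(sa_seq: str, positions: list[int]) -> str:
    """Replace the amino-acid token at each 1-based residue position with '#'.

    Hypothetical helper: assumes sa_seq interleaves one amino-acid character
    and one 3Di structure character per residue.
    """
    if len(sa_seq) % 2 != 0:
        raise ValueError("An SA sequence must have an even length (AA + 3Di per residue).")
    n_residues = len(sa_seq) // 2
    chars = list(sa_seq)
    for pos in positions:
        if not 1 <= pos <= n_residues:
            raise ValueError(f"Position {pos} is outside 1..{n_residues}.")
        # The amino-acid token of residue `pos` sits at even index 2 * (pos - 1).
        chars[2 * (pos - 1)] = "#"
    return "".join(chars)


sa = "MdEvEvLvGp"  # 5 residues: M/d, E/v, E/v, L/v, G/p
print(mask_sa_positions(sa, [2, 4]))  # -> Md#vEv#vGp
```

In the notebook itself, the same effect is achieved by typing `#` over the chosen positions in the Amino Acid Sequence box of 3.3.1.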

<br>

<!-- Predict the residue sequence of a structure-aware sequence with masked amino acids (which could be all masked or partially masked).

<br>

### Dataset

Enter a **SA sequence with masked amino acids** into the `sa_seq` input box.

<br>

For example,
**input** is a SA Sequence with masked amino acids:

`#d#v#v#v#p#p#p#p#a#p#a#q#k#k#k#k#w`

and the **output** predicted by model is an AA Sequence:

`MEELGLPDLPPGGVVVV`.

<br> -->


In [None]:
#@title 3.3.1: Upload .pdb/.cif structure file

#@markdown After clicking the run button, an upload button will appear for you to upload your .pdb/.cif structure file.

#@markdown Once the upload finishes, the structure is converted into its Amino Acid Sequence and Structure (3Di) Sequence.

#@markdown To **mask specific amino acids**, replace them with '#' in the Amino Acid Sequence box; the model will then predict the amino acids at those masked positions.

data_type = "Single PDB/CIF Structure"

raw_data = input_raw_data_by_data_type(data_type)

sa_seq = get_SA_sequence_by_data_type(data_type, raw_data)

aa_seq = sa_seq[0::2]
struc_seq = sa_seq[1::2]

# masked_sa_seq = ''
# for s in sa_seq[1::2]:
#   masked_sa_seq += '#' + s

clear_output(wait=True)

################################################################################
################################################################################
################################################################################

input_aa_seq = ipywidgets.Text(
      value=aa_seq,
      placeholder='Enter Amino Acid Sequence here',
      disabled=False)
print(Fore.BLUE+"Amino Acid Sequence:"+Style.RESET_ALL)
input_aa_seq.layout.width = '500px'
display(input_aa_seq)

input_struc_seq = ipywidgets.Text(
  value=struc_seq,
  placeholder='Enter Structure Sequence here',
  disabled=False)
print(Fore.BLUE+"Structure Sequence:"+Style.RESET_ALL)
input_struc_seq.layout.width = '500px'
display(input_struc_seq)


In [None]:
#@title 3.3.2: Predict Amino Acid Sequence

#@markdown Click the run button to predict the Amino Acid Sequence.


################################################################################
############################### Dataset ########################################
################################################################################

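masked_aa_seq = input_aa_seq.value
masked_struc_seq = input_struc_seq.value
# zip() silently truncates the longer input, so check that both sequences have the same length
assert len(masked_aa_seq) == len(masked_struc_seq), "The Amino Acid Sequence and the Structure Sequence must have the same length."
masked_sa_seq = ''.join(a + b for a, b in zip(masked_aa_seq, masked_struc_seq))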

################################################################################
############################### Model ##########################################
################################################################################
base_model = "westlake-repl/SaProt_650M_AF2"

config = {
    "config_path": base_model,
    "load_pretrained": True,
}
from saprot.model.esm.saprot_if_model import SaProtIFModel
model = SaProtIFModel(**config)
tokenizer = model.tokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

################################################################################
############################### Predict ########################################
################################################################################

pred_aa_seq = model.predict(masked_sa_seq)

print("="*100)
print(Fore.BLUE+"Output:"+Style.RESET_ALL)
print(pred_aa_seq)


# **4: (Optional) Data Preparation**

## 4.1: Get Structure-Aware Sequence
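An SA sequence is the per-residue interleaving of the amino-acid sequence and the 3Di structure sequence: 4.1.2 prints the two halves with `sa_seq[0::2]` and `sa_seq[1::2]`, and 3.3.2 rebuilds one with `zip`. The sketch below packages both directions and adds a length check, since `zip` silently truncates when its inputs differ in length. The function names `merge_sa` and `split_sa` are illustrative and not part of ColabSaProt.

```python
def merge_sa(aa_seq: str, struc_seq: str) -> str:
    """Interleave amino-acid and 3Di tokens into a structure-aware sequence."""
    if len(aa_seq) != len(struc_seq):
        raise ValueError(
            f"Length mismatch: {len(aa_seq)} amino acids vs {len(struc_seq)} structure tokens."
        )
    return "".join(a + s for a, s in zip(aa_seq, struc_seq))


def split_sa(sa_seq: str) -> tuple[str, str]:
    """Recover (amino-acid sequence, 3Di sequence) from an SA sequence."""
    if len(sa_seq) % 2 != 0:
        raise ValueError("An SA sequence must have an even length.")
    return sa_seq[0::2], sa_seq[1::2]


sa = merge_sa("MEEL", "dvvp")
print(sa)            # -> MdEvEvLp
print(split_sa(sa))  # -> ('MEEL', 'dvvp')
```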

In [None]:
#@title 4.1.1: Input

################################################################################
################################ input #########################################
################################################################################
data_type = "Multiple PDB/CIF Structures" # @param ["Single AA Sequence", "Single UniProt ID", "Single PDB/CIF Structure", "Multiple AA Sequences", "Multiple UniProt IDs", "Multiple PDB/CIF Structures"]
mode = "Multiple Sequences" if data_type in data_type_list[4:8] else "Single Sequence"

if data_type == data_type_list[7]:
    # upload and unzip PDB files
    print(Fore.BLUE+"Please upload a .zip file containing your .pdb/.cif structure files."+Style.RESET_ALL)
    pdb_zip_path = upload_file(UPLOAD_FILE_HOME)
    if pdb_zip_path.suffix != ".zip":
      logger.error("The data type does not match. Please click the run button again to upload a .zip file!")
      raise RuntimeError("The data type does not match.")
    print(Fore.BLUE+"Successfully uploaded your .zip file!"+Style.RESET_ALL)
    print("="*100)

    import zipfile
    with zipfile.ZipFile(pdb_zip_path, 'r') as zip_ref:
      file_names = zip_ref.namelist()
      zip_ref.extractall(STRUCTURE_HOME)


    uploaded_csv_path = UPLOAD_FILE_HOME / f"{pdb_zip_path.stem}.csv"
    df = pd.DataFrame(file_names, columns=['Sequence'])
    df.to_csv(uploaded_csv_path, index=False)
    raw_data = uploaded_csv_path

else:
  raw_data = input_raw_data_by_data_type(data_type)

# input_raw_data_by_data_type -> raw_data -> get_SA_sequence_by_data_type -> single_input_seq / csv_dataset_path

In [None]:
#@title 4.1.2: Output
#@markdown Click the run button to get the SA Sequence.
sa_seq = get_SA_sequence_by_data_type(data_type, raw_data)

clear_output(wait=True)

if mode == "Single Sequence":
  print(f"Amino Acid Sequence: {sa_seq[0::2]}")
  print(f"Structure Sequence: {sa_seq[1::2]}")
  print("="*100)
  print(Fore.BLUE  + "The Structure-Aware Sequence is here, double click to select and copy it:" + Style.RESET_ALL)
  print(sa_seq)

elif mode == "Multiple Sequences":
  print(Fore.BLUE  + "The Structure-Aware Sequences are saved in a .csv file here:" + Style.RESET_ALL)
  print(sa_seq)
  files.download(sa_seq)


## 4.2: Convert `.fa/.fasta` file to `.csv` file in the data format of "Multiple AA Sequences"
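The cell below uses Biopython's `SeqIO` to read every FASTA record and writes one sequence per row under a `Sequence` header. If you would rather prepare the file offline, the same conversion can be sketched with the standard library only; this is our own minimal parser (header lines start with `>`, sequences may be wrapped over several lines) and the function names are hypothetical.

```python
import csv
import io


def read_fasta_sequences(handle) -> list[str]:
    """Collect one sequence string per FASTA record (header lines start with '>')."""
    sequences, current = [], []
    for line in handle:
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            if current:
                sequences.append("".join(current))
            current = []
        else:
            current.append(line)  # a record's sequence may span several lines
    if current:
        sequences.append("".join(current))
    return sequences


def write_sequence_csv(sequences: list[str], out) -> None:
    """Write a single-column CSV whose header is 'Sequence'."""
    writer = csv.writer(out)
    writer.writerow(["Sequence"])
    writer.writerows([s] for s in sequences)


fasta = io.StringIO(">seq1\nMEEL\nGLPD\n>seq2\nLPPG\n")
out = io.StringIO()
write_sequence_csv(read_fasta_sequences(fasta), out)
print(out.getvalue())
```

Replace the `io.StringIO` objects with `open(...)` handles to convert real files.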

In [None]:
#@title 4.2.1: `.fa/.fasta` -> Multiple AA Sequences `.csv` <a name="fa2csv"></a>
from Bio import SeqIO
import numpy as np

aa_seq_dict = { "Sequence": [],
                # "label": [],
                # "stage":[]
                }

fa_file_path = upload_file(UPLOAD_FILE_HOME)
assert Path(fa_file_path).suffix.lower() in ['.fa', '.fasta'], "Please upload a .fa or .fasta file."
with fa_file_path.open("r") as fa:
  for record in tqdm(SeqIO.parse(fa, 'fasta'), leave=True):
      aa_seq_dict["Sequence"].append(str(record.seq))

fa_df = pd.DataFrame(aa_seq_dict)
print(fa_df.head())  # preview the first five sequences

csv_file_path = f'/content/saprot/upload_files/{fa_file_path.stem}.csv'
fa_df.to_csv(csv_file_path, index=None)
files.download(csv_file_path)

################################################################################
############################ .fa 2 .csv and split ##############################
################################################################################

# automatically_split_dataset = False # @param {type:"boolean"}
# split = ['train', 'valid', 'test']

# aa_seq_dict = { "Sequence": [],
#                 "label": [],
#                 "stage":[]}



# if automatically_split_dataset:

#   fa_file_path = upload_file(UPLOAD_FILE_HOME)
#   label = fa_file_path.stem

#   with fa_file_path.open("r") as fa:
#       for record in tqdm(SeqIO.parse(fa, 'fasta'), leave=True):
#           aa_seq_dict["Sequence"].append(str(record.seq))
#           aa_seq_dict["label"].append(label)
#   weights = [0.8, 0.1, 0.1]
#   aa_seq_dict["stage"] = np.random.choice(split, size=len(aa_seq_dict["Sequence"]), p=weights).tolist()

# else:
#   for i in range(3):
#     print(Fore.BLUE+f"Please upload a .fa file as your {split[i]} dataset")
#     fa_file_path = upload_file(UPLOAD_FILE_HOME)
#     label = fa_file_path.stem

#     with fa_file_path.open("r") as fa:
#         for record in tqdm(SeqIO.parse(fa, 'fasta')):
#             aa_seq_dict["Sequence"].append(str(record.seq))
#             aa_seq_dict["label"].append(label)
#             aa_seq_dict["stage"].append(split[i])

#     print()
#     print("="*100)

# fa_df = pd.DataFrame(aa_seq_dict)
# timestamp = datetime.now().strftime("%y%m%d%H%M%S")
# fa_df.to_csv(f'/content/saprot/upload_files/{timestamp}.csv', index=None)
# files.download(f'/content/saprot/upload_files/{timestamp}.csv')
# print(fa_df[5:])

## 4.3: Dataset Split

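Click the run button and upload your .csv dataset. If the dataset has no `stage` column, or its `stage` column contains fewer than three distinct values, every row is randomly assigned to `train`, `valid` or `test` with probabilities 0.8, 0.1 and 0.1, and the uploaded copy of the file is updated with the new `stage` column.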
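Because 4.3.1 samples each row's stage independently and without a fixed seed, the split changes on every run and the actual proportions only approximate 80/10/10. If you need a reproducible split, one option (our own suggestion, not what the cell does) is to shuffle row positions with a seeded NumPy generator and cut the shuffled order at the desired counts. The helper name `assign_stages` and the default seed are hypothetical.

```python
import numpy as np
import pandas as pd


def assign_stages(df: pd.DataFrame, ratios=(0.8, 0.1, 0.1), seed: int = 42) -> pd.DataFrame:
    """Return a copy of df with a 'stage' column of train/valid/test labels.

    Row positions are shuffled with a seeded generator, so the same seed
    always yields the same split.
    """
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(df))
    n_train = int(round(ratios[0] * len(df)))
    n_valid = int(round(ratios[1] * len(df)))
    stages = np.empty(len(df), dtype=object)
    stages[order[:n_train]] = "train"
    stages[order[n_train:n_train + n_valid]] = "valid"
    stages[order[n_train + n_valid:]] = "test"  # the remainder goes to test
    out = df.copy()
    out["stage"] = stages
    return out


demo = pd.DataFrame({"Sequence": [f"SEQ{i}" for i in range(10)], "label": range(10)})
split_df = assign_stages(demo)
print(split_df["stage"].value_counts().to_dict())
```

Unlike independent sampling, cutting a shuffled order guarantees the counts match the ratio up to rounding, which matters for small datasets where a random draw could leave the validation or test set empty.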

In [None]:
#@title 4.3.1: Randomly split your .csv dataset <a name="split_dataset"></a>

csv_dataset_path = upload_file(UPLOAD_FILE_HOME)
dataset_df = pd.read_csv(csv_dataset_path)

split = ['train', 'valid', 'test']
split_ratio = [0.8, 0.1, 0.1]

if ('stage' not in dataset_df.columns) or (dataset_df["stage"].nunique()<3):
  dataset_df["stage"] = np.random.choice(split, size=len(dataset_df), p=split_ratio).tolist()

dataset_df.to_csv(csv_dataset_path, index=None)


## 4.4: Multiple AA Sequences & Mutation Information -> Multiple AA Sequences
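The cell below applies mutation strings of the form `<wild-type><position><mutant>` with 1-based positions, joining multi-site mutations with `:` (for example `A2G:L4P`). `seq_mut_2_seq` does not check that the stated wild-type residue matches the sequence, so an off-by-one position or a wrong reference sequence goes unnoticed. The sketch below adds that check; the name `apply_mutations` and the regular expression are our own.

```python
import re

MUTATION_RE = re.compile(r"^([A-Z])(\d+)([A-Z])$")


def apply_mutations(seq: str, mutation: str) -> str:
    """Apply a mutation string such as 'A2G' or 'A2G:L4P' to an amino-acid sequence.

    Positions are 1-based, and the wild-type letter must match the sequence.
    """
    residues = list(seq)
    for token in mutation.split(":"):
        match = MUTATION_RE.match(token.strip())
        if match is None:
            raise ValueError(f"Malformed mutation {token!r}; expected e.g. 'A123G'.")
        wt, pos, mt = match.group(1), int(match.group(2)), match.group(3)
        if not 1 <= pos <= len(residues):
            raise ValueError(f"Position {pos} is outside 1..{len(residues)}.")
        if residues[pos - 1] != wt:
            raise ValueError(f"{token}: expected {wt} at position {pos}, found {residues[pos - 1]}.")
        residues[pos - 1] = mt
    return "".join(residues)


print(apply_mutations("MAKLE", "A2G:L4P"))  # -> MGKPE
```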

In [None]:
#@title 4.4.1: Multiple AA Sequences & Mutation Information ->  Multiple AA Sequences
csv_dataset_path = upload_file(UPLOAD_FILE_HOME)
dataset_df = pd.read_csv(csv_dataset_path)

mutation_aa_seq_list = []

def seq_mut_2_seq(seq, mut):
  # pandas reads empty cells as NaN, so treat NaN or blank values as "no mutation"
  if pd.isnull(mut) or str(mut).strip() == '':
    return seq

  seq_list = list(seq)
  for m in mut.split(":"):
    pos = int(m[1:-1])
    mut_aa = m[-1]

    seq_list[pos-1] = mut_aa

  return ''.join(seq_list)


for index, row in tqdm(dataset_df.iterrows(), total=len(dataset_df), leave=False, desc=f"Processing..."):
  seq = row['Sequence']
  mut_info = row['mutation']

  mut_seq = seq_mut_2_seq(seq, mut_info)
  mutation_aa_seq_list.append(mut_seq)

mutation_aa_seq_df = pd.DataFrame(mutation_aa_seq_list, columns=['Sequence'])
mutation_aa_seq_df.to_csv(csv_dataset_path, index=False)


# def seq_mut_2_seq(seq, mut):
#     if pd.isnull(mut):  # if the mutation is empty, return the original sequence unchanged
#         return seq
#     else:
#         seq_list = list(seq)
#         for m in mut.split(":"):
#             pos = int(m[1:-1])
#             mut_aa = m[-1]

#             seq_list[pos-1] = mut_aa

#         return ''.join(seq_list)

# for index, row in tqdm(dataset_df.iterrows(), total=len(dataset_df), leave=False, desc=f"Processing..."):
#     seq = row['Sequence']
#     mut_info = row['mutation']

#     mut_seq = seq_mut_2_seq(seq, mut_info)
#     mutation_aa_seq_list.append(mut_seq)

# mutation_aa_seq_df = pd.DataFrame(mutation_aa_seq_list, columns=['Sequence'])
# mutation_aa_seq_df.to_csv(csv_dataset_path, index=False)



# Manual <a name="manual"></a>





## How to train and share your model





### Train your Model

#### Step 1

Complete the input and selection of Task Configs, and then click the run button.

- `task_name` is the name of the training task you're working on.
- `task_objective` describes the goal of your task, such as classifying protein sequences into categories or predicting the values of protein properties.
- `base_model` is the model you start training from. By default, it is the officially pretrained SaProt, but you can also use a model you previously trained with ColabSaProt or one shared on [SaProtHub](https://huggingface.co/SaProtHub). For example, choose `Trained-by-peers` if you want to further train a SaProt model shared by others on your own data. A wide range of retrained models is available on [SaProtHub](https://huggingface.co/SaProtHub).
- `data_type` indicates the kind of data you're using, which determines the dataset file you should upload. You can find more details about the formats for different types of data in the provided <a href="#data_format">instruction</a>.

![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/train-1-1.png?raw=true)

<!-- <img src="https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/train-1-1.png?raw=true" height="400" width="800px" align="center"> -->

#### Step 2

After clicking the "Run" button, additional input boxes will appear.

![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/train-2-1.png?raw=true)

Fill in the additional information and upload your files. (Note: do not click the "Run" button of the next cell until you have finished the input and upload.)

![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/train-2-2.png?raw=true)

If you want to continue training an existing model, choose "Existing Models with your data" as `base_model` in Step 1; an "Existing model" input box will then appear. Enter a Hugging Face model ID or select a local model.

![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/train-2-3.png?raw=true)

#### Step 3

Complete the input of training configs, and then click the "Run" button to start training.

- `batch_size` depends on the number of training samples. If your training dataset is large enough, we recommend 32, 64, 128, 256, ...; otherwise, use a smaller value such as 8, 4 or 2. (Note that you cannot use a large batch size on Colab's default T4 GPU; we strongly suggest subscribing to Colab Pro for an A100 GPU.)
- `max_epochs` is the maximum number of training epochs (full passes over the training set). A larger value requires more training time. The best model so far is saved after each epoch, so you can adjust `max_epochs` to control the training duration. (Note that the maximum running time of a Colab session is 12 hours for free users and 24 hours for Colab Pro+ users.)
- `learning_rate` affects the convergence speed of the model. Through experimentation, we have found that `5.0e-4` is a good default value for base model `SaProt_650M_AF2` and `1.0e-3` for `SaProt_35M_AF2`.
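To relate `batch_size` and `max_epochs` to Colab's runtime limit, it helps to count batches: one epoch is one pass over the training set, i.e. ceil(N / batch_size) batches. The sketch below does that arithmetic. The seconds-per-batch figure is a placeholder you would measure yourself (for example from the duration of the first epoch), and the helper name is hypothetical; real runs also spend time on validation and checkpointing, so treat the result as an upper bound.

```python
import math


def epochs_within_budget(n_train: int, batch_size: int, seconds_per_batch: float,
                         budget_hours: float) -> int:
    """Estimate how many full epochs fit into a runtime budget.

    One epoch = one pass over the training set = ceil(n_train / batch_size) batches.
    seconds_per_batch is a value you measure; the example below uses a placeholder.
    """
    batches_per_epoch = math.ceil(n_train / batch_size)
    seconds_per_epoch = batches_per_epoch * seconds_per_batch
    return int(budget_hours * 3600 // seconds_per_epoch)


# Example: 5,000 training samples, batch size 64, 0.5 s per batch (placeholder),
# within a 12-hour session.
print(math.ceil(5000 / 64))                     # -> 79 (batches per epoch)
print(epochs_within_budget(5000, 64, 0.5, 12))  # -> 1093
```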

![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/train-3-1.png?raw=true)

![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/train-3-2.png?raw=true)

#### Step 4

You can monitor the training process with these plots. After training, check the training results and the saved model.

![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/train-4-1.png?raw=true)

![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/train-4-2.png?raw=true)

### (Optional) Upload your Model to Hugging Face

#### Step 1

Click the "Run" button and the Hugging Face login interface will appear. Enter your access token and click the "Login" button.

![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/upload-1-1.png?raw=true)

![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/upload-1-2.png?raw=true)

#### Step 2

Enter the model name and model description, and then click the button to upload the model. You can then view your model via the link provided.

![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/upload-2-1.png?raw=true)

![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/upload-2-2.png?raw=true)

## How to use your model for prediction





### Classification & Regression prediction task

#### Step 1

Complete the input and selection of Task Configs, and then click the run button.

- `task_objective` describes the goal of your task, such as classifying protein sequences into categories or predicting the values of protein properties.
- `use_model_from` depends on whether you want to use a local model or a model from Hugging Face. If you choose `SaProtHub Models`, enter the Hugging Face model ID in the input box. If you choose `Local Model`, select your local model from the options. A wide range of models is available on SaProtHub.
- `data_type` indicates the kind of data you're using, which determines the dataset file you should upload. You can find more details about the formats for different types of data in the provided <a href="#data_format">instruction</a>.

![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/cls_regr-1-1.png?raw=true)

#### Step 2

After clicking the "Run" button, additional input boxes and an upload button will appear.

![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/cls_regr-2-1.png?raw=true)

Fill in the additional information and upload your files. (Note: do not click the "Run" button of the next cell until you have finished the input and upload.)

![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/cls_regr-2-2.png?raw=true)

#### Step 3

Click the run button to start predicting.

![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/cls_regr-3-1.png?raw=true)

Check your results once prediction finishes.

![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/cls_regr-3-2.png?raw=true)

### Mutational effect prediction task

#### Step 1

Complete the selection of Task Configs, and then click the run button.

- `mutation_task` indicates the type of mutation task. You can choose from `Single-site or Multi-site mutagenesis` and `Saturation mutagenesis`.
- `data_type` indicates the kind of data you're using, which determines the dataset file you should upload. You can find more details about the formats for different types of data in the provided <a href="#data_format">instruction</a>.
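Saturation mutagenesis scores every possible single amino-acid substitution, so a protein of length L yields 19 × L candidate mutations (3.2.2 loops over every residue position and calls `predict_pos_mut` on each). The sketch below enumerates those candidates in the `<wild-type><position><mutant>` notation used in 4.4, to give a sense of how large the output table will be. It assumes the 20 standard amino acids and is not guaranteed to reproduce the exact mutation labels SaProt writes to the result file.

```python
STANDARD_AA = "ACDEFGHIKLMNPQRSTVWY"  # the 20 standard amino acids


def saturation_mutations(seq: str) -> list[str]:
    """List every single substitution, e.g. 'M1A', 'M1C', ..., with 1-based positions."""
    mutations = []
    for pos, wt in enumerate(seq, start=1):
        for mt in STANDARD_AA:
            if mt != wt:
                mutations.append(f"{wt}{pos}{mt}")
    return mutations


muts = saturation_mutations("MEEL")
print(len(muts))  # -> 76 (4 positions x 19 substitutions)
print(muts[:3])   # -> ['M1A', 'M1C', 'M1D']
```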

![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/mutation-1-1.png?raw=true)

#### Step 2

After clicking the "Run" button, additional input boxes and an upload button will appear.

For a single sequence, enter the sequence and the mutation information into the corresponding input fields. (Note that for Saturation mutagenesis, you won't see the Mutation input box.)

![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/mutation-2-1.png?raw=true)

For multiple sequences, click the upload button to upload your dataset. (Note that for Saturation mutagenesis you don't need to provide mutation information, so only the `Sequence` column is required in the .csv dataset.)

![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/mutation-2-2.png?raw=true)
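In multiple-sequence mode, 3.2.2 reads each row's `Sequence` and `mutation` columns (only `Sequence` for Saturation mutagenesis). Assuming your uploaded file uses the same column names, it can be assembled with pandas as sketched below; the sequences and mutations are made-up placeholders, and the full set of accepted formats is described in the data-format instructions. The short loop checks that each mutation's wild-type letter matches the sequence before you upload.

```python
import pandas as pd

# Placeholder data: replace with your own sequences and mutation strings.
rows = [
    {"Sequence": "MAKLE", "mutation": "A2G"},
    {"Sequence": "MAKLE", "mutation": "A2G:L4P"},  # multi-site mutations are joined with ':'
    {"Sequence": "GPLKV", "mutation": "K4R"},
]
mutation_df = pd.DataFrame(rows, columns=["Sequence", "mutation"])

# Sanity check: the wild-type letter of each mutation must match the sequence (1-based positions).
for seq, mut in zip(mutation_df["Sequence"], mutation_df["mutation"]):
    for m in mut.split(":"):
        pos = int(m[1:-1])
        assert seq[pos - 1] == m[0], f"{m}: wild type does not match {seq}"

mutation_df.to_csv("mutation_dataset.csv", index=False)
print(mutation_df)
```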

#### Step 3

Click the run button to start predicting.

![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/mutation-3-1.png?raw=true)

Check your results once prediction finishes.

For a single sequence, the predicted score will be shown in the output.

![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/mutation-3-2.png?raw=true)

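For multiple sequences, the predicted scores are saved to a .csv file and downloaded to your local computer. For Saturation mutagenesis, one .csv file is generated per sequence, and all of them are bundled into a single .zip file.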

![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/mutation-3-3.png?raw=true)

![Untitled](https://github.com/westlake-repl/SaProtHub/blob/main/Figure/Instruction/mutation-3-4.png?raw=true)