# Fine-tuning ESM-2

In this notebook, we'll fine-tune ESM-2 to predict the subcellular location of proteins from their amino acid sequences. We'll start with the published model [hosted on HuggingFace](https://huggingface.co/facebook/esm2_t33_650M_UR50D) and then show how the same task can be done with NVIDIA's [BioNeMo 2 Framework](https://docs.nvidia.com/bionemo-framework/latest/user-guide/).

> Inspired by ESM-2's [example notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/protein_language_modeling.ipynb) for fine-tuning.

This notebook needs the [evaluate](https://pypi.org/project/evaluate/) package, which is not present in the BioNeMo v2.3 container, so we'll need to install it manually with `pip`.

> Make sure to restart the kernel of this notebook after installing

In [1]:
!pip install evaluate
# Restart kernel after installing

Defaulting to user installation because normal site-packages is not writeable

In [None]:
# Import necessary packages
import requests, pandas, os, evaluate
# Rerun previous cell and restart kernel if this fails
from io import BytesIO

# Set environment variables for huggingface
for var in ['HF_HOME','HF_HUB_CACHE']:
    os.environ[var] = '/tmp/hf'

We're going to fine-tune ESM2-650M on human protein sequences (`organism_id:9606`) that have been reviewed (`reviewed:true`) and are 80 to 500 amino acids long (`length:[80 TO 500]`), keeping only the accession (`Entry`), `Sequence`, and `Subcellular location [CC]` columns. [UniProt](https://www.uniprot.org/) provides a REST API, so this query has been encoded into `query_url`.
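
To see how the search terms map onto the percent-encoded `query_url`, here is a small sketch that rebuilds the URL with the standard library's `urllib.parse`. The `params` dictionary and the variable names are illustrative; the parameter values are read off the encoded URL used in this notebook.

```python
from urllib.parse import urlencode, quote, urlsplit, parse_qs

# Human-readable form of the UniProt search described above
query = "((organism_id:9606) AND (reviewed:true) AND (length:[80 TO 500]))"

params = {
    "compressed": "true",  # the notebook reads the response as gzip
    "fields": "accession,sequence,cc_subcellular_location",  # columns to return
    "format": "tsv",  # tab-separated values
    "query": query,
}

# quote_via=quote encodes spaces as %20 rather than '+'
built_url = "https://rest.uniprot.org/uniprotkb/stream?" + urlencode(params, quote_via=quote)
print(built_url)

# Decoding the URL again recovers the original query string
decoded = parse_qs(urlsplit(built_url).query)
print(decoded["query"][0])
```

Building the URL from a dictionary makes it easier to change one filter (for example the length range) without hand-editing escape sequences.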

This download can sometimes fail or take a while, so we cache the data in parquet format once it succeeds. This will make any repeated runs of the notebook go much faster on the same compute node.
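
The cache-or-download pattern described here can be sketched in isolation. This is a minimal, dependency-free illustration: it caches to JSON rather than parquet, and `load_or_build` and `fake_download` are hypothetical names standing in for the real download step.

```python
import json
import os
import tempfile

def load_or_build(cache_path, build):
    """Return cached data if the file exists; otherwise build, cache, and return it."""
    if os.path.exists(cache_path):
        with open(cache_path) as handle:
            return json.load(handle)
    data = build()
    # Write under a temporary name first so a failed run never leaves a partial cache
    partial_path = cache_path + ".partial"
    with open(partial_path, "w") as handle:
        json.dump(data, handle)
    os.replace(partial_path, cache_path)
    return data

calls = []

def fake_download():
    # Stand-in for the slow network request; records how often it runs
    calls.append(1)
    return {"Entry": ["P1"], "Sequence": ["MKT"]}

cache_file = os.path.join(tempfile.mkdtemp(), "toy_cache.json")
first = load_or_build(cache_file, fake_download)
second = load_or_build(cache_file, fake_download)  # served from the cache
print(len(calls))
```

The second call never reaches `fake_download`, which is the behaviour that makes repeated notebook runs faster.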

In [2]:
query_url = "https://rest.uniprot.org/uniprotkb/stream?compressed=true&fields=accession%2Csequence%2Ccc_subcellular_location&format=tsv&query=%28%28organism_id%3A9606%29%20AND%20%28reviewed%3Atrue%29%20AND%20%28length%3A%5B80%20TO%20500%5D%29%29"
tmp_file = "/tmp/uniprot.parquet.gz"

# Logic to quickly load data if cached
if not os.path.exists(tmp_file):
    # Download data
    uniprot_request = requests.get(query_url)
    # Fail early with a clear error if the download did not succeed
    uniprot_request.raise_for_status()
    # Store data as binary object that works like a file
    bio = BytesIO(uniprot_request.content)
    # Read binary object as compressed csv
    df = pandas.read_csv(bio, compression='gzip', sep='\t')
    # Cache to local location for faster reloads
    df.to_parquet(tmp_file, compression="gzip")
else:
    # Load from cache
    df = pandas.read_parquet(tmp_file)
df

Unnamed: 0,Entry,Sequence,Subcellular location [CC]
0,A0A0K2S4Q6,MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE...,SUBCELLULAR LOCATION: [Isoform 1]: Membrane {E...
1,A0AVI4,MDSPEVTFTLAYLVFAVCFVFTPNEFHAAGLTVQNLLSGWLGSEDA...,SUBCELLULAR LOCATION: Endoplasmic reticulum me...
2,A0JLT2,MENFTALFGAQADPPPPPTALGFGPGKPPPPPPPPAGGGPGTAPPP...,SUBCELLULAR LOCATION: Nucleus {ECO:0000305}.
3,A0M8Q6,GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG...,SUBCELLULAR LOCATION: Secreted {ECO:0000303|Pu...
4,A0PJY2,MDSSCHNATTKMLATAPARGNMMSTSKPLAFSIERIMARTPEPKAL...,SUBCELLULAR LOCATION: Nucleus {ECO:0000269|Pub...
...,...,...,...
11972,Q9H8W2,MRPGSSPRAPECGAPALPRPQLDRLPARPAPSRGRGAPSLRWPAKE...,
11973,Q9HAA7,MLFGIRILVNTPSPLVTGLHHYNPSIHRDQGECANQWRKGPGSAHL...,
11974,Q9NZ38,MAFPGQSDTKMQWPEVPALPLLSSLCMAMVRKSSALGKEVGRRSEG...,
11975,Q9UFV3,MAETYRRSRQHEQLPGQRHMDLLTGYSKLIQSRLKLLLHLGSQPPV...,


Our goal is to train a model that predicts whether a protein is located in the [cytosol](https://en.wikipedia.org/wiki/Cytosol) (intracellular fluid) or in the membrane of a cell. Proteins in the cytosol have a `Subcellular location [CC]` containing `Cytoplasm` or `Cytosol`. Proteins in the cell membrane have a `Subcellular location [CC]` containing `Membrane` or `Cell membrane`.

Once these are selected, we can create new dataframes for each type, while also excluding proteins that exist in both.
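
The selection logic can be illustrated without pandas. The sketch below uses made-up annotation strings (not real UniProt records) and hypothetical helper names; it shows that the substring match is case-sensitive and that a protein matching both groups ends up in neither.

```python
# Toy annotation strings, invented for illustration only
annotations = {
    "P1": "SUBCELLULAR LOCATION: Cytoplasm, cytosol.",
    "P2": "SUBCELLULAR LOCATION: Cell membrane; Single-pass membrane protein.",
    "P3": "SUBCELLULAR LOCATION: Cytoplasm. Cell membrane.",
    "P4": "SUBCELLULAR LOCATION: Nucleus.",
}

def is_cytosolic(text):
    # Case-sensitive: lowercase "cytosol" alone would not match "Cytosol"
    return "Cytosol" in text or "Cytoplasm" in text

def is_membrane(text):
    # "Membrane" does not match the lowercase "membrane" in "Cell membrane",
    # which is why both patterns are checked
    return "Membrane" in text or "Cell membrane" in text

cytosolic_only = [k for k, v in annotations.items() if is_cytosolic(v) and not is_membrane(v)]
membrane_only = [k for k, v in annotations.items() if is_membrane(v) and not is_cytosolic(v)]

print(cytosolic_only)
print(membrane_only)
```

`P3` carries both annotations and is dropped from both lists, and `P4` matches neither pattern.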

In [3]:
# Drop proteins with missing columns
df = df.dropna()
# Get ids of proteins with Cytosol or Cytoplasm locations
cytosolic = df['Subcellular location [CC]'].str.contains("Cytosol") | df['Subcellular location [CC]'].str.contains("Cytoplasm")
# Get ids of proteins with Membrane or Cell membrane locations
membrane = df['Subcellular location [CC]'].str.contains("Membrane") | df['Subcellular location [CC]'].str.contains("Cell membrane")

# Create new cytosolic dataframe with proteins
cytosolic_df = df[cytosolic & ~membrane]
cytosolic_df

Unnamed: 0,Entry,Sequence,Subcellular location [CC]
9,A1E959,MKIIILLGFLGATLSAPLIPQRLMSASNSNELLLNLNNGQLLPLQL...,SUBCELLULAR LOCATION: Secreted {ECO:0000250|Un...
14,A1XBS5,MMRRTLENRNAQTKQLQTAVSNVEKHFGELCQIFAAYVRKTARLRD...,SUBCELLULAR LOCATION: Cytoplasm {ECO:0000269|P...
18,A2RU49,MSSGNYQQSEALSKPTFSEEQASALVESVFGLKVSKVRPLPSYDDQ...,SUBCELLULAR LOCATION: Cytoplasm {ECO:0000305}.
20,A2RUH7,MEAATAPEVAAGSKLKVKEASPADAEPPQASPGQGAGSPTPQLLPP...,"SUBCELLULAR LOCATION: Cytoplasm, myofibril, sa..."
21,A4D126,MEAGPPGSARPAEPGPCLSGQRGADHTASASLQSVAGTEPGRHPQA...,"SUBCELLULAR LOCATION: Cytoplasm, cytosol {ECO:..."
...,...,...,...
11495,Q8NBC4,MFPRPVLNSRAQAILLPQPPNMLDHRQWPPRLASFPFTKTGMLSRA...,SUBCELLULAR LOCATION: Cytoplasm {ECO:0000269|P...
11535,Q8TDY3,MFNPHALDSPAVIFDNGSGFCKAGLSGEFGPRHMVSSIVGHLKFQA...,"SUBCELLULAR LOCATION: Cytoplasm, cytoskeleton ..."
11547,Q8WWF8,MAGTARHDREMAIQAKKKLTTATDPIERLRLQCLARGSAGIKGLGR...,SUBCELLULAR LOCATION: Cytoplasm {ECO:0000305}.
11671,Q9NUJ7,MGGQVSASNSFSRLHCRNANEDWMSALCPRLWDVPLHHLSIPGSHD...,SUBCELLULAR LOCATION: Cytoplasm {ECO:0000269|P...


In [4]:
# Create new membrane dataframe with proteins
membrane_df = df[membrane & ~cytosolic]
membrane_df

Unnamed: 0,Entry,Sequence,Subcellular location [CC]
0,A0A0K2S4Q6,MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE...,SUBCELLULAR LOCATION: [Isoform 1]: Membrane {E...
3,A0M8Q6,GQPKAAPSVTLFPPSSEELQANKATLVCLVSDFNPGAVTVAWKADG...,SUBCELLULAR LOCATION: Secreted {ECO:0000303|Pu...
17,A2RU14,MAGTVLGVGAGVFILALLWVAVLLLCVLLSRASGAARFSVIFLFFG...,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
33,A5X5Y0,MEGSWFHRKRFSFYLLLGFLLQGRGVTFTINCSGFGQHGADPTALN...,SUBCELLULAR LOCATION: Postsynaptic cell membra...
36,A6ND01,MACWWPLLLELWTVMPTWAGDELLNICMNAKHHKRVPSPEDKLYEE...,SUBCELLULAR LOCATION: Cell membrane {ECO:00002...
...,...,...,...
11895,Q86UQ5,MQSDIYHPGHSFPSWVLCWVHSCGHEGHLRETAEIRKTHQNGDLQI...,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
11918,Q8N8V8,MLLKVRRASLKPPATPHQGAFRAGNVIGQLIYLLTWSLFTAWLRPP...,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
11955,Q96N68,MQGQGALKESHIHLPTEQPEASLVLQGQLAESSALGPKGALRPQAQ...,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
11963,Q9H0A3,MMNNTDFLMLNNPWNKLCLVSMDFCFPLDFVSNLFWIFASKFIIVT...,SUBCELLULAR LOCATION: Membrane {ECO:0000255}; ...


Now that we've filtered and separated out the proteins of interest, we can extract the sequences and encode the locations as `0` for cytosolic proteins and `1` for membrane proteins. Then, we can combine the two types into two lists: sequences and labels.

In [5]:
cytosolic_sequences = cytosolic_df["Sequence"].tolist()
cytosolic_labels = [0 for protein in cytosolic_sequences]
membrane_sequences = membrane_df["Sequence"].tolist()
membrane_labels = [1 for protein in membrane_sequences]

sequences = cytosolic_sequences + membrane_sequences
labels = cytosolic_labels + membrane_labels

# Quick check to make sure we got it right
assert len(sequences) == len(labels)

When training a model, you need training data for the model to learn from and held-out validation data that it never learns from, to check that what it learns generalizes. We're going to use sklearn to split our lists 75%/25% into `train` and `test` datasets. Once that is done, we'll cache the datasets in CSV (comma-separated values) format so both fine-tuning methods use the same data.

BioNeMo also only validates for two steps while training, so we're going to truncate our test data to 64 items (8 batches of 8).
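
As a rough picture of what a shuffled 75%/25% split does, here is a dependency-free sketch. `split_train_test` is a hypothetical helper, not sklearn's implementation, and the toy sequences are made up.

```python
import math
import random

def split_train_test(items, labels, test_fraction=0.25, seed=0):
    """Shuffle indices with a fixed seed, then hold out `test_fraction` of them."""
    indices = list(range(len(items)))
    random.Random(seed).shuffle(indices)
    n_test = math.ceil(len(indices) * test_fraction)
    test_idx, train_idx = indices[:n_test], indices[n_test:]

    def pick(source, idx):
        return [source[i] for i in idx]

    return (pick(items, train_idx), pick(items, test_idx),
            pick(labels, train_idx), pick(labels, test_idx))

# Toy data: 8 made-up sequences with alternating labels
toy_sequences = [f"SEQ{i}" for i in range(8)]
toy_labels = [i % 2 for i in range(8)]

train_x, test_x, train_y, test_y = split_train_test(toy_sequences, toy_labels)
print(len(train_x), len(test_x))
```

Sequences and labels are indexed with the same shuffled indices, so each sequence keeps its label. The fixed `seed` plays the role that the `random_state` argument plays in `train_test_split`: without it, each run produces a different split.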

In [6]:
from sklearn.model_selection import train_test_split

train_sequences, test_sequences, train_labels, test_labels = train_test_split(sequences, labels, test_size=0.25, shuffle=True)

pandas.DataFrame({"sequences":train_sequences, "labels":train_labels}).to_csv("/tmp/train_df.csv")
pandas.DataFrame({"sequences":test_sequences, "labels":test_labels}).to_csv("/tmp/test_df.csv")

# Only keep the first 64 validation sequences for parity with BioNeMo
test_sequences_small = test_sequences[:64]
test_labels_small = test_labels[:64]

Input sequences need to be tokenized into a numerical format for the model. When we pull down the ESM2-650M model checkpoint from HuggingFace, we get the tokenizer too.

In [7]:
from transformers import AutoTokenizer

# Using the ESM2 650M model checkpoint from HuggingFace
model_checkpoint = "facebook/esm2_t33_650M_UR50D"

# Load the tokenizer from the model
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)



To understand what happens during tokenization, we can tokenize the first sequence.

You'll notice that the tokenized sequence is 2 values longer than the original sequence. That's because the tokenizer adds special tokens marking the start and end of the sequence (ids `0` and `2` in the output below).
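
A toy tokenizer makes the length difference concrete. This is only a sketch: `toy_tokenize` is a hypothetical function, and while the ids `0` and `2` for the start and end tokens match the output in this notebook, the residue ordering in `RESIDUES` is an assumption made for illustration.

```python
# Toy vocabulary: special tokens first, then amino acid residues.
# The residue order is assumed for illustration, not read from the real tokenizer.
SPECIAL_TOKENS = ["<cls>", "<pad>", "<eos>", "<unk>"]
RESIDUES = "LAGVSERTIDPKQNFYMHWC"
VOCAB = {token: idx for idx, token in enumerate(SPECIAL_TOKENS + list(RESIDUES))}

def toy_tokenize(sequence):
    """Map each residue to an id and wrap the result in start/end tokens."""
    ids = [VOCAB["<cls>"]]
    ids += [VOCAB.get(residue, VOCAB["<unk>"]) for residue in sequence]
    ids.append(VOCAB["<eos>"])
    # No padding here, so every position is attended to
    return {"input_ids": ids, "attention_mask": [1] * len(ids)}

example = toy_tokenize("MVAT")
print(example["input_ids"])
print(len("MVAT"), len(example["input_ids"]))
```

The four-residue input yields six ids: one per residue plus the two special tokens.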

In [8]:
# Tokenize the first sequence for demonstration
seq = train_sequences[0]
tokenized = tokenizer(seq)

print(f"Sequence length: {len(seq)}\nToken length: {len(tokenized['input_ids'])}")

tokenized

{'input_ids': [0, 20, 7, 5, 11, 23, 4, 16, 7, 7, 6, 18, 7, 11, 8, 18, 7, 6, 22, 12, 6, 7, 12, 7, 11, 11, 8, 11, 17, 13, 22, 7, 7, 11, 23, 6, 19, 11, 12, 14, 11, 23, 10, 15, 4, 13, 9, 4, 6, 8, 15, 6, 4, 22, 5, 13, 23, 7, 20, 5, 11, 6, 4, 19, 21, 23, 15, 14, 4, 7, 13, 12, 4, 12, 4, 14, 6, 19, 7, 16, 5, 23, 10, 5, 4, 20, 12, 5, 5, 8, 7, 4, 6, 4, 14, 5, 12, 4, 4, 4, 4, 11, 7, 4, 14, 23, 12, 10, 20, 6, 16, 9, 14, 6, 7, 5, 15, 19, 10, 10, 5, 16, 4, 5, 6, 7, 4, 4, 12, 4, 4, 5, 4, 23, 5, 4, 7, 5, 11, 12, 22, 18, 14, 7, 23, 5, 21, 10, 9, 11, 11, 12, 7, 8, 18, 6, 19, 8, 4, 19, 5, 6, 22, 12, 6, 5, 7, 4, 23, 4, 7, 6, 6, 23, 7, 12, 4, 23, 23, 5, 6, 13, 5, 16, 5, 18, 6, 9, 17, 10, 18, 19, 19, 11, 5, 6, 8, 8, 8, 14, 11, 21, 5, 15, 8, 5, 21, 7, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [9]:
# Tokenize all sequences
train_tokenized = tokenizer(train_sequences)
test_tokenized = tokenizer(test_sequences_small)

After tokenizing all of our sequences, we can create a `Dataset` object that will handle data shuffling and iterating while training.

In [10]:
from datasets import Dataset
train_dataset = Dataset.from_dict(train_tokenized)
test_dataset = Dataset.from_dict(test_tokenized)

# Add labels to Dataset
train_dataset = train_dataset.add_column("labels", train_labels)
test_dataset = test_dataset.add_column("labels", test_labels_small)

# Print the shape and columns of the datasets after adding labels
print(train_dataset.shape, list(train_dataset[0].keys()))
print(test_dataset.shape, list(test_dataset[0].keys()))

(3886, 3) ['input_ids', 'attention_mask', 'labels']
(16, 3) ['input_ids', 'attention_mask', 'labels']




Now that data preparation is done, we can pull the model and configure it. We're going to be training the model to classify an input sequence as one of two labels {0: cytosolic, 1: membrane}, so we're going to load the ESM-2 650M checkpoint using the `AutoModelForSequenceClassification` class. Notice that we're also telling it that there are two possible labels.

In [11]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

num_labels = max(train_labels + test_labels) + 1  # 2: {0: cytosolic, 1: membrane}
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)

Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


After loading the model, we can configure the training run. To mimic the way BioNeMo does fine-tuning, we're going to train for 200 steps and evaluate every 50 steps. On a 48GB GPU like the L40S, a batch size of 8 can be used. For other GPUs, experiment with other values based on memory usage.
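
It helps to work out how much data 200 steps actually covers. The sketch below uses the 3886 training rows printed earlier and assumes a single GPU with no gradient accumulation, so one step consumes exactly one batch.

```python
import math

num_train_sequences = 3886  # rows in train_dataset, as printed earlier
batch_size = 8
max_steps = 200
eval_steps = 50

# Assumption: one device and no gradient accumulation
steps_per_epoch = math.ceil(num_train_sequences / batch_size)
sequences_seen = max_steps * batch_size
epoch_fraction = max_steps / steps_per_epoch
eval_points = list(range(eval_steps, max_steps + 1, eval_steps))

print(steps_per_epoch, sequences_seen, round(epoch_fraction, 4))
print(eval_points)
```

Under these assumptions the run sees 1600 sequences, well under half of one pass over the training set, and is evaluated four times.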

In [12]:
model_name = model_checkpoint.split("/")[-1]
batch_size = 8 # Works for 48GB GPU
strat = "steps" # "epoch"

args = TrainingArguments(
    f"/tmp/{model_name}-finetuned-localization", # Make sure to change this for a real model
    eval_strategy = strat,
    save_strategy = strat,
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    max_steps=200,
    eval_steps=50,
    #num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
    report_to="none",
    include_tokens_per_second=True,
)

We're also going to use HuggingFace's evaluate package to compute the accuracy of the classifications.
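
Accuracy here means: take the highest-scoring class for each sequence and count how often it matches the label. A dependency-free sketch with made-up logits (the helper names are hypothetical):

```python
def argmax(values):
    """Index of the largest value in a list."""
    return max(range(len(values)), key=lambda i: values[i])

def accuracy_from_logits(logits, labels):
    """Pick the top class per row and compare against the reference labels."""
    predictions = [argmax(row) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return predictions, correct / len(labels)

# Made-up scores for four sequences over two classes (0: cytosolic, 1: membrane)
toy_logits = [[2.0, -1.0], [0.1, 0.3], [1.5, 1.4], [-0.2, 0.9]]
toy_labels = [0, 1, 1, 1]

predictions, accuracy = accuracy_from_logits(toy_logits, toy_labels)
print(predictions, accuracy)
```

The third sequence is misclassified (class 0 scores slightly higher than the true class 1), so three of four predictions are correct.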

In [13]:
from evaluate import load
import numpy as np

metric = load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)

Next, we create a `Trainer` instance from the model, data, and configuration. If everything is valid, we can start training.

In [14]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)



In [15]:
trainer.train()

Step,Training Loss,Validation Loss,Accuracy
50,No log,0.177743,0.9375
100,No log,0.032178,1.0
150,No log,0.018176,1.0
200,No log,0.020032,1.0


Could not locate the best model at /tmp/esm2_t33_650M_UR50D-finetuned-localization/checkpoint-100/pytorch_model.bin, if you are running a distributed training on multiple nodes, you should activate `--save_on_each_node`.
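
The "Could not locate the best model" message is probably a matter of checkpoint timing rather than multi-node training. The training arguments set `eval_steps=50` but leave `save_steps` unset; assuming it falls back to a default of 500, no checkpoint is scheduled within the 200-step run, so `checkpoint-100` never exists. A small sketch of that reasoning (the helper and the assumed default are illustrative):

```python
def scheduled_steps(interval, max_steps):
    """Steps at which an action repeating every `interval` steps would fire."""
    return list(range(interval, max_steps + 1, interval))

max_steps = 200
eval_steps = 50
assumed_save_steps = 500  # assumption: save_steps left at its default
best_step = 100  # from the checkpoint-100 path in the message

eval_at = scheduled_steps(eval_steps, max_steps)
saved_at = scheduled_steps(assumed_save_steps, max_steps)
print(eval_at, saved_at, best_step in saved_at)

# Saving on the same schedule as evaluation would cover the best step
saved_at_aligned = scheduled_steps(eval_steps, max_steps)
print(best_step in saved_at_aligned)
```

If this is the cause, passing `save_steps=50` to `TrainingArguments` would likely let `load_best_model_at_end` find the checkpoint, at the cost of writing four checkpoints to disk.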


TrainOutput(global_step=200, training_loss=0.2943057441711426, metrics={'train_runtime': 111.3586, 'train_samples_per_second': 14.368, 'train_steps_per_second': 1.796, 'train_tokens_per_second': 6867.903, 'total_flos': 2792200504173504.0, 'train_loss': 0.2943057441711426, 'epoch': 0.411522633744856})

Once training is done, shut down this notebook's kernel so the model is freed from memory before we fine-tune with BioNeMo.