<a name="setup"></a>
# 🛠️ Environment Setup
---

In [1]:
!pip install transformers==4.39.3
!pip install -q -U peft==0.10.0
!pip install -q -U accelerate==0.29.2
!pip install -q -U datasets==2.18.0
!pip install -q -U evaluate==0.4.1



In [None]:
# The latest version of transformers was not working properly, so I had to use an earlier one
import transformers
transformers.__version__

'4.39.3'

In [None]:
# import libraries
import random
import os
import torch

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    roc_auc_score,
    average_precision_score,
)

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    EsmForSequenceClassification,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,
)

from peft import get_peft_model, LoraConfig, PeftModel
from datasets import Dataset, DatasetDict
from evaluate import load

import torch.nn as nn
import torch.nn.functional as F

In [None]:
from google.colab import files
files.upload()

<a name="data"></a>
# 🗃️ Load and Preprocess Data

In [None]:
# Read in data
df = pd.read_csv("Dataset.csv")
print(len(df))
df.head()
print(df['Class'].value_counts())

9974


Unnamed: 0,Protein_ID,Sequence,Class
0,113927,MALSLFTVGQLIFLFWTLRITEANPDPAAKAAPAAVADPAAAAAAA...,AFP
1,210960,MKSAILTGLLFVLLCVDHMSSASQQSVVATQLIPINTALTPIMMKG...,AFP
2,213510,MLAALLVCAMVALTRAANGDTGKEAVMTGSSGKNLTECPTDWKMFN...,AFP
3,2315605,MRRQTTAIFVLLGLLAVFVVQGSTEDTGSTPTADNAPAASNGTAAP...,AFP
4,2411496,MSFKISTFTKIWLIIAVIVMCLCNEYNCQCTGAADCTSCTAACTGC...,AFP


In [5]:
# create a new column 'labels' to contain binary labels
df['labels'] = df['Class'].apply(lambda x: 0 if x == 'NON-AFP' else 1)
df['labels'].value_counts(normalize=True)

Unnamed: 0_level_0,proportion
labels,Unnamed: 1_level_1
0,0.951775
1,0.048225


We have highly imbalanced classes

In [6]:
# train, validation split, splits unbalanced classes properly
train_df, val_df = train_test_split(
    df,
    test_size=0.2,
    stratify=df['labels'],
    random_state=42)

print(len(train_df))
print(len(val_df))

7979
1995


In [7]:
# convert data into HuggingFace DatasetDict
train_dataset = Dataset.from_pandas(train_df)
val_dataset = Dataset.from_pandas(val_df)
dataset_dict = DatasetDict({
    'train': train_dataset,
    'validation': val_dataset
})

In [None]:
# load tokenizer and tokenize data
# 35M seems to be the largest model that a T4 on Colab can handle
# model_checkpoint = "facebook/esm2_t30_150M_UR50D"
model_checkpoint = "facebook/esm2_t12_35M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

def tokenize(examples, max_length=1023):
    text = examples["Sequence"]
    encoding = tokenizer(text, truncation=True, max_length=max_length)
    encoding["labels"] = examples["labels"]
    return encoding

encoded_dataset = dataset_dict.map(
    tokenize,
    batched=True,
    num_proc=os.cpu_count(),
    remove_columns=dataset_dict["train"].column_names
)

encoded_dataset.set_format("torch")



tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Map (num_proc=2):   0%|          | 0/7979 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/1995 [00:00<?, ? examples/s]

<a name="train"></a>
# 💪 Train Model
---

In [54]:
# load model checkpoint for classification
model = EsmForSequenceClassification.from_pretrained(
	model_checkpoint,
	num_labels=2
)



config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/595M [00:00<?, ?B/s]

Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t30_150M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# configure model for LoRA fine-tuning
peft_config = LoraConfig(
    task_type="SEQ_CLS",
    inference_mode=False,
    bias="none",
    r=16, # rank number
    lora_alpha=16, # scaling factor
    lora_dropout=0.2, # dropout prob
    target_modules=[ # which layers to apply LoRA
        "query",
        "key",
        "value"
    ],
    modules_to_save=['classifier'] # ensures that the fine-tuned classifier head is saved when calling trainer.save_model later
)

model = get_peft_model(model, peft_config)

# adjust dropout in the classifier head
model.base_model.model.classifier.modules_to_save.default.dropout.p = 0.2

In [56]:
# show amount of trainable parameters
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

print_trainable_parameters(model)

trainable params: 2254722 || all params: 151051485 || trainable%: 1.492684431404299


In [None]:
# configure training args
num_train_epochs = 5
batch_size = 8
learning_rate = 1e-3
eval_howoften = 250

args = TrainingArguments(
    seed=42,
    fp16=True,
    output_dir='./results',
    evaluation_strategy = "steps",
    num_train_epochs=num_train_epochs,
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    # gradient_checkpointing=True,
    logging_steps=eval_howoften,
    eval_steps=eval_howoften,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    metric_for_best_model="auc_roc",
    load_best_model_at_end=True,
    report_to='none'  # Disable Weights & Biases logging
)

# define metrics to compute during training
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    softmax = torch.nn.Softmax(dim=1)
    probabilities = softmax(torch.tensor(logits)).numpy()
    predictions = np.argmax(probabilities, axis=1)
    probabilities_pos_class = probabilities[:, 1]

    accuracy = accuracy_score(labels, predictions)
    precision = precision_score(labels, predictions, zero_division=0)
    recall = recall_score(labels, predictions, zero_division=0)
    auc = roc_auc_score(labels, probabilities_pos_class)
    auc_pr = average_precision_score(labels, probabilities_pos_class)

    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "auc_roc": auc,
        "auc_pr": auc_pr,

    }

# define early stopping
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience=3,
    early_stopping_threshold=0.01
)

# Custom Trainer for unbalanced classes
class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")  # Extract labels
        outputs = model(**inputs)  # Forward pass
        logits = outputs.logits  # Get logits

        # Balanced class weights for 95% / 5% class distribution
        class_weights = torch.tensor([0.05, 0.95], device=logits.device)  # Adjusted for better training stability
        loss_fn = nn.CrossEntropyLoss(weight=class_weights)  # Weighted CE Loss

        loss = loss_fn(logits, labels)  # Compute loss
        return (loss, outputs) if return_outputs else loss


In [72]:
trainer = CustomTrainer(
    model=model,
    args=args,
    train_dataset=encoded_dataset['train'],
    eval_dataset=encoded_dataset['validation'],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping_callback]
    )

# train model
trainer.train()

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,Auc Roc,Auc Pr
250,0.3697,0.310349,0.97594,0.928571,0.541667,0.880762,0.658682
500,0.2298,0.336471,0.976441,1.0,0.510417,0.905638,0.702569
750,0.1561,0.380596,0.978446,0.964912,0.572917,0.91617,0.743477
1000,0.1811,0.608643,0.959398,0.570093,0.635417,0.924689,0.657209
1250,0.1867,0.295661,0.979449,0.966102,0.59375,0.922731,0.768008
1500,0.301,0.952528,0.952381,1.0,0.010417,0.680857,0.112609
1750,0.3436,0.590952,0.976441,0.980392,0.520833,0.916812,0.700222




TrainOutput(global_step=1750, training_loss=0.2525857870919364, metrics={'train_runtime': 2228.5812, 'train_samples_per_second': 17.902, 'train_steps_per_second': 4.476, 'total_flos': 3314519000108100.0, 'train_loss': 0.2525857870919364, 'epoch': 0.88})

In [14]:
# evaluate model on validation set
eval_dict = trainer.evaluate()
eval_dict

{'eval_loss': 0.08110496401786804,
 'eval_accuracy': 0.9769423558897243,
 'eval_precision': 0.8787878787878788,
 'eval_recall': 0.6041666666666666,
 'eval_auc_roc': 0.9408296032999824,
 'eval_auc_pr': 0.797464720524441,
 'eval_runtime': 38.8891,
 'eval_samples_per_second': 51.3,
 'eval_steps_per_second': 3.214,
 'epoch': 3.21}

In [None]:
# save fine-tuned LoRA adapters + classification head
model_path = "model"
trainer.save_model(model_path)

<a name="inference"></a>
## 🎯 Inference
---

In [None]:
# load fine-tuned model adapters onto the base model checkpoint
base_model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
fine_tuned_model = PeftModel.from_pretrained(base_model, model_path)

In [None]:
# generate predictions
def predict(text, model, tokenizer):
    inputs = tokenizer(
        text,
        return_tensors='pt',
        truncation=True,
        max_length=1023
    )

    model.eval()

    with torch.no_grad():
        logits = model(**inputs).logits

    probabilities = torch.nn.functional.softmax(logits, dim=-1)

    return probabilities[:, 1].numpy()

val_pred_probas = val_df['peptide'].apply(lambda x: predict(x, fine_tuned_model, tokenizer))

print('ROC-AUC:', roc_auc_score(val_df['labels'], val_pred_probas))
print('PR-AUC:', average_precision_score(val_df['labels'], val_pred_probas))