<a name="setup"></a>
# 🛠️ Environment Setup
---

In [None]:
# import libraries
import random
import os
import torch

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, 
    precision_score, 
    recall_score, 
    roc_auc_score,
    average_precision_score
)

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    EsmForSequenceClassification,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,
    get_scheduler
)

from peft import get_peft_model, LoraConfig, PeftModel
from datasets import Dataset, DatasetDict
from evaluate import load

import torch.nn as nn
import torch.nn.functional as F

<a name="data"></a>
# 🗃️ Load and Preprocess Data

In [None]:
# read data
df = pd.read_csv("Dataset.csv")
print(len(df))
df.head()

9974


Unnamed: 0,Protein_ID,Sequence,Class
0,113927,MALSLFTVGQLIFLFWTLRITEANPDPAAKAAPAAVADPAAAAAAA...,AFP
1,210960,MKSAILTGLLFVLLCVDHMSSASQQSVVATQLIPINTALTPIMMKG...,AFP
2,213510,MLAALLVCAMVALTRAANGDTGKEAVMTGSSGKNLTECPTDWKMFN...,AFP
3,2315605,MRRQTTAIFVLLGLLAVFVVQGSTEDTGSTPTADNAPAASNGTAAP...,AFP
4,2411496,MSFKISTFTKIWLIIAVIVMCLCNEYNCQCTGAADCTSCTAACTGC...,AFP


In [None]:
df['Class'].value_counts()

Class
NON-AFP    9493
AFP         481
Name: count, dtype: int64

In [5]:
# create a new column 'labels' to contain binary labels
df['labels'] = df['Class'].apply(lambda x: 0 if x == 'NON-AFP' else 1)
df['labels'].value_counts(normalize=True)

labels
0    0.951775
1    0.048225
Name: proportion, dtype: float64

Imbalanced classes 95% to 5% ish

In [8]:
# train, validation split, splits unbalanced classes properly
train_df, val_df = train_test_split(
    df,
    test_size=0.2,
    stratify=df['labels'],
    random_state=42)

print(len(train_df))
print(len(val_df))

7979
1995


In [9]:
# convert data into HuggingFace DatasetDict
train_dataset = Dataset.from_pandas(train_df)
val_dataset = Dataset.from_pandas(val_df)
dataset_dict = DatasetDict({
    'train': train_dataset,
    'validation': val_dataset
})

In [11]:
# load tokenizer and tokenize data
model_checkpoint = "facebook/esm2_t6_8M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

def tokenize(examples, max_length=1023):
    text = examples["Sequence"]
    encoding = tokenizer(text, truncation=True, max_length=max_length)
    encoding["labels"] = examples["labels"]
    return encoding

encoded_dataset = dataset_dict.map(
    tokenize,
    batched=True,
    num_proc=os.cpu_count(),
    remove_columns=dataset_dict["train"].column_names
)

encoded_dataset.set_format("torch")

Map (num_proc=8): 100%|██████████| 7979/7979 [00:04<00:00, 1823.39 examples/s]
  block_group = [InMemoryTable(cls._concat_blocks(list(block_group), axis=axis))]
  table = cls._concat_blocks(blocks, axis=0)
Map (num_proc=8): 100%|██████████| 1995/1995 [00:01<00:00, 1958.20 examples/s]


<a name="train"></a>
# 💪 Train Model
---

In [12]:
# load model checkpoint for classification
model = EsmForSequenceClassification.from_pretrained(
	model_checkpoint,
	num_labels=2
)

Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
# configure model for LoRA fine-tuning
peft_config = LoraConfig(
    task_type="SEQ_CLS",
    inference_mode=False,
    bias="none",
    r=8, # rank number
    lora_alpha=16, # scaling factor)
    lora_dropout=0.2, # dropout prob
    target_modules=[ # which layers to apply LoRA
        "query",
        "key",
        "value"
    ],
    modules_to_save=['classifier'] # ensures that the fine-tuned classifier head is saved when calling trainer.save_model later
)

model = get_peft_model(model, peft_config)

# adjust dropout in the classifier head
model.base_model.model.classifier.modules_to_save.default.dropout.p = 0.25

In [14]:
# show amount of trainable parameters
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

print_trainable_parameters(model)

trainable params: 195522 || all params: 8036285 || trainable%: 2.4329898703194326


Only 195,522 parameters out of 8,036,285 (2.43%) are adjusted during training.

In [None]:
# configure training args
num_train_epochs = 10
batch_size = 16
learning_rate = 1e-3

args = TrainingArguments(
    seed=42,
    fp16=True,
    output_dir='./results',
    evaluation_strategy = "steps",
    num_train_epochs=num_train_epochs,
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=64,
    gradient_accumulation_steps=4,
    # gradient_checkpointing=True,
    logging_steps=50,
    eval_steps=50,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    metric_for_best_model="auc_roc",
    load_best_model_at_end=True,
    report_to='none'  # Disable Weights & Biases logging
)

# define metrics to compute during training
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    softmax = torch.nn.Softmax(dim=1)
    probabilities = softmax(torch.tensor(logits)).numpy()
    predictions = np.argmax(probabilities, axis=1)
    probabilities_pos_class = probabilities[:, 1]

    accuracy = accuracy_score(labels, predictions)
    precision = precision_score(labels, predictions, zero_division=0)
    recall = recall_score(labels, predictions, zero_division=0)
    auc = roc_auc_score(labels, probabilities_pos_class)
    auc_pr = average_precision_score(labels, probabilities_pos_class)

    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "auc_roc": auc,
        "auc_pr": auc_pr
    }

# define early stopping
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience=3,
    early_stopping_threshold=0.01
)

class FocalLoss(nn.Module):
	def __init__(self, alpha=0.25, gamma=2.0):
		super().__init__()
		self.alpha = alpha  # Alpha for the positive class (minority class)
		self.gamma = gamma

	def forward(self, inputs, targets):
		# Compute binary cross-entropy loss
		BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
		
		# Compute pt
		pt = torch.exp(-BCE_loss)  # pt = p if target == 1, else 1-p
		
		# Apply alpha based on the target class
		alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)  # alpha_t = alpha if target == 1, else 1-alpha
		focal_loss = alpha_t * (1 - pt) ** self.gamma * BCE_loss
		
		return focal_loss.mean()

# Initialize Focal Loss
loss_fn = FocalLoss(alpha=0.95, gamma=2.0)

# Pass the loss function to the Trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=encoded_dataset['train'],
    eval_dataset=encoded_dataset['validation'],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping_callback],
    loss_fn=loss_fn
)

# train model
trainer.train()

  trainer = Trainer(


ValueError: fp16 mixed precision requires a GPU (not 'mps').

In [None]:
# evaluate model on validation set
eval_dict = trainer.evaluate()
eval_dict

In [None]:
# save fine-tuned LoRA adapters + classification head
model_path = "/content/demo_model"
trainer.save_model(model_path)

<a name="inference"></a>
## 🎯 Inference Demo
---

In [None]:
# load fine-tuned model adapters onto the base model checkpoint
base_model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
fine_tuned_model = PeftModel.from_pretrained(base_model, model_path)

In [None]:
# generate predictions
def predict(text, model, tokenizer):
    inputs = tokenizer(
        text,
        return_tensors='pt',
        truncation=True,
        max_length=1023
    )
    
    model.eval()
    
    with torch.no_grad():
        logits = model(**inputs).logits

    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    
    return probabilities[:, 1].numpy()

val_pred_probas = val_df['peptide'].apply(lambda x: predict(x, fine_tuned_model, tokenizer))

print('ROC-AUC:', roc_auc_score(val_df['labels'], val_pred_probas))
print('PR-AUC:', average_precision_score(val_df['labels'], val_pred_probas))