# Fine-tuning AntiBERTy to Predict S Protein Binding

This Colab notebook shows how to fine-tune the **AntiBERTy** antibody language model to predict whether an antibody sequence binds the SARS-CoV-2 Spike (S) protein.

It is adapted from the supervised fine-tuning script for antigen-specificity prediction with pre-trained antibody language models published by [Wang et al. 2025](https://doi.org/10.1371/journal.pcbi.1012153). The notebook:
- Loads antibody sequences and S-binding labels
- Formats sequences for AntiBERTy
- Uses a grouped, stratified cross-validation split
- Fine-tunes AntiBERTy with Hugging Face `Trainer`
- Evaluates with AUC, MCC, balanced accuracy, etc.

---

## How to use this notebook

1. Download a parquet file from [figshare](https://figshare.com/articles/dataset/Fine-tuning_Pre-trained_Antibody_Language_Models_for_Antigen_Specificity_Prediction/25342924), e.g. `S_CDR3.parquet` (set `SEQUENCE_SCOPE = "FULL"` to load `S_FULL.parquet` instead).

2. Update `DATA_DIR` and `OUTPUT_DIR` in the configuration cell if needed.

---

> **Dataset assumptions**
> - Columns:
>   - `HL` or `H`: sequences (paired heavy + light chains, or heavy chain only)
>   - `label`: includes antigen binding values like `"S+"`, `"S1+"`, `"S2+"` (positive) and others (negative)
>   - `subject`: donor / study ID for grouped cross-validation.
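
To make these assumptions concrete, here is a minimal sketch using hypothetical rows (the sequences, the `S-` negative label, and the donor IDs are invented for illustration, not taken from the dataset). It assumes, as the `<cls><cls>` replacement in `load_data` implies, that `HL` joins the heavy and light chains with `<cls><cls>`, and it applies the same positive-label rule as the loader.

```python
# Hypothetical rows illustrating the assumed schema; all values are invented.
POSITIVE_LABELS = {"S+", "S1+", "S2+"}  # positive labels used by load_data
REQUIRED_COLUMNS = {"label", "subject"}

rows = [
    {"HL": "CARDRGYW<cls><cls>CQQYNSYPLTF", "label": "S+", "subject": "donor_A"},
    {"HL": "CAKGGSSWYFDYW<cls><cls>CQSYDSSLSGSVF", "label": "S2+", "subject": "donor_A"},
    # "S-" is a hypothetical negative label; any label not in POSITIVE_LABELS counts as 0
    {"HL": "CARESGDYW<cls><cls>CQQSYSTPWTF", "label": "S-", "subject": "donor_B"},
]


def check_row(row: dict, sequence_column: str = "HL") -> None:
    """Raise KeyError if a row lacks a column the notebook relies on."""
    missing = (REQUIRED_COLUMNS | {sequence_column}) - set(row)
    if missing:
        raise KeyError(f"missing columns: {sorted(missing)}")


def binarize(label: str) -> int:
    """1 for S-binding labels, 0 otherwise (same rule as np.isin in load_data)."""
    return int(label in POSITIVE_LABELS)


for row in rows:
    check_row(row)

labels = [binarize(r["label"]) for r in rows]
print(labels)
```

Rows sharing a `subject` (here `donor_A`) are kept on the same side of every split by the grouped cross-validation used later.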


## Install dependencies (run once per session)

In [1]:
# Uncomment to install:
#!pip install -q antiberty transformers datasets scikit-learn biopython pyarrow

## Imports

In [2]:
import os
import random
from collections import Counter

import numpy as np
import pandas as pd
import torch

from sklearn.metrics import (
    precision_score, recall_score, f1_score,
    matthews_corrcoef, roc_auc_score,
    average_precision_score, balanced_accuracy_score
)

from sklearn.model_selection import StratifiedGroupKFold
from datasets import Dataset, DatasetDict, ClassLabel
import transformers
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer
)

import antiberty

print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device



PyTorch: 2.3.0
CUDA available: True


device(type='cuda')

## Configuration

In [3]:
# Which column in the parquet to use as sequences: "HL" (heavy + light) or "H" (heavy only)
MODEL_TYPE = "HL"

# Which dataset variant to use: "CDR3" or "FULL" (full-length sequences)
SEQUENCE_SCOPE = "CDR3"

# Path to your data directory
DATA_DIR = "../data/"

# Path where models and logs will be saved
OUTPUT_DIR = "../models/"

# Training hyperparameters
BATCH_SIZE = 64
LR = 1e-5
N_EPOCHS = 10

# Seeds for the outer (train/test) and inner (train/val) cross-validation splits
RANDOM_STATE_OUTER = 7 if SEQUENCE_SCOPE == "CDR3" else 9
RANDOM_STATE_INNER = 1

RUN_ID = f"S_antiBERTy_{MODEL_TYPE}_fine_tuning_{SEQUENCE_SCOPE}"
print("Run ID:", RUN_ID)
print("Data dir:", DATA_DIR)
print("Output dir:", OUTPUT_DIR)

Run ID: S_antiBERTy_HL_fine_tuning_CDR3
Data dir: ../data/
Output dir: ../models/


In [4]:
os.path.dirname(os.path.realpath(antiberty.__file__))

'/vast/palmer/home.mccleary/mw957/.conda/envs/bcrembed/lib/python3.12/site-packages/antiberty'

## Helper functions: model loading, freezing, formatting, metrics

In [5]:
def get_antiberty_paths():
    """Locate AntiBERTy model + vocab from the antiberty package."""
    project_path = os.path.dirname(os.path.realpath(antiberty.__file__))
    trained_dir = os.path.join(project_path, "trained_models")
    model_dir = os.path.join(trained_dir, "AntiBERTy_md_smooth")
    vocab = os.path.join(trained_dir, "vocab.txt")
    print("AntiBERTy model:", model_dir)
    print("AntiBERTy vocab:", vocab)
    return model_dir, vocab


def load_antiberty_classifier(num_labels: int = 2):
    """Load AntiBERTy as a sequence-classification model + tokenizer."""
    model_dir, vocab = get_antiberty_paths()
    tokenizer = transformers.BertTokenizer(
        vocab_file=vocab,
        do_lower_case=False
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_dir,
        num_labels=num_labels
    )
    model.to(device)
    size = sum(p.numel() for p in model.parameters())
    print(f"Model size: {size/1e6:.2f}M parameters")
    return model, tokenizer


def freeze_antiberty_layers(model, train_last_n_layers: int = 3):
    """Freeze embeddings and early encoder layers of AntiBERTy."""
    for p in model.bert.embeddings.parameters():
        p.requires_grad = False

    total_layers = len(model.bert.encoder.layer)  # AntiBERTy has 8 layers
    for layer in model.bert.encoder.layer[: total_layers - train_last_n_layers]:
        for p in layer.parameters():
            p.requires_grad = False
    return model


def insert_space_every_other_except_cls(s: str) -> str:
    """Space-separate every residue (one token per amino acid), keeping [CLS] tokens intact."""
    parts = s.split("[CLS]")
    spaced = [" ".join(list(part)) for part in parts]
    out = " [CLS] ".join(spaced)
    return " ".join(out.split())


def set_seed(seed: int = 42):
    """Set random seeds for reproducibility."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)


def compute_metrics(eval_pred):
    """Metrics callback for Hugging Face Trainer."""
    logits, labels = eval_pred
    probs = torch.softmax(torch.tensor(logits), dim=1).numpy()[:, 1]
    preds = np.argmax(logits, axis=1)

    return {
        "precision": precision_score(labels, preds),
        "recall": recall_score(labels, preds),
        "f1_weighted": f1_score(labels, preds, average="weighted"),
        "apr": average_precision_score(labels, probs),
        "balanced_accuracy": balanced_accuracy_score(labels, preds),
        "auc": roc_auc_score(labels, probs),
        "mcc": matthews_corrcoef(labels, preds),
    }
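
`compute_metrics` receives raw logits of shape `(n, 2)` plus integer labels. As an illustration only (not a replacement for the scikit-learn calls), the sketch below re-derives three of its quantities in NumPy so the formulas are visible: a numerically stable softmax for P(binder), balanced accuracy as the mean of sensitivity and specificity, and MCC from the confusion matrix. The toy logits and labels are invented.

```python
import numpy as np


def positive_class_probs(logits: np.ndarray) -> np.ndarray:
    """Softmax over the class axis, returning P(class 1); subtracting the row max avoids overflow."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp[:, 1] / exp.sum(axis=1)


def confusion_counts(labels, preds):
    """Return (TP, TN, FP, FN) for binary labels/predictions."""
    labels = np.asarray(labels)
    preds = np.asarray(preds)
    tp = int(np.sum((preds == 1) & (labels == 1)))
    tn = int(np.sum((preds == 0) & (labels == 0)))
    fp = int(np.sum((preds == 1) & (labels == 0)))
    fn = int(np.sum((preds == 0) & (labels == 1)))
    return tp, tn, fp, fn


def balanced_accuracy(labels, preds) -> float:
    """Mean of sensitivity (TPR) and specificity (TNR)."""
    tp, tn, fp, fn = confusion_counts(labels, preds)
    return 0.5 * (tp / (tp + fn) + tn / (tn + fp))


def mcc(labels, preds) -> float:
    """(TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)); 0 when the denominator is 0."""
    tp, tn, fp, fn = confusion_counts(labels, preds)
    denom = np.sqrt(float((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)))
    return 0.0 if denom == 0 else (tp * tn - fp * fn) / denom


# Toy logits for four sequences (invented values).
logits = np.array([[2.0, 0.0], [0.5, 1.5], [0.0, 3.0], [1.0, 0.2]])
labels = np.array([0, 1, 1, 1])
preds = logits.argmax(axis=1)  # same rule as np.argmax(logits, axis=1) in compute_metrics
probs = positive_class_probs(logits)
print(preds, probs.round(3))
print(balanced_accuracy(labels, preds), mcc(labels, preds))
```

Note that threshold-based metrics (precision, recall, MCC, balanced accuracy) use the argmax, i.e. an implicit 0.5 probability cutoff, while AUC and average precision use `probs` directly.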

## Load dataset

In [6]:
MAX_LENGTH = 512 - 2  # AntiBERTy max length minus specials

def load_data(scope: str = "CDR3", model_type: str = "HL"):
    if scope == "FULL":
        filename = "S_FULL.parquet"
        print("Loading full-length sequences...")
    else:
        filename = "S_CDR3.parquet"
        print("Loading CDR3 sequences...")

    path = os.path.join(DATA_DIR, filename)
    if not os.path.exists(path):
        raise FileNotFoundError(f"Could not find {path}. Please check DATA_DIR and filename.")

    df = pd.read_parquet(path)

    X = df[model_type].apply(lambda s: s[:MAX_LENGTH])
    X = X.str.replace("<cls><cls>", "[CLS][CLS]", regex=False)
    X = X.apply(insert_space_every_other_except_cls)

    y = np.isin(df["label"], ["S+", "S1+", "S2+"]).astype(int)
    groups = df["subject"].values

    print(f"Total sequences: {len(X)}")
    print(f"Unique donors: {len(np.unique(groups))}")
    print("Label counts:", Counter(y))

    return X, y, groups, df


X, y, y_groups, raw_df = load_data(SEQUENCE_SCOPE, MODEL_TYPE)

Loading CDR3 sequences...
Total sequences: 15539
Unique donors: 427
Label counts: Counter({1: 8658, 0: 6881})
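
To see what the tokenizer actually receives, the sketch below applies the three formatting steps from `load_data` (truncate, map `<cls><cls>` to `[CLS][CLS]`, space-separate residues) to one hypothetical paired CDR3 string. `space_residues_keep_cls` is a standalone re-implementation of `insert_space_every_other_except_cls`, and the input string is invented. Note that the `s[:MAX_LENGTH]` truncation counts characters, including the `<cls><cls>` markup, before tokenization; `truncation=True` in the tokenizer call adds a separate token-level cap.

```python
MAX_LENGTH = 512 - 2  # mirrors the notebook's constant


def space_residues_keep_cls(s: str) -> str:
    """Same logic as insert_space_every_other_except_cls in the notebook."""
    parts = s.split("[CLS]")
    spaced = [" ".join(part) for part in parts]
    # Re-join with [CLS] and collapse the double spaces left by empty parts
    return " ".join(" [CLS] ".join(spaced).split())


def format_for_antiberty(raw: str) -> str:
    """Character-level truncation, separator mapping, then residue spacing."""
    s = raw[:MAX_LENGTH]
    s = s.replace("<cls><cls>", "[CLS][CLS]")
    return space_residues_keep_cls(s)


raw_hl = "CARDYW<cls><cls>CQQYF"  # hypothetical heavy + light CDR3 pair
formatted = format_for_antiberty(raw_hl)
print(formatted)
```

With whitespace between residues, `BertTokenizer` should map each amino acid to its own vocabulary entry and keep `[CLS]` as a single special token rather than splitting it.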


## Create train/val/test splits (StratifiedGroupKFold)

In [7]:
outer_cv = StratifiedGroupKFold(
    n_splits=4,
    shuffle=True,
    random_state=RANDOM_STATE_OUTER
)

inner_cv = StratifiedGroupKFold(
    n_splits=3,
    shuffle=True,
    random_state=RANDOM_STATE_INNER
)

# Use the first outer fold for this tutorial
for fold_idx, (train_index, test_index) in enumerate(outer_cv.split(X, y, y_groups), start=1):
    print(f"Using outer fold {fold_idx}")
    X_train_all, X_test = X.iloc[train_index], X.iloc[test_index]
    y_train_all, y_test = y[train_index], y[test_index]
    y_groups_train = y_groups[train_index]
    break

print(f"Outer train size: {len(X_train_all)}, test size: {len(X_test)}")
print(f"Fraction positive - train: {np.mean(y_train_all):.3f}, test: {np.mean(y_test):.3f}")

# Inner split to create validation set
for inner_idx, (inner_train_index, val_index) in enumerate(
    inner_cv.split(X_train_all, y_train_all, y_groups_train),
    start=1
):
    print(f"Using inner fold {inner_idx} for train/val split")
    X_train = X_train_all.iloc[inner_train_index]
    y_train = y_train_all[inner_train_index]
    X_val = X_train_all.iloc[val_index]
    y_val = y_train_all[val_index]
    break

print(f"Final sizes - train: {len(X_train)}, val: {len(X_val)}, test: {len(X_test)}")

Using outer fold 1
Outer train size: 12969, test size: 2570
Fraction positive - train: 0.555, test: 0.566
Using inner fold 1 for train/val split
Final sizes - train: 8423, val: 4546, test: 2570
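
`StratifiedGroupKFold` is designed to keep each donor on one side of a split, so the nested splits above should be donor-disjoint. A cheap sanity check is sketched below; `assert_groups_disjoint` is a hypothetical helper written for this notebook, and the toy donor IDs are invented.

```python
import numpy as np


def assert_groups_disjoint(groups, *index_sets) -> None:
    """Raise ValueError if any donor ID occurs in more than one index set."""
    groups = np.asarray(groups)
    donor_sets = [set(groups[np.asarray(idx)].tolist()) for idx in index_sets]
    for i in range(len(donor_sets)):
        for j in range(i + 1, len(donor_sets)):
            shared = donor_sets[i] & donor_sets[j]
            if shared:
                raise ValueError(f"splits {i} and {j} share donors: {sorted(shared)[:5]}")


# Toy donor IDs (invented) for six sequences.
toy_groups = np.array(["d1", "d1", "d2", "d3", "d3", "d4"])
assert_groups_disjoint(toy_groups, [0, 1, 2], [3, 4], [5])  # donor-disjoint: passes silently

try:
    assert_groups_disjoint(toy_groups, [0, 2], [1, 3])  # d1 appears in both splits
    leak_detected = False
except ValueError:
    leak_detected = True
print("leak detected:", leak_detected)
```

In the notebook, `inner_train_index` and `val_index` index into `X_train_all`, not `X`, so they need to be mapped back through `train_index` before comparing against the outer test set, e.g. `assert_groups_disjoint(y_groups, train_index[inner_train_index], train_index[val_index], test_index)`.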


## Build Hugging Face Datasets & tokenize

In [8]:
train_df = pd.DataFrame({"sequence": X_train.values, "labels": y_train})
val_df   = pd.DataFrame({"sequence": X_val.values,   "labels": y_val})
test_df  = pd.DataFrame({"sequence": X_test.values,  "labels": y_test})

raw_datasets = DatasetDict({
    "train": Dataset.from_pandas(train_df.reset_index(drop=True)),
    "validation": Dataset.from_pandas(val_df.reset_index(drop=True)),
    "test": Dataset.from_pandas(test_df.reset_index(drop=True)),
})

model, tokenizer = load_antiberty_classifier(num_labels=2)
model = freeze_antiberty_layers(model, train_last_n_layers=3)

def preprocess_function(batch):
    encodings = tokenizer(
        batch["sequence"],
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
    )
    encodings["labels"] = batch["labels"]
    return encodings

tokenized_datasets = raw_datasets.map(
    preprocess_function,
    batched=True,
    remove_columns=["sequence"],
)

AntiBERTy model: /vast/palmer/home.mccleary/mw957/.conda/envs/bcrembed/lib/python3.12/site-packages/antiberty/trained_models/AntiBERTy_md_smooth
AntiBERTy vocab: /vast/palmer/home.mccleary/mw957/.conda/envs/bcrembed/lib/python3.12/site-packages/antiberty/trained_models/vocab.txt


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /vast/palmer/home.mccleary/mw957/.conda/envs/bcrembed/lib/python3.12/site-packages/antiberty/trained_models/AntiBERTy_md_smooth and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model size: 25.76M parameters


Map: 100%|██████████| 8423/8423 [00:02<00:00, 3713.91 examples/s]
Map: 100%|██████████| 4546/4546 [00:01<00:00, 4072.29 examples/s]
Map: 100%|██████████| 2570/2570 [00:00<00:00, 3655.37 examples/s]
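
`preprocess_function` pads every example to `MAX_LENGTH` (510) tokens, although paired CDR3 inputs are much shorter, so most positions in each batch are padding that the encoder still processes (masked out by the attention mask). A common alternative is to skip padding at tokenization time and pad each batch only to its longest member. The plain-Python sketch below shows the mechanics; the pad ID of 0 and the toy token IDs are assumptions, not values read from AntiBERTy's vocabulary.

```python
from typing import Dict, List


def pad_to_longest(batch: List[List[int]], pad_id: int = 0) -> Dict[str, List[List[int]]]:
    """Pad token-ID lists to the longest sequence in this batch and build attention masks."""
    longest = max(len(ids) for ids in batch)
    input_ids, attention_mask = [], []
    for ids in batch:
        n_pad = longest - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)  # 1 = real token, 0 = padding
    return {"input_ids": input_ids, "attention_mask": attention_mask}


def padding_fraction(lengths: List[int], padded_length: int) -> float:
    """Fraction of positions in a batch that are padding."""
    return 1 - sum(lengths) / (padded_length * len(lengths))


# Toy token IDs (hypothetical values, not real AntiBERTy IDs).
batch = [[2, 10, 11, 12, 3], [2, 10, 3], [2, 13, 14, 15, 16, 17, 3]]
padded = pad_to_longest(batch)
lengths = [len(ids) for ids in batch]
print("padding fraction, fixed 510:", padding_fraction(lengths, 510))
print("padding fraction, per batch:", padding_fraction(lengths, len(padded["input_ids"][0])))
```

In `transformers`, this per-batch behaviour is what `DataCollatorWithPadding` provides. Because this notebook passes `tokenizer=tokenizer` to `Trainer`, its default collator should already be a `DataCollatorWithPadding`, so dropping `padding="max_length"` from `preprocess_function` may be enough; verify this against your installed version before relying on it.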


## Training setup (Trainer)

In [9]:
set_seed(1)

FOLD_ID = 1
OUT_PATH = os.path.join(OUTPUT_DIR, f"{RUN_ID}_Fold_{FOLD_ID}")
os.makedirs(OUT_PATH, exist_ok=True)
print("Saving checkpoints to:", OUT_PATH)

training_args = TrainingArguments(
    output_dir=OUT_PATH,
    evaluation_strategy="epoch",  # renamed to eval_strategy in newer transformers releases
    save_strategy="epoch",
    logging_strategy="epoch",
    learning_rate=LR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=N_EPOCHS,
    warmup_ratio=0.0,
    load_best_model_at_end=True,
    metric_for_best_model="auc",
    lr_scheduler_type="linear",
    seed=1
)

trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,  # newer transformers releases use processing_class=tokenizer instead
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    compute_metrics=compute_metrics,
)

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Saving checkpoints to: ../models/S_antiBERTy_HL_fine_tuning_CDR3_Fold_1
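
With `lr_scheduler_type="linear"` and `warmup_ratio=0.0`, the learning rate starts at `LR` and decays linearly to zero over all optimizer steps. Assuming single-device training without gradient accumulation, the step count is ceil(8423 / 64) = 132 steps per epoch, i.e. 1,320 steps over 10 epochs, which matches `train_steps_per_second` × `train_runtime` in the training output (2.106 × 626.7 ≈ 1,320). The plain-Python sketch below reproduces this schedule for illustration; the helper names are hypothetical.

```python
import math


def total_training_steps(n_train: int, batch_size: int, n_epochs: int) -> int:
    """Optimizer steps, assuming one device and no gradient accumulation."""
    return math.ceil(n_train / batch_size) * n_epochs


def linear_lr(step: int, total_steps: int, base_lr: float, warmup_steps: int = 0) -> float:
    """Linear warmup from 0 to base_lr, then linear decay to 0 at total_steps."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


total = total_training_steps(n_train=8423, batch_size=64, n_epochs=10)
for step in (0, total // 2, total):
    print(step, linear_lr(step, total, base_lr=1e-5))
```

Setting a nonzero `warmup_ratio` would correspond to `warmup_steps = ceil(warmup_ratio * total)` in this sketch.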


## Train AntiBERTy (about 10 minutes for the run shown here)

In [10]:
train_result = trainer.train()
train_result.metrics

| Epoch | Training Loss | Validation Loss | Precision | Recall | F1 Weighted | APR | Balanced Accuracy | AUC | MCC |
|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| 1 | 0.686 | 0.672582 | 0.578891 | 0.859656 | 0.531055 | 0.643557 | 0.547432 | 0.602929 | 0.122046 |
| 2 | 0.6636 | 0.663214 | 0.6018 | 0.77529 | 0.576286 | 0.66619 | 0.573953 | 0.625232 | 0.161951 |
| 3 | 0.6515 | 0.656686 | 0.629101 | 0.705318 | 0.604632 | 0.679539 | 0.59838 | 0.640495 | 0.201339 |
| 4 | 0.6363 | 0.65439 | 0.635659 | 0.688525 | 0.608874 | 0.689057 | 0.602942 | 0.647817 | 0.208721 |
| 5 | 0.6292 | 0.652828 | 0.638046 | 0.689324 | 0.611427 | 0.693718 | 0.605542 | 0.652462 | 0.213864 |
| 6 | 0.6197 | 0.656932 | 0.674221 | 0.570972 | 0.612711 | 0.698114 | 0.616782 | 0.656764 | 0.232928 |
| 7 | 0.6159 | 0.651222 | 0.647674 | 0.668133 | 0.616673 | 0.702476 | 0.611817 | 0.660182 | 0.224564 |
| 8 | 0.6118 | 0.651564 | 0.652422 | 0.640944 | 0.614947 | 0.704302 | 0.61167 | 0.661166 | 0.222945 |
| 9 | 0.6076 | 0.651861 | 0.657959 | 0.644542 | 0.620488 | 0.705702 | 0.617381 | 0.662358 | 0.23429 |
| 10 | 0.6066 | 0.651196 | 0.648607 | 0.670132 | 0.617949 | 0.706074 | 0.613061 | 0.662967 | 0.227118 |


{'train_runtime': 626.7467,
 'train_samples_per_second': 134.392,
 'train_steps_per_second': 2.106,
 'total_flos': 6568285780076400.0,
 'train_loss': 0.6328125693581321,
 'epoch': 10.0}
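
`load_best_model_at_end=True` with `metric_for_best_model="auc"` makes `Trainer` reload the checkpoint with the best validation AUC; for a metric whose name is not a loss, higher is treated as better by default. The sketch below applies that selection rule to the validation AUC and MCC columns copied from the training log above (`best_epoch` is a hypothetical helper). It shows that the choice of selection metric matters: AUC picks epoch 10, MCC would pick epoch 9.

```python
# Validation AUC and MCC per epoch, copied from the training log above.
val_auc = [0.602929, 0.625232, 0.640495, 0.647817, 0.652462,
           0.656764, 0.660182, 0.661166, 0.662358, 0.662967]
val_mcc = [0.122046, 0.161951, 0.201339, 0.208721, 0.213864,
           0.232928, 0.224564, 0.222945, 0.23429, 0.227118]


def best_epoch(values, greater_is_better=True):
    """1-based epoch with the best metric value; ties keep the earliest epoch."""
    best_idx = 0
    for i, v in enumerate(values):
        better = v > values[best_idx] if greater_is_better else v < values[best_idx]
        if better:
            best_idx = i
    return best_idx + 1


print("best by AUC:", best_epoch(val_auc))
print("best by MCC:", best_epoch(val_mcc))
```

Since the best AUC falls on the final epoch here, the reloaded model is simply the epoch-10 checkpoint.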

## Evaluate on the held-out test set

In [11]:
model.eval()
test_outputs = trainer.predict(tokenized_datasets["test"])
test_metrics = test_outputs.metrics

print("Test metrics:")
for k, v in test_metrics.items():
    print(f"{k}: {v:.4f}")

Test metrics:
test_loss: 0.6747
test_precision: 0.6530
test_recall: 0.6224
test_f1_weighted: 0.6003
test_apr: 0.6858
test_balanced_accuracy: 0.5957
test_auc: 0.6316
test_mcc: 0.1903
test_runtime: 8.7614
test_samples_per_second: 293.3340
test_steps_per_second: 4.6800
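
A single outer test fold gives one AUC value (0.6316 here, below the 0.663 reached on validation). One way to gauge its uncertainty is a bootstrap over test examples, sketched below with a pairwise (Mann-Whitney) AUC in NumPy; the helper names are hypothetical and the demo data are synthetic, not the notebook's predictions. Because sequences from the same donor are correlated, resampling donors rather than individual sequences would give a more conservative interval; that variant is not shown.

```python
import numpy as np


def pairwise_auc(labels: np.ndarray, scores: np.ndarray) -> float:
    """P(random positive scores above random negative); ties count 1/2."""
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    diff = pos[:, None] - neg[None, :]  # all positive/negative pairs
    return float((diff > 0).mean() + 0.5 * (diff == 0).mean())


def bootstrap_auc_ci(labels, scores, n_boot=1000, alpha=0.05, seed=0):
    """Percentile bootstrap interval, resampling examples with replacement."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    scores = np.asarray(scores)
    n = len(labels)
    aucs = []
    for _ in range(n_boot):
        idx = rng.integers(0, n, size=n)
        if labels[idx].min() == labels[idx].max():
            continue  # resample contains one class only, so AUC is undefined
        aucs.append(pairwise_auc(labels[idx], scores[idx]))
    lo, hi = np.quantile(aucs, [alpha / 2, 1 - alpha / 2])
    return float(lo), float(hi)


# Synthetic example: scores weakly shifted upward for positives.
rng = np.random.default_rng(1)
labels = rng.integers(0, 2, size=300)
scores = labels * 0.3 + rng.normal(size=300)
print(pairwise_auc(labels, scores), bootstrap_auc_ci(labels, scores, n_boot=200))
```

Applied to this notebook, `labels` would be `test_outputs.label_ids` and `scores` the class-1 probabilities, e.g. `torch.softmax(torch.tensor(test_outputs.predictions), dim=1).numpy()[:, 1]`.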
