# Embedding with a fine-tuned custom model using AMULETY

In this tutorial we show how to fine-tune the **AntiBERTy** antibody language model to predict binding to the SARS-CoV-2 Spike protein (S), as published in [Wang et al. 2025](https://doi.org/10.1371/journal.pcbi.1012153).



The tutorial goes through the following steps:
- Loading antibody sequences and S-binding labels
- Formatting sequences for AntiBERTy
- Using a grouped, stratified cross-validation split
- Fine-tuning AntiBERTy with Hugging Face `Trainer`
- Evaluating with AUC, MCC, balanced accuracy, etc.
- Using AMULETY to embed new sequences with the fine-tuned model


## Set-up

We recommend running this tutorial on a GPU with at least 25 GB of RAM. On such a setup, the full notebook typically completes in about 15 minutes.

First, download a parquet file containing the training data:
   - `S_CDR3.parquet` from [figshare](https://figshare.com/articles/dataset/Fine-tuning_Pre-trained_Antibody_Language_Models_for_Antigen_Specificity_Prediction/25342924)

The dataset contains the following columns:
- `HL` or `H`: antibody sequences (heavy + light chains, or heavy chain only).
- `label`: antigen-binding label. The values `"S+"`, `"S1+"` and `"S2+"` are treated as positive; all other values are treated as negative.
- `subject`: donor / study ID, used for grouped cross-validation.
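
The mapping from raw labels to a binary target can be sketched as follows. This is a minimal illustration on made-up labels; `binarize_labels` and the example values `"RBD-"` and `"unknown"` are hypothetical and only stand in for "any other value" in the `label` column.

```python
# Labels treated as positive in this tutorial (from the column description above).
POSITIVE_LABELS = {"S+", "S1+", "S2+"}


def binarize_labels(labels):
    """Map raw antigen labels to 1 (S-binding) or 0 (everything else)."""
    return [1 if label in POSITIVE_LABELS else 0 for label in labels]


# Toy labels for illustration only; "RBD-" and "unknown" are made-up negatives.
example_labels = ["S+", "S1+", "RBD-", "S2+", "unknown"]
binary = binarize_labels(example_labels)
print(binary)  # [1, 1, 0, 1, 0]
```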


## Install dependencies (run once per session)

In [1]:
#!pip install -q antiberty transformers datasets scikit-learn biopython pyarrow

## Imports

In [1]:
import os
import random
from collections import Counter

import numpy as np
import pandas as pd
import torch

from sklearn.metrics import (
    precision_score, recall_score, f1_score,
    matthews_corrcoef, roc_auc_score,
    average_precision_score, balanced_accuracy_score
)

from sklearn.model_selection import StratifiedGroupKFold
from datasets import Dataset, DatasetDict, ClassLabel
import transformers
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer
)

import antiberty

print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device



PyTorch: 2.3.0
CUDA available: True


device(type='cuda')

## Configuration

If needed, update `DATA_DIR` (where the downloaded parquet file is stored) and `OUTPUT_DIR` (where models and logs will be saved) below.


In [2]:
# Which column in the parquet to use as sequences
MODEL_TYPE = "HL"  

# Which dataset variant to use
SEQUENCE_SCOPE = "CDR3"  

# Path to your data directory
DATA_DIR = "../data/" 

# Path where models and logs will be saved
OUTPUT_DIR = "../models/"

# Training hyperparameters
BATCH_SIZE = 64        
LR = 1e-5              
N_EPOCHS = 10          

RANDOM_STATE_OUTER = 7 if SEQUENCE_SCOPE == "CDR3" else 9
RANDOM_STATE_INNER = 1

RUN_ID = f"S_antiBERTy_{MODEL_TYPE}_fine_tuning_{SEQUENCE_SCOPE}"
print("Run ID:", RUN_ID)
print("Data dir:", DATA_DIR)
print("Output dir:", OUTPUT_DIR)

Run ID: S_antiBERTy_HL_fine_tuning_CDR3
Data dir: ../data/
Output dir: ../models/


## Helper functions: model loading, freezing, formatting, metrics

In [None]:
def get_antiberty_paths():
    """Locate AntiBERTy model + vocab from the antiberty package."""
    project_path = os.path.dirname(os.path.realpath(antiberty.__file__))
    trained_dir = os.path.join(project_path, "trained_models")
    model_dir = os.path.join(trained_dir, "AntiBERTy_md_smooth")
    vocab = os.path.join(trained_dir, "vocab.txt")
    print("AntiBERTy model:", model_dir)
    print("AntiBERTy vocab:", vocab)
    return model_dir, vocab


def load_antiberty_classifier(num_labels: int = 2):
    """Load AntiBERTy as a sequence-classification model + tokenizer."""
    model_dir, vocab = get_antiberty_paths()
    tokenizer = transformers.BertTokenizer(
        vocab_file=vocab,
        do_lower_case=False
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_dir,
        num_labels=num_labels
    )
    model.to(device)
    size = sum(p.numel() for p in model.parameters())
    print(f"Model size: {size/1e6:.2f}M parameters")
    return model, tokenizer


def freeze_antiberty_layers(model, train_last_n_layers: int = 3):
    """Freeze embeddings and early encoder layers of AntiBERTy."""
    for p in model.bert.embeddings.parameters():
        p.requires_grad = False

    total_layers = len(model.bert.encoder.layer)  # AntiBERTy has 8 layers
    for layer in model.bert.encoder.layer[: total_layers - train_last_n_layers]:
        for p in layer.parameters():
            p.requires_grad = False
    return model


def insert_space_every_other_except_cls(s: str) -> str:
    """Add spaces between residues, keeping [CLS] intact."""
    parts = s.split("[CLS]")
    spaced = [" ".join(list(part)) for part in parts]
    out = " [CLS] ".join(spaced)
    return " ".join(out.split())


def set_seed(seed: int = 42):
    """Set random seeds for reproducibility."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)


def compute_metrics(eval_pred):
    """Metrics callback for Hugging Face Trainer."""
    logits, labels = eval_pred
    probs = torch.softmax(torch.tensor(logits), dim=1).numpy()[:, 1]
    preds = np.argmax(logits, axis=1)

    return {
        "precision": precision_score(labels, preds),
        "recall": recall_score(labels, preds),
        "f1_weighted": f1_score(labels, preds, average="weighted"),
        "apr": average_precision_score(labels, probs),
        "balanced_accuracy": balanced_accuracy_score(labels, preds),
        "auc": roc_auc_score(labels, probs),
        "mcc": matthews_corrcoef(labels, preds),
    }
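
To make two of the less familiar metrics concrete, here is a small hand computation of balanced accuracy and the Matthews correlation coefficient (MCC) from confusion-matrix counts. It is a sketch on toy labels using only the standard library; the helper names are illustrative, and the tutorial itself relies on the scikit-learn implementations imported above.

```python
import math


def confusion_counts(labels, preds):
    """Return (tp, fp, tn, fn) for binary labels and predictions."""
    pairs = list(zip(labels, preds))
    tp = sum(1 for y, p in pairs if y == 1 and p == 1)
    fp = sum(1 for y, p in pairs if y == 0 and p == 1)
    tn = sum(1 for y, p in pairs if y == 0 and p == 0)
    fn = sum(1 for y, p in pairs if y == 1 and p == 0)
    return tp, fp, tn, fn


def balanced_accuracy(tp, fp, tn, fn):
    """Mean of sensitivity (recall on positives) and specificity."""
    return 0.5 * (tp / (tp + fn) + tn / (tn + fp))


def mcc(tp, fp, tn, fn):
    """Matthews correlation coefficient; 0.0 when the denominator is zero."""
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / denom if denom else 0.0


# Toy example: 4 positives and 4 negatives, one mistake in each class.
toy_labels = [1, 1, 1, 0, 0, 0, 1, 0]
toy_preds = [1, 1, 0, 0, 0, 1, 1, 0]
counts = confusion_counts(toy_labels, toy_preds)
print(counts)                      # (3, 1, 3, 1)
print(balanced_accuracy(*counts))  # 0.75
print(mcc(*counts))                # 0.5
```

An MCC of 0 corresponds to chance-level predictions and 1 to perfect agreement, which helps when reading the training log later in the tutorial.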

## Load dataset

We load the data and convert the sequences to the format required by AntiBERTy: residues are separated by spaces, and the `<cls><cls>` marker between the heavy and light chain sequences is replaced with `[CLS][CLS]`.
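
As a quick illustration of this conversion, the sketch below applies the same two steps to a single made-up heavy/light CDR3 pair. `space_residues` mirrors the `insert_space_every_other_except_cls` helper defined earlier; the sequence itself is invented for the example.

```python
def space_residues(seq: str, sep: str = "[CLS]") -> str:
    """Insert spaces between residues while keeping the separator token intact."""
    parts = seq.split(sep)
    spaced = [" ".join(part) for part in parts]
    joined = f" {sep} ".join(spaced)
    # Collapse any repeated whitespace left by empty parts.
    return " ".join(joined.split())


raw = "CARDYW<cls><cls>CQQYF"  # made-up CDR3 pair, for illustration only
converted = raw.replace("<cls><cls>", "[CLS][CLS]")
formatted = space_residues(converted)
print(formatted)  # C A R D Y W [CLS] [CLS] C Q Q Y F
```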

In [None]:
MAX_LENGTH = 512 - 2  # AntiBERTy max length minus specials

def load_data(scope: str = "CDR3", model_type: str = "HL"):
    if scope == "FULL":
        filename = "S_FULL.parquet"
        print("Loading full-length sequences...")
    else:
        filename = "S_CDR3.parquet"
        print("Loading CDR3 sequences...")

    path = os.path.join(DATA_DIR, filename)
    if not os.path.exists(path):
        raise FileNotFoundError(f"Could not find {path}. Please check DATA_DIR and filename.")

    df = pd.read_parquet(path)

    X = df[model_type].apply(lambda s: s[:MAX_LENGTH])
    X = X.str.replace("<cls><cls>", "[CLS][CLS]", regex=False)
    X = X.apply(insert_space_every_other_except_cls)

    y = np.isin(df["label"], ["S+", "S1+", "S2+"]).astype(int)
    groups = df["subject"].values

    print(f"Total sequences: {len(X)}")
    print(f"Unique donors: {len(np.unique(groups))}")
    print("Label counts:", Counter(y))

    return X, y, groups, df

X, y, y_groups, raw_df = load_data(SEQUENCE_SCOPE, MODEL_TYPE)

Loading CDR3 sequences...
Total sequences: 15539
Unique donors: 427
Label counts: Counter({1: 8658, 0: 6881})


## Create data splits

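We generate train, validation, and test splits with `StratifiedGroupKFold`, which stratifies by label while keeping all sequences from a given donor in the same split. One of four outer folds is held out as the test set (nominally 25% of the data), and one of three inner folds of the remaining data is used as the validation set (nominally 33%). Because folds are built from whole donors, the actual proportions can deviate from these nominal values.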

In [5]:
outer_cv = StratifiedGroupKFold(
    n_splits=4,
    shuffle=True,
    random_state=RANDOM_STATE_OUTER
)

inner_cv = StratifiedGroupKFold(
    n_splits=3,
    shuffle=True,
    random_state=RANDOM_STATE_INNER
)

# Use the first outer fold for this tutorial
for fold_idx, (train_index, test_index) in enumerate(outer_cv.split(X, y, y_groups), start=1):
    print(f"Using outer fold {fold_idx}")
    X_train_all, X_test = X.iloc[train_index], X.iloc[test_index]
    y_train_all, y_test = y[train_index], y[test_index]
    y_groups_train = y_groups[train_index]
    break

print(f"Outer train size: {len(X_train_all)}, test size: {len(X_test)}")
print(f"% positive train: {np.mean(y_train_all):.3f}, % positive test: {np.mean(y_test):.3f}")

# Inner split to create validation set
for inner_idx, (inner_train_index, val_index) in enumerate(
    inner_cv.split(X_train_all, y_train_all, y_groups_train),
    start=1
):
    print(f"Using inner fold {inner_idx} for train/val split")
    X_train = X_train_all.iloc[inner_train_index]
    y_train = y_train_all[inner_train_index]
    X_val = X_train_all.iloc[val_index]
    y_val = y_train_all[val_index]
    break

print(f"Final sizes - train: {len(X_train)}, val: {len(X_val)}, test: {len(X_test)}")

Using outer fold 1
Outer train size: 12969, test size: 2570
% positive train: 0.555, % positive test: 0.566
Using inner fold 1 for train/val split
Final sizes - train: 8423, val: 4546, test: 2570
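
Since the point of a grouped split is that no donor appears on both sides, a quick leakage check can be reassuring. The sketch below uses toy donor IDs and a hypothetical helper, `shared_groups`; in the notebook you could pass `y_groups[train_index]` and `y_groups[test_index]` (or `y_groups_train[inner_train_index]` and `y_groups_train[val_index]`) instead.

```python
def shared_groups(groups_a, groups_b):
    """Return the set of group IDs that occur in both collections."""
    return set(groups_a) & set(groups_b)


# Toy donor IDs for illustration only.
train_donors = ["d1", "d1", "d2", "d3"]
test_donors = ["d4", "d5", "d5"]
leaky_donors = ["d3", "d4"]

print(shared_groups(train_donors, test_donors))   # set()
print(shared_groups(train_donors, leaky_donors))  # {'d3'}
```

An empty set means the two splits are donor-disjoint.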


## Build Hugging Face Datasets & tokenize

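We convert the train, validation, and test sets to Hugging Face `Dataset` objects, load AntiBERTy with only its last three encoder layers left trainable, and tokenize the sequences.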

In [6]:
train_df = pd.DataFrame({"sequence": X_train.values, "labels": y_train})
val_df   = pd.DataFrame({"sequence": X_val.values,   "labels": y_val})
test_df  = pd.DataFrame({"sequence": X_test.values,  "labels": y_test})

raw_datasets = DatasetDict({
    "train": Dataset.from_pandas(train_df.reset_index(drop=True)),
    "validation": Dataset.from_pandas(val_df.reset_index(drop=True)),
    "test": Dataset.from_pandas(test_df.reset_index(drop=True)),
})

model, tokenizer = load_antiberty_classifier(num_labels=2)
model = freeze_antiberty_layers(model, train_last_n_layers=3)

def preprocess_function(batch):
    encodings = tokenizer(
        batch["sequence"],
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
    )
    encodings["labels"] = batch["labels"]
    return encodings

tokenized_datasets = raw_datasets.map(
    preprocess_function,
    batched=True,
    remove_columns=["sequence"],
)

AntiBERTy model: /vast/palmer/home.mccleary/mw957/.conda/envs/bcrembed/lib/python3.12/site-packages/antiberty/trained_models/AntiBERTy_md_smooth
AntiBERTy vocab: /vast/palmer/home.mccleary/mw957/.conda/envs/bcrembed/lib/python3.12/site-packages/antiberty/trained_models/vocab.txt


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /vast/palmer/home.mccleary/mw957/.conda/envs/bcrembed/lib/python3.12/site-packages/antiberty/trained_models/AntiBERTy_md_smooth and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model size: 25.76M parameters


Map: 100%|██████████| 8423/8423 [00:02<00:00, 3736.40 examples/s]
Map: 100%|██████████| 4546/4546 [00:01<00:00, 3854.12 examples/s]
Map: 100%|██████████| 2570/2570 [00:00<00:00, 4093.28 examples/s]


## Set up the parameters for fine-tuning AntiBERTy

In [7]:
set_seed(1)

FOLD_ID = 1
OUT_PATH = os.path.join(OUTPUT_DIR, f"{RUN_ID}_Fold_{FOLD_ID}")
os.makedirs(OUT_PATH, exist_ok=True)
print("Saving checkpoints to:", OUT_PATH)

training_args = TrainingArguments(
    output_dir=OUT_PATH,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    learning_rate=LR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=N_EPOCHS,
    warmup_ratio=0.0,
    load_best_model_at_end=True,
    metric_for_best_model="auc",
    lr_scheduler_type="linear",
    seed=1
)

trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    compute_metrics=compute_metrics,
)

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Saving checkpoints to: ../models/S_antiBERTy_HL_fine_tuning_CDR3_Fold_1
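
To get a feel for what these settings imply, the sketch below works out the number of optimizer steps and the learning rate under a linear schedule with no warmup. It assumes a single device, no gradient accumulation, and that the last partial batch of each epoch is kept; `linear_lr` is an illustrative helper, not a call into `transformers`.

```python
import math

train_size = 8423  # training set size reported by the split above
batch_size = 64    # BATCH_SIZE
n_epochs = 10      # N_EPOCHS
base_lr = 1e-5     # LR

steps_per_epoch = math.ceil(train_size / batch_size)
total_steps = steps_per_epoch * n_epochs
print(steps_per_epoch, total_steps)  # 132 1320


def linear_lr(step, total_steps, base_lr, warmup_steps=0):
    """Learning rate after `step` optimizer steps: linear warmup, then linear decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


for step in (0, total_steps // 2, total_steps):
    print(step, linear_lr(step, total_steps, base_lr))
```

With `warmup_ratio=0.0` the learning rate starts at `LR` and decays linearly, so it is halved midway through training and reaches zero at the final step.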


## Fine-tune AntiBERTy on the specificity dataset

In [8]:
train_result = trainer.train()
train_result.metrics

Epoch,Training Loss,Validation Loss,Precision,Recall,F1 Weighted,Apr,Balanced Accuracy,Auc,Mcc
1,0.6821,0.674185,0.571391,0.870452,0.512843,0.641419,0.535959,0.599263,0.097128
2,0.6683,0.668239,0.586726,0.841263,0.548888,0.662187,0.558285,0.619711,0.142084
3,0.6516,0.659277,0.617122,0.732107,0.594114,0.676344,0.588303,0.634982,0.184527
4,0.6417,0.65443,0.623169,0.731307,0.601285,0.688067,0.595238,0.647339,0.198104
5,0.6309,0.651899,0.628552,0.72531,0.606789,0.693455,0.600552,0.653557,0.207789
6,0.6232,0.657183,0.664266,0.590164,0.611359,0.695164,0.612686,0.654203,0.224292
7,0.6165,0.650168,0.643408,0.70052,0.618827,0.700375,0.612852,0.659554,0.229099
8,0.611,0.650533,0.650999,0.664534,0.618953,0.701372,0.614419,0.660647,0.229431
9,0.6062,0.650401,0.648743,0.670532,0.618158,0.703339,0.613261,0.662185,0.227533
10,0.6059,0.649999,0.642884,0.695322,0.617356,0.703881,0.611475,0.662833,0.225944


{'train_runtime': 628.6721,
 'train_samples_per_second': 133.981,
 'train_steps_per_second': 2.1,
 'total_flos': 6568285780076400.0,
 'train_loss': 0.6337348244406961,
 'epoch': 10.0}

## Evaluate the model on the held-out test set

In [9]:
model.eval()
test_outputs = trainer.predict(tokenized_datasets["test"])
test_metrics = test_outputs.metrics

print("Test metrics:")
for k, v in test_metrics.items():
    print(f"{k}: {v:.4f}")

Test metrics:
test_loss: 0.6636
test_precision: 0.6531
test_recall: 0.6527
test_f1_weighted: 0.6074
test_apr: 0.6931
test_balanced_accuracy: 0.6005
test_auc: 0.6438
test_mcc: 0.2010
test_runtime: 8.7197
test_samples_per_second: 294.7350
test_steps_per_second: 4.7020


## Using the fine-tuned AntiBERTy model with AMULETY on epitope-specific sequences

In this section we:
1. Export the fine-tuned AntiBERTy classifier as a Hugging Face model.
2. Register it as a **custom** model for AMULETY.
3. Use AMULETY to embed example sequences with this fine-tuned encoder.

This illustrates how you can fine-tune AntiBERTy for S-protein binding and then
reuse the same model inside AMULETY to generate embeddings for other datasets
(e.g. epitope-specificity panels).


We start by saving the fine-tuned AntiBERTy classifier as a Hugging Face model.

In [10]:
from pathlib import Path

# Directory where we will save the fine-tuned model for Amulety
CUSTOM_MODEL_PATH = Path(OUTPUT_DIR) / f"{RUN_ID}_amulety_custom"
CUSTOM_MODEL_PATH.mkdir(parents=True, exist_ok=True)

# `trainer` already holds the fine-tuned best model because we used `load_best_model_at_end=True`
# but to be explicit, we save the current `model` and `tokenizer`.
model.save_pretrained(CUSTOM_MODEL_PATH)
tokenizer.save_pretrained(CUSTOM_MODEL_PATH)

print("Saved fine-tuned model for Amulety at:", CUSTOM_MODEL_PATH)


Saved fine-tuned model for Amulety at: ../models/S_antiBERTy_HL_fine_tuning_CDR3_amulety_custom


We then download a new dataset and, for demonstration purposes, sample up to 1,000 cells that have both a heavy and a light chain.

In [33]:
airr_demo = pd.read_csv("https://zenodo.org/records/17186858/files/ML_bcr_airr_dataset.tsv", sep='\t')
# cells that have at least one H and one L
mask_both = (
    airr_demo.groupby("cell_id")["chain_type"]
      .agg(lambda x: set(x))
      .pipe(lambda s: s[s.apply(lambda st: {"H", "L"}.issubset(st))])
)
cells_with_both = mask_both.index

# sample up to 1000 such cells (no replacement)
n_cells = min(1000, len(cells_with_both))
sampled_cells = np.random.choice(cells_with_both, size=n_cells, replace=False)
df_sampled = airr_demo[airr_demo["cell_id"].isin(sampled_cells)].copy()
df_sampled.head()



Unnamed: 0,sequence_id,sequence_vdj_aa,locus,cell_id,chain_type,v_call,v_call_family,j_call_family,mu_freq,junction_aa_length,isotype,source,subject,specificity,duplicate_count,productive,rev_comp,stop_codon,vj_in_frame
79,80_heavy,EVQLVESGGGLVQPGGSLRLSCVASGFTFSSYWMSWVRQAPGKGLE...,IGH,cell_80,H,IGHV3-7,IGHV3,IGHJ4,0.003378,18.0,,OAS,OAS_King_Subject-BCP3,unlabeled,1,True,False,False,True
189,190_heavy,EVQLVQSGAEVKKPGESLKISCKGSAYSFTNYWIAWVRQMPGKGLE...,IGH,cell_190,H,IGHV5-51,IGHV5,IGHJ3,0.013158,20.0,,OAS,OAS_King_Subject-BCP3,unlabeled,1,True,False,False,True
298,299_heavy,EVQLVESGGGLVKPGGSLRLSCSASRFTFSTYRMNWVRQAPGKGLE...,IGH,cell_299,H,IGHV3-21,IGHV3,IGHJ2,0.023256,23.0,,OAS,OAS_King_Subject-BCP3,unlabeled,1,True,False,False,True
1084,1085_heavy,QVQLVQSGAEVREPGASVKVSCKASGYTFTIYDINWVRQAPGQGLE...,IGH,cell_1085,H,IGHV1-8,IGHV1,IGHJ4,0.062706,13.0,,OAS,OAS_King_Subject-BCP3,unlabeled,1,True,False,False,True
2712,2713_heavy,EVQLVESGGGLVQRGGSLRLSCGASGFTFSSYNMNWVRQAPGKGLE...,IGH,cell_2713,H,IGHV3-48,IGHV3,IGHJ3,0.032895,14.0,,OAS,OAS_King_Subject-BCP4,unlabeled,1,True,False,False,True


We then use AMULETY to generate embedding vectors for these sequences with our fine-tuned AntiBERTy model.

In [31]:
from amulety import embed_airr

CHAIN = MODEL_TYPE
EMBED_DIM = int(model.config.hidden_size)

print("Custom model path:", CUSTOM_MODEL_PATH)
print("Embedding dimension:", EMBED_DIM)
print("Max length used by tutorial:", MAX_LENGTH)

embeddings_df, meta_df = embed_airr(
    airr=df_sampled,
    chain=CHAIN,
    model="custom",
    sequence_col="sequence_vdj_aa",
    cell_id_col="cell_id",
    batch_size=8,
    model_path=str(CUSTOM_MODEL_PATH),
    embedding_dimension=EMBED_DIM,
    max_length=MAX_LENGTH,
    output_type="df",       # return embeddings as a DataFrame
    residue_level=False,    # sequence-level embeddings
)

print("Embeddings shape:", embeddings_df.shape)
embeddings_df.head()

Some weights of BertForMaskedLM were not initialized from the model checkpoint at ../models/S_antiBERTy_HL_fine_tuning_CDR3_amulety_custom and are newly initialized: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Custom model path: ../models/S_antiBERTy_HL_fine_tuning_CDR3_amulety_custom
Embedding dimension: 512
Max length used by tutorial: 510
Embeddings shape: (1000, 513)


Unnamed: 0,cell_id,dim_1,dim_2,dim_3,dim_4,dim_5,dim_6,dim_7,dim_8,dim_9,...,dim_503,dim_504,dim_505,dim_506,dim_507,dim_508,dim_509,dim_510,dim_511,dim_512
0,cell_100681,-0.077534,-0.109347,-1.533826,-0.553431,1.118874,0.011019,1.411259,-1.877443,1.221428,...,-0.486696,0.669253,-0.199235,0.142055,-0.412014,0.974601,-0.765339,-0.230175,0.045813,-0.22751
1,cell_101224,-0.077534,-0.109347,-1.533826,-0.553431,1.118874,0.011019,1.411259,-1.877443,1.221428,...,-0.486696,0.669253,-0.199235,0.142055,-0.412014,0.974601,-0.765339,-0.230175,0.045813,-0.22751
2,cell_101233,-0.077534,-0.109347,-1.533826,-0.553431,1.118874,0.011019,1.411259,-1.877443,1.221428,...,-0.486696,0.669253,-0.199235,0.142055,-0.412014,0.974601,-0.765339,-0.230175,0.045813,-0.22751
3,cell_101360,-0.077534,-0.109347,-1.533826,-0.553431,1.118874,0.011019,1.411259,-1.877443,1.221428,...,-0.486696,0.669253,-0.199235,0.142055,-0.412014,0.974601,-0.765339,-0.230175,0.045813,-0.22751
4,cell_102058,-0.077534,-0.109347,-1.533826,-0.553431,1.118874,0.011019,1.411259,-1.877443,1.221428,...,-0.486696,0.669253,-0.199235,0.142055,-0.412014,0.974601,-0.765339,-0.230175,0.045813,-0.22751


At this point:
- `embeddings_df` contains one row per cell, with 512 embedding features
  plus the `cell_id` column, generated by the **fine-tuned** AntiBERTy model via AMULETY.
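
The five rows displayed above share the same value in every column shown, so it may be worth confirming that the embeddings actually differ between cells before using them downstream. A minimal check, assuming `cell_id` is the only non-feature column, could look like this; `count_distinct_embeddings` is a hypothetical helper and the toy frame only stands in for `embeddings_df`.

```python
import pandas as pd


def count_distinct_embeddings(df, id_col="cell_id", decimals=6):
    """Count distinct embedding vectors after rounding to `decimals` places."""
    features = df.drop(columns=[id_col]).round(decimals)
    return len(features.drop_duplicates())


# Toy stand-in for embeddings_df: the first two cells share one vector.
toy = pd.DataFrame({
    "cell_id": ["c1", "c2", "c3"],
    "dim_1": [0.1, 0.1, 0.3],
    "dim_2": [0.5, 0.5, 0.7],
})
n_distinct = count_distinct_embeddings(toy)
print(f"{n_distinct} distinct embedding rows out of {len(toy)}")
```

If the count on the real `embeddings_df` is far below the number of cells, the input formatting and the `max_length` setting would be natural first things to inspect.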

You can now:
- project these embeddings with UMAP / t-SNE,
- train simple classifiers or regressors for predictions,
- or integrate them into larger downstream pipelines, all using the same
  AMULETY interface you use for other models.
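
As a dependency-light starting point for the projection step, the sketch below reduces embedding vectors to two dimensions with PCA via NumPy's SVD. PCA is swapped in here as a simple stand-in for UMAP / t-SNE, the random matrix only imitates an embedding table, and `pca_project` is an illustrative helper rather than part of AMULETY.

```python
import numpy as np


def pca_project(features, n_components=2):
    """Project rows of `features` onto their top principal components."""
    X = np.asarray(features, dtype=float)
    X_centered = X - X.mean(axis=0, keepdims=True)
    # Rows of vt are principal directions, ordered by explained variance.
    _, _, vt = np.linalg.svd(X_centered, full_matrices=False)
    return X_centered @ vt[:n_components].T


# Random toy "embeddings": 20 cells, 8 dimensions.
rng = np.random.default_rng(0)
toy_embeddings = rng.normal(size=(20, 8))
coords = pca_project(toy_embeddings)
print(coords.shape)  # (20, 2)
```

With the real table, the feature matrix could be obtained with `embeddings_df.drop(columns=["cell_id"]).to_numpy()`, and the two resulting columns plotted against each other.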