# Fine-Tune an LLM for Antibody Sequence Generation

In [None]:
# pip install -r ../requirements.txt

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset, Dataset
from peft import get_peft_model, LoraConfig, TaskType
import pandas as pd
import torch

import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
torch.cuda.empty_cache()

## For Pytorch, see if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
## Load dataset
df = pd.read_csv("../data/sabdab/sabdab_with_sequences.tsv", sep='\t')

## Remove rows with missing sequences
df = df.dropna(subset=['HeavySeq', 'LightSeq', 'AntigenSeq'])

df.head()

Unnamed: 0,pdb,Hchain,Lchain,AntigenChains,HeavySeq,LightSeq,AntigenSeq
4,8xa4,C,D,A | B,QLQLQESGPGLVKPSETLSLTCTVSGGSISSNNDYWGWIRQPPGKG...,EIVLTQSPGTLSLSPGERVTLSCRASQRVSSTYLAWYQQKPGQAPR...,SCNGLYYQGSCYILHSDYKSFEDAKANCAAESSTLPNKSDVLTTWL...
9,9cph,H,L,A,EVQLVESGGGLVQPGGSLRLSCAASGFNLSSSSIHWVRQAPGKGLE...,AQMTQSPSSLSASVGDRVTITCRASQSVSSAVAWYQQKPGKAPKLL...,KIEEGKLVIWINGDKGYNGLAEVGKKFEKDTGIKVTVEHPDKLEEK...
10,9d7i,H,G,E,VQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLEW...,YELTQPPSVSVSPGQTATITCSGASTNVCWYQVKPGQSPEVVIFEN...,LWVTVYYGVPVWKDAETTLFCASDNVWATHACVPTDPNPQEIHLEN...
11,9d7i,J,I,C,VQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLEW...,YELTQPPSVSVSPGQTATITCSGASTNVCWYQVKPGQSPEVVIFEN...,LWVTVYYGVPVWKDAETTLFCASDNVWATHACVPTDPNPQEIHLEN...
12,9d7o,H,G,E,QVQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLE...,YELTQPPSVSVSPGQTATITCSGASTNVCWYQVKPGQSPEVVIFEN...,LWVTVYYGVPVWKDAETTLFCASDNVWATHACVPTDPNPQEIHLEN...


In [4]:
## Format prompts
def format_prompt(example):
    return {
        "text": f"Antigen: {example['AntigenSeq']}\nAntibody: {example['HeavySeq']}|{example['LightSeq']}\n"
    }

dataset = Dataset.from_pandas(df)
dataset = dataset.map(format_prompt)


Map: 100%|██████████| 10073/10073 [00:02<00:00, 4625.80 examples/s]


In [5]:
## Load base tokenizer and model
model_name = "microsoft/phi-4"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map=0, trust_remote_code=True)

## Convert model to float32 for training
model = model.to(torch.float32)

Loading checkpoint shards: 100%|██████████| 6/6 [00:30<00:00,  5.11s/it]


In [6]:
## Extend tokenizer with special tokens
amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
extra_tokens = amino_acids + ["|"]# ["[", "]", "|"]

In [None]:
## Check if tokens already exist in the tokenizer's vocabulary
new_tokens = [t for t in extra_tokens if t not in tokenizer.get_vocab()]
tokenizer.add_tokens(new_tokens)
model.resize_token_embeddings(len(tokenizer))

model.train()

Embedding(100352, 5120, padding_idx=100349)

In [24]:
## Tokenize the dataset
def tokenize(example):
    # return tokenizer(example["text"], padding="max_length", truncation=True, max_length=512)
    encoded = tokenizer(example["text"], padding="max_length", truncation=True, max_length=256)
    # encoded["labels"] = encoded["input_ids"]#.copy()
    return encoded

tokenized_dataset = dataset.map(tokenize)

## Remove unnecessary columns from the tokenized dataset
tokenized_dataset = tokenized_dataset.remove_columns([
    'pdb', 'Hchain', 'Lchain', 'AntigenSeq', 'AntigenChains',
    'HeavySeq', 'LightSeq', '__index_level_0__', 'text'
])

Map: 100%|██████████| 10073/10073 [00:10<00:00, 989.07 examples/s]


In [27]:
## Split the dataset into train and validation sets
train_test_split = tokenized_dataset.train_test_split(test_size=0.2, seed=1337)

train_dataset = train_test_split["train"]
eval_dataset = train_test_split["test"]

train_dataset

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 8058
})

In [28]:
## Training arguments
training_args = TrainingArguments(
    output_dir=f"../models/peleke-{model_name.split('/')[-1]}",
    ## Batching
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    ## Epochs and warmups
    num_train_epochs=3,
    warmup_steps=100,
    ## Optimization
    weight_decay=0.01,
    ## Logging and saving
    logging_dir="../logs",
    logging_steps=50,
    save_strategy="epoch",
    # fp16=True,
    gradient_checkpointing=True, ## If having memory issues
    report_to="none"
)

In [29]:
## PEFT configuration
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
    # target_modules=["q_proj", "v_proj", "k_proj", "o_proj"]
    target_modules=["o_proj", "qkv_proj"],
)

peft_model = get_peft_model(model, peft_config)

In [30]:
## Data Collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False  ## Important: MLM=False for causal LM
)

In [31]:
## Trainer
trainer = Trainer(
    # model=model,
    model=peft_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

  trainer = Trainer(
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [None]:
## Fine-tune
trainer.train()

Step,Training Loss
