# Fine-Tune an LLM for Antibody Sequence Generation

In [1]:
# pip install -r ../requirements.txt
# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import load_dataset, Dataset
from peft import get_peft_model, LoraConfig, TaskType
import torch

import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
## Pick the compute device: the GPU when CUDA is usable, otherwise the CPU
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

Using device: cuda


In [4]:
## Load the tab-separated SAbDab table and keep only rows where the heavy
## chain, light chain and antigen sequences are all present
required_columns = ['HeavySeq', 'LightSeq', 'AntigenSeq']
raw_df = pd.read_csv("../data/sabdab/sabdab_with_sequences.tsv", sep='\t')
df = raw_df.dropna(subset=required_columns)

df

Unnamed: 0,pdb,Hchain,Lchain,AntigenChains,HeavySeq,LightSeq,AntigenSeq
4,8xa4,C,D,A | B,QLQLQESGPGLVKPSETLSLTCTVSGGSISSNNDYWGWIRQPPGKG...,EIVLTQSPGTLSLSPGERVTLSCRASQRVSSTYLAWYQQKPGQAPR...,SCNGLYYQGSCYILHSDYKSFEDAKANCAAESSTLPNKSDVLTTWL...
9,9cph,H,L,A,EVQLVESGGGLVQPGGSLRLSCAASGFNLSSSSIHWVRQAPGKGLE...,AQMTQSPSSLSASVGDRVTITCRASQSVSSAVAWYQQKPGKAPKLL...,KIEEGKLVIWINGDKGYNGLAEVGKKFEKDTGIKVTVEHPDKLEEK...
10,9d7i,H,G,E,VQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLEW...,YELTQPPSVSVSPGQTATITCSGASTNVCWYQVKPGQSPEVVIFEN...,LWVTVYYGVPVWKDAETTLFCASDNVWATHACVPTDPNPQEIHLEN...
11,9d7i,J,I,C,VQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLEW...,YELTQPPSVSVSPGQTATITCSGASTNVCWYQVKPGQSPEVVIFEN...,LWVTVYYGVPVWKDAETTLFCASDNVWATHACVPTDPNPQEIHLEN...
12,9d7o,H,G,E,QVQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLE...,YELTQPPSVSVSPGQTATITCSGASTNVCWYQVKPGQSPEVVIFEN...,LWVTVYYGVPVWKDAETTLFCASDNVWATHACVPTDPNPQEIHLEN...
...,...,...,...,...,...,...,...
18296,7tas,H,L,E,QVQLVESGGVVVQPGGSLRLSCAASGFTFHDHTMHWVRQAPGKGLE...,QSVLTQPPSASGTPGQRVTISCSGSSANIGSNTVNWYQHLPGTAPK...,LCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCY...
18298,7lo6,J,I,C,EVQLVESGAEVKKPGSSVKVSCKASGDTFIRYSFTWVRQAPGQGLE...,DIVMTQSPATLSVSPGERATLSCRASESVSSDLAWYQQKPGQAPRL...,NLWVTVYYGVPVWKDAETTLFCASDAKAYETEKHNVWATHACVPTD...
18299,3vi3,H,L,D,QVHLQQSGAELMKPGASVKISCKATGYTFTSYWIEWVKQRPGHGLE...,DIVMTQATPSIPVTPGESVSISCRSNKSLLHSNGNTYLYWFLQRPG...,NRCLKANAKSCGECIQAGPNCGWCTNSTFRCDDLEALKKKGCPPDD...
18300,6zdg,F,G,D,EVQLVESGGGVVQPGRSLRLSCAASAFTFSSYDMHWVRQAPGKGLE...,DIQMTQSPSSLSASVGDRVTITCRASQSISSYLNWYQQKPGKAPKL...,TNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFK...


In [6]:
## Format prompts
def format_prompt(example):
    """Build the training prompt for one record.

    Reads the 'AntigenSeq', 'HeavySeq' and 'LightSeq' entries of `example`
    and returns a dict with a single "text" key: an "Antigen:" line followed
    by an "Antibody:" line in which heavy and light chains are joined by "|".
    """
    antigen_line = f"Antigen: {example['AntigenSeq']}\n"
    antibody_line = f"Antibody: {example['HeavySeq']}|{example['LightSeq']}\n"
    return {"text": antigen_line + antibody_line}

# Convert the filtered frame to a Hugging Face Dataset and attach the
# formatted "text" prompt to every record.
dataset = Dataset.from_pandas(df).map(format_prompt)


Map: 100%|██████████| 10073/10073 [00:01<00:00, 5142.64 examples/s]


In [7]:
## Load the base tokenizer and the causal LM (half precision, placed on GPU 0)
model_name = "microsoft/phi-4"
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=0,
    torch_dtype=torch.float16,
)

Loading checkpoint shards: 100%|██████████| 6/6 [00:28<00:00,  4.76s/it]


In [8]:
## Candidate tokens for the tokenizer: the 20 one-letter amino-acid codes
## plus the "[", "]" and "|" delimiter characters
amino_acids = [*"ACDEFGHIKLMNPQRSTVWY"]
extra_tokens = [*amino_acids, "[", "]", "|"]

In [9]:
## Add only the tokens that are not already in the tokenizer's vocabulary
# Fetch the vocabulary once: get_vocab() was previously called again for
# every candidate token inside the comprehension.
existing_vocab = tokenizer.get_vocab()
new_tokens = [t for t in extra_tokens if t not in existing_vocab]
tokenizer.add_tokens(new_tokens)
# Keep the embedding matrix in step with the tokenizer size; the resized
# embedding module is the cell's displayed value.
model.resize_token_embeddings(len(tokenizer))

Embedding(100352, 5120, padding_idx=100349)

In [10]:
## Tokenize the dataset
def tokenize(example):
    """Tokenize one record's "text" prompt and attach causal-LM labels.

    The prompt is padded/truncated to 512 tokens. `labels` is a copy of
    `input_ids` in which every position with attention_mask == 0 (padding)
    is replaced by -100, the index ignored by the loss. Without a `labels`
    entry the Trainer has nothing to compute a loss from.
    """
    encoded = tokenizer(example["text"], padding="max_length", truncation=True, max_length=512)
    encoded["labels"] = [
        token_id if keep == 1 else -100
        for token_id, keep in zip(encoded["input_ids"], encoded["attention_mask"])
    ]
    return encoded

tokenized_dataset = dataset.map(tokenize)

Map: 100%|██████████| 10073/10073 [00:11<00:00, 880.30 examples/s]


In [11]:
## Training arguments
training_args = TrainingArguments(
    output_dir="../models/peleke-phi4",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=3,
    warmup_steps=100,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=50,
    save_strategy="epoch",
    # No evaluation is run. The evaluation strategy already defaults to "no",
    # so the deprecated `evaluation_strategy` keyword is no longer passed
    # (newer transformers releases only accept `eval_strategy`).
    fp16=True,
    gradient_checkpointing=True,  ## If having memory issues
)



In [None]:
## PEFT configuration
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
    # phi-4 uses a fused attention projection named `qkv_proj`, so listing only
    # q_proj/k_proj/v_proj would leave `o_proj` as the sole adapted layer.
    # The split names are kept so the config still fits models with separate
    # projections; names that match no module are simply not used.
    target_modules=["q_proj", "k_proj", "v_proj", "qkv_proj", "o_proj"],
)

# The training arguments turn on gradient checkpointing. With the base weights
# frozen by LoRA, the embedding output must require grad so the checkpointed
# backward pass can reach the adapter weights.
model.enable_input_require_grads()

peft_model = get_peft_model(model, peft_config)
# Report how many parameters are trainable as a sanity check on the adapter.
peft_model.print_trainable_parameters()

In [12]:
## Trainer
# Train the LoRA-wrapped model rather than the raw base model. The base
# weights were loaded in float16, and fp16 mixed-precision training cannot
# update float16 master weights directly; the adapter is also the only part
# this notebook means to fine-tune.
trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
)

  trainer = Trainer(


In [None]:
## Fine-tune: run the training loop and show the result it returns
train_result = trainer.train()
train_result