# Fine-Tune an LLM for Antibody Sequence Generation

In [None]:
# pip install -r ../requirements.txt

In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset, Dataset
from peft import get_peft_model, LoraConfig, TaskType
import pandas as pd
import torch
import os

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
torch.cuda.empty_cache()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Test your GPU setup
print(f"Number of GPUs: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
    print(f"Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.1f} GB")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda
Number of GPUs: 2
GPU 0: NVIDIA GeForce RTX 5090
Memory: 33.7 GB
GPU 1: NVIDIA GeForce RTX 3090 Ti
Memory: 25.3 GB


In [4]:
## Load dataset
df = pd.read_csv("../data/sabdab/sabdab_sequences.csv")

## Remove rows with missing sequences
df = df.dropna(subset=['h_chain_seq', 'l_chain_seq', 'antigen_seqs'])

df.head()

Unnamed: 0,pdb_id,h_chain_id,l_chain_id,antigen_ids,h_chain_seq,l_chain_seq,antigen_seqs
0,8xa4,C,D,A|B,QLQLQESGPGLVKPSETLSLTCTVSGGSISSNNDYWGWIRQPPGKG...,EIVLTQSPGTLSLSPGERVTLSCRASQRVSSTYLAWYQQKPGQAPR...,SCNGLYYQGSCYILHSDYKSFEDAKANCAAESSTLPNKSDVLTTWL...
2,9cph,H,L,A,EVQLVESGGGLVQPGGSLRLSCAASGFNLSSSSIHWVRQAPGKGLE...,AQMTQSPSSLSASVGDRVTITCRASQSVSSAVAWYQQKPGKAPKLL...,KIEEGKLVIWINGDKGYNGLAEVGKKFEKDTGIKVTVEHPDKLEEK...
3,9d7i,H,G,E,VQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLEW...,YELTQPPSVSVSPGQTATITCSGASTNVCWYQVKPGQSPEVVIFEN...,LWVTVYYGVPVWKDAETTLFCASDNVWATHACVPTDPNPQEIHLEN...
4,9d7i,J,I,C,VQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLEW...,YELTQPPSVSVSPGQTATITCSGASTNVCWYQVKPGQSPEVVIFEN...,LWVTVYYGVPVWKDAETTLFCASDNVWATHACVPTDPNPQEIHLEN...
5,9d7o,H,G,E,QVQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLE...,YELTQPPSVSVSPGQTATITCSGASTNVCWYQVKPGQSPEVVIFEN...,LWVTVYYGVPVWKDAETTLFCASDNVWATHACVPTDPNPQEIHLEN...


In [5]:
## Format prompts
def format_prompt(example):
    return {
        "text": f"Antigen: {example['antigen_seqs']}\nAntibody: {example['h_chain_seq']}|{example['l_chain_seq']}\n"
    }

dataset = Dataset.from_pandas(df)
dataset = dataset.map(format_prompt)


Map: 100%|██████████| 9560/9560 [00:00<00:00, 27747.38 examples/s]


### This script is used to clear vram. For testing purposes only and when you want to clear the GPU memory.

In [2]:
import gc
import torch
# Clear any existing models from GPU memory
torch.cuda.empty_cache()
gc.collect()

# Check current GPU memory usage
print(f"GPU Memory before: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB allocated")
print(f"GPU Memory reserved: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB reserved")
# If you have a model loaded, delete it first
try:
    del model
    torch.cuda.empty_cache()
    gc.collect()
    print("Previous model cleared from memory")
except:
    print("No previous model to clear")

GPU Memory before: 0.00 GB allocated
GPU Memory reserved: 0.00 GB reserved
No previous model to clear


In [6]:
## Load base tokenizer and model
 
model_name =  "microsoft/phi-4"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name, 
    device_map="auto",  # Use only the first GPU
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,  # Load model in bfloat16 for better performance
    low_cpu_mem_usage=True,  # Reduce CPU memory usage during loading
    )  

## Convert model to float32 for training
#model = model.to(torch.float32)

Loading checkpoint shards: 100%|██████████| 6/6 [00:11<00:00,  1.92s/it]


In [7]:
## Extend tokenizer with special tokens
amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
extra_tokens = amino_acids + ["|"]# ["[", "]", "|"]

In [8]:
## Check if tokens already exist in the tokenizer's vocabulary
new_tokens = [t for t in extra_tokens if t not in tokenizer.get_vocab()]
tokenizer.add_tokens(new_tokens)
model.resize_token_embeddings(len(tokenizer))

model.train()

Phi3ForCausalLM(
  (model): Phi3Model(
    (embed_tokens): Embedding(100352, 5120, padding_idx=100349)
    (layers): ModuleList(
      (0-39): 40 x Phi3DecoderLayer(
        (self_attn): Phi3Attention(
          (o_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (qkv_proj): Linear(in_features=5120, out_features=7680, bias=False)
        )
        (mlp): Phi3MLP(
          (gate_up_proj): Linear(in_features=5120, out_features=35840, bias=False)
          (down_proj): Linear(in_features=17920, out_features=5120, bias=False)
          (activation_fn): SiLU()
        )
        (input_layernorm): Phi3RMSNorm((5120,), eps=1e-05)
        (post_attention_layernorm): Phi3RMSNorm((5120,), eps=1e-05)
        (resid_attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
      )
    )
    (norm): Phi3RMSNorm((5120,), eps=1e-05)
    (rotary_emb): Phi3RotaryEmbedding()
  )
  (lm_head): Linear(in_features=5120, out_features=1003

In [9]:
## Tokenize the dataset
def tokenize(example):
    # return tokenizer(example["text"], padding="max_length", truncation=True, max_length=512)
    encoded = tokenizer(example["text"], padding="max_length", truncation=True, max_length=256)
    # encoded["labels"] = encoded["input_ids"]#.copy()
    return encoded

tokenized_dataset = dataset.map(tokenize)

## Remove unnecessary columns from the tokenized dataset
tokenized_dataset = tokenized_dataset.remove_columns([
    'pdb_id', 'h_chain_id', 'l_chain_id', 'antigen_seqs', 'antigen_ids',
    'h_chain_seq', 'l_chain_seq', '__index_level_0__', 'text'
])

Map: 100%|██████████| 9560/9560 [00:02<00:00, 3386.81 examples/s]


In [10]:
## Split the dataset into train and validation sets
train_test_split = tokenized_dataset.train_test_split(test_size=0.2, seed=1337)

train_dataset = train_test_split["train"]
eval_dataset = train_test_split["test"]

train_dataset

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 7648
})

In [18]:
## Training arguments
training_args = TrainingArguments(
    output_dir=f"../models/peleke-{model_name.split('/')[-1]}",
    ## Batching
    per_device_train_batch_size=1, # Adjust based on GPU memory
    gradient_accumulation_steps=16,  # Adjust based on GPU memory
    per_device_eval_batch_size=1, # Adjust based on GPU memory
    ## Epochs and warmups
    num_train_epochs=3,
    warmup_steps=25, 
    ## Optimization
    weight_decay=0.01,
    ## Logging and saving
    logging_dir="../logs",
    logging_steps=50,
    save_strategy="epoch",
    # fp16=True,
    gradient_checkpointing=True, ## If having memory issues
    report_to="none"
)

In [19]:
## PEFT configuration
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
    # target_modules=["q_proj", "v_proj", "k_proj", "o_proj"]
    target_modules=["o_proj", "qkv_proj"],
)

peft_model = get_peft_model(model, peft_config)



In [20]:
## Data Collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False  ## Important: MLM=False for causal LM
)

In [21]:
## Trainer
trainer = Trainer(
    # model=model,
    model=peft_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

  trainer = Trainer(
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [22]:
# Enable gradients for input embeddings
if hasattr(model, 'enable_input_require_grads'):
    model.enable_input_require_grads()
else:
    # Manual approach
    def make_inputs_require_grad(module, input, output):
        output.requires_grad_(True)
    
    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

In [23]:
## Fine-tune
trainer.train()

Step,Training Loss
50,3.5168
100,3.2332
150,3.0309
200,2.9624
250,2.843
300,2.6463
350,2.6743
400,2.6132
450,2.5759
500,2.5372


TrainOutput(global_step=1434, training_loss=2.472764045789817, metrics={'train_runtime': 10148.174, 'train_samples_per_second': 2.261, 'train_steps_per_second': 0.141, 'total_flos': 4.9878253996867584e+17, 'train_loss': 2.472764045789817, 'epoch': 3.0})