# Fine-Tune an LLM for Antibody Sequence Generation

In [1]:
# pip install -r ../requirements.txt

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
from datasets import load_dataset, Dataset
from trl import SFTTrainer
from peft import get_peft_model, LoraConfig, TaskType
import pandas as pd
import torch
import os
import wandb
from transformers import TrainerCallback

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
torch.cuda.empty_cache()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Test your GPU setup
print(f"Number of GPUs: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
    print(f"Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.1f} GB")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda
Number of GPUs: 2
GPU 0: NVIDIA GeForce RTX 5090
Memory: 33.7 GB
GPU 1: NVIDIA GeForce RTX 3090 Ti
Memory: 25.3 GB


In [3]:
# Configure wandb for your local server
os.environ["WANDB_BASE_URL"] = "http://192.168.0.24:8080"

In [None]:
# Login to your wandb instance
# wandb.login()
# This will prompt for your API key or you can set it programmatically:
os.environ["WANDB_API_KEY"] = ""

In [5]:
## Load dataset
df = pd.read_csv("../data/sabdab/sabdab_training_dataset.csv")

df.columns


Index(['pdb_id', 'h_chain_id', 'l_chain_id', 'antigen_ids', 'h_chain_seq',
       'l_chain_seq', 'antigen_seqs', 'antibody_sequences',
       'highlighted_epitope_sequences', 'epitope_residues'],
      dtype='object')

In [6]:
## Remove rows with missing sequences
df = df.dropna(subset=['h_chain_seq', 'l_chain_seq', 'antigen_seqs', 'highlighted_epitope_sequences'])

df.head()

Unnamed: 0,pdb_id,h_chain_id,l_chain_id,antigen_ids,h_chain_seq,l_chain_seq,antigen_seqs,antibody_sequences,highlighted_epitope_sequences,epitope_residues
0,8xa4,C,D,A|B,QLQLQESGPGLVKPSETLSLTCTVSGGSISSNNDYWGWIRQPPGKG...,EIVLTQSPGTLSLSPGERVTLSCRASQRVSSTYLAWYQQKPGQAPR...,SCNGLYYQGSCYILHSDYKSFEDAKANCAAESSTLPNKSDVLTTWL...,QLQLQESGPGLVKPSETLSLTCTVSGGSISSNNDYWGWIRQPPGKG...,SCNGLYYQGSCYI[L]HSD[Y]KSFEDAKANCAAESSTLPNKSDVL...,A:ARG 176|A:ASP 146|A:ASP 150|A:ASP 170|A:GLN ...
1,9cph,H,L,A,EVQLVESGGGLVQPGGSLRLSCAASGFNLSSSSIHWVRQAPGKGLE...,AQMTQSPSSLSASVGDRVTITCRASQSVSSAVAWYQQKPGKAPKLL...,KIEEGKLVIWINGDKGYNGLAEVGKKFEKDTGIKVTVEHPDKLEEK...,EVQLVESGGGLVQPGGSLRLSCAASGFNLSSSSIHWVRQAPGKGLE...,KIEEGKLVIWINGDKGYNGLAEVGKKFEKDTGIKVTVEHPDKLEEK...,A:ALA 1116|A:ALA 1122|A:ALA 1128|A:ALA 900|A:A...
2,9d7i,H,G,E,VQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLEW...,YELTQPPSVSVSPGQTATITCSGASTNVCWYQVKPGQSPEVVIFEN...,LWVTVYYGVPVWKDAETTLFCASDNVWATHACVPTDPNPQEIHLEN...,VQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLEW...,LWVTVYYGVPVWKDAETTLFCASDNVWATHACVPTDPNPQEIHLEN...,E:ARG 429|E:ARG 469|E:ASN 177|E:ASN 197|E:ASN ...
3,9d7i,J,I,C,VQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLEW...,YELTQPPSVSVSPGQTATITCSGASTNVCWYQVKPGQSPEVVIFEN...,LWVTVYYGVPVWKDAETTLFCASDNVWATHACVPTDPNPQEIHLEN...,VQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLEW...,LWVTVYYGVPVWKDAETTLFCASDNVWATHACVPTDPNPQEIHLEN...,C:ARG 469|C:ASN 197|C:ASN 280|C:ASN 425|C:ASP ...
4,9d7o,H,G,E,QVQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLE...,YELTQPPSVSVSPGQTATITCSGASTNVCWYQVKPGQSPEVVIFEN...,LWVTVYYGVPVWKDAETTLFCASDNVWATHACVPTDPNPQEIHLEN...,QVQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLE...,LWVTVYYGVPVWKDAETTLFCASDNVWATHACVPTDPNPQEIHLEN...,E:ARG 429|E:ARG 469|E:ASN 197|E:ASN 280|E:ASN ...


In [7]:
## Load base tokenizer and model FIRST
model_name = "microsoft/phi-4"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map= {'model.embed_tokens': 0, 
                 'model.layers.0': 0, 
                 'model.layers.1': 0, 
                 'model.layers.2': 0, 
                 'model.layers.3': 0, 
                 'model.layers.4': 0, 
                 'model.layers.5': 0, 
                 'model.layers.6': 0, 
                 'model.layers.7': 0, 
                 'model.layers.8': 0, 
                 'model.layers.9': 0, 
                 'model.layers.10': 0, 
                 'model.layers.11': 0, 
                 'model.layers.12': 0, 
                 'model.layers.13': 0, 
                 'model.layers.14': 0, 
                 'model.layers.15': 0, 
                 'model.layers.16': 0, 
                 'model.layers.17': 0, 
                 'model.layers.18': 0, 
                 'model.layers.19': 0, 
                 'model.layers.20': 0, 
                 'model.layers.21': 0, 
                 'model.layers.22': 0, 
                 'model.layers.23': 0, 
                 'model.layers.24': 0, 
                 'model.layers.25': 0, 
                 'model.layers.26': 1, 
                 'model.layers.27': 1, 
                 'model.layers.28': 1, 
                 'model.layers.29': 1, 
                 'model.layers.30': 1, 
                 'model.layers.31': 1, 
                 'model.layers.32': 1, 
                 'model.layers.33': 1, 
                 'model.layers.34': 1, 
                 'model.layers.35': 1, 
                 'model.layers.36': 1, 
                 'model.layers.37': 1, 
                 'model.layers.38': 1, 
                 'model.layers.39': 1, 
                 'model.norm': 1, 
                 'model.rotary_emb': 1, 
                 'lm_head': 1}, # Use only the first GPU
    trust_remote_code=True,
    torch_dtype=torch.float16, # Load model in bfloat16 for better performance
    low_cpu_mem_usage=True, # Reduce CPU memory usage during loading
    max_memory={0: "30GB", 1:"23.5GB"}  # Limit GPU memory usage to 20GB
)

Loading checkpoint shards: 100%|██████████| 6/6 [00:06<00:00,  1.07s/it]


In [8]:
# Check current device mapping
print("Current device map:", model.hf_device_map)

Current device map: {'model.embed_tokens': 0, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0, 'model.layers.6': 0, 'model.layers.7': 0, 'model.layers.8': 0, 'model.layers.9': 0, 'model.layers.10': 0, 'model.layers.11': 0, 'model.layers.12': 0, 'model.layers.13': 0, 'model.layers.14': 0, 'model.layers.15': 0, 'model.layers.16': 0, 'model.layers.17': 0, 'model.layers.18': 0, 'model.layers.19': 0, 'model.layers.20': 0, 'model.layers.21': 0, 'model.layers.22': 0, 'model.layers.23': 0, 'model.layers.24': 0, 'model.layers.25': 0, 'model.layers.26': 1, 'model.layers.27': 1, 'model.layers.28': 1, 'model.layers.29': 1, 'model.layers.30': 1, 'model.layers.31': 1, 'model.layers.32': 1, 'model.layers.33': 1, 'model.layers.34': 1, 'model.layers.35': 1, 'model.layers.36': 1, 'model.layers.37': 1, 'model.layers.38': 1, 'model.layers.39': 1, 'model.norm': 1, 'model.rotary_emb': 1, 'lm_head': 1}


In [9]:
# Add epitope tokens
epitope_tokens = ["<epi>", "</epi>"]
tokenizer.add_special_tokens({"additional_special_tokens": epitope_tokens})



2

In [10]:
# Add amino acid tokens
amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
extra_tokens = amino_acids + ["|"]
new_tokens = [t for t in extra_tokens if t not in tokenizer.get_vocab()]
tokenizer.add_tokens(new_tokens)

# Resize model embeddings ONCE after adding all tokens
model.resize_token_embeddings(len(tokenizer))
model.train()


The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Phi3ForCausalLM(
  (model): Phi3Model(
    (embed_tokens): Embedding(100354, 5120, padding_idx=100349)
    (layers): ModuleList(
      (0-39): 40 x Phi3DecoderLayer(
        (self_attn): Phi3Attention(
          (o_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (qkv_proj): Linear(in_features=5120, out_features=7680, bias=False)
        )
        (mlp): Phi3MLP(
          (gate_up_proj): Linear(in_features=5120, out_features=35840, bias=False)
          (down_proj): Linear(in_features=17920, out_features=5120, bias=False)
          (activation_fn): SiLU()
        )
        (input_layernorm): Phi3RMSNorm((5120,), eps=1e-05)
        (post_attention_layernorm): Phi3RMSNorm((5120,), eps=1e-05)
        (resid_attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
      )
    )
    (norm): Phi3RMSNorm((5120,), eps=1e-05)
    (rotary_emb): Phi3RotaryEmbedding()
  )
  (lm_head): Linear(in_features=5120, out_features=1003

In [11]:
# Convert epitope format function
import re
def convert_epitope_format(sequence):
    return re.sub(r'\[([A-Z])\]', r'<epi>\1</epi>', sequence)

## NOW create dataset with all tokens available
def format_prompt(example):
    epitope_seq = convert_epitope_format(example['highlighted_epitope_sequences'])
    return {
        "text": f"Antigen: {epitope_seq}<|im_end|>\nAntibody: {example['antibody_sequences']}<|im_end|>\n"
    }

dataset = Dataset.from_pandas(df)
dataset = dataset.map(format_prompt)

Map: 100%|██████████| 9529/9529 [00:00<00:00, 20998.28 examples/s]


In [12]:
# Add task-specific tokens
task_tokens = ["Antigen", "Antibody", "Epitope"]
tokenizer.add_tokens(task_tokens)
model.resize_token_embeddings(len(tokenizer))

 

Embedding(100357, 5120, padding_idx=100349)

In [13]:
## Tokenize the dataset
def tokenize(example):
    encoded = tokenizer(example["text"], truncation=True, max_length=1024)
    # Make sure labels are a proper list, not nested
    encoded["labels"] = encoded["input_ids"].copy()
    return encoded

tokenized_dataset = dataset.map(tokenize)

Map: 100%|██████████| 9529/9529 [00:02<00:00, 3272.67 examples/s]


In [14]:
# Verify tokenization is working with epitope tokens
print("Sample tokenized text:")
sample_tokens = tokenizer.tokenize(dataset[0]['text'][:200])
print(sample_tokens)

Sample tokenized text:
['Antigen', ':', 'ĠSCN', 'GL', 'YY', 'Q', 'G', 'SC', 'Y', 'I', '<epi>', 'L', '</epi>', 'H', 'SD', '<epi>', 'Y', '</epi>', 'K', 'SF', 'ED', 'AK', 'AN', 'CAA', 'ES', 'ST', 'LP', 'NK', 'SD', 'VL', 'TT', 'W', 'LI', '<epi>', 'D', '</epi>', '<epi>', 'Y', '</epi>', 'V', '<epi>', 'E', '</epi>', '<epi>', 'D', '</epi>', '<epi>', 'T', '</epi>', 'WG', 'SD', 'GN', 'P', 'IT', 'K', 'TT', 'SD', '<epi>', 'Y', '</epi>', 'Q', 'DS', '<epi>', 'D', '</epi>', 'VS', '<epi>', 'Q', '</epi>', '<epi>', 'E']


In [15]:
# Remove unnecessary columns
tokenized_dataset = tokenized_dataset.remove_columns([
    'pdb_id', 'h_chain_id', 'l_chain_id', 'antigen_ids', 'antigen_seqs',
    'h_chain_seq', 'l_chain_seq', 'antibody_sequences',
    'highlighted_epitope_sequences', 'epitope_residues', 'text'
])



In [16]:
# Apply the gradient fix to your model
if hasattr(model, 'enable_input_require_grads'):
    model.enable_input_require_grads()
else:
    def make_inputs_require_grad(module, input, output):
        output.requires_grad_(True)
    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

In [17]:
# # Create data collator
# data_collator = DataCollatorForLanguageModeling(
#     tokenizer=tokenizer,
#     mlm=False,
#     return_tensors="pt",
#     pad_to_multiple_of=8, # Pad to multiple of 8 for better performance on GPUs
# )
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    label_pad_token_id=-100,
    return_tensors="pt",
)

In [18]:
 # Configure LoRA
## PEFT configuration
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
    # target_modules=["q_proj", "v_proj", "k_proj", "o_proj"]
    target_modules=["o_proj", "qkv_proj"],
)

# Apply LoRA to the model
model = get_peft_model(model, peft_config)

# Print trainable parameters
model.print_trainable_parameters()



trainable params: 7,372,800 || all params: 14,666,931,200 || trainable%: 0.0503


In [19]:
# Update training arguments to enable wandb logging
training_args = TrainingArguments(
    output_dir=f"../models/peleke-{model_name.split('/')[-1]}-07222025",
    per_device_train_batch_size=9,
    gradient_accumulation_steps=1,
    per_device_eval_batch_size=6,
    num_train_epochs=3,
    warmup_steps=25,
    weight_decay=0.01,
    learning_rate=2e-4,
    logging_dir="../logs",
    logging_steps=25,
    gradient_checkpointing=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    report_to="wandb",  # Enable wandb reporting
    run_name=f"lora-epitope-{model_name.split('/')[-1]}",  # Run name for wandb
    # optim="adamw_torch",
    fp16=True,  # Enable mixed precision training
    dataloader_num_workers=8,  # Add parallel data loading
    dataloader_pin_memory=True,  # Pin memory for faster data loading
    remove_unused_columns=False,
    max_grad_norm=1.0,
)

In [20]:
# Function to extract LoRA config dynamically
def get_lora_config_from_model(model):
    """Extract LoRA configuration from a PEFT model"""
    if hasattr(model, 'peft_config') and model.peft_config:
        # Get the first (and usually only) PEFT config
        peft_config = list(model.peft_config.values())[0]
        
        return {
            "lora_r": peft_config.r,
            "lora_alpha": peft_config.lora_alpha,
            "lora_dropout": peft_config.lora_dropout,
            "lora_bias": peft_config.bias,
            "lora_task_type": str(peft_config.task_type),
            "lora_target_modules": peft_config.target_modules,
            "lora_fan_in_fan_out": getattr(peft_config, 'fan_in_fan_out', False),
            "lora_init_lora_weights": getattr(peft_config, 'init_lora_weights', True),
        }
    else:
        return {"lora_config": "No PEFT config found"}

# Initialize wandb with dynamic LoRA config
wandb.init(
    project="phi4-antibody-epitope",
    name=f"lora-r{get_lora_config_from_model(model).get('lora_r', 'unknown')}-{model_name.split('/')[-1]}-{training_args.per_device_train_batch_size}bs",
    config={
        # Model Configuration
        "model": model_name,
        "model_type": "phi-4",
        "total_parameters": sum(p.numel() for p in model.parameters()),
        "trainable_parameters": sum(p.numel() for p in model.parameters() if p.requires_grad),
        "trainable_percentage": round(sum(p.numel() for p in model.parameters() if p.requires_grad) / sum(p.numel() for p in model.parameters()) * 100, 4),
        
        # Dynamic LoRA Configuration
        **get_lora_config_from_model(model),  # Unpack LoRA config
        
        # Training Configuration
        "batch_size": training_args.per_device_train_batch_size,
        "effective_batch_size": training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
        "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
        "learning_rate": getattr(training_args, 'learning_rate', 5e-5),
        "optimizer": training_args.optim,
        "num_epochs": training_args.num_train_epochs,
        "warmup_steps": training_args.warmup_steps,
        "weight_decay": training_args.weight_decay,
        "fp16": training_args.fp16,
        "gradient_checkpointing": training_args.gradient_checkpointing,
        
        # Data Configuration
        "max_seq_length": 1024,
        "dataset_size": len(tokenized_dataset),
        "epitope_tokens_added": len(epitope_tokens) if 'epitope_tokens' in locals() else 2,
        "data_format": "antigen_epitope_to_antibody",
        
        # Hardware
        "gpus_used": torch.cuda.device_count(),
        "gpu_names": [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())],
    },
    tags=["phi4", "lora", "antibody", "epitope", "fine-tuning"],
    notes="Fine-tuning Phi-4 with LoRA for antibody generation from epitope-highlighted antigens"
)



class GPUMemoryCallback(TrainerCallback):
    def __init__(self, log_every_n_steps=50):
        self.log_every_n_steps = log_every_n_steps
    
    def on_log(self, args, state, control, **kwargs):
        """Log GPU memory on logging steps"""
        if state.global_step % self.log_every_n_steps == 0:
            for i in range(torch.cuda.device_count()):
                allocated = torch.cuda.memory_allocated(i) / 1024**3
                cached = torch.cuda.memory_reserved(i) / 1024**3
                wandb.log({
                    f"gpu_{i}/memory_allocated_gb": allocated,
                    f"gpu_{i}/memory_cached_gb": cached,
                }, step=state.global_step)

# Verify the config was captured correctly
print("LoRA Config captured:", get_lora_config_from_model(model))

[34m[1mwandb[0m: Currently logged in as: [33mnicholas1-santolla[0m to [32mhttp://192.168.0.24:8080[0m. Use [1m`wandb login --relogin`[0m to force relogin


LoRA Config captured: {'lora_r': 8, 'lora_alpha': 16, 'lora_dropout': 0.05, 'lora_bias': 'none', 'lora_task_type': 'TaskType.CAUSAL_LM', 'lora_target_modules': {'qkv_proj', 'o_proj'}, 'lora_fan_in_fan_out': False, 'lora_init_lora_weights': True}


In [21]:
#log custom metrics during training
def log_custom_metrics():
    wandb.log({
        "epitope_tokens_added": len(epitope_tokens),
        "total_parameters": sum(p.numel() for p in model.parameters()),
        "trainable_parameters": sum(p.numel() for p in model.parameters() if p.requires_grad),
    })


log_custom_metrics()

In [None]:
# Set up SFTTrainer
from trl import SFTTrainer
trainer = SFTTrainer(
    model=model,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
    args=training_args,
    callbacks=[GPUMemoryCallback(log_every_n_steps=50)],  # Log every 50 steps
)

# Start training
print("Starting training...")
trainer.train()

# Finish wandb run
wandb.finish()

# Save the trained model
trainer.save_model()
print(f"Model saved to {training_args.output_dir}")

Truncating train dataset: 100%|██████████| 9529/9529 [00:00<00:00, 207359.58 examples/s]
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Starting training...


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Step,Training Loss
25,4.4935
50,3.6406
75,3.5135


### This script is used to clear vram. For testing purposes only and when you want to clear the GPU memory.

In [10]:
import gc
import torch
# Clear any existing models from GPU memory
torch.cuda.empty_cache()
gc.collect()

# Check current GPU memory usage
print(f"GPU Memory before: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB allocated")
print(f"GPU Memory reserved: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB reserved")
# If you have a model loaded, delete it first
try:
    del model
    torch.cuda.empty_cache()
    gc.collect()
    print("Previous model cleared from memory")
except:
    print("No previous model to clear")

GPU Memory before: 13.99 GB allocated
GPU Memory reserved: 15.27 GB reserved
Previous model cleared from memory
