# Fine-Tune an LLM for Antibody Sequence Generation

In [1]:
# pip install -r ../requirements.txt

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
from datasets import load_dataset, Dataset
from trl import SFTTrainer
from peft import get_peft_model, LoraConfig, TaskType
import pandas as pd
import torch
import os
from transformers import TrainerCallback

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
torch.cuda.empty_cache()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Test your GPU setup
print(f"Number of GPUs: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
    print(f"Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.1f} GB")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda
Number of GPUs: 2
GPU 0: NVIDIA GeForce RTX 5090
Memory: 33.7 GB
GPU 1: NVIDIA GeForce RTX 3090 Ti
Memory: 25.3 GB


In [2]:
## Load dataset
df = pd.read_csv("../data/sabdab/sabdab_training_dataset.csv")

df.columns


Index(['pdb_id', 'h_chain_id', 'l_chain_id', 'antigen_ids', 'h_chain_seq',
       'l_chain_seq', 'antigen_seqs', 'antibody_seqs', 'h_chain_fv_seq',
       'l_chain_fv_seq', 'antibody_fv_seqs', 'highlighted_epitope_seqs',
       'epitope_residues'],
      dtype='object')

In [3]:
## Remove rows with missing sequences
df = df.dropna(subset=['h_chain_seq', 'l_chain_seq', 'antigen_seqs', 'highlighted_epitope_seqs'])

df.head()

Unnamed: 0,pdb_id,h_chain_id,l_chain_id,antigen_ids,h_chain_seq,l_chain_seq,antigen_seqs,antibody_seqs,h_chain_fv_seq,l_chain_fv_seq,antibody_fv_seqs,highlighted_epitope_seqs,epitope_residues
0,8xa4,C,D,A|B,QLQLQESGPGLVKPSETLSLTCTVSGGSISSNNDYWGWIRQPPGKG...,EIVLTQSPGTLSLSPGERVTLSCRASQRVSSTYLAWYQQKPGQAPR...,SCNGLYYQGSCYILHSDYKSFEDAKANCAAESSTLPNKSDVLTTWL...,QLQLQESGPGLVKPSETLSLTCTVSGGSISSNNDYWGWIRQPPGKG...,QLQLQESGPGLVKPSETLSLTCTVSGGSISSNNDYWGWIRQPPGKG...,EIVLTQSPGTLSLSPGERVTLSCRASQRVSSTYLAWYQQKPGQAPR...,QLQLQESGPGLVKPSETLSLTCTVSGGSISSNNDYWGWIRQPPGKG...,SCNGLYYQGSCYI[L]HSD[Y]KSFEDAKANCAAESSTLPNKSDVL...,A:ARG 176|A:ASP 146|A:ASP 150|A:ASP 170|A:GLN ...
1,9cph,H,L,A,EVQLVESGGGLVQPGGSLRLSCAASGFNLSSSSIHWVRQAPGKGLE...,AQMTQSPSSLSASVGDRVTITCRASQSVSSAVAWYQQKPGKAPKLL...,KIEEGKLVIWINGDKGYNGLAEVGKKFEKDTGIKVTVEHPDKLEEK...,EVQLVESGGGLVQPGGSLRLSCAASGFNLSSSSIHWVRQAPGKGLE...,EVQLVESGGGLVQPGGSLRLSCAASGFNLSSSSIHWVRQAPGKGLE...,AQMTQSPSSLSASVGDRVTITCRASQSVSSAVAWYQQKPGKAPKLL...,EVQLVESGGGLVQPGGSLRLSCAASGFNLSSSSIHWVRQAPGKGLE...,KIEEGKLVIWINGDKGYNGLAEVGKKFEKDTGIKVTVEHPDKLEEK...,A:ALA 1116|A:ALA 1122|A:ALA 1128|A:ALA 900|A:A...
2,9d7i,H,G,E,VQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLEW...,YELTQPPSVSVSPGQTATITCSGASTNVCWYQVKPGQSPEVVIFEN...,LWVTVYYGVPVWKDAETTLFCASDNVWATHACVPTDPNPQEIHLEN...,VQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLEW...,VQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLEW...,YELTQPPSVSVSPGQTATITCSGASTNVCWYQVKPGQSPEVVIFEN...,VQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLEW...,LWVTVYYGVPVWKDAETTLFCASDNVWATHACVPTDPNPQEIHLEN...,E:ARG 429|E:ARG 469|E:ASN 177|E:ASN 197|E:ASN ...
3,9d7i,J,I,C,VQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLEW...,YELTQPPSVSVSPGQTATITCSGASTNVCWYQVKPGQSPEVVIFEN...,LWVTVYYGVPVWKDAETTLFCASDNVWATHACVPTDPNPQEIHLEN...,VQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLEW...,VQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLEW...,YELTQPPSVSVSPGQTATITCSGASTNVCWYQVKPGQSPEVVIFEN...,VQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLEW...,LWVTVYYGVPVWKDAETTLFCASDNVWATHACVPTDPNPQEIHLEN...,C:ARG 469|C:ASN 197|C:ASN 280|C:ASN 425|C:ASP ...
4,9d7o,H,G,E,QVQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLE...,YELTQPPSVSVSPGQTATITCSGASTNVCWYQVKPGQSPEVVIFEN...,LWVTVYYGVPVWKDAETTLFCASDNVWATHACVPTDPNPQEIHLEN...,QVQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLE...,QVQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLE...,YELTQPPSVSVSPGQTATITCSGASTNVCWYQVKPGQSPEVVIFEN...,QVQLQESGPGVVKSSETLSLTCTVSGGSMGGTYWSWLRLSPGKGLE...,LWVTVYYGVPVWKDAETTLFCASDNVWATHACVPTDPNPQEIHLEN...,E:ARG 429|E:ARG 469|E:ASN 197|E:ASN 280|E:ASN ...


In [4]:
## Load base tokenizer and model FIRST
model_name = "microsoft/phi-4"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map= {'model.embed_tokens': 0, 
                 'model.layers.0': 0, 
                 'model.layers.1': 0, 
                 'model.layers.2': 0, 
                 'model.layers.3': 0, 
                 'model.layers.4': 0, 
                 'model.layers.5': 0, 
                 'model.layers.6': 0, 
                 'model.layers.7': 0, 
                 'model.layers.8': 0, 
                 'model.layers.9': 0, 
                 'model.layers.10': 0, 
                 'model.layers.11': 0, 
                 'model.layers.12': 0, 
                 'model.layers.13': 0, 
                 'model.layers.14': 0, 
                 'model.layers.15': 0, 
                 'model.layers.16': 0, 
                 'model.layers.17': 0, 
                 'model.layers.18': 0, 
                 'model.layers.19': 0, 
                 'model.layers.20': 0, 
                 'model.layers.21': 0, 
                 'model.layers.22': 0, 
                 'model.layers.23': 0, 
                 'model.layers.24': 0, 
                 'model.layers.25': 0, 
                 'model.layers.26': 1, 
                 'model.layers.27': 1, 
                 'model.layers.28': 1, 
                 'model.layers.29': 1, 
                 'model.layers.30': 1, 
                 'model.layers.31': 1, 
                 'model.layers.32': 1, 
                 'model.layers.33': 1, 
                 'model.layers.34': 1, 
                 'model.layers.35': 1, 
                 'model.layers.36': 1, 
                 'model.layers.37': 1, 
                 'model.layers.38': 1, 
                 'model.layers.39': 1, 
                 'model.norm': 1, 
                 'model.rotary_emb': 1, 
                 'lm_head': 1}, # Use only the first GPU
    trust_remote_code=True,
    torch_dtype=torch.float16, # Load model in bfloat16 for better performance
    low_cpu_mem_usage=True, # Reduce CPU memory usage during loading
    max_memory={0: "30GB", 1:"23.5GB"}  # Limit GPU memory usage to 20GB
)

Loading checkpoint shards: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 6/6 [00:06<00:00,  1.07s/it]


In [5]:
# Check current device mapping
print("Current device map:", model.hf_device_map)

Current device map: {'model.embed_tokens': 0, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0, 'model.layers.6': 0, 'model.layers.7': 0, 'model.layers.8': 0, 'model.layers.9': 0, 'model.layers.10': 0, 'model.layers.11': 0, 'model.layers.12': 0, 'model.layers.13': 0, 'model.layers.14': 0, 'model.layers.15': 0, 'model.layers.16': 0, 'model.layers.17': 0, 'model.layers.18': 0, 'model.layers.19': 0, 'model.layers.20': 0, 'model.layers.21': 0, 'model.layers.22': 0, 'model.layers.23': 0, 'model.layers.24': 0, 'model.layers.25': 0, 'model.layers.26': 1, 'model.layers.27': 1, 'model.layers.28': 1, 'model.layers.29': 1, 'model.layers.30': 1, 'model.layers.31': 1, 'model.layers.32': 1, 'model.layers.33': 1, 'model.layers.34': 1, 'model.layers.35': 1, 'model.layers.36': 1, 'model.layers.37': 1, 'model.layers.38': 1, 'model.layers.39': 1, 'model.norm': 1, 'model.rotary_emb': 1, 'lm_head': 1}


In [6]:
# Add epitope tokens
epitope_tokens = ["<epi>", "</epi>"]
tokenizer.add_special_tokens({"additional_special_tokens": epitope_tokens})



2

In [7]:
# Add amino acid tokens
amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
extra_tokens = amino_acids + ["|"]
new_tokens = [t for t in extra_tokens if t not in tokenizer.get_vocab()]
tokenizer.add_tokens(new_tokens)

# Resize model embeddings ONCE after adding all tokens
model.resize_token_embeddings(len(tokenizer))
model.train()


The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Phi3ForCausalLM(
  (model): Phi3Model(
    (embed_tokens): Embedding(100354, 5120, padding_idx=100349)
    (layers): ModuleList(
      (0-39): 40 x Phi3DecoderLayer(
        (self_attn): Phi3Attention(
          (o_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (qkv_proj): Linear(in_features=5120, out_features=7680, bias=False)
        )
        (mlp): Phi3MLP(
          (gate_up_proj): Linear(in_features=5120, out_features=35840, bias=False)
          (down_proj): Linear(in_features=17920, out_features=5120, bias=False)
          (activation_fn): SiLU()
        )
        (input_layernorm): Phi3RMSNorm((5120,), eps=1e-05)
        (post_attention_layernorm): Phi3RMSNorm((5120,), eps=1e-05)
        (resid_attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
      )
    )
    (norm): Phi3RMSNorm((5120,), eps=1e-05)
    (rotary_emb): Phi3RotaryEmbedding()
  )
  (lm_head): Linear(in_features=5120, out_features=1003

In [8]:
# Convert epitope format function
import re
def convert_epitope_format(sequence):
    return re.sub(r'\[([A-Z])\]', r'<epi>\1</epi>', sequence)

## NOW create dataset with all tokens available
def format_prompt(example):
    epitope_seq = convert_epitope_format(example['highlighted_epitope_seqs'])
    return {
        "text": f"Antigen: {epitope_seq}<|im_end|>\nAntibody: {example['antibody_fv_seqs']}<|im_end|>\n"
    }

dataset = Dataset.from_pandas(df)
dataset = dataset.map(format_prompt)

Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 9523/9523 [00:00<00:00, 17577.79 examples/s]


In [11]:
# Add task-specific tokens
task_tokens = ["Antigen", "Antibody", "Epitope"]
tokenizer.add_tokens(task_tokens)
model.resize_token_embeddings(len(tokenizer))

 

Embedding(100357, 5120, padding_idx=100349)

In [12]:
# Check truncation at 800
sequence_lengths = [len(tokenizer(example["text"], truncation=False)["input_ids"]) for example in dataset]
truncated_800 = sum(1 for length in sequence_lengths if length > 800)
print(f"Sequences truncated at max_length=800: {truncated_800}/{len(sequence_lengths)} ({100*truncated_800/len(sequence_lengths):.1f}%)")

Sequences truncated at max_length=800: 198/9523 (2.1%)


In [13]:
## Tokenize the dataset
def tokenize(example):
    encoded = tokenizer(example["text"], truncation=True, max_length=800)
    # Make sure labels are a proper list, not nested
    encoded["labels"] = encoded["input_ids"].copy()
    return encoded

tokenized_dataset = dataset.map(tokenize)

Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 9523/9523 [00:02<00:00, 3635.71 examples/s]


In [14]:
# Verify tokenization is working with epitope tokens
print("Sample tokenized text:")
sample_tokens = tokenizer.tokenize(dataset[0]['text'][:200])
print(sample_tokens)

Sample tokenized text:
['Antigen', ':', 'Ä SCN', 'GL', 'YY', 'Q', 'G', 'SC', 'Y', 'I', '<epi>', 'L', '</epi>', 'H', 'SD', '<epi>', 'Y', '</epi>', 'K', 'SF', 'ED', 'AK', 'AN', 'CAA', 'ES', 'ST', 'LP', 'NK', 'SD', 'VL', 'TT', 'W', 'LI', '<epi>', 'D', '</epi>', '<epi>', 'Y', '</epi>', 'V', '<epi>', 'E', '</epi>', '<epi>', 'D', '</epi>', '<epi>', 'T', '</epi>', 'WG', 'SD', 'GN', 'P', 'IT', 'K', 'TT', 'SD', '<epi>', 'Y', '</epi>', 'Q', 'DS', '<epi>', 'D', '</epi>', 'VS', '<epi>', 'Q', '</epi>', '<epi>', 'E']


In [15]:
# Remove unnecessary columns
tokenized_dataset = tokenized_dataset.remove_columns([
    'pdb_id', 'h_chain_id', 'l_chain_id', 'antigen_ids', 'antigen_seqs',
    'h_chain_seq', 'l_chain_seq', 'antibody_seqs',
    'highlighted_epitope_seqs', 'epitope_residues','h_chain_fv_seq',
       'l_chain_fv_seq', 'antibody_fv_seqs', 'text'
])
print("Columns after removal:", tokenized_dataset.column_names)
# Should show: ['input_ids', 'attention_mask', 'labels']


Columns after removal: ['input_ids', 'attention_mask', 'labels']


In [16]:
# Apply the gradient fix to your model
if hasattr(model, 'enable_input_require_grads'):
    model.enable_input_require_grads()
else:
    def make_inputs_require_grad(module, input, output):
        output.requires_grad_(True)
    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

In [17]:
# # Create data collator
# data_collator = DataCollatorForLanguageModeling(
#     tokenizer=tokenizer,
#     mlm=False,
#     return_tensors="pt",
#     pad_to_multiple_of=8, # Pad to multiple of 8 for better performance on GPUs
# )
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    label_pad_token_id=-100,
    return_tensors="pt",
)

In [18]:
 # Configure LoRA
## PEFT configuration
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
    # target_modules=["q_proj", "v_proj", "k_proj", "o_proj"]
    target_modules=["o_proj", "qkv_proj"],
)

# Apply LoRA to the model
model = get_peft_model(model, peft_config)

# Print trainable parameters
model.print_trainable_parameters()



trainable params: 7,372,800 || all params: 14,666,931,200 || trainable%: 0.0503


In [19]:
# Update training arguments to enable wandb logging
training_args = TrainingArguments(
    output_dir=f"../models/peleke-{model_name.split('/')[-1]}-0806025",
    per_device_train_batch_size=9,
    gradient_accumulation_steps=1,
    per_device_eval_batch_size=6,
    num_train_epochs=3,
    warmup_steps=25,
    weight_decay=0.01,
    learning_rate=2e-4,
    logging_dir="../logs",
    logging_steps=25,
    gradient_checkpointing=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    report_to="none", #"wandb",  # Enable wandb reporting
    run_name=f"lora-epitope-{model_name.split('/')[-1]}",  # Run name for wandb
    # optim="adamw_torch",
    fp16=True,  # Enable mixed precision training
    dataloader_num_workers=8,  # Add parallel data loading
    dataloader_pin_memory=True,  # Pin memory for faster data loading
    remove_unused_columns=False,
    max_grad_norm=1.0,
)

In [20]:
import re

def convert_brackets_to_epi(sequence):
    """Convert [X] format to <epi>X</epi> format"""
    return re.sub(r'\[([A-Z])\]', r'<epi>\1</epi>', sequence)

# Convert your bracket sequences to the training format
sequences_with_brackets = [
    "KVFGRCELAAAM[K][R]HGL[D][N][Y]RG[Y][S]LG[N]WVCAAKFESNFNTQATNRNTDGSTDYGILQINSRWWCNDGRTPGSRNLCNIPCSALLSSDITASVNCA[K]KIVSDGNGMNAWVAWRNRCK[G][T][D]V[Q]AW[I][R]GCRL",
    "NLCPFHEVFNATTFASVYAWNRKRISNCVADYSVIYNFAPFFAFKCYGVSPTKLNDLCFTNVYADSFVI[R]G[N]EV[S][Q]IAPGQ[T]GNIADYNYKLPDDFTGCVIAWNSN[K]LDSKPSGNYNYLYRLLRKSKLKPFERDISTEIYQAGNKPCNGVAGPNCYSPLQSYGF[R]P[T][Y][G][V]GH[Q]PYRVVVLSFELLHAPATVCGP",
]

# Convert to the exact training format
test_antigens = [convert_brackets_to_epi(seq) for seq in sequences_with_brackets]

# Verify the conversion
for i, (orig, conv) in enumerate(zip(sequences_with_brackets, test_antigens)):
    print(f"=== Sequence {i+1} ===")
    print(f"Original: {orig[:60]}...")
    print(f"Converted: {conv[:60]}...")
    print("-" * 60)

print(f"\nFinal test_antigens for training callback:")
for i, antigen in enumerate(test_antigens):
    print(f"Test {i+1}: {antigen[:80]}...")




=== Sequence 1 ===
Original: KVFGRCELAAAM[K][R]HGL[D][N][Y]RG[Y][S]LG[N]WVCAAKFESNFNTQATN...
Converted: KVFGRCELAAAM<epi>K</epi><epi>R</epi>HGL<epi>D</epi><epi>N</e...
------------------------------------------------------------
=== Sequence 2 ===
Original: NLCPFHEVFNATTFASVYAWNRKRISNCVADYSVIYNFAPFFAFKCYGVSPTKLNDLCFT...
Converted: NLCPFHEVFNATTFASVYAWNRKRISNCVADYSVIYNFAPFFAFKCYGVSPTKLNDLCFT...
------------------------------------------------------------

Final test_antigens for training callback:
Test 1: KVFGRCELAAAM<epi>K</epi><epi>R</epi>HGL<epi>D</epi><epi>N</epi><epi>Y</epi>RG<ep...
Test 2: NLCPFHEVFNATTFASVYAWNRKRISNCVADYSVIYNFAPFFAFKCYGVSPTKLNDLCFTNVYADSFVI<epi>R</epi...


In [22]:
from transformers import TrainerCallback
import torch
from datetime import datetime
import os

class TestGenerationCallback(TrainerCallback):
    def __init__(self, model, tokenizer, test_antigens, log_every_n_steps=100, output_file="test_generations.txt"):
        self.model = model
        self.tokenizer = tokenizer
        self.test_antigens = test_antigens
        self.log_every_n_steps = log_every_n_steps
        self.output_file = output_file
        
        # Create/clear the output file
        with open(self.output_file, 'w') as f:
            f.write(f"Test Generation Log - Started: {datetime.now()}\n")
            f.write("="*80 + "\n\n")
    
    def create_test_prompt(self, antigen_with_epitopes):
        return f"Antigen: {antigen_with_epitopes}<|im_end|>\nAntibody:"
    
    def generate_antibody_test(self, antigen_with_epitopes, max_length=800):
        """Generate antibody for testing during training"""
        prompt = self.create_test_prompt(antigen_with_epitopes)
        
        # Tokenize
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_length)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        
        # Generate
        self.model.eval()
        try:
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=200,
                    temperature=0.7,
                    top_p=0.9,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
                    repetition_penalty=1.1,
                )
                
                # Decode and extract antibody
                generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=False)
                if "Antibody:" in generated_text:
                    antibody_part = generated_text.split("Antibody:", 1)[1]
                    if "<|im_end|>" in antibody_part:
                        antibody_sequence = antibody_part.split("<|im_end|>", 1)[0].strip()
                    else:
                        antibody_sequence = antibody_part.strip()
                else:
                    antibody_sequence = "Generation failed"
                
                return antibody_sequence
                
        except Exception as e:
            return f"Error: {str(e)}"
        finally:
            self.model.train()  # Put model back in training mode
    
    def run_test_generation(self, state, phase="TRAINING"):
        """Run test generation and print/save results"""
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        header = f"TEST GENERATION - {phase} - STEP {state.global_step} - {timestamp}"
        
        # Print to terminal
        print(f"\n{'='*80}")
        print(header)
        print(f"{'='*80}")
        
        # Write to file
        with open(self.output_file, 'a') as f:
            f.write(f"\n{'='*80}\n")
            f.write(f"{header}\n")
            f.write(f"{'='*80}\n")
        
        for i, test_antigen in enumerate(self.test_antigens):
            case_header = f"--- Test Case {i+1} ---"
            input_display = f"Input: {test_antigen[:60]}{'...' if len(test_antigen) > 60 else ''}"
            
            # Generate antibody
            antibody = self.generate_antibody_test(test_antigen)
            generated_display = f"Generated: {antibody}"
            
            # Print to terminal
            print(f"\n{case_header}")
            print(input_display)
            print(generated_display)
            
            # Write to file (with full input)
            with open(self.output_file, 'a') as f:
                f.write(f"\n{case_header}\n")
                f.write(f"Full Input: {test_antigen}\n")
                f.write(f"Generated: {antibody}\n")
                f.write(f"Length: {len(antibody)} characters\n")
        
        # Terminal footer
        print(f"{'='*80}\n")
        
        # File footer
        with open(self.output_file, 'a') as f:
            f.write(f"{'='*80}\n\n")
    
    def on_train_begin(self, args, state, control, **kwargs):
        """Test at the beginning of training"""
        print("ðŸ§¬ INITIAL GENERATION TEST (Before Training)")
        self.run_test_generation(state, "INITIAL")
    
    def on_log(self, args, state, control, **kwargs):
        """Test periodically during training"""
        if state.global_step % self.log_every_n_steps == 0 and state.global_step > 0:
            self.run_test_generation(state, "PERIODIC")
    
    def on_train_end(self, args, state, control, **kwargs):
        """Test at the end of training"""
        print("ðŸŽ‰ FINAL GENERATION TEST (After Training)")
        self.run_test_generation(state, "FINAL")
        
        # Add summary to file
        with open(self.output_file, 'a') as f:
            f.write(f"\nTraining completed: {datetime.now()}\n")
            f.write(f"Final step: {state.global_step}\n")


In [23]:


# Create the callback
test_callback = TestGenerationCallback(
    model=model, 
    tokenizer=tokenizer, 
    test_antigens=test_antigens,
    log_every_n_steps=50,  # Test every 50 steps
    output_file="/home/nicholas/Documents/GitHub/peleke/logs/test_generations.txt"  # Save to logs directory
)




In [None]:
# Set up SFTTrainer
from trl import SFTTrainer
trainer = SFTTrainer(
    model=model,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
    args=training_args,
    callbacks=[test_callback],  # Log every 50 steps
)

# Start training
print("Starting training...")
trainer.train()

# Finish wandb run
#wandb.finish()

# Save the trained model
trainer.save_model()
print(f"Model saved to {training_args.output_dir}")

Truncating train dataset: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 9523/9523 [00:00<00:00, 167215.19 examples/s]
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Starting training...
ðŸ§¬ INITIAL GENERATION TEST (Before Training)

TEST GENERATION - INITIAL - STEP 0 - 2025-08-06 11:15:22

--- Test Case 1 ---
Input: KVFGRCELAAAM<epi>K</epi><epi>R</epi>HGL<epi>D</epi><epi>N</e...
Generated: Greetings! I'm here to assist you with any questions or topics you might have. What can I help you with today? Whether it's about a specific subject, general knowledge, advice, or just a friendly chat, feel free to ask anything!

--- Test Case 2 ---
Input: NLCPFHEVFNATTFASVYAWNRKRISNCVADYSVIYNFAPFFAFKCYGVSPTKLNDLCFT...
Generated: NLCPFHEVFNATTFASVYAWNRKRISNCVADYSVIYNFAPFFAFKCYGVSPTKLNDLCFTNVYADSFVI AccessControlList { 
    Effect           = "Allow" 
    Principal        = "*" 
    Action           = [ 
                        "s3:GetObject", 
                        "s3:PutObject"
                    ] 
    Resource         = "arn:aws:s3::${var.bucket_name}/*"
}

# Create an S3 bucket
resource "aws_s3_bucket" "my_bucket" {
  bucket = var.bucket_name

  acl    

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Step,Training Loss
25,4.9024
50,4.3579
75,4.2579
100,3.9475
125,3.6464
150,3.6204
175,3.5222
200,3.6167
225,3.7327
250,3.3504



TEST GENERATION - PERIODIC - STEP 50 - 2025-08-06 11:20:46

--- Test Case 1 ---
Input: KVFGRCELAAAM<epi>K</epi><epi>R</epi>HGL<epi>D</epi><epi>N</e...
Generated: QVQLQQSGAELVRPGASVKLSCTASGFNIHWYQQKPGKAPKLLIYSADTAVYYCARLDPFGGGTKVEIK

--- Test Case 2 ---
Input: NLCPFHEVFNATTFASVYAWNRKRISNCVADYSVIYNFAPFFAFKCYGVSPTKLNDLCFT...
Generated: VQLVESGGGLVQPGRSLRLSCAASGFTFSNYAMSWVRQAPGKGLEWIGEIIPGYEQYVQPGQSPKLLIYDASSLRSSDTAMHWVKQRPGQGLEWMGWINPYNGYNYTSKSRIINPKNTLVTSFGDRVTITCQASQSIISSYEDLTIYYCQQSSGVTVLTQPEDFEWIFMSHDFTLTSKISVEKMTVDYRFWGQGTLVTVSS|ELSVAYSGSTVKAISCKGHTVTVSRASQGIHHYTHSEFVQTKISCTGCSGTDFVCPARVGSDYTQTVSISCSLAEYSVTPVYDVVNQKEFKRDYQLSVSPLTFGGGTKLEIK


TEST GENERATION - PERIODIC - STEP 100 - 2025-08-06 11:25:44

--- Test Case 1 ---
Input: KVFGRCELAAAM<epi>K</epi><epi>R</epi>HGL<epi>D</epi><epi>N</e...
Generated: EVQLVESGGGVVKPGSGSLRLSCAASGFTFSSYMWWMRWIRQPPGKGLEWIGWIYYDGSTYYADSVKDYEYLDSWGHGFRTFSNFGDDVTVS|DIQMTQSPSSLSASVGDRVTITCRASESVSSSWYSAYVHWYQQKPGKAPKLILPDLPDPPEPEDSPTFGGGTKLTVL

--- Test Cas




TEST GENERATION - PERIODIC - STEP 550 - 2025-08-06 12:13:22

--- Test Case 1 ---
Input: KVFGRCELAAAM<epi>K</epi><epi>R</epi>HGL<epi>D</epi><epi>N</e...
Generated: EVQLVESGGGLIQPGGSLRLSCAASEFTISKFMHWVRQAPGKGLEWVASISSYSGGSTYYADSVKGRFTISRDNSKNTLYLQMNSLRVEDTAVYYCARDDYYDVWGQGTLVTV|DIQMTQSPSTLSASVGDRVTITCRASKSYAYYWSWYQQKPGKAPKLMIYKVSNRFSGVPSRFSGSRSGTDFTLTINNVQPEDFATYYCQQSQSYPLTFGAGTKLELK

--- Test Case 2 ---
Input: NLCPFHEVFNATTFASVYAWNRKRISNCVADYSVIYNFAPFFAFKCYGVSPTKLNDLCFT...
Generated: EVQLVESGGGLVKPGGSLRLSCAASGYEFIFNWMSWVRQAPGKGLEWVAIIWDGSGDTYYADSVGRFTISRDNAKKMFYLELRAEDTAVYYCAKQGKYWGQGTLVTVSS|DIQMTQSPSSLSASVGDRVTITCRASQSISSWLAWYQQKPGKAPKLLIYAASTLQSGVPSRFSGSRSGTDFTLTINSLQPEDFATYYCQQSNFPYTFGQGTKVEIK


TEST GENERATION - PERIODIC - STEP 600 - 2025-08-06 12:18:34

--- Test Case 1 ---
Input: KVFGRCELAAAM<epi>K</epi><epi>R</epi>HGL<epi>D</epi><epi>N</e...
Generated: VQLVESGGGVVQPGRSLRLSCAASGFNIFTDYGMHWVRQAPGKGLEWVASISSYYGYTTYYADSVKGRFTISRDNSKSLSLQMRAEDTAVYYCARERVVQDIWGQGTSVTVSS|EIVLTQSPGTLSLSP




TEST GENERATION - PERIODIC - STEP 1050 - 2025-08-06 13:06:12

--- Test Case 1 ---
Input: KVFGRCELAAAM<epi>K</epi><epi>R</epi>HGL<epi>D</epi><epi>N</e...
Generated: EVQLVESGGGLVKPGGSLKLSCAASGFIFSSYWMHWVRQTPEKRLEWVASISNSGGYTYYADSVKGRFTISRDNAKNSLYLQMRAEDTAIYYCARDPPLGSDWGQGTAVTVSS|DIVMTQSPSSLTVSVGDRVTITCRASEDIYSNLAWYQQKPGKAPKLLIYKTSTLASSGVPSRFSGSGSGTEFTLTISRLEPEDFAVYYCQQYDNWPRTFGQGTKVEIK

--- Test Case 2 ---
Input: NLCPFHEVFNATTFASVYAWNRKRISNCVADYSVIYNFAPFFAFKCYGVSPTKLNDLCFT...
Generated: EVQLVESGGGLVKPGGSLRLSCAASGFTFRNYAMSWVRQAPGKGLEWVSAISSSGGSTYYADSVKGRFTISRDNAKNSLYLQMRAEDTAVYYCARDGPYYYGYFAVWGAGTTVTVSS|DIQMTQSPSSLSASVGDRVTITCRASQGISSYLAWYQQKPGKAPKLLIYAASSLQSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQRSFPVTFGGGTKVEIK



huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av


TEST GENERATION - PERIODIC - STEP 1100 - 2025-08-06 13:11:29

--- Test Case 1 ---
Input: KVFGRCELAAAM<epi>K</epi><epi>R</epi>HGL<epi>D</epi><epi>N</e...
Generated: QLQLQQPGAELVKPGASVKLSCTASGFNIKDYYMTWVKQRPEQGLEWIGRIDPANGHTNYNEKFKNRVTLTADKSSSTAYMQLSSLASEDSAVYYCARERGDGYAMDYWGQGTTLTVSS|DIQMTQSPSSVSASVGDRVTITCRASQSISSWLAWYQQKPGRAPKLLISAASTLQSGVPSRFSGSGSGTEFTLTISTLRPEDFATYYCQQSYEDPYTFGGGTKVEIK

--- Test Case 2 ---
Input: NLCPFHEVFNATTFASVYAWNRKRISNCVADYSVIYNFAPFFAFKCYGVSPTKLNDLCFT...
Generated: QVQLVESGGGVVQPGRSLRLSCAASGFTFSNYGMHWVRQAPDKGLEWVALIKSGGTSAKYDTSVKGRFTISRDNAKNTLYLEMSSLRSEDTAMYYCARRRGYYYAYWGQGTLVTVSA|DIQMTQSPSTLAASPAAVTTINCPGSQQTYLNWLTQRASESIIYWYRKNPGPRPPRRLIYRGAISIRRFSGSDRRSASIDGYTNQPEDEAIYYCMHFWSNHTPVFGAGTKLEIK


TEST GENERATION - PERIODIC - STEP 1150 - 2025-08-06 13:16:44

--- Test Case 1 ---
Input: KVFGRCELAAAM<epi>K</epi><epi>R</epi>HGL<epi>D</epi><epi>N</e...
Generated: QLQLQQSGAELVRPGTSVKLSCKASEYTFTNYGMNWVKQRPEQGLEWIGRIYPGDGDTNYNEKFQKFKDKATLTADKSSSTAYMQLSSLTSEDSAVYYCARERFD




TEST GENERATION - PERIODIC - STEP 1550 - 2025-08-06 13:58:34

--- Test Case 1 ---
Input: KVFGRCELAAAM<epi>K</epi><epi>R</epi>HGL<epi>D</epi><epi>N</e...
Generated: EVQLVESGGGVVQPGRSLRLSCAASGFTFRNYGMHWVRQAPGKGLEWVAFIRYDGGNKYYADSVKGRFIISRDNSKNTLYLQMRAEDTAVYYCARHVLDDFDIWGQGTLVTV|ALTQPPSVSGSPGQSVTISCTGTSSDVGGYNYVSWYQQHPGKAPKLMIYEVSKRPSGVSNRFSGSKSGNTASLTISGLQAEDEADYYCSSYTTSSTWVFGGGTKLTV

--- Test Case 2 ---
Input: NLCPFHEVFNATTFASVYAWNRKRISNCVADYSVIYNFAPFFAFKCYGVSPTKLNDLCFT...
Generated: QMQLVQSGPEVKKPGTSVKVSCKASGFTFSRYAMSWVRQAPGRGLEWMGWIFTSGTINYAQNFQGRVTITADRSTSTAYLELRSEDTAVYYCAKHGDYDYDSSWGQGTLVTVSS|EIVLTQSPGTLSLSPGERATLSCRASQSVDYLGWYQQKRGQEPSPRLLIKYASESISKSRSGIPSRFSGSGFGTDFTLTISRLEPEDFAVYYCQQYGSSPWTFGQGTKVEI


TEST GENERATION - PERIODIC - STEP 1600 - 2025-08-06 14:03:42

--- Test Case 1 ---
Input: KVFGRCELAAAM<epi>K</epi><epi>R</epi>HGL<epi>D</epi><epi>N</e...
Generated: QLVESGGGVVKPGGSLKLSCAASGFTFSNYGMNWVRQTPEKRLEWVASISDGGSNNYNYPDKFKGKATLTADTSSSTAYMELSSLTSEDSAVYYCARERDFGDYWGQGTLVTVSS|DI




TEST GENERATION - PERIODIC - STEP 2050 - 2025-08-06 14:51:07

--- Test Case 1 ---
Input: KVFGRCELAAAM<epi>K</epi><epi>R</epi>HGL<epi>D</epi><epi>N</e...
Generated: QVQLVESGGGVVQPGRSLRLSCAASGFTFSNYGMHWVRQAPGKGLEWVASISSSGSTYYADSVKGRFTISRDNSKNTLYLQMRAEDTAVYYCARPPDYWGQGTLVTVSS|DIQMTQSPSSLSASVGDRVTITCRASKGIYSNLAWYQQKPGKAPKLLIYAASSLQSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYSTPRTFGQGTKVEIK

--- Test Case 2 ---
Input: NLCPFHEVFNATTFASVYAWNRKRISNCVADYSVIYNFAPFFAFKCYGVSPTKLNDLCFT...
Generated: QMQLVESGGGVVQPGRSLRLSCAASGFPFSSYGMSWVRQAPGKGLEWLGLIGIWNNHSNRYYADSVKGRFTISRDNARNTLYLQMNSLRPEDTAVYYCAREGSGTYGNWGQGTTLTVSS|DIQMTQSPSTLSASVGDRVTITCRASSIRSNFLNWYQQKPGKAPKLLIYDASNLETGVPSRFSGSGSGTDFTFTISSLQPEDIATYYCLQQRNNLPRTFGPGTKVDIK


TEST GENERATION - PERIODIC - STEP 2100 - 2025-08-06 14:56:26

--- Test Case 1 ---
Input: KVFGRCELAAAM<epi>K</epi><epi>R</epi>HGL<epi>D</epi><epi>N</e...
Generated: EVQLVESGGGVVRPGGSLRLSCAASAFTFNNYMHWVRQAPGKGLEWVAYISSSGSTYYADSVKGRFTISRDNSKNTLYLQMRAEDTAVYYCARDLPYYYDDIWGQGTLVTVSS|SYELT

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av


TEST GENERATION - PERIODIC - STEP 2150 - 2025-08-06 15:01:41

--- Test Case 1 ---
Input: KVFGRCELAAAM<epi>K</epi><epi>R</epi>HGL<epi>D</epi><epi>N</e...
Generated: EVQLVESGGGLVKPGGSLKLSCAASGFTFSSYTMNWVRQSPEKRLEWVASISSSGSTYYTDSVKGRFTISRDNAKNSLYLQMRAEDTAVYYCARERDYWGQGTLVTVSS|IQMTQSPSSLSASVGDRVTITCRASQSVSSAVAWYQQKPGKAPKLLIYSASSLYSGVPSRFSGSRSGTDFTLTISSLQPEDFATYYCQQSSSSLITFGQGTKVEIK

--- Test Case 2 ---
Input: NLCPFHEVFNATTFASVYAWNRKRISNCVADYSVIYNFAPFFAFKCYGVSPTKLNDLCFT...
Generated: VQLVESGGGLIQPGGSLRLSCAASGVTVSRNYMSWVRQAPGKGLEWVSVMFSGGSTFYADSVKGRFTISRDNSKNTLYLQMRAEDTAVYYCARDLRDYYGDVWGQGTTVT|IVLTQSPGTLSLSPGERATLSCRASQSVSSYLAWYQQKPGQAPRLLIYGASSRAPGIPDRFSGSGSGTDFTLTISRLEPEDFAVYYCQQFGDSPRTFGQGTKVE


TEST GENERATION - PERIODIC - STEP 2200 - 2025-08-06 15:06:32

--- Test Case 1 ---
Input: KVFGRCELAAAM<epi>K</epi><epi>R</epi>HGL<epi>D</epi><epi>N</e...
Generated: QVQLQQPGAELVKPGASVKLSCKASGYTFTSYNMHWVRQAPGQRLEWMGWIKGSGSTNYAQKFQDWVTMTRDTSTTTAYMELRSDDTAIYYCARDPFHYDFWGQGTTLTVSS|EIVLTQSPGTLSLSPGERAT

### This script is used to clear vram. For testing purposes only and when you want to clear the GPU memory.

In [1]:
import gc
import torch
# Clear any existing models from GPU memory
torch.cuda.empty_cache()
gc.collect()

# Check current GPU memory usage
print(f"GPU Memory before: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB allocated")
print(f"GPU Memory reserved: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB reserved")
# If you have a model loaded, delete it first
try:
    del model
    torch.cuda.empty_cache()
    gc.collect()
    print("Previous model cleared from memory")
except:
    print("No previous model to clear")

GPU Memory before: 0.00 GB allocated
GPU Memory reserved: 0.00 GB reserved
No previous model to clear
