# Fine-Tuning Microsoft's Dayhoff Model for Antibody Generation

In [1]:
# conda create --name dayhoff python=3.10
# conda activate dayhoff
# pip install dayhoff peft trl
# conda install ipykernel

In [2]:
## To avoid "device-side assert triggered" RuntimeErrors
# import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# os.environ['TORCH_USE_CUDA_DSA'] = '1'
# os.environ["MAMBA_DISABLE_FAST_KERNELS"] = "1"

In [3]:
import torch
import time
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
    set_seed,
    TrainerCallback,
    DataCollatorForSeq2Seq,
)

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
## Config
MODEL_ID = "microsoft/Dayhoff-3b-UR90"
DATA_PATH = "../data/sabdab/sabdab_training_dataset.csv"  # CSV with columns: antigen, antibody
OUTPUT_DIR = "models/peleke-dayhoff-3b-ur90_2.01"
#CACHE_DIR = "/mnt/batch/tasks/shared/LS_root/mounts/clusters/colby-h100-01-ci/code/.cache/huggingface/"

MAX_LEN = 3072 # old value2048
BATCH_SIZE = 1 # old value 2
GRAD_ACCUM = 8 # old value 8
EPOCHS = 3
LR = 5e-5 # old value 2e-4
WARMUP = 0.03
USE_QLORA = True
SEED = 1337
DEVICE = torch.device("cuda")

In [5]:
## Set Seeds
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
set_seed(SEED)

In [6]:
## Tokenizer Setup
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    #cache_dir=CACHE_DIR
    )

## Add epitope tokens
# epitope_tokens = ["[", "]"]
# tokenizer.add_special_tokens({"additional_special_tokens": epitope_tokens})
# special_tokens_dict = {"additional_special_tokens": ['|']}
# num_added = tokenizer.add_special_tokens(special_tokens_dict)

## Add amino acid tokens (which should all exist) and delimiter
# amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
# extra_tokens = amino_acids + ["|"]
# new_tokens = [t for t in extra_tokens if t not in tokenizer.get_vocab()]
# tokenizer.add_tokens(new_tokens)

# num_added = tokenizer.add_tokens(['|'])


bnb_config = None
# if USE_QLORA:
#     from transformers import BitsAndBytesConfig
#     bnb_config = BitsAndBytesConfig(
#         load_in_4bit=True,
#         bnb_4bit_quant_type="nf4",
#         bnb_4bit_use_double_quant=True,
#         bnb_4bit_compute_dtype=torch.bfloat16,
#     )

## Confirm existing special tokens
print("Special tokens:", tokenizer.special_tokens_map)

# print("Num tokens added:", num_added)
# print("New tokenizer size:", len(tokenizer))
# print("ID for |:", tokenizer.convert_tokens_to_ids("|"))

Special tokens: {'bos_token': '@', 'eos_token': '*', 'sep_token': '/', 'pad_token': '!', 'mask_token': '#'}


In [7]:
## Model Setup
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    # torch_dtype=torch.bfloat16,
    torch_dtype="auto",
    # device_map="auto",
    device_map="auto",
    # quantization_config=bnb_config,
    trust_remote_code=True,
    use_mamba_kernels=False,
    #cache_dir=CACHE_DIR
)#.cuda()

## Resize model to accomodate new tokens
model.resize_token_embeddings(len(tokenizer))

if USE_QLORA:
    # model = prepare_model_for_kbit_training(model)
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        # target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        target_modules="all-linear"
    )
    model = get_peft_model(model, peft_config)#.cuda()

#model.to(DEVICE)
# model.train()

The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config
Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  1.67it/s]


In [8]:
## Load Dataset (your current code)
dataset = load_dataset("csv", data_files=DATA_PATH)
all_columns = dataset["train"].column_names
columns_to_remove = [col for col in all_columns if col not in ['antibody_fv_seqs', 'highlighted_epitope_seqs']]
dataset = dataset.remove_columns(columns_to_remove) \
    .rename_column("antibody_fv_seqs", "antibody") \
    .rename_column("highlighted_epitope_seqs", "antigen")

## Pre-tokenization: Check sequence lengths BEFORE truncation
def analyze_sequence_lengths(dataset, max_len=2048):
    """Analyze sequence lengths to determine optimal max_length"""
    sequence_lengths = []
    
    for example in dataset:
        # Build the text format (same as your build_example function)
        allowed_chars = "ACDEFGHIKLMNPQRSTVWY[]@*/#"
        antigen = "".join([c for c in example['antigen'] if c in allowed_chars])
        antibody = "".join([c for c in example['antibody'].replace('|','/') if c in allowed_chars])
        text = f"@{antigen}/{antibody}*"
        
        # Tokenize without truncation to get true length
        tokens = tokenizer(text, add_special_tokens=False, truncation=False)
        sequence_lengths.append(len(tokens["input_ids"]))
    
    # Calculate statistics
    import numpy as np
    lengths = np.array(sequence_lengths)
    
    print(f"Sequence Length Analysis:")
    print(f"Total sequences: {len(lengths)}")
    print(f"Mean length: {np.mean(lengths):.1f}")
    print(f"Median length: {np.median(lengths):.1f}")
    print(f"Max length: {np.max(lengths)}")
    print(f"Min length: {np.min(lengths)}")
    
    # Check truncation at different max lengths
    for test_max_len in [1024, 1536, 2048, 2560, 3072]:
        truncated = np.sum(lengths > test_max_len)
        pct_truncated = 100 * truncated / len(lengths)
        pct_retained = 100 - pct_truncated
        print(f"At max_length={test_max_len}: {truncated}/{len(lengths)} truncated ({pct_truncated:.1f}%) | {pct_retained:.1f}% data retained")
    
    return sequence_lengths

# Run analysis on your train split
print("=== SEQUENCE LENGTH ANALYSIS ===")
sequence_lengths = analyze_sequence_lengths(dataset['train'])

Token indices sequence length is longer than the specified maximum sequence length for this model (2269 > 2048). Running this sequence through the model will result in indexing errors


=== SEQUENCE LENGTH ANALYSIS ===
Sequence Length Analysis:
Total sequences: 9523
Mean length: 631.4
Median length: 542.0
Max length: 3560
Min length: 224
At max_length=1024: 1409/9523 truncated (14.8%) | 85.2% data retained
At max_length=1536: 92/9523 truncated (1.0%) | 99.0% data retained
At max_length=2048: 41/9523 truncated (0.4%) | 99.6% data retained
At max_length=2560: 3/9523 truncated (0.0%) | 100.0% data retained
At max_length=3072: 3/9523 truncated (0.0%) | 100.0% data retained


In [9]:
def build_example(example):
    # Clean the sequences
    allowed_chars = "ACDEFGHIKLMNPQRSTVWY[]@*/#"
    antigen = "".join([c for c in example['antigen'] if c in allowed_chars])
    antibody = "".join([c for c in example['antibody'].replace('|','/') if c in allowed_chars])
    
    # Format: @ANTIGEN/ANTIBODY*
    text = f"@{antigen}/{antibody}*"
    
    # Check length before truncation (optional - for monitoring)
    full_tokens = tokenizer(text, add_special_tokens=False, truncation=False)
    original_length = len(full_tokens["input_ids"])
    
    # Tokenize with your settings
    enc = tokenizer(
        text,
        max_length=MAX_LEN,  # Use the optimal length from analysis
        truncation=True,
        padding="max_length",
        return_tensors=None,
        add_special_tokens=False,
    )
    
    # Create labels with masking (your current approach)
    prompt_ids = tokenizer(text, add_special_tokens=False, truncation=True, max_length=MAX_LEN)["input_ids"]
    seq_len = len(enc["input_ids"])
    labels = [-100] * seq_len
    
    # Only label the actual sequence, not padding
    for i in range(len(prompt_ids)):
        if i < seq_len and enc["input_ids"][i] != tokenizer.pad_token_id:
            labels[i] = enc["input_ids"][i]
    
    enc["labels"] = labels
    
    # Optional: track truncation
    if original_length > MAX_LEN:
        enc["was_truncated"] = True
        enc["original_length"] = original_length
    
    return enc

In [10]:
# 1. Analyze sequence lengths first
sequence_lengths = analyze_sequence_lengths(dataset['train'])

# 2. Choose optimal MAX_LEN based on analysis (e.g., for 95% retention)
MAX_LEN = 2048  # Adjust based on your analysis results

# 3. Apply tokenization
tokenized = dataset.map(build_example, remove_columns=dataset['train'].column_names)

# 4. Check how many were actually truncated
if 'was_truncated' in tokenized['train'][0]:
    truncated_count = sum(1 for x in tokenized['train'] if x.get('was_truncated', False))
    print(f"Actually truncated during tokenization: {truncated_count}/{len(tokenized['train'])} ({100*truncated_count/len(tokenized['train']):.1f}%)")

print("Final tokenized dataset:")
print(tokenized)

Sequence Length Analysis:
Total sequences: 9523
Mean length: 631.4
Median length: 542.0
Max length: 3560
Min length: 224
At max_length=1024: 1409/9523 truncated (14.8%) | 85.2% data retained
At max_length=1536: 92/9523 truncated (1.0%) | 99.0% data retained
At max_length=2048: 41/9523 truncated (0.4%) | 99.6% data retained
At max_length=2560: 3/9523 truncated (0.0%) | 100.0% data retained
At max_length=3072: 3/9523 truncated (0.0%) | 100.0% data retained
Final tokenized dataset:
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 9523
    })
})


In [11]:
test_antibody = "EVQLVESGGGLVQPGGSLRLSCAASGFNLSSSSIHWVRQAPGKGLEWVASIYSYYGSTSYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCAREYHSYWSYSWWPRVGLDYWGQGTLVTVSS|AQMTQSPSSLSASVGDRVTITCRASQSVSSAVAWYQQKPGKAPKLLIYSASSLYSGVPSRFSGSRSGTDFTLTISSLQPEDFATYYCQQASLTALLTFGQGTKVEIK"
test_antigen = "KIEEGKLVIWINGDKGYNGLAEVGKKFEKDTGIKVTVEHPDKLEEKFPQVAATGDGPDIIFWAHDRFGGYAQSGLLAEITPDKAFQDKL[Y][P][F]TW[D][A]VRYN[G]KLIAYPIAVEALSLIYNKDLLPNPPKTWEEIPALDKELKAKGKSALMFNLQEPYFTWPLIAADGGYAFK[Y]EN[G][K][Y]DIKDVGVDNAGAKAGLTFLVDLIKNKHMNADTDYSIAEAAFNKGETAMTINGPWAWSNIDTSKVNYGVTVLPTFKGQPSKPF[V]GVLSAGINAASPNKELAKEFLENYLLTDEGLEAVNKDKPLGAVALKSY[E][E]EL[A]KDPR[I][A]AT[M][E]N[A][Q][K][G][E][I]M[P]NIPQMSAFWYAVRTAVINAASGRQTVDEALKDAQTIIELYRQSLEIISRYLREQATGAADTAPMGATSRKALETLRRVGDGVQRNHETAFQGMLRKLDIKNEDDVKSLSRVMIHVFSDGVTNWGRIVTLISFGAFVAKHLKTINQESAIEPLAESITDVLVRTKRDWLVKQRGWDGFVEFF"

example = {
    'antigen': test_antigen,
    'antibody': test_antibody
}

ex = build_example(example)
print("Max input id:", max(ex["input_ids"]))
print("Max label id:", max([i for i in ex["labels"] if i != -100]))
print("Vocab size:", len(tokenizer))

Max input id: 33
Max label id: 33
Vocab size: 36


In [12]:
 
# Data collator
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    label_pad_token_id=-100,
    pad_to_multiple_of=8
)

 

In [13]:
## Training Arguments
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    label_names=["labels"],
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    num_train_epochs=EPOCHS,
    learning_rate=LR,
    warmup_ratio=WARMUP,
    logging_steps=25,
    save_strategy="epoch",
    fp16=True, # old code torch.cuda.is_available(),
    # bf16=torch.cuda.is_available(),
    optim="adamw_torch", # old code --> "paged_adamw_8bit" if USE_QLORA else "adamw_torch",
    lr_scheduler_type="cosine",
    max_grad_norm=1.0,
    report_to="none",
    seed=int(SEED),
    dataloader_num_workers=8,        # Increase workers
    dataloader_pin_memory=False,  # Try False first
    # no_cuda=True ## To Test
)

In [14]:
# trainer = Trainer(
#     model=model,
#     args=training_args,
#     train_dataset=tokenized['train'],
#     data_collator=data_collator,
#     # data_collator=sft_collator,
# )

 
# ## Set up SFTTrainer
from trl import SFTTrainer
trainer = SFTTrainer(
    model=model,
    train_dataset=tokenized["train"],
    # data_collator=sft_collator,
    data_collator=data_collator,
    args=training_args,
    # callbacks=[test_callback],  # Log every 50 steps
)


In [15]:
for batch in trainer.get_train_dataloader():
    # print(batch.keys())      # e.g., dict_keys(['input_ids', 'attention_mask', 'labels'])
    # print(batch["input_ids"].shape)
    for key, tensor in batch.items():
        print(key, tensor.device, tensor.dtype)
    break

input_ids cuda:0 torch.int64
attention_mask cuda:0 torch.int64
labels cuda:0 torch.int64


In [16]:
# Get a single batch from your trainer
batch = next(iter(trainer.get_train_dataloader()))

print("=== Batch tensor devices ===")
for key, tensor in batch.items():
    print(f"{key}: device={tensor.device}, dtype={tensor.dtype}, shape={tensor.shape}")

# Check model device
model_device = next(model.parameters()).device
print(f"\nModel device: {model_device}")

=== Batch tensor devices ===
input_ids: device=cuda:0, dtype=torch.int64, shape=torch.Size([1, 1024])
attention_mask: device=cuda:0, dtype=torch.int64, shape=torch.Size([1, 1024])
labels: device=cuda:0, dtype=torch.int64, shape=torch.Size([1, 1024])

Model device: cuda:0


In [17]:
device = next(model.parameters()).device
batch = {k: v.to(device) for k, v in batch.items()}


In [None]:
import time

# Before training, check current setup
print("=== SPEED DIAGNOSTICS ===")
print(f"Dataset size: {len(tokenized['train'])}")
print(f"Batch size per device: {training_args.per_device_train_batch_size}")
print(f"Gradient accumulation: {training_args.gradient_accumulation_steps}")
print(f"Number of GPUs: {torch.cuda.device_count()}")
effective_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * torch.cuda.device_count()
print(f"Effective batch size: {effective_batch_size}")
steps_per_epoch = len(tokenized['train']) // effective_batch_size
print(f"Steps per epoch: {steps_per_epoch}")

# Time the training
start_time = time.time()
trainer.train()
end_time = time.time()

total_time = end_time - start_time
print(f"\nTraining completed in: {total_time:.2f} seconds")
print(f"Steps per second: {steps_per_epoch / total_time:.3f}")

=== SPEED DIAGNOSTICS ===
Dataset size: 9523
Batch size per device: 1
Gradient accumulation: 8
Number of GPUs: 2
Effective batch size: 16
Steps per epoch: 595


In [None]:
trainer.train()

In [93]:
next(model.parameters()).dtype

torch.float32

In [None]:
## Save the trained model
trainer.save_model(training_args.output_dir)
print(f"Model saved to {training_args.output_dir}")

In [None]:
## Inference Test
def generate_antibody(antigen: str, max_new_tokens: int = 700):
    prompt = tokenizer.bos_token + antigen + tokenizer.sep_token
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=0.9,
            temperature=0.8,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    text = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
    return text.split(tokenizer.sep_token, 1)[-1].rstrip(tokenizer.eos_token)


test_antigen = "KVFGRCELAAAM[K][R]HGL[D][N][Y]RG[Y][S]LG[N]WVCAAKFESNFNTQATNRNTDGSTDYGILQINSRWWCNDGRTPGSRNLCNIPCSALLSSDITASVNCA[K]KIVSDGNGMNAWVAWRNRCK[G][T][D]V[Q]AW[I][R]GCRL"
print("Generated antibody:", generate_antibody(test_antigen))