# Fine-Tuning Microsoft's Dayhoff Model for Antibody Generation

In [None]:
# conda create --name dayhoff python=3.10
# conda activate dayhoff
# pip install dayhoff peft trl
# conda install ipykernel

In [None]:
## To avoid "device-side assert triggered" RuntimeErrors
# import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# os.environ['TORCH_USE_CUDA_DSA'] = '1'
# os.environ["MAMBA_DISABLE_FAST_KERNELS"] = "1"

In [None]:
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
    set_seed
)

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

In [None]:
## Config
## --- Locations -------------------------------------------------------------
MODEL_ID = "microsoft/Dayhoff-3b-UR90"
## Training CSV; the dataset cell keeps only its 'antibody_fv_seqs' and
## 'highlighted_epitope_seqs' columns (renamed to antibody / antigen).
DATA_PATH = "../data/sabdab/sabdab_training_dataset.csv"
OUTPUT_DIR = "models/peleke-dayhoff-3b-ur90_2"
CACHE_DIR = "/mnt/batch/tasks/shared/LS_root/mounts/clusters/colby-h100-01-ci/code/.cache/huggingface/"

## --- Reproducibility / hardware ---------------------------------------------
SEED = 1337
DEVICE = torch.device("cuda")

## --- Tokenization -------------------------------------------------------------
MAX_LEN = 2048  # pad/truncate length used when tokenizing each example

## --- Optimisation -------------------------------------------------------------
EPOCHS = 3
BATCH_SIZE = 2   # per-device batch size
GRAD_ACCUM = 8   # gradient accumulation steps
LR = 2e-4
WARMUP = 0.03    # passed as warmup_ratio

## Gates LoRA adapter creation and the paged 8-bit optimizer choice further
## down; the 4-bit quantization config itself is currently commented out.
USE_QLORA = True

In [None]:
## Set Seeds
## Seed every RNG entry point used here with the same value: the PyTorch CPU
## generator, the transformers helper, and finally all CUDA devices.
for _seed_fn in (torch.manual_seed, set_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

In [None]:
## Tokenizer Setup
## Load the tokenizer published with the checkpoint. trust_remote_code=True
## allows custom code shipped in the model repository to run; files are
## cached under CACHE_DIR.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR, trust_remote_code=True)

## Add epitope tokens
# epitope_tokens = ["[", "]"]
# tokenizer.add_special_tokens({"additional_special_tokens": epitope_tokens})
# special_tokens_dict = {"additional_special_tokens": ['|']}
# num_added = tokenizer.add_special_tokens(special_tokens_dict)

## Add amino acid tokens (which should all exist) and delimiter
# amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
# extra_tokens = amino_acids + ["|"]
# new_tokens = [t for t in extra_tokens if t not in tokenizer.get_vocab()]
# tokenizer.add_tokens(new_tokens)

# num_added = tokenizer.add_tokens(['|'])


## Quantization config placeholder. It stays None because the 4-bit setup
## below is commented out, and the model-loading call also has its
## quantization_config argument commented out.
bnb_config = None
# if USE_QLORA:
#     from transformers import BitsAndBytesConfig
#     bnb_config = BitsAndBytesConfig(
#         load_in_4bit=True,
#         bnb_4bit_quant_type="nf4",
#         bnb_4bit_use_double_quant=True,
#         bnb_4bit_compute_dtype=torch.bfloat16,
#     )

## Confirm existing special tokens
## (prints the tokenizer's special-token mapping as loaded; nothing was added above)
print("Special tokens:", tokenizer.special_tokens_map)

## Diagnostics for the token-adding experiments above; they rely on
## `num_added`, which is only defined if one of those experiments is re-enabled.
# print("Num tokens added:", num_added)
# print("New tokenizer size:", len(tokenizer))
# print("ID for |:", tokenizer.convert_tokens_to_ids("|"))

In [None]:
## Model Setup
## Load the base checkpoint. Placement is requested once via device_map="auto"
## and enforced once by the model.to(DEVICE) at the end of this cell; the
## extra chained .cuda() calls were redundant with both and have been removed.
## No quantization config is passed, so the weights load unquantized.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    # torch_dtype=torch.bfloat16,
    torch_dtype="auto",
    device_map="auto",
    # quantization_config=bnb_config,
    trust_remote_code=True,
    cache_dir=CACHE_DIR,
)

## Resize model to accommodate new tokens.
## The token-adding experiments in the tokenizer cell are commented out, so
## this only aligns the embedding matrix with len(tokenizer).
model.resize_token_embeddings(len(tokenizer))

if USE_QLORA:
    ## NOTE: despite the flag name, only LoRA adapters are attached here;
    ## k-bit preparation stays disabled because the model is not quantized.
    # model = prepare_model_for_kbit_training(model)
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        # target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        target_modules="all-linear",
    )
    model = get_peft_model(model, peft_config)

## Single explicit device move for the (possibly adapter-wrapped) model.
model.to(DEVICE)
# model.train()

In [None]:
## Load Dataset
## After the renames below each row has two text columns:
##   antigen  - antigen sequence with [ ] epitope marks
##   antibody - heavy|light sequence
##
## Example row (post-rename):
## antigen,antibody
## KVFGRCELAAAM[K][R]HGL[D]...GCRL,EVQLVESGG...|DIQMTQSP...

## Source column -> friendlier name; every other CSV column is dropped
COLUMN_RENAMES = {
    "antibody_fv_seqs": "antibody",
    "highlighted_epitope_seqs": "antigen",
}

dataset = load_dataset("csv", data_files=DATA_PATH)

## Get all column names, then drop everything that is not being renamed
all_columns = dataset["train"].column_names
columns_to_remove = [col for col in all_columns if col not in COLUMN_RENAMES]
dataset = dataset.remove_columns(columns_to_remove)

for _old_name, _new_name in COLUMN_RENAMES.items():
    dataset = dataset.rename_column(_old_name, _new_name)

dataset

In [94]:
## Build Example Function
def build_example(example):
    """Tokenize one antigen/antibody row and build prompt-masked LM labels.

    The row is rendered as ``@ANTIGEN/ANTIBODY*`` (any ``|`` between antibody
    chains becomes ``/``), tokenized without extra special tokens, and
    padded/truncated to ``MAX_LEN``.

    Args:
        example: Mapping with string fields ``'antigen'`` and ``'antibody'``.

    Returns:
        The tokenizer encoding with an added ``"labels"`` list of the same
        length as ``"input_ids"``: ``-100`` on the ``@ANTIGEN/`` prompt and on
        padding positions, the token id everywhere else (antibody + ``*``).
    """
    ## Clean the sequences
    allowed_chars = set("ACDEFGHIKLMNPQRSTVWY[]@*/#")

    antigen = "".join(c for c in example['antigen'] if c in allowed_chars)
    antibody = "".join(c for c in example['antibody'].replace('|', '/') if c in allowed_chars)

    ## Format: @ANTIGEN/ANTIBODY*
    prompt = f"@{antigen}/"
    text = f"{prompt}{antibody}*"

    enc = tokenizer(
        text,
        max_length=MAX_LEN,
        truncation=True,
        padding="max_length",
        return_tensors=None,
        add_special_tokens=False,
    )

    ## Mask prompt for loss computation.
    ## The prompt length must come from the prompt alone; measuring the full
    ## text (as before) masked every real token and left labels only on padding.
    ## Assumes the prompt's tokens are a prefix of the full text's tokens
    ## (true for a per-character vocabulary) - TODO confirm for this tokenizer.
    prompt_len = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])

    input_ids = enc["input_ids"]
    attention_mask = enc.get("attention_mask") or [1] * len(input_ids)

    ## Walk the sequence counting real (non-padding) tokens, so the masking is
    ## correct whichever side the tokenizer pads on.
    labels = []
    real_seen = 0
    for token_id, is_real in zip(input_ids, attention_mask):
        if not is_real:
            labels.append(-100)  # padding never contributes to the loss
            continue
        labels.append(token_id if real_seen >= prompt_len else -100)
        real_seen += 1

    enc["labels"] = labels
    return enc

In [None]:
# def build_example(example):
#     allowed_chars = "ACDEFGHIKLMNPQRSTVWY[]@*/#"

#     antigen = "".join([c for c in example['antigen'] if c in allowed_chars])
#     antibody = "".join([c for c in example['antibody'].replace('|','/') if c in allowed_chars])

#     # construct the full sequence with start token, epitope markers, and chain delimiter
#     text = f"@{antigen}/{antibody}*"

#     enc = tokenizer(
#         text,
#         truncation=True,
#         padding="max_length",
#         max_length=MAX_LEN
#     )

#     # causal LM labels = input_ids
#     enc["labels"] = enc["input_ids"].copy()

#     # ensure attention_mask exists
#     if "attention_mask" not in enc:
#         enc["attention_mask"] = [1] * len(enc["input_ids"])

#     return enc

In [95]:
## Map formatting into tokenized dataset
## build_example runs once per row; the raw text columns are dropped so only
## the tokenizer outputs and labels remain.
raw_columns = dataset['train'].column_names
tokenized = dataset.map(build_example, remove_columns=raw_columns)
# tokenized = tokenized.remove_columns(['token_type_ids'])
tokenized

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 9523
    })
})

In [96]:
## Hand-written pair used to smoke-test build_example below:
## an antibody given as two chains joined by '|', and an antigen whose
## bracketed residues follow the [ ] epitope-mark convention of the dataset.
test_antibody = "EVQLVESGGGLVQPGGSLRLSCAASGFNLSSSSIHWVRQAPGKGLEWVASIYSYYGSTSYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCAREYHSYWSYSWWPRVGLDYWGQGTLVTVSS|AQMTQSPSSLSASVGDRVTITCRASQSVSSAVAWYQQKPGKAPKLLIYSASSLYSGVPSRFSGSRSGTDFTLTISSLQPEDFATYYCQQASLTALLTFGQGTKVEIK"
test_antigen = "KIEEGKLVIWINGDKGYNGLAEVGKKFEKDTGIKVTVEHPDKLEEKFPQVAATGDGPDIIFWAHDRFGGYAQSGLLAEITPDKAFQDKL[Y][P][F]TW[D][A]VRYN[G]KLIAYPIAVEALSLIYNKDLLPNPPKTWEEIPALDKELKAKGKSALMFNLQEPYFTWPLIAADGGYAFK[Y]EN[G][K][Y]DIKDVGVDNAGAKAGLTFLVDLIKNKHMNADTDYSIAEAAFNKGETAMTINGPWAWSNIDTSKVNYGVTVLPTFKGQPSKPF[V]GVLSAGINAASPNKELAKEFLENYLLTDEGLEAVNKDKPLGAVALKSY[E][E]EL[A]KDPR[I][A]AT[M][E]N[A][Q][K][G][E][I]M[P]NIPQMSAFWYAVRTAVINAASGRQTVDEALKDAQTIIELYRQSLEIISRYLREQATGAADTAPMGATSRKALETLRRVGDGVQRNHETAFQGMLRKLDIKNEDDVKSLSRVMIHVFSDGVTNWGRIVTLISFGAFVAKHLKTINQESAIEPLAESITDVLVRTKRDWLVKQRGWDGFVEFF"

## Sanity check: every id produced for a single example must fit inside the
## tokenizer vocabulary (out-of-range ids are what trigger device-side asserts).
example = dict(antigen=test_antigen, antibody=test_antibody)

ex = build_example(example)
supervised_ids = [i for i in ex["labels"] if i != -100]  # positions that carry loss
print("Max input id:", max(ex["input_ids"]))
print("Max label id:", max(supervised_ids))
print("Vocab size:", len(tokenizer))

Max input id: 33
Max label id: 30
Vocab size: 36


In [97]:
## Data Collator
## Causal-LM collation (mlm=False).
## NOTE(review): this collator is understood to rebuild `labels` from
## `input_ids` (masking only padding), which would override the labels
## written by build_example - confirm against the transformers docs.
data_collator = DataCollatorForLanguageModeling(mlm=False, tokenizer=tokenizer)

# from transformers import default_data_collator
# data_collator = default_data_collator

# from transformers import DataCollatorForSeq2Seq
# data_collator = DataCollatorForSeq2Seq(
#     tokenizer=tokenizer,
#     model=model,
#     label_pad_token_id=-100,
#     return_tensors="pt",
# )

def simple_collator(batch):
    """Collate tokenized examples into embeddings, attention mask and labels.

    Args:
        batch: List of examples, each providing equal-length ``"input_ids"``,
            ``"attention_mask"`` and ``"labels"`` sequences.

    Returns:
        Dict with ``"inputs_embeds"`` (cast to ``model.dtype``),
        ``"attention_mask"`` and ``"labels"``, all placed on ``DEVICE``.
    """
    input_ids = torch.stack([torch.tensor(x["input_ids"], dtype=torch.long) for x in batch])
    attention_mask = torch.stack([torch.tensor(x["attention_mask"], dtype=torch.long) for x in batch])
    labels = torch.stack([torch.tensor(x["labels"], dtype=torch.long) for x in batch])

    ## Convert input_ids to embeddings.
    ## The ids are built on CPU while the embedding weights live wherever the
    ## model was moved, so move the ids to the embedding's device first;
    ## otherwise the lookup fails with a device-mismatch error.
    embedding = model.get_input_embeddings()
    embedding_device = next(embedding.parameters()).device
    inputs_embeds = embedding(input_ids.to(embedding_device)).to(model.dtype)  # match model dtype

    return {
        "inputs_embeds": inputs_embeds.to(DEVICE),
        "attention_mask": attention_mask.to(DEVICE),
        "labels": labels.to(DEVICE),
    }



# def sft_collator(batch):
#     input_ids = torch.stack([torch.tensor(x["input_ids"], dtype=torch.long) for x in batch])#.to(model.dtype).to(DEVICE)
#     # attention_mask = torch.stack([torch.tensor(x["attention_mask"], dtype=torch.long) for x in batch])
#     labels = torch.stack([torch.tensor(x["labels"], dtype=torch.long) for x in batch]).to(model.dtype).to(DEVICE)

#     # convert input_ids to embeddings and match model dtype
#     inputs_embeds = model.get_input_embeddings()(input_ids).to(model.dtype).to(DEVICE)
#     # attention_mask = attention_mask.to(DEVICE)
#     # labels = labels.to(DEVICE)

#     return {
#         "inputs_embeds": inputs_embeds,
#         # "attention_mask": attention_mask,
#         "labels": labels,
#     }


def sft_collator(batch):
    """Collate tokenized examples into embeddings and labels (no attention mask).

    Args:
        batch: List of examples, each providing equal-length ``"input_ids"``
            and ``"labels"`` sequences.

    Returns:
        Dict with ``"inputs_embeds"`` (cast to ``model.dtype``) and
        ``"labels"``, both placed on ``DEVICE``.
    """
    # Stack CPU tensors
    input_ids = torch.stack([torch.tensor(x["input_ids"], dtype=torch.long) for x in batch])
    # attention_mask = torch.stack([torch.tensor(x["attention_mask"], dtype=torch.long) for x in batch])
    labels = torch.stack([torch.tensor(x["labels"], dtype=torch.long) for x in batch])

    # Look the ids up on whatever device holds the embedding weights; CPU ids
    # against GPU weights raise a device-mismatch error.
    embedding = model.get_input_embeddings()
    embedding_device = next(embedding.parameters()).device
    inputs_embeds = embedding(input_ids.to(embedding_device)).to(model.dtype)

    # torch.Tensor has no `to_device` method; `.to(device)` is the tensor API.
    return {
        "inputs_embeds": inputs_embeds.to(DEVICE),
        # "attention_mask": attention_mask,
        "labels": labels.to(DEVICE),
    }




In [98]:
## Training Arguments
## Values derived from the environment / config flags are named first so the
## constructor call below reads as a flat list of settings.
gpu_available = torch.cuda.is_available()
optimizer_name = "paged_adamw_8bit" if USE_QLORA else "adamw_torch"

training_args = TrainingArguments(
    ## Bookkeeping
    output_dir=OUTPUT_DIR,
    seed=int(SEED),
    report_to="none",
    logging_steps=25,
    save_strategy="epoch",
    label_names=["labels"],
    ## Batching and run length
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    num_train_epochs=EPOCHS,
    ## Optimiser and schedule
    optim=optimizer_name,
    learning_rate=LR,
    lr_scheduler_type="cosine",
    warmup_ratio=WARMUP,
    max_grad_norm=1.0,
    ## Precision: fp16 is switched on whenever CUDA is available
    fp16=gpu_available,
    # bf16=torch.cuda.is_available(),
    # no_cuda=True ## To Test
)

In [99]:
# trainer = Trainer(
#     model=model,
#     args=training_args,
#     train_dataset=tokenized['train'],
#     data_collator=data_collator,
#     # data_collator=sft_collator,
# )


## Set up SFTTrainer
## NOTE(review): the recorded outputs below show a "Truncating train dataset"
## pass and batches of width 1024 although MAX_LEN is 2048 - presumably a
## default maximum length applied by SFTTrainer; confirm and set it explicitly
## if full-length sequences are intended.
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized["train"],
    data_collator=data_collator,  # alternative: data_collator=sft_collator
    # callbacks=[test_callback],  # Log every 50 steps
)


Truncating train dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9523/9523 [00:00<00:00, 69950.49 examples/s]


In [101]:
## Peek at the first training batch only: report where each tensor lives and
## its dtype, then stop iterating.
for batch in trainer.get_train_dataloader():
    # print(batch.keys())      # e.g., dict_keys(['input_ids', 'attention_mask', 'labels'])
    # print(batch["input_ids"].shape)
    for field_name, field_tensor in batch.items():
        print(field_name, field_tensor.device, field_tensor.dtype)
    break

input_ids cuda:0 torch.int64
labels cuda:0 torch.int64
attention_mask cuda:0 torch.int64


In [102]:
# Pull one batch from the trainer's dataloader and compare tensor placement
# with the device holding the model's first parameter.
train_loader = trainer.get_train_dataloader()
batch = next(iter(train_loader))
model_device = next(model.parameters()).device

print("=== Batch tensor devices ===")
for key, tensor in batch.items():
    print(f"{key}: device={tensor.device}, dtype={tensor.dtype}, shape={tensor.shape}")

# Check model device
print(f"\nModel device: {model_device}")

=== Batch tensor devices ===
input_ids: device=cuda:0, dtype=torch.int64, shape=torch.Size([2, 1024])
labels: device=cuda:0, dtype=torch.int64, shape=torch.Size([2, 1024])
attention_mask: device=cuda:0, dtype=torch.int64, shape=torch.Size([2, 1024])

Model device: cuda:0


In [103]:
## Re-home every tensor of the current batch onto the model's device.
device = next(iter(model.parameters())).device
moved_batch = {}
for k, v in batch.items():
    moved_batch[k] = v.to(device)
batch = moved_batch


In [None]:
## Run fine-tuning with the trainer, arguments and dataset configured in the cells above.
trainer.train()

In [93]:
## Inspect the dtype of the model's first parameter (the recorded output is torch.float32).
next(model.parameters()).dtype

torch.float32

In [None]:
## Save the trained model
## Written to the same directory the training arguments were configured with.
trainer.save_model(output_dir=training_args.output_dir)
print(f"Model saved to {training_args.output_dir}")

In [None]:
## Inference Test
def generate_antibody(antigen: str, max_new_tokens: int = 700) -> str:
    """Sample an antibody continuation for ``antigen`` from the fine-tuned model.

    The prompt is ``<bos> + antigen + <sep>``. If the tokenizer does not define
    ``bos_token`` / ``sep_token``, the literal ``@`` and ``/`` used by the
    training format (``@ANTIGEN/ANTIBODY*``) are used instead.

    Args:
        antigen: Antigen sequence, optionally with ``[X]`` epitope marks.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded text of the newly generated tokens only (prompt removed),
        with special tokens skipped and a trailing end-of-sequence token
        stripped if still present.
    """
    bos = tokenizer.bos_token or "@"
    sep = tokenizer.sep_token or "/"
    prompt = bos + antigen + sep
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    # Training leaves the model in train mode; switch to eval so dropout
    # (including the LoRA dropout) is inactive while sampling.
    model.eval()
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=0.9,
            temperature=0.8,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # The returned ids start with the prompt. Slice it off by length instead of
    # splitting decoded text on the separator: skip_special_tokens may already
    # have removed that separator, and the antibody itself contains '/'.
    # NOTE(review): if '/' is registered as a special token, the heavy/light
    # chain delimiter is also dropped by skip_special_tokens - confirm.
    generated_ids = out[0, prompt_len:]
    text = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Strip the EOS marker as a suffix; str.rstrip would treat it as a
    # character set and could remove legitimate trailing residues.
    eos = tokenizer.eos_token
    if eos:
        text = text.removesuffix(eos)
    return text


## Smoke test: generate one antibody for an epitope-marked antigen and show it.
test_antigen = "KVFGRCELAAAM[K][R]HGL[D][N][Y]RG[Y][S]LG[N]WVCAAKFESNFNTQATNRNTDGSTDYGILQINSRWWCNDGRTPGSRNLCNIPCSALLSSDITASVNCA[K]KIVSDGNGMNAWVAWRNRCK[G][T][D]V[Q]AW[I][R]GCRL"
generated_antibody = generate_antibody(test_antigen)
print("Generated antibody:", generated_antibody)