In [1]:
import gc
import os

import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM

# Clear PyTorch cache.
# Run the garbage collector first: tensors that are only kept alive by
# unreachable Python objects are released back to the caching allocator by
# gc.collect(), and only then can empty_cache() hand those blocks back to the
# driver. Calling empty_cache() first would leave that memory cached.
gc.collect()
torch.cuda.empty_cache()
print("GPU memory cleared!")

# Load the tokenizer that belongs to the base model
model_name = "microsoft/phi-4"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# ADD ALL THE SPECIAL TOKENS IN THE SAME ORDER AS TRAINING
# (token ids are assigned in registration order, so the three steps below
# must not be reordered)
# 1. Epitope tokens
epitope_tokens = ["<epi>", "</epi>"]
tokenizer.add_special_tokens({"additional_special_tokens": epitope_tokens})

# 2. Amino acids and other tokens
amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
extra_tokens = amino_acids + ["|"]
# Snapshot the vocabulary once instead of rebuilding it for every candidate
# token; nothing is added to the tokenizer while the filter runs.
existing_vocab = tokenizer.get_vocab()
new_tokens = [t for t in extra_tokens if t not in existing_vocab]
if new_tokens:
    tokenizer.add_tokens(new_tokens)

# 3. Task-specific tokens (this was missing!)
task_tokens = ["Antigen", "Antibody", "Epitope"]
tokenizer.add_tokens(task_tokens)

print(f"Tokenizer vocab size: {len(tokenizer)}")

# Load the base model in fp16, split across two GPUs:
# embeddings + decoder layers 0-25 on device 0,
# decoder layers 26-39 + final norm, rotary embedding and LM head on device 1.
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map={
        "model.embed_tokens": 0,
        **{f"model.layers.{layer}": 0 for layer in range(0, 26)},
        **{f"model.layers.{layer}": 1 for layer in range(26, 40)},
        "model.norm": 1,
        "model.rotary_emb": 1,
        "lm_head": 1,
    },
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)

# Resize model embeddings to match the tokenizer
base_model.resize_token_embeddings(len(tokenizer))

# The embedding row count must now equal the tokenizer size. Report it and
# fail loudly on a mismatch instead of relying on someone reading the printout.
model_vocab_size = base_model.get_input_embeddings().weight.shape[0]
print(f"Model vocab size after resize: {model_vocab_size}")
if model_vocab_size != len(tokenizer):
    raise RuntimeError(
        f"Embedding size {model_vocab_size} does not match tokenizer size {len(tokenizer)}"
    )

# Load your trained LoRA adapters.
# The PELEKE_MODEL_PATH environment variable overrides the adapter location so
# the notebook can run on machines where the default path does not exist.
model_path = os.environ.get(
    "PELEKE_MODEL_PATH",
    "/home/nicholas/Documents/GitHub/peleke/models/peleke-phi-4-0806025",
)
model = PeftModel.from_pretrained(base_model, model_path)

print("Model loaded successfully!")

  from .autonotebook import tqdm as notebook_tqdm


GPU memory cleared!
Tokenizer vocab size: 100357


Loading checkpoint shards: 100%|██████████| 6/6 [00:11<00:00,  1.88s/it]
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Model vocab size after resize: 100357
Model loaded successfully!


In [2]:
# Function to create test prompts (same format as training)
def create_test_prompt(antigen_with_epitopes):
    """Build the generation prompt for one antigen.

    The antigen text is placed after an ``Antigen: `` label, terminated with
    ``<|im_end|>``, and followed on the next line by the ``Antibody:`` cue.

    Args:
        antigen_with_epitopes: Antigen sequence text to embed in the prompt.

    Returns:
        The prompt string.
    """
    template = "Antigen: {}<|im_end|>\nAntibody:"
    return template.format(antigen_with_epitopes)

# Preview the prompts built from a couple of training-data antigens
# (epitope residues are written in [X] bracket notation here)
test_antigens = [
    "KVFGRCELAAAM[K][R]HGL[D][N][Y]RG[Y][S]LG[N]WVCAAKFESNFNTQATNRNTDGSTDYGILQINSRWWCNDGRTPGSRNLCNIPCSALLSSDITASVNCA[K]KIVSDGNGMNAWVAWRNRCK[G][T][D]V[Q]AW[I][R]GCRL",
    "NLCPFHEVFNATTFASVYAWNRKRISNCVADYSVIYNFAPFFAFKCYGVSPTKLNDLCFTNVYADSFVI[R]G[N]EV[S][Q]IAPGQ[T]GNIADYNYKLPDDFTGCVIAWNSN[K]LDSKPSGNYNYLYRLLRKSKLKPFERDISTEIYQAGNKPCNGVAGPNCYSPLQSYGF[R]P[T][Y][G][V]GH[Q]PYRVVVLSFELLHAPATVCGP",
]

for test_number, antigen in enumerate(test_antigens, start=1):
    prompt = create_test_prompt(antigen)
    # One print call, one item per line: header, prompt, separator rule
    print(f"Test prompt {test_number}:", prompt, "-" * 50, sep="\n")

Test prompt 1:
Antigen: KVFGRCELAAAM[K][R]HGL[D][N][Y]RG[Y][S]LG[N]WVCAAKFESNFNTQATNRNTDGSTDYGILQINSRWWCNDGRTPGSRNLCNIPCSALLSSDITASVNCA[K]KIVSDGNGMNAWVAWRNRCK[G][T][D]V[Q]AW[I][R]GCRL<|im_end|>
Antibody:
--------------------------------------------------
Test prompt 2:
Antigen: NLCPFHEVFNATTFASVYAWNRKRISNCVADYSVIYNFAPFFAFKCYGVSPTKLNDLCFTNVYADSFVI[R]G[N]EV[S][Q]IAPGQ[T]GNIADYNYKLPDDFTGCVIAWNSN[K]LDSKPSGNYNYLYRLLRKSKLKPFERDISTEIYQAGNKPCNGVAGPNCYSPLQSYGF[R]P[T][Y][G][V]GH[Q]PYRVVVLSFELLHAPATVCGP<|im_end|>
Antibody:
--------------------------------------------------


In [9]:
def generate_antibody(model, tokenizer, antigen_with_epitopes, max_length=1000, temperature=0.7, top_p=0.9,
                      max_new_tokens=300, repetition_penalty=1.1):
    """Sample one antibody sequence for an epitope-highlighted antigen.

    The prompt is built with ``create_test_prompt``, tokenized, and passed to
    ``model.generate`` with sampling enabled. Generation stops at the
    ``<|im_end|>`` token or after ``max_new_tokens`` tokens.

    Args:
        model: Model exposing ``eval()``, ``device`` and ``generate()``.
            It is switched to eval mode as a side effect.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        antigen_with_epitopes: Antigen text inserted into the prompt.
        max_length: Maximum number of *prompt* tokens; it is passed to the
            tokenizer together with ``truncation=True`` and does not limit
            the generated length.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        max_new_tokens: Maximum number of tokens to generate.
        repetition_penalty: Repetition penalty forwarded to ``generate``.

    Returns:
        A ``(antibody_sequence, generated_text)`` tuple. ``generated_text`` is
        the full decoded output (prompt included, special tokens kept).
        ``antibody_sequence`` is the stripped text between the first
        ``Antibody:`` and the following ``<|im_end|>`` (or the end of the
        text); it is the literal string ``"Generation failed"`` when
        ``Antibody:`` does not occur in the decoded text.
    """
    # Create prompt
    prompt = create_test_prompt(antigen_with_epitopes)

    # Tokenize.
    # NOTE(review): the "Antibody:" cue sits at the end of the prompt, so a
    # prompt longer than max_length presumably loses it when truncated from
    # the right -- confirm the tokenizer's truncation side for long antigens.
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_length)
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}

    stop_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    # Fall back to the stop token when the tokenizer defines no pad token so
    # that generate() always receives an explicit pad id.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = stop_token_id

    # Generate
    model.eval()
    with torch.no_grad():
        outputs = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=pad_token_id,
            eos_token_id=stop_token_id,
            repetition_penalty=repetition_penalty,
        )

    # Decode the whole sequence; special tokens are kept because "<|im_end|>"
    # is used below as the end-of-answer delimiter.
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=False)

    # Extract just the antibody part: everything after the first "Antibody:"
    # up to the next "<|im_end|>" (or the end of the text if there is none).
    _, cue, after_cue = generated_text.partition("Antibody:")
    if not cue:
        return "Generation failed", generated_text

    antibody_sequence = after_cue.split("<|im_end|>", 1)[0].strip()
    return antibody_sequence, generated_text

# Run a single generation on one epitope-tagged antigen and show both the
# extracted antibody and the raw decoded text
test_antigen = "KVFGRCELAAAM<epi>K</epi><epi>R</epi>HGL<epi>D</epi><epi>N</epi><epi>Y</epi>RG<epi>Y</epi><epi>S</epi>LG<epi>N</epi>WVCAAKFESNFNTQATNRNTDGSTDYGILQINSRWWCNDGRTPGSRNLCNIPCSALLSSDITASVNCA<epi>K</epi>KIVSDGNGMNAWVAWRNRCK<epi>G</epi><epi>T</epi><epi>D</epi>V<epi>Q</epi>AW<epi>I</epi><epi>R</epi>GCRL"

antibody_seq, full_generation = generate_antibody(model, tokenizer, test_antigen)

print("Generated Antibody Sequence:", antibody_seq, sep="\n")
print("\nFull Generation:", full_generation, sep="\n")

Generated Antibody Sequence:
QLVQSGAEVKKPGSSVKVSCTASGFNIKDYYAVSWVRQAPGQGLEWMGWISYNGDTNYAQRFQGRVTITADKSTRTAYMELTSDDSAVYFCARERGDGYFAVWGQGTLVTVSS|DIQLTQSPDSLAVSLGERATINCKSSQNNKNYLAWYQQKPGQPPKLLIFATSKLESGVPVRFSGSGSGTDFTLNIHPVEEEDAATYYCQQANSFPYTFGGGTKLEIK

Full Generation:
Antigen: KVFGRCELAAAM<epi>K</epi><epi>R</epi>HGL<epi>D</epi><epi>N</epi><epi>Y</epi>RG<epi>Y</epi><epi>S</epi>LG<epi>N</epi>WVCAAKFESNFNTQATNRNTDGSTDYGILQINSRWWCNDGRTPGSRNLCNIPCSALLSSDITASVNCA<epi>K</epi>KIVSDGNGMNAWVAWRNRCK<epi>G</epi><epi>T</epi><epi>D</epi>V<epi>Q</epi>AW<epi>I</epi><epi>R</epi>GCRL<|im_end|>Antibody: QLVQSGAEVKKPGSSVKVSCTASGFNIKDYYAVSWVRQAPGQGLEWMGWISYNGDTNYAQRFQGRVTITADKSTRTAYMELTSDDSAVYFCARERGDGYFAVWGQGTLVTVSS|DIQLTQSPDSLAVSLGERATINCKSSQNNKNYLAWYQQKPGQPPKLLIFATSKLESGVPVRFSGSGSGTDFTLNIHPVEEEDAATYYCQQANSFPYTFGGGTKLEIK<|im_end|>


In [4]:
import re

def convert_brackets_to_epi(sequence):
    """Rewrite bracketed residues as epitope tags.

    Every occurrence of a single uppercase letter enclosed in square brackets
    (``[X]``) is replaced by ``<epi>X</epi>``; all other text is left as is.

    Args:
        sequence: Sequence string using ``[X]`` notation.

    Returns:
        The sequence string using ``<epi>X</epi>`` notation.
    """
    bracketed_residue = re.compile(r'\[([A-Z])\]')
    return bracketed_residue.sub(r'<epi>\1</epi>', sequence)

# Convert the bracket-annotated antigens to <epi> notation and show each
# original next to its converted form
sequences_with_brackets = [
    "KVFGRCELAAAM[K][R]HGL[D][N][Y]RG[Y][S]LG[N]WVCAAKFESNFNTQATNRNTDGSTDYGILQINSRWWCNDGRTPGSRNLCNIPCSALLSSDITASVNCA[K]KIVSDGNGMNAWVAWRNRCK[G][T][D]V[Q]AW[I][R]GCRL",
    "NLCPFHEVFNATTFASVYAWNRKRISNCVADYSVIYNFAPFFAFKCYGVSPTKLNDLCFTNVYADSFVI[R]G[N]EV[S][Q]IAPGQ[T]GNIADYNYKLPDDFTGCVIAWNSN[K]LDSKPSGNYNYLYRLLRKSKLKPFERDISTEIYQAGNKPCNGVAGPNCYSPLQSYGF[R]P[T][Y][G][V]GH[Q]PYRVVVLSFELLHAPATVCGP",
]

converted_sequences = list(map(convert_brackets_to_epi, sequences_with_brackets))

for orig, conv in zip(sequences_with_brackets, converted_sequences):
    # One print call, one item per line: original, converted, separator rule
    print(f"Original: {orig}", f"Converted: {conv}", "-" * 40, sep="\n")

Original: KVFGRCELAAAM[K][R]HGL[D][N][Y]RG[Y][S]LG[N]WVCAAKFESNFNTQATNRNTDGSTDYGILQINSRWWCNDGRTPGSRNLCNIPCSALLSSDITASVNCA[K]KIVSDGNGMNAWVAWRNRCK[G][T][D]V[Q]AW[I][R]GCRL
Converted: KVFGRCELAAAM<epi>K</epi><epi>R</epi>HGL<epi>D</epi><epi>N</epi><epi>Y</epi>RG<epi>Y</epi><epi>S</epi>LG<epi>N</epi>WVCAAKFESNFNTQATNRNTDGSTDYGILQINSRWWCNDGRTPGSRNLCNIPCSALLSSDITASVNCA<epi>K</epi>KIVSDGNGMNAWVAWRNRCK<epi>G</epi><epi>T</epi><epi>D</epi>V<epi>Q</epi>AW<epi>I</epi><epi>R</epi>GCRL
----------------------------------------
Original: NLCPFHEVFNATTFASVYAWNRKRISNCVADYSVIYNFAPFFAFKCYGVSPTKLNDLCFTNVYADSFVI[R]G[N]EV[S][Q]IAPGQ[T]GNIADYNYKLPDDFTGCVIAWNSN[K]LDSKPSGNYNYLYRLLRKSKLKPFERDISTEIYQAGNKPCNGVAGPNCYSPLQSYGF[R]P[T][Y][G][V]GH[Q]PYRVVVLSFELLHAPATVCGP
Converted: NLCPFHEVFNATTFASVYAWNRKRISNCVADYSVIYNFAPFFAFKCYGVSPTKLNDLCFTNVYADSFVI<epi>R</epi>G<epi>N</epi>EV<epi>S</epi><epi>Q</epi>IAPGQ<epi>T</epi>GNIADYNYKLPDDFTGCVIAWNSN<epi>K</epi>LDSKPSGNYNYLYRLLRKSKLKPFERDISTEIYQAGNKPCNGVAGPNCYSPLQSYGF<epi>R</epi>P

In [6]:
# Generate one antibody per epitope-tagged antigen (same sampling settings,
# temperature 0.7, for every antigen)
test_antigens = [
    "KVFGRCELAAAM<epi>K</epi><epi>R</epi>HGL<epi>D</epi><epi>N</epi><epi>Y</epi>RG<epi>Y</epi><epi>S</epi>LG<epi>N</epi>WVCAAKFESNFNTQATNRNTDGSTDYGILQINSRWWCNDGRTPGSRNLCNIPCSALLSSDITASVNCA<epi>K</epi>KIVSDGNGMNAWVAWRNRCK<epi>G</epi><epi>T</epi><epi>D</epi>V<epi>Q</epi>AW<epi>I</epi><epi>R</epi>GCRL",
    "NLCPFHEVFNATTFASVYAWNRKRISNCVADYSVIYNFAPFFAFKCYGVSPTKLNDLCFTNVYADSFVI<epi>R</epi>G<epi>N</epi>EV<epi>S</epi><epi>Q</epi>IAPGQ<epi>T</epi>GNIADYNYKLPDDFTGCVIAWNSN<epi>K</epi>LDSKPSGNYNYLYRLLRKSKLKPFERDISTEIYQAGNKPCNGVAGPNCYSPLQSYGF<epi>R</epi>P<epi>T</epi><epi>Y</epi><epi>G</epi><epi>V</epi>GH<epi>Q</epi>PYRVVVLSFELLHAPATVCGP",
]

for test_number, antigen in enumerate(test_antigens, start=1):
    print(f"\n=== Test {test_number} ===", f"Input antigen: {antigen}", sep="\n")

    # Only the extracted sequence is needed here; the full decoded text is discarded
    antibody_seq, _ = generate_antibody(model, tokenizer, antigen, temperature=0.7)
    print(f"Generated antibody: {antibody_seq}")


=== Test 1 ===
Input antigen: KVFGRCELAAAM<epi>K</epi><epi>R</epi>HGL<epi>D</epi><epi>N</epi><epi>Y</epi>RG<epi>Y</epi><epi>S</epi>LG<epi>N</epi>WVCAAKFESNFNTQATNRNTDGSTDYGILQINSRWWCNDGRTPGSRNLCNIPCSALLSSDITASVNCA<epi>K</epi>KIVSDGNGMNAWVAWRNRCK<epi>G</epi><epi>T</epi><epi>D</epi>V<epi>Q</epi>AW<epi>I</epi><epi>R</epi>GCRL
Generated antibody: EVQLVESGGGLVKPGGSLKLSCAASGFTFSNYAMSWVRQTPEKRLEWVASISAGGSYTYYADSVKGRFTISRDNARNILYLQMNSLKTEDTAIYYCTRGELTYDHWGQGTLVTVSS|DIVMTQSPLSLPVTPGEPASISCRSSQSLLHRSGHTYLHWYLQRPGQSPQVLIIFGDNNRFSGVPDRFSGSGSGTDFTLKISRVEAEDVGVYYCMQGTHWPRTFGQGTKVEIK

=== Test 2 ===
Input antigen: NLCPFHEVFNATTFASVYAWNRKRISNCVADYSVIYNFAPFFAFKCYGVSPTKLNDLCFTNVYADSFVI<epi>R</epi>G<epi>N</epi>EV<epi>S</epi><epi>Q</epi>IAPGQ<epi>T</epi>GNIADYNYKLPDDFTGCVIAWNSN<epi>K</epi>LDSKPSGNYNYLYRLLRKSKLKPFERDISTEIYQAGNKPCNGVAGPNCYSPLQSYGF<epi>R</epi>P<epi>T</epi><epi>Y</epi><epi>G</epi><epi>V</epi>GH<epi>Q</epi>PYRVVVLSFELLHAPATVCGP
Generated antibody: EVQLVESGGGLIQPGGSLRLSCAASAFTVSSNYMSWVRQAPGKGL