In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig
import torch

model_path = 'silicobio/peleke-mistral-7b-instruct-v0.2'

# Load configuration and tokenizer
config = PeftConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Set pad token if not already set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load base model
base_model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    device_map="auto",  # Automatically handle device placement
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True
)

# IMPORTANT: Resize embeddings to match the fine-tuned vocabulary size
expected_vocab_size = 32005  # The fine-tuned model has 5 additional tokens
base_model.resize_token_embeddings(expected_vocab_size)

# Load PEFT adapters
model = PeftModel.from_pretrained(
    base_model, 
    model_path,
    is_trainable=False  # Set to False for inference
)

model.eval()


  from .autonotebook import tqdm as notebook_tqdm
`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 3/3 [00:09<00:00,  3.15s/it]
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): MistralForCausalLM(
      (model): MistralModel(
        (embed_tokens): Embedding(32005, 4096)
        (layers): ModuleList(
          (0-31): 32 x MistralDecoderLayer(
            (self_attn): MistralAttention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora

In [2]:
import re

def convert_epitope_format(sequence):
    """Convert [X] format to <epi>X</epi> format"""
    return re.sub(r'\[([A-Z])\]', r'<epi>\1</epi>', sequence)

def format_prompt(antigen_sequence):
    """Format the antigen sequence for Mistral model"""
    formatted_antigen = convert_epitope_format(antigen_sequence)
    prompt = f"Antigen: <s>{formatted_antigen}</s>\nAntibody:"
    return prompt


In [3]:
def generate_antibody(model, tokenizer, antigen_sequence):
    # Format the prompt
    prompt = format_prompt(antigen_sequence)
    
    # Tokenize
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
    
    # Move to device
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    # Store input length to extract only generated tokens
    input_length = inputs['input_ids'].shape[1]
    
    # Generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=800,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.pad_token_id,
            use_cache=False,
        )
    
    # Decode only the generated part
    generated_tokens = outputs[0][input_length:]
    antibody_sequence = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
    
    return antibody_sequence



In [4]:

antigen = "ICLQKTSNQILKPKLISYTLGQSGTCITDPLLAMDEGYFAYSHLERIG[S][C][S][R]GVSKQRIIGVGEVLDRGDEVPSLFMTNVWTPPNPNTVYHCSAVYNNEFYYVLCAVSTVGDPI[L]NSTYWSGSLMMTRLAVKPKSNGGGYNQHQLALRSIEKGRYDKVMPYGPSGIKQGDTLYFPAVGFLVRTEFKYNDSNCPITKC[Q][Y]SKPENCRLSMG[I][R]PNSHYILRSGLLKYNLSDGENPKVVFIEISDQRLSIGSPSKIYDSLGQPVFYQAS[F]SWDTMIKFGDVLTVNPLVVNWRNNTVISR[P][G][Q][S][Q]CPRFNTCP[E]IC[W][E][G][V]YNDAFLIDRINWISAGVFLDSN[Q][T][A][E]NPVFTVFKDNEILYRAQLASE[D]T[N][A][Q]KTITNCFLLKNKIWCISLV[E][I][Y]D[T]GDNV[I]RPKLFAVKIPEQCTH"
antibody = generate_antibody(model, tokenizer, antigen)
print(f"Antigen: {antigen}\nAntibody: {antibody}")


Antigen: ICLQKTSNQILKPKLISYTLGQSGTCITDPLLAMDEGYFAYSHLERIG[S][C][S][R]GVSKQRIIGVGEVLDRGDEVPSLFMTNVWTPPNPNTVYHCSAVYNNEFYYVLCAVSTVGDPI[L]NSTYWSGSLMMTRLAVKPKSNGGGYNQHQLALRSIEKGRYDKVMPYGPSGIKQGDTLYFPAVGFLVRTEFKYNDSNCPITKC[Q][Y]SKPENCRLSMG[I][R]PNSHYILRSGLLKYNLSDGENPKVVFIEISDQRLSIGSPSKIYDSLGQPVFYQAS[F]SWDTMIKFGDVLTVNPLVVNWRNNTVISR[P][G][Q][S][Q]CPRFNTCP[E]IC[W][E][G][V]YNDAFLIDRINWISAGVFLDSN[Q][T][A][E]NPVFTVFKDNEILYRAQLASE[D]T[N][A][Q]KTITNCFLLKNKIWCISLV[E][I][Y]D[T]GDNV[I]RPKLFAVKIPEQCTH
Antibody: VGPRTTLTVSLRSGASVKMSCKASGYSFTWVRQKPGQGLEWVKISYDGLKDYTNYKFQGVKATITADKSSNTAYLQISNLTSEDTAVYFCSRRAVYYDYWGQGTTLTVSS|VTVALGTVSLAPGTVSLRSCRASQSVSLSYLHWYQQKPGQAPLLVYGDNSKRPSGIPDRFSGSSSGNTASLTISGVQAEDEADYYCQSSDSSNWVFGGGTKLTVL
