# Generate novel antibody sequences for a single antigen with `peleke-mistral-7b-instruct-v0.2`

In [None]:
import pandas as pd
import torch
import re
import gc
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig

def clear_gpu_memory():
    """Free unreferenced Python objects, then release cached CUDA memory.

    ``gc.collect()`` runs first and unconditionally. Tensors that are kept
    alive only by reference cycles are actually destroyed before
    ``torch.cuda.empty_cache()`` tries to hand their cached blocks back to
    the driver. Calling ``empty_cache()`` first (as before) could not return
    memory still owned by not-yet-collected objects.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def setup_mistral_tokenizer(tokenizer, model):
    """Register epitope, task, amino-acid and separator tokens on *tokenizer*.

    Tokens already present in the tokenizer's vocabulary are skipped. Any
    missing ones are appended to ``additional_special_tokens``, and the
    model's token embeddings are resized to ``len(tokenizer)`` so the new ids
    get rows. When nothing is missing the call is a no-op, which makes it
    safe to run more than once.

    Args:
        tokenizer: Hugging Face tokenizer; mutated in place.
        model: Model whose embedding matrix must match the tokenizer.

    Returns:
        The same ``(tokenizer, model)`` pair.
    """
    print("Setting up Mistral-specific tokenizer...")

    known = set(tokenizer.get_vocab().keys())

    # New ids are handed out in this order. It presumably has to match the
    # order used when the adapter was fine-tuned, so do not reorder.
    wanted = (
        ["<epi>", "</epi>"]                     # epitope markers
        + ["Antigen", "Antibody", "Epitope"]    # task words
        + list("ACDEFGHIKLMNPQRSTVWY")          # 20 standard amino acids
        + ["|"]                                 # separator
    )
    missing = [tok for tok in wanted if tok not in known]

    if not missing:
        print("No new tokens needed")
        return tokenizer, model

    existing_specials = tokenizer.additional_special_tokens or []
    added_count = tokenizer.add_special_tokens(
        {"additional_special_tokens": existing_specials + missing}
    )
    model.resize_token_embeddings(len(tokenizer))
    print(f"Added {added_count} new tokens: {missing}")

    return tokenizer, model

def convert_epitope_format(sequence):
    """Rewrite bracketed epitope residues as ``<epi>`` tags.

    Each ``[X]``, where ``X`` is a single uppercase letter, becomes
    ``<epi>X</epi>``. Everything else, including brackets around
    lowercase or multi-letter content, is left unchanged.

    >>> convert_epitope_format("AB[S]C")
    'AB<epi>S</epi>C'
    """
    bracketed_residue = r'\[([A-Z])\]'
    tagged_residue = r'<epi>\1</epi>'
    return re.sub(bracketed_residue, tagged_residue, sequence)

def format_mistral_prompt(antigen_sequence):
    """Build the generation prompt for one antigen sequence.

    Bracketed epitope residues are first converted to ``<epi>X</epi>`` tags.
    The tagged sequence is then wrapped in literal ``<s>``/``</s>`` markers
    and followed by a trailing ``Antibody:`` cue, which is what the
    generation step continues from.

    Returns:
        ``"Antigen: <s>{tagged}</s>\\nAntibody:"``
    """
    tagged = convert_epitope_format(antigen_sequence)
    return "Antigen: <s>" + tagged + "</s>\nAntibody:"

def load_mistral_model(model_path):
    """Load a Mistral model (plain or PEFT adapter) together with its tokenizer.

    Loading paths:

    * If *model_path* has a PEFT config, the base model named in that config
      is loaded. Its embeddings are resized to the expected vocabulary size,
      and the adapter is attached on top.
    * If attaching the adapter raises a ``RuntimeError`` mentioning
      "size mismatch", the first base model is released and
      ``load_mistral_model_with_manual_resize`` is used instead.
    * Otherwise *model_path* is loaded as a regular causal LM.

    In every case the epitope/task/amino-acid tokens are registered via
    ``setup_mistral_tokenizer`` and the model is switched to eval mode.

    Args:
        model_path: Hub id or local directory of the model or adapter.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        RuntimeError: re-raised from the PEFT path when it is not a
            vocabulary size mismatch.
    """
    print(f"Loading Mistral model: {model_path}...")

    clear_gpu_memory()

    try:
        # Try to load PEFT config first
        config = PeftConfig.from_pretrained(model_path)
        is_peft_model = True
        base_model_name = config.base_model_name_or_path
        print(f"Detected PEFT model with base: {base_model_name}")
    except Exception as e:
        print(f"Not a PEFT model or couldn't load config: {e}")
        is_peft_model = False
        base_model_name = model_path

    if is_peft_model:
        base_model = None
        needs_manual_resize = False
        try:
            # Load tokenizer
            try:
                print("Loading tokenizer from adapter path...")
                tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
                expected_vocab_size = len(tokenizer)
                print(f"Tokenizer loaded with vocab size: {expected_vocab_size}")
            except Exception as e:
                print(f"Couldn't load tokenizer from adapter path: {e}")
                print("Loading tokenizer from base model...")
                tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
                # Presumably the 32000-token base vocabulary plus the five
                # tokens added at fine-tuning time — confirm for other adapters.
                expected_vocab_size = 32005
                print(f"Using expected vocab size: {expected_vocab_size}")

            # Set pad token if needed
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # Load base model
            print(f"Loading base model: {base_model_name}")
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True
            )

            # Resize embeddings if needed
            current_vocab_size = base_model.config.vocab_size
            print(f"Base model vocab size: {current_vocab_size}")

            if current_vocab_size != expected_vocab_size:
                print(f"Resizing embeddings from {current_vocab_size} to {expected_vocab_size}")
                base_model.resize_token_embeddings(expected_vocab_size)
                print("Embeddings resized successfully")

            # Load PEFT adapters
            print("Loading PEFT adapters...")
            model = PeftModel.from_pretrained(
                base_model,
                model_path,
                is_trainable=False
            )
            print("PEFT model loaded successfully")

        except RuntimeError as e:
            if "size mismatch" not in str(e):
                raise
            print("\n" + "="*60)
            print("VOCABULARY SIZE MISMATCH DETECTED")
            print("="*60)
            print("The model has a vocabulary size mismatch.")
            print("This typically happens when the model was fine-tuned with additional tokens.")
            print("\nAttempting alternative loading method...")
            needs_manual_resize = True

        if needs_manual_resize:
            # The fallback loads a second full copy of the base model. Drop
            # ours first, and do it outside the ``except`` clause: while the
            # exception is active its traceback keeps the failed PEFT frames
            # (and the model they reference) alive, so the GPU memory could
            # not be reclaimed. The fallback runs clear_gpu_memory() itself.
            del base_model
            model, tokenizer = load_mistral_model_with_manual_resize(model_path)

    else:
        # Load as regular model
        print("Loading as regular model...")
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            ignore_mismatched_sizes=True
        )

    # Setup tokenizer with special tokens. This is a no-op ("No new tokens
    # needed") when the manual-resize fallback already registered them.
    tokenizer, model = setup_mistral_tokenizer(tokenizer, model)

    model.eval()

    device = next(model.parameters()).device
    print(f"Model loaded on device: {device}")
    print(f"Final vocab size: {len(tokenizer)}")

    return model, tokenizer

def load_mistral_model_with_manual_resize(model_path, target_vocab_size=32005):
    """Fallback PEFT loader: force the base vocabulary size before attaching the adapter.

    The base model named in the adapter's PEFT config is loaded, and its
    token embeddings are resized unconditionally to *target_vocab_size*.
    The adapter is then attached. Afterwards ``setup_mistral_tokenizer``
    registers the project's special tokens and, if it adds any, resizes the
    embeddings again to ``len(tokenizer)``.

    Args:
        model_path: Hub id or local directory of the PEFT adapter.
        target_vocab_size: Embedding rows the adapter checkpoint expects
            (default 32005).

    Returns:
        ``(model, tokenizer)`` with the model in eval mode.
    """
    print("Using manual vocabulary resize method...")

    clear_gpu_memory()

    # Load PEFT config
    config = PeftConfig.from_pretrained(model_path)
    base_model_name = config.base_model_name_or_path

    # Load tokenizer: prefer the adapter's own, fall back to the base model's.
    # Catch Exception (not a bare ``except:``) so Ctrl-C / SystemExit still
    # propagate, and report why the adapter tokenizer was rejected.
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    except Exception as e:
        print(f"Couldn't load tokenizer from adapter path: {e}")
        print("Loading tokenizer from base model...")
        tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load base model
    print(f"Loading base model: {base_model_name}")
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True
    )

    # Force resize to target vocabulary size so the adapter's embedding
    # tensors load without a shape mismatch.
    print(f"Force resizing embeddings to {target_vocab_size}")
    base_model.resize_token_embeddings(target_vocab_size)

    # Now load PEFT adapters
    print("Loading PEFT adapters with resized model...")
    model = PeftModel.from_pretrained(
        base_model,
        model_path,
        is_trainable=False
    )

    # Setup tokenizer
    tokenizer, model = setup_mistral_tokenizer(tokenizer, model)

    model.eval()
    print("Model loaded successfully with manual resize!")

    return model, tokenizer

def generate_mistral_antibody(model, tokenizer, antigen_seq, *, max_new_tokens=800, temperature=0.7):
    """Sample one antibody sequence for *antigen_seq*.

    The antigen is formatted with ``format_mistral_prompt``, tokenized and
    moved to the model's device. Generation uses sampling (``do_sample=True``)
    and only the newly generated tokens are decoded.

    Args:
        model: Causal LM (or PEFT-wrapped model) with ``generate``.
        tokenizer: Matching tokenizer.
        antigen_seq: Antigen sequence, optionally with ``[X]`` epitope marks.
        max_new_tokens: Upper bound on generated tokens (default 800).
        temperature: Sampling temperature (default 0.7).

    Returns:
        The decoded continuation, with the prompt excluded, special tokens
        skipped and whitespace stripped. On any exception the error is
        printed and ``"ERROR: <message>"`` is returned instead of raising.
        Callers detect failures by that prefix.
    """
    try:
        # Format prompt for Mistral
        prompt = format_mistral_prompt(antigen_seq)

        # NOTE(review): if the tokenizer truncates on the right (the usual
        # default), a prompt longer than 1024 tokens loses its trailing
        # "Antibody:" cue — confirm antigen lengths stay well under this.
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)

        # Move to the model's device
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Prompt length in tokens; generated ids start right after it.
        input_length = inputs['input_ids'].shape[1]

        pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                pad_token_id=pad_id,
                # Keep the KV cache on: with use_cache=False every decoding
                # step re-encodes the whole prefix, so long samples become
                # quadratic in length for no change in the distribution.
                use_cache=True,
            )

        # Only decode the generated part (exclude the prompt)
        generated_tokens = outputs[0][input_length:]
        antibody_sequence = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

        return antibody_sequence

    except Exception as e:
        print(f"Error generating sequence: {e}")
        return f"ERROR: {str(e)}"

def generate_antibodies_for_single_antigen(antigen_sequence, num_samples=20, model_path="silicobio/peleke-mistral-7b-instruct-v0.2"):
    """Load the model once, sample *num_samples* antibodies for one antigen, then free it.

    Args:
        antigen_sequence: Antigen sequence, optionally with ``[X]`` epitope marks.
        num_samples: Number of independent samples to draw (default 20).
        model_path: Model or PEFT adapter passed to ``load_mistral_model``.

    Returns:
        A list with one dict per sample, holding these keys:

        * ``sample_id``: ``antibody_01``, ``antibody_02``, ...
        * ``antigen_sequence``
        * ``antibody_sequence``
        * ``sequence_length``: ``len`` of the returned string
        * ``contains_error``: True when generation reported a failure

        The model and tokenizer are deleted and GPU memory is cleared even if
        the sampling loop raises.
    """
    print(f"Loading model: {model_path}")
    model, tokenizer = load_mistral_model(model_path)

    print(f"\nGenerating {num_samples} antibody samples for antigen...")
    print(f"Antigen sequence: {antigen_sequence[:100]}...")
    print(f"Converted format: {convert_epitope_format(antigen_sequence)[:100]}...")

    results = []

    try:
        for i in range(num_samples):
            print(f"Generating sample {i+1}/{num_samples}...")

            # Generate antibody sequence
            antibody_seq = generate_mistral_antibody(model, tokenizer, antigen_sequence)

            results.append({
                'sample_id': f"antibody_{i+1:02d}",
                'antigen_sequence': antigen_sequence,
                'antibody_sequence': antibody_seq,
                'sequence_length': len(antibody_seq),
                # generate_mistral_antibody reports failures as "ERROR: ...".
                # Test the prefix, not a substring, so arbitrary generated
                # text that happens to contain "ERROR" is not misclassified.
                'contains_error': antibody_seq.startswith("ERROR:"),
            })

            # Show progress
            if (i + 1) % 5 == 0:
                print(f"  Generated {i + 1}/{num_samples} samples")

    finally:
        # Clean up
        print("Cleaning up model...")
        del model, tokenizer
        clear_gpu_memory()

    return results

def save_results_to_csv(results, filename="Mistral_antibodies.csv"):
    """Write generation records to CSV and print a short summary.

    Args:
        results: List of dicts, typically with ``contains_error`` and
            ``sequence_length`` keys. An empty list is accepted.
        filename: Output CSV path (default ``Mistral_antibodies.csv``).

    Returns:
        The ``pandas.DataFrame`` built from *results*.

    Without a ``contains_error`` column (e.g. empty *results*) every row
    counts as successful. With no successful rows the average length is
    reported as ``n/a`` instead of ``nan``.
    """
    df = pd.DataFrame(results)
    df.to_csv(filename, index=False)

    if "contains_error" in df.columns:
        failed_mask = df["contains_error"].astype(bool)
    else:
        # Empty input has no columns at all; indexing would raise KeyError.
        failed_mask = pd.Series(False, index=df.index, dtype=bool)
    successes = df.loc[~failed_mask]

    print(f"\nResults saved to: {filename}")
    print(f"Total sequences generated: {len(df)}")
    print(f"Successful generations: {len(successes)}")
    print(f"Failed generations: {int(failed_mask.sum())}")
    if len(successes) and "sequence_length" in successes.columns:
        print(f"Average sequence length: {successes['sequence_length'].mean():.1f} characters")
    else:
        print("Average sequence length: n/a (no successful generations)")

    return df

def validate_antibody_sequences(results, *, min_length=50, max_length=300):
    """Run basic sanity checks on generated antibody sequences and print a summary.

    A sequence is valid when every character, upper-cased, is one of the 20
    standard amino-acid letters and its length lies in
    ``[min_length, max_length]``. Records whose ``contains_error`` flag is
    set are marked invalid without inspection.

    Args:
        results: Records with ``sample_id``, ``antibody_sequence`` and
            ``contains_error`` keys.
        min_length: Inclusive lower length bound (default 50).
        max_length: Inclusive upper length bound (default 300).

    Returns:
        One dict per input record, in input order. Errored records carry
        ``sample_id``, ``is_valid`` (False), ``error_type`` and
        ``sequence_length`` (0). The others carry ``sample_id``,
        ``is_valid``, ``sequence_length``, ``invalid_chars`` (sorted list,
        or None) and ``reasonable_length``.
    """
    print("\n" + "="*60)
    print("ANTIBODY SEQUENCE VALIDATION")
    print("="*60)

    valid_amino_acids = frozenset("ACDEFGHIKLMNPQRSTVWY")
    validation_results = []

    for result in results:
        # Skip error sequences
        if result['contains_error']:
            validation_results.append({
                'sample_id': result['sample_id'],
                'is_valid': False,
                'error_type': 'Generation error',
                'sequence_length': 0
            })
            continue

        antibody_seq = result['antibody_sequence']

        # Check for valid amino acids only
        invalid_chars = set(antibody_seq.upper()) - valid_amino_acids

        # Generous window around the ~100-150 aa typical of an Fv region.
        seq_length = len(antibody_seq)
        reasonable_length = min_length <= seq_length <= max_length

        is_valid = not invalid_chars and reasonable_length

        validation_results.append({
            'sample_id': result['sample_id'],
            'is_valid': is_valid,
            'sequence_length': seq_length,
            # Sorted so the report is reproducible; set iteration order of
            # str elements changes between runs under hash randomization.
            'invalid_chars': sorted(invalid_chars) if invalid_chars else None,
            'reasonable_length': reasonable_length
        })

    # Summary
    valid_count = sum(1 for v in validation_results if v['is_valid'])
    total_count = len(validation_results)
    # Guard against ZeroDivisionError when there is nothing to validate.
    valid_pct = 100 * valid_count / total_count if total_count else 0.0

    print(f"Valid antibody sequences: {valid_count}/{total_count} ({valid_pct:.1f}%)")

    return validation_results

# Main execution
if __name__ == "__main__":
    # Define your single antigen sequence ([X] marks epitope residues)
    antigen_sequence = "ICLQKTSNQILKPKLISYTLGQSGTCITDPLLAMDEGYFAYSHLERIG[S][C][S][R]GVSKQRIIGVGEVLDRGDEVPSLFMTNVWTPPNPNTVYHCSAVYNNEFYYVLCAVSTVGDPI[L]NSTYWSGSLMMTRLAVKPKSNGGGYNQHQLALRSIEKGRYDKVMPYGPSGIKQGDTLYFPAVGFLVRTEFKYNDSNCPITKC[Q][Y]SKPENCRLSMG[I][R]PNSHYILRSGLLKYNLSDGENPKVVFIEISDQRLSIGSPSKIYDSLGQPVFYQAS[F]SWDTMIKFGDVLTVNPLVVNWRNNTVISR[P][G][Q][S][Q]CPRFNTCP[E]IC[W][E][G][V]YNDAFLIDRINWISAGVFLDSN[Q][T][A][E]NPVFTVFKDNEILYRAQLASE[D]T[N][A][Q]KTITNCFLLKNKIWCISLV[E][I][Y]D[T]GDNV[I]RPKLFAVKIPEQCTH"

    # Single source of truth for the CSV path, so the closing message names
    # the file that was actually written.
    output_csv = "Mistral_antibodies.csv"

    print("Starting single antigen antibody generation...")

    # Generate antibodies
    results = generate_antibodies_for_single_antigen(
        antigen_sequence=antigen_sequence,
        num_samples=20
    )

    # Save results
    df = save_results_to_csv(results, filename=output_csv)

    # Validate sequences
    validation_results = validate_antibody_sequences(results)

    # Show sample results
    print("\nSample generated antibodies:")
    successful_results = [r for r in results if not r['contains_error']]

    for result in successful_results[:3]:
        print(f"\nSample {result['sample_id']}:")
        print(f"Length: {result['sequence_length']} amino acids")
        print(f"Sequence: {result['antibody_sequence'][:100]}...")

    print(f"\nGeneration completed! Check '{output_csv}' for full results.")

Starting single antigen antibody generation...
Loading model: silicobio/peleke-mistral-7b-instruct-v0.2
Loading Mistral model: silicobio/peleke-mistral-7b-instruct-v0.2...
Detected PEFT model with base: mistralai/Mistral-7B-Instruct-v0.2
Loading tokenizer from adapter path...
Tokenizer loaded with vocab size: 32000
Loading base model: mistralai/Mistral-7B-Instruct-v0.2


Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  1.56it/s]


Base model vocab size: 32000
Loading PEFT adapters...

VOCABULARY SIZE MISMATCH DETECTED
The model has a vocabulary size mismatch.
This typically happens when the model was fine-tuned with additional tokens.

Attempting alternative loading method...
Using manual vocabulary resize method...
Loading base model: mistralai/Mistral-7B-Instruct-v0.2


Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  1.56it/s]
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Force resizing embeddings to 32005
Loading PEFT adapters with resized model...
Setting up Mistral-specific tokenizer...
Added 5 new tokens: ['<epi>', '</epi>', 'Antigen', 'Antibody', 'Epitope']
Model loaded successfully with manual resize!
Setting up Mistral-specific tokenizer...
No new tokens needed
Model loaded on device: cuda:0
Final vocab size: 32005

Generating 20 antibody samples for antigen...
Antigen sequence: ICLQKTSNQILKPKLISYTLGQSGTCITDPLLAMDEGYFAYSHLERIG[S][C][S][R]GVSKQRIIGVGEVLDRGDEVPSLFMTNVWTPPNPNTVYHC...
Converted format: ICLQKTSNQILKPKLISYTLGQSGTCITDPLLAMDEGYFAYSHLERIG<epi>S</epi><epi>C</epi><epi>S</epi><epi>R</epi>GVSK...
Generating sample 1/20...
Generating sample 2/20...
Generating sample 3/20...
Generating sample 4/20...
Generating sample 5/20...
  Generated 5/20 samples
Generating sample 6/20...
Generating sample 7/20...
Generating sample 8/20...
Generating sample 9/20...
Generating sample 10/20...
  Generated 10/20 samples
Generating sample 11/20...
Generating sa