# P2 Sampling Demo for Protein Sequence Generation

This notebook demonstrates how to use P2 (Path Planning) sampling to generate protein sequences.

## Setup

First, let's import the necessary libraries:

In [None]:
import torch
import math
import time
from transformers import AutoTokenizer, EsmForMaskedLM
from path_planning.p2 import p2_sampling
from path_planning.utils import seed_everything


## Helper Functions

Let's define some helper functions for our demo:

In [None]:
def ignore_special_tokens_logits(logits, tokenizer):
    """Set the logits of non-residue tokens to negative infinity.

    The columns for the mask, "X", padding, CLS and EOS tokens are
    overwritten along the last axis so those tokens cannot be sampled.

    Args:
        logits: Array-like indexed by token id on its last axis. It is
            modified in place.
        tokenizer: Object exposing ``mask_token_id``, ``pad_token_id``,
            ``cls_token_id``, ``eos_token_id`` and a ``_token_to_id`` mapping
            that contains the key ``"X"``.

    Returns:
        The same ``logits`` object that was passed in.
    """
    # Same ids, in the same order, as they are written into ``logits``.
    blocked_ids = (
        tokenizer.mask_token_id,
        tokenizer._token_to_id["X"],
        tokenizer.pad_token_id,
        tokenizer.cls_token_id,
        tokenizer.eos_token_id,
    )
    for token_id in blocked_ids:
        logits[..., token_id] = -math.inf
    return logits

class ModelWrapper:
    """Callable adapter that returns sampling-ready logits from a model.

    Calling an instance runs the wrapped model on the input, takes the
    ``logits`` attribute of its output, casts it with ``.float()`` and
    passes it through ``ignore_special_tokens_logits``.
    """

    def __init__(self, model, tokenizer):
        """Store the model and the tokenizer used to look up special tokens."""
        self.tokenizer = tokenizer
        self.model = model

    def __call__(self, x):
        """Run the model on ``x`` and return logits with special tokens masked.

        Args:
            x: Input forwarded unchanged to the wrapped model.

        Returns:
            The model's logits, cast with ``.float()``, after
            ``ignore_special_tokens_logits`` has been applied.
        """
        raw_logits = self.model(x).logits
        return ignore_special_tokens_logits(raw_logits.float(), self.tokenizer)

def create_masked_sequence(sequence_length, tokenizer, batch_size=1, device='cuda'):
    """Build a batch of all-mask token ids to start generation from.

    Args:
        sequence_length: Number of mask tokens in each sequence.
        tokenizer: Callable tokenizer exposing ``mask_token``; it is called
            with ``add_special_tokens=True``, ``padding=True`` and
            ``return_tensors='pt'``.
        batch_size: Number of identical masked sequences in the batch.
        device: Target passed to ``.to()`` on the encoded ids.

    Returns:
        The ``'input_ids'`` entry of the tokenizer output, moved to ``device``.
    """
    # One string made of the mask token repeated back-to-back.
    masked_text = tokenizer.mask_token * sequence_length
    batch = [masked_text] * batch_size

    batch_encoding = tokenizer(
        batch,
        add_special_tokens=True,
        padding=True,
        return_tensors='pt',
    )
    input_ids = batch_encoding['input_ids']
    return input_ids.to(device)

## Configuration

Set the parameters for protein sequence generation:

In [None]:
# --- Generation settings ---
# Checkpoint to load; "zhangzhi/EvoFlow-150M-fs" is an alternative to try.
model_name = "airkingbd/dplm_650m"
# Number of sequences to generate
num_seqs = 5
# Length of sequences
seq_len = 100
# Number of P2 sampling steps
num_steps = 100
# Sampling temperature
temperature = 1.0
# Stochasticity strength
eta = 1.0
# Random seed
seed = 42

# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Seed before anything else runs, for reproducibility.
seed_everything(seed)

## Load Model

Load the protein language model:

In [None]:
print(f"Loading model {model_name}...")
# Tokenizer and masked-LM weights both come from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = EsmForMaskedLM.from_pretrained(model_name)

# Switch to evaluation mode, then move the model to the selected device.
model.eval()
model = model.to(device)

# Wrap the model so that calling it returns logits with special tokens masked.
model_wrapper = ModelWrapper(model=model, tokenizer=tokenizer)

## Create Initial Sequence

Create a fully masked sequence as the starting point:

In [None]:
print("Creating initial sequence...")
# Start from num_seqs sequences, each made of seq_len mask tokens.
xt = create_masked_sequence(seq_len, tokenizer, batch_size=num_seqs, device=device)
print(f"Initial sequence shape: {xt.shape}")

## Run P2 Sampling

Generate protein sequences using P2 sampling:

In [None]:
print("Starting P2 sampling...")
# CUDA work is queued asynchronously, so wait for the device to go idle before
# starting the clock and again before stopping it; otherwise the measured time
# can exclude GPU work that has been launched but not yet finished.
if xt.is_cuda:
    torch.cuda.synchronize(xt.device)
# perf_counter is a monotonic clock meant for measuring elapsed time;
# time.time() can jump if the system clock is adjusted.
start_time = time.perf_counter()
# check out p2_sampling to see the full parameters
sampled_xt = p2_sampling(
    xt=xt,
    model=model_wrapper,
    tokenizer=tokenizer,
    num_steps=num_steps,
    tau=temperature,
    eta=eta
)
if xt.is_cuda:
    torch.cuda.synchronize(xt.device)

elapsed_time = time.perf_counter() - start_time
print(f"Generation completed in {elapsed_time:.2f} seconds")
# Throughput counts every position of every generated sequence.
print(f"Tokens/second: {num_seqs * seq_len / elapsed_time:.2f}")

## Decode and Display Results

Decode the generated sequences and display them:

In [None]:
# Convert token ids back to text, leaving out special tokens.
raw_decoded = tokenizer.batch_decode(sampled_xt, skip_special_tokens=True)
# Strip all whitespace from each decoded string so it reads as one sequence.
decoded_seqs = ["".join(text.split()) for text in raw_decoded]

# Print each sequence under a 1-based header, followed by a blank line.
print("\nGenerated Protein Sequences:")
for number, seq in enumerate(decoded_seqs, start=1):
    print(f"Sequence {number} (length {len(seq)}):")
    print(seq, end="\n\n")

## Save Sequences (Optional)

Save the generated sequences to a FASTA file:

In [None]:
def save_sequences_to_fasta(sequences, seq_len, save_path):
    """Write sequences to ``save_path`` in FASTA format.

    Each record is a header line ``>SEQUENCE_{idx}_L={seq_len}`` followed by
    the sequence on its own line, where ``idx`` is the 0-based position in
    ``sequences``.

    Args:
        sequences: Iterable of sequence strings to write.
        seq_len: Value recorded after ``L=`` in every header line.
        save_path: Destination file path. Missing parent directories are
            created; a bare file name is written to the current directory.
    """
    import os

    # dirname is "" for a bare file name, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(save_path, 'w', encoding='utf-8') as fp:
        for idx, seq in enumerate(sequences):
            fp.write(f">SEQUENCE_{idx}_L={seq_len}\n")
            fp.write(f"{seq}\n")

# Uncomment to save sequences
# save_path = "generated_sequences.fasta"
# save_sequences_to_fasta(decoded_seqs, seq_len, save_path)
# print(f"Saved sequences to {save_path}")