## 1. Environment & Hardware Verification
To perform efficient fine-tuning of **ESM-2** using **LoRA** and **RL**, a GPU is required to handle the high-dimensional tensor operations and gradient calculations.

* **Tool:** `nvidia-smi` (NVIDIA System Management Interface)
* **Purpose:** Confirms the presence of a CUDA-enabled device and monitors VRAM availability.
* **Safety Check:** If no GPU is detected, the cell stops execution with `sys.exit(1)` rather than letting the notebook fall back to impractically slow CPU processing.

In [1]:
import subprocess
import sys

print("Checking GPU availability...")
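try:
    gpu_info = subprocess.check_output(['nvidia-smi'], text=True)
    print(" GPU detected!")
    # In the default nvidia-smi table, line 8 is GPU 0's name/bus-ID row
    # (the layout can differ between driver versions)
    print(gpu_info.split('\n')[8])
except (FileNotFoundError, subprocess.CalledProcessError):
    # nvidia-smi is missing or failed, so there is no usable NVIDIA GPU/driver
    print(" WARNING: No GPU detected! This notebook requires a GPU.")
    print("Go to Runtime > Change runtime type > Select 'T4 GPU'")
    sys.exit(1)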

Checking GPU availability...
 GPU detected!
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
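A sturdier variant of this check asks `nvidia-smi` for machine-readable CSV (`--query-gpu=name,memory.total --format=csv,noheader,nounits`) instead of indexing a fixed row of the human-readable table. The sketch below assumes that query format; the helper names `parse_gpu_csv` and `query_gpus` are illustrative and not part of any library.

```python
import subprocess
from typing import List, Tuple


def parse_gpu_csv(text: str) -> List[Tuple[str, float]]:
    """Parse output of
    `nvidia-smi --query-gpu=name,memory.total --format=csv,noheader,nounits`.

    Each line has the form "<name>, <memory in MiB>"; returns (name, GiB) pairs.
    """
    gpus = []
    for line in text.strip().splitlines():
        if not line.strip():
            continue
        name, mem_mib = (part.strip() for part in line.rsplit(",", 1))
        gpus.append((name, float(mem_mib) / 1024))
    return gpus


def query_gpus() -> List[Tuple[str, float]]:
    """Return visible NVIDIA GPUs, or [] if nvidia-smi is missing or fails."""
    try:
        out = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=name,memory.total",
             "--format=csv,noheader,nounits"],
            text=True,
        )
    except (OSError, subprocess.CalledProcessError):
        return []
    return parse_gpu_csv(out)
```

In the notebook, an empty list from `query_gpus()` would trigger the same warning and `sys.exit(1)` as the `except` branch above.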


## 2. Dependency Installation
This project utilizes the Hugging Face ecosystem and Meta's FAIR-ESM tools to implement a parameter-efficient training pipeline.

| Library | Primary Function |
| :--- | :--- |
| **transformers** | Provides the pre-trained ESM-2 model architecture and tokenizers. |
| **peft** | Implements **LoRA**, so only a small fraction of the model's parameters is trained (about 0.3% in this configuration). |
| **accelerate** | Handles device placement and distributed training optimizations. |
| **fair-esm** | Native Meta AI tools for working with Evolutionary Scale Modeling (ESM) weights. |
| **wandb** | Used for experiment tracking and visualizing multi-objective reward trade-offs. |

In [2]:

# Quote the specifier: unquoted, the shell treats ">" as an output redirect
!pip install -q "transformers>=4.41.0" peft==0.7.1 accelerate==0.25.0
!pip install -q datasets wandb
!pip install -q fair-esm

print(" All packages installed successfully!")

 All packages installed successfully!


In [3]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForMaskedLM, get_cosine_schedule_with_warmup
from peft import LoraConfig, get_peft_model, TaskType
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
import seaborn as sns
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional
import wandb
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')


In [None]:
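import random

def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs for reproducible sampling."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)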

set_seed(42)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"using device: {device}")
if torch.cuda.is_available():
    print(f"GPU:{torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")


using device: cuda
GPU:Tesla T4
Memory: 15.83 GB


In [6]:
import wandb

# Never hard-code a secret in a notebook: wandb.login() reads WANDB_API_KEY
# from the environment, falls back to ~/.netrc, or prompts interactively.
wandb.login()


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mmanivarshithpc[0m ([33mmanivarshithpc-vignan-institute-of-technology-and-science[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

## 3. Global Configuration & Experiment Setup
This section defines the architectural and behavioral parameters for the fine-tuning process. All hyperparameters live in a single `Config` dataclass, which is also passed to WandB so that every run records its settings.

### Key Components:
* **Model Backbone:** `ESM-2 (650M parameters)` - A large-scale protein language model.
* **LoRA Strategy:** Targets the **self-attention** projections (`query`, `key`, `value`) to adapt sequence generation with minimal parameter updates.
* **RL Steering:**
    * **KL Coefficient:** Controls the trade-off between exploring new sequences and staying close to the pretrained base model, whose distribution serves as a proxy for biologically plausible sequences.
    * **Reward Weights:** A weighted sum balances **Stability** (structural integrity), **Diversity** (novelty), and **Constraint Satisfaction**.
* **Logging:** Integrated with **Weights & Biases (WandB)** for real-time monitoring of reward convergence and sequence entropy.
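The reward shaping described above can be sketched as a weighted sum of task scores minus a KL term. This is a minimal illustration under stated assumptions: the stability, diversity and constraint scores are hypothetical placeholders (the reward functions are not defined in this section), and the KL divergence to the reference model is estimated, for one sequence sampled from the actor, as the sum over tokens of the actor log-probability minus the reference log-probability. The function name `shaped_reward` is illustrative; its default weights mirror `Config`.

```python
from typing import Sequence


def shaped_reward(
    stability: float,
    diversity: float,
    constraint: float,
    actor_logprobs: Sequence[float],
    ref_logprobs: Sequence[float],
    stability_weight: float = 1.0,
    diversity_weight: float = 0.5,
    constraint_weight: float = 0.5,
    kl_coef: float = 0.1,
) -> float:
    """Weighted task reward minus a KL penalty toward the reference model.

    For a sequence sampled from the actor, sum(log pi_actor - log pi_ref)
    over its tokens is a single-sample estimate of KL(actor || ref).
    """
    if len(actor_logprobs) != len(ref_logprobs):
        raise ValueError("log-prob sequences must have equal length")
    task_reward = (
        stability_weight * stability
        + diversity_weight * diversity
        + constraint_weight * constraint
    )
    kl_estimate = sum(a - r for a, r in zip(actor_logprobs, ref_logprobs))
    return task_reward - kl_coef * kl_estimate
```

A sequence to which the actor assigns much higher probability than the reference model gets a large positive KL estimate, so raising `kl_coef` pulls generation back toward the pretrained ESM-2 distribution.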

In [7]:
from dataclasses import field

@dataclass
class Config:
    # Model
    model_name: str = "facebook/esm2_t33_650M_UR50D"

    # LoRA configuration
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    # default_factory gives each Config its own list (a mutable default is not allowed)
    lora_target_modules: List[str] = field(
        default_factory=lambda: ["query", "key", "value"]
    )

    # Generation configuration
    max_seq_length: int = 64
    min_seq_length: int = 32
    temperature: float = 1.0
    top_k: int = 50          # ESM-2 has only 33 tokens, so this keeps every token
    top_p: float = 0.9

    # RL training configuration
    num_epochs: int = 5
    batch_size: int = 4
    gradient_accumulation_steps: int = 4
    num_sequences_per_batch: int = 8
    learning_rate: float = 5e-5
    kl_coef: float = 0.1     # KL penalty coefficient
    clip_range: float = 0.2  # PPO-style probability-ratio clipping range

    # Reward weights
    stability_weight: float = 1.0
    diversity_weight: float = 0.5
    constraint_weight: float = 0.5

    # Optimizer
    adam_epsilon: float = 1e-8
    max_grad_norm: float = 1.0
    warmup_steps: int = 100

    # Logging and checkpointing
    log_interval: int = 10
    save_interval: int = 100
    use_wandb: bool = True

config = Config()


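import gc
gc.collect()              # drop unreferenced Python objects first
torch.cuda.empty_cache()  # then release cached, unused CUDA blocks

print(f"\n Configuration loaded")
# Note: total_memory is the device's total capacity, not the memory currently free
print(f"GPU Memory Available: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")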


if config.use_wandb:
    wandb.init(
        project="protein-rl-design",
        config=vars(config),
        name="esm2-rl-experiment"
    )


 Configuration loaded
GPU Memory Available: 15.83 GB
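The `clip_range = 0.2` field suggests a PPO-style clipped objective, although the training loop itself is not part of this section. A minimal per-token sketch of that clipped surrogate, assuming advantages are already computed and the "old" log-probabilities come from the policy that generated the samples (the function name `ppo_clipped_loss` is illustrative):

```python
import math
from typing import Sequence


def ppo_clipped_loss(
    new_logprobs: Sequence[float],
    old_logprobs: Sequence[float],
    advantages: Sequence[float],
    clip_range: float = 0.2,
) -> float:
    """Negative mean of the PPO clipped surrogate over tokens (lower is better).

    ratio = exp(new_logprob - old_logprob)
    objective = min(ratio * A, clip(ratio, 1 - clip_range, 1 + clip_range) * A)
    """
    if not (len(new_logprobs) == len(old_logprobs) == len(advantages)):
        raise ValueError("inputs must have equal length")
    if len(new_logprobs) == 0:
        raise ValueError("inputs must be non-empty")
    total = 0.0
    for new, old, adv in zip(new_logprobs, old_logprobs, advantages):
        ratio = math.exp(new - old)
        clipped_ratio = min(max(ratio, 1.0 - clip_range), 1.0 + clip_range)
        total += min(ratio * adv, clipped_ratio * adv)
    return -total / len(new_logprobs)
```

With `clip_range = 0.2`, a token whose probability ratio moves outside [0.8, 1.2] in the direction its advantage favors earns no further gain, which limits how far a single update can move the policy away from the one that produced the samples.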


## 4. Model Initialization & LoRA Adaptation
This section loads the ESM-2 weights and configures the **Parameter-Efficient Fine-Tuning (PEFT)** setup.

### Dual-Model Architecture:
1.  **Active Model (Actor):** The `base_model` wrapped with **LoRA** adapters. Only the adapter weights receive gradient updates; the original ESM-2 weights stay frozen.
2.  **Reference Model:** A separate, frozen copy of the original ESM-2, used to compute the **KL penalty**. The penalty discourages the RL policy from drifting far from the pretrained distribution, which serves as a proxy for biologically plausible sequences.

### Efficiency Optimizations:
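* **FP16 Precision:** Stores weights in half precision, halving weight memory: each copy of the 651M-parameter model takes about 1.3 GB instead of 2.6 GB, which matters because two copies (actor and reference) share one GPU.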
* **LoRA Injection:** Targets the Query, Key, and Value matrices, allowing the model to adapt its attention patterns toward desired structural traits with <1% trainable parameters.
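The trainable-parameter count printed by `model.print_trainable_parameters()` in the next cell can be checked by hand, assuming ESM-2 650M's 33 layers, hidden size $d = 1280$, and square `query`/`key`/`value` projections. LoRA adds $A \in \mathbb{R}^{r \times d}$ and $B \in \mathbb{R}^{d \times r}$ to each targeted projection, i.e. $2rd$ parameters, and scales the update $BA$ by $\alpha / r$. The base count 651,043,254 used below is PEFT's printed total minus the LoRA parameters.

```python
def lora_trainable_params(num_layers: int, hidden: int, rank: int, modules_per_layer: int) -> int:
    """LoRA adds A (rank x hidden) and B (hidden x rank) to each square projection."""
    per_module = rank * hidden + hidden * rank
    return num_layers * modules_per_layer * per_module


def weight_memory_gb(num_params: float, bytes_per_param: int) -> float:
    """Memory for the weights alone (no activations or optimizer state), in GB (1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


trainable = lora_trainable_params(num_layers=33, hidden=1280, rank=8, modules_per_layer=3)
total = 651_043_254 + trainable                      # base + LoRA parameters
print(f"trainable params: {trainable:,}")            # 2,027,520
print(f"trainable%: {100 * trainable / total:.4f}")  # 0.3105
print(f"LoRA scaling alpha/r: {16 / 8}")             # 2.0
print(f"fp32 weights: {weight_memory_gb(651_043_254, 4):.2f} GB")  # 2.60
print(f"fp16 weights: {weight_memory_gb(651_043_254, 2):.2f} GB")  # 1.30
```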

In [8]:
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
print(f"Vocabulary size: {len(tokenizer)}")


base_model = AutoModelForMaskedLM.from_pretrained(
    config.model_name,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)

base_model = base_model.to(device)

print(f"\nBase model parameters: {sum(p.numel() for p in base_model.parameters()) / 1e6:.2f}M")

# Configure LoRA 
lora_config = LoraConfig(
    r=config.lora_r,
    lora_alpha=config.lora_alpha,
    target_modules=config.lora_target_modules,
    lora_dropout=config.lora_dropout,
    bias="none",
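    # PEFT defines no masked-LM task type; FEATURE_EXTRACTION wraps the model
    # generically and the forward pass still returns the MLM logits.
    task_type=TaskType.FEATURE_EXTRACTION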
)


model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()


ref_model = AutoModelForMaskedLM.from_pretrained(
    config.model_name,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)
ref_model = ref_model.to(device)
ref_model.eval()
for param in ref_model.parameters():
    param.requires_grad = False

print("\n Models loaded successfully!")

`torch_dtype` is deprecated! Use `dtype` instead!


Vocabulary size: 33

Base model parameters: 651.04M
trainable params: 2,027,520 || all params: 653,070,774 || trainable%: 0.3104594602483314

 Models loaded successfully!


## 5. Autoregressive Protein Generation
Since **ESM-2** is a bidirectional Masked Language Model rather than a left-to-right generator, this module implements a custom generation wrapper that builds sequences one residue at a time, enabling **de novo** protein design.

### Key Sampling Techniques:
* **Top-K Filtering:** Restricts sampling to the $K$ most likely amino acids, cutting off the low-probability tail.
* **Top-P (Nucleus) Sampling:** Keeps the smallest set of tokens whose cumulative probability exceeds the threshold $P$, so the candidate set shrinks when the model is confident and widens when it is not.
* **Log-Probability Tracking:** Crucial for the **Policy Gradient** updates in Reinforcement Learning, allowing the agent to correlate specific sequence choices with resulting rewards.
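The two filters can be illustrated on a toy logit vector in plain Python. This sketch follows the same rules as `_top_k_top_p_filtering` in the cell below (keep logits at or above the $K$-th largest; then drop a token once the tokens ranked above it already have cumulative probability above $P$, always keeping the most likely one). The function name `filter_logits` and the numbers are illustrative.

```python
import math
from typing import List

NEG_INF = float("-inf")


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax; -inf entries get probability 0."""
    m = max(x for x in logits if x != NEG_INF)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def filter_logits(logits: List[float], top_k: int = 0, top_p: float = 1.0) -> List[float]:
    """Return a copy with tokens outside top-k / the top-p nucleus set to -inf."""
    out = list(logits)
    if top_k > 0:
        k = min(top_k, len(out))  # a top_k larger than the vocabulary keeps everything
        kth_largest = sorted(out, reverse=True)[k - 1]
        out = [x if x >= kth_largest else NEG_INF for x in out]
    if top_p < 1.0:
        order = sorted(range(len(out)), key=lambda i: out[i], reverse=True)
        probs = softmax(out)
        cumulative = 0.0
        for rank, idx in enumerate(order):
            # Drop a token once the higher-ranked tokens already exceed top_p
            if rank > 0 and cumulative > top_p:
                out[idx] = NEG_INF
            cumulative += probs[idx]
    return out
```

Because ESM-2's vocabulary has 33 tokens (see the tokenizer output below), the configured `top_k = 50` is clamped to 33 and removes nothing, so in practice only top-p narrows the distribution.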

### Generation Workflow:
1.  **Initialize:** Start with the `<cls>` token.
2.  **Predict:** Generate logits for the next amino acid position.
3.  **Filter & Sample:** Apply temperature, Top-K, and Top-P to select the next token.
4.  **Iterate:** Append the token and repeat until `max_length` is reached or every sequence has emitted `<eos>`.

In [9]:
class ProteinGenerator:
    """Handles autoregressive protein sequence generation"""
    
    def __init__(self, model, tokenizer, config):
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.device = next(model.parameters()).device
        
        
        self.cls_token_id = tokenizer.cls_token_id
        self.eos_token_id = tokenizer.eos_token_id
        self.pad_token_id = tokenizer.pad_token_id
        self.mask_token_id = tokenizer.mask_token_id
        
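    def generate_sequences(
        self,
        num_sequences: int,
        temperature: float = None,
        max_length: int = None,
        return_log_probs: bool = True
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], List[str]]:
        """
        Generate protein sequences one residue at a time.

        ESM-2 is bidirectional: its output at position i reconstructs the token
        already at position i, so the logits of the last real token mostly copy
        that token. Instead, a <mask> is appended at every step and the
        prediction for that masked slot is sampled as the next residue.

        Returns:
            sequences: Token IDs (num_sequences, seq_len), padded after <eos>
            log_probs: Log probabilities of the sampled tokens
                (num_sequences, seq_len - 1), 0 after <eos>, without gradients
            sequences_str: Decoded amino-acid strings (no spaces or special tokens)
        """
        temperature = temperature or self.config.temperature
        max_length = max_length or self.config.max_seq_length
        # Minimum number of residues before <eos> may be sampled
        min_length = min(self.config.min_seq_length, max_length - 1)

        self.model.eval()

        # Start with the CLS token
        sequences = torch.full(
            (num_sequences, 1),
            self.cls_token_id,
            dtype=torch.long,
            device=self.device
        )
        mask_column = torch.full_like(sequences, self.mask_token_id)
        finished = torch.zeros(num_sequences, dtype=torch.bool, device=self.device)
        # Special tokens other than <eos> are never valid residues
        banned_ids = [i for i in self.tokenizer.all_special_ids if i != self.eos_token_id]

        log_probs_list = []

        with torch.no_grad():
            for step in range(max_length - 1):
                # Predict the appended <mask>; upcast fp16 logits before softmax
                outputs = self.model(input_ids=torch.cat([sequences, mask_column], dim=1))
                logits = outputs.logits[:, -1, :].float() / temperature

                logits[:, banned_ids] = float('-inf')
                if step < min_length:
                    logits[:, self.eos_token_id] = float('-inf')

                logits = self._top_k_top_p_filtering(
                    logits,
                    top_k=self.config.top_k,
                    top_p=self.config.top_p
                )

                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                # Sequences that already emitted <eos> only receive padding
                next_token[finished] = self.pad_token_id

                if return_log_probs:
                    # Log-prob under the filtered sampling distribution (no gradients here;
                    # the policy update must recompute log-probs with gradients enabled)
                    log_prob = F.log_softmax(logits, dim=-1).gather(1, next_token)
                    log_probs_list.append(log_prob.masked_fill(finished.unsqueeze(1), 0.0))
                sequences = torch.cat([sequences, next_token], dim=1)

                finished |= next_token.squeeze(1) == self.eos_token_id
                if finished.all():
                    break

        log_probs = torch.cat(log_probs_list, dim=1) if return_log_probs else None

        # The ESM tokenizer separates residues with spaces when decoding
        sequences_str = [
            s.replace(" ", "")
            for s in self.tokenizer.batch_decode(sequences, skip_special_tokens=True)
        ]
        return sequences, log_probs, sequences_str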
    
    def _top_k_top_p_filtering(
        self, 
        logits: torch.Tensor, 
        top_k: int = 0, 
        top_p: float = 1.0
    ) -> torch.Tensor:
        """Set logits outside the top-k set or the top-p nucleus to -inf (modifies in place)."""
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits[indices_to_remove] = float('-inf')
        
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            
            indices_to_remove = sorted_indices_to_remove.scatter(
                1, sorted_indices, sorted_indices_to_remove
            )
            logits[indices_to_remove] = float('-inf')
        
        return logits

generator = ProteinGenerator(model, tokenizer, config)


print("\n" + "="*80)
print("Testing Sequence Generation")
print("="*80)

test_seqs, test_logprobs, test_strs = generator.generate_sequences(
    num_sequences=3,
    max_length=50
)

print("\nGenerated sequences:")
for i, seq in enumerate(test_strs):
    print(f"{i+1}. {seq} (length: {len(seq)})")


Testing Sequence Generation

Generated sequences:
1. W W W W W W W W W W W W W W W W N N N N N N N N N N N N (length: 55)
2. E E E E E E E E E E E E Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y (length: 79)
3. S S S S S S S S S S S S S S S S S S S S S S N N N (length: 49)
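The long runs of a single repeated residue above are consistent with how ESM-2 is trained: as a bidirectional masked language model, its output at position $i$ reconstructs the token already at $i$, so reading the logits of the last real token tends to echo that token. One way to request a *next* residue is to append a `<mask>` and sample from the prediction at that slot. The sketch below shows this loop in plain Python under loud assumptions: `toy_mask_logits` is a hypothetical stand-in for ESM-2 that returns random logits, the four-letter vocabulary and token IDs are made up, and top-k/top-p filtering is omitted for brevity.

```python
import math
import random
from typing import Callable, List

# Hypothetical toy vocabulary (not ESM-2's real token IDs).
CLS, EOS, MASK = 0, 1, 2
RESIDUES = {3: "A", 4: "G", 5: "L", 6: "K"}
VOCAB_SIZE = 7


def toy_mask_logits(tokens: List[int], rng: random.Random) -> List[float]:
    """Stand-in for ESM-2: random logits for the <mask> at the last position."""
    if tokens[-1] != MASK:
        raise ValueError("expected the query position to be <mask>")
    return [rng.gauss(0.0, 1.0) for _ in range(VOCAB_SIZE)]


def generate(
    logits_fn: Callable[[List[int], random.Random], List[float]],
    max_length: int,
    min_length: int,
    temperature: float = 1.0,
    seed: int = 0,
) -> List[int]:
    """Sample one sequence: <cls>, then residues until <eos> or max_length tokens."""
    rng = random.Random(seed)
    tokens = [CLS]                                    # 1. initialize with <cls>
    while len(tokens) < max_length:
        logits = logits_fn(tokens + [MASK], rng)      # 2. predict at an appended <mask>
        logits = [x / temperature for x in logits]
        # 3. sample among residues; <eos> only once min_length residues exist
        allowed = list(RESIDUES)
        if len(tokens) - 1 >= min_length:
            allowed.append(EOS)
        top = max(logits[i] for i in allowed)
        weights = [math.exp(logits[i] - top) for i in allowed]
        next_token = rng.choices(allowed, weights=weights, k=1)[0]
        tokens.append(next_token)                     # 4. append and iterate
        if next_token == EOS:
            break
    return tokens


def decode(tokens: List[int]) -> str:
    """Map residue IDs to letters, dropping special tokens."""
    return "".join(RESIDUES[t] for t in tokens if t in RESIDUES)
```

For example, `decode(generate(toy_mask_logits, max_length=20, min_length=5, seed=1))` returns a string of 5 to 19 letters drawn from "AGLK"; with the real model, the prediction at the masked slot depends on the whole prefix rather than being random.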
