In [1]:
!pip install -q transformers>=4.41.0 peft==0.7.1 accelerate==0.25.0
!pip install -q bitsandbytes==0.41.3 datasets wandb
!pip install -q fair-esm


In [2]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForMaskedLM, get_cosine_schedule_with_warmup
from peft import LoraConfig, get_peft_model, TaskType
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
import seaborn as sns
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional
import wandb
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')


  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [3]:
def set_seed(seed=42):
    """Seed every random number generator this notebook may draw from.

    Seeds Python's built-in ``random`` module, NumPy's legacy global RNG, and
    PyTorch (CPU plus all CUDA devices) with the same value.

    Args:
        seed: Integer seed applied to all generators. Defaults to 42.
    """
    # Local import: `random` is not part of the notebook's import cell.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Silently ignored when CUDA is not available.
    torch.cuda.manual_seed_all(seed)

set_seed(42)

In [4]:
# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"using device: {device}")

# Report the name and total memory of GPU 0 when one was selected.
if device.type == 'cuda':
    print(f"GPU:{torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")


using device: cuda
GPU:Tesla T4
Memory: 15.83 GB


In [1]:
import os

import wandb

# SECURITY: never embed an API key in the notebook. The key is read from the
# WANDB_API_KEY environment variable; when it is unset, `key=None` lets wandb
# fall back to its own credential lookup / interactive prompt.
# NOTE(review): the key that was previously hardcoded in this cell has been
# exposed and should be revoked and rotated in the W&B account settings.
wandb.login(key=os.environ.get("WANDB_API_KEY") or None)


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mmanivarshithpc[0m ([33mmanivarshithpc-vignan-institute-of-technology-and-science[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
from dataclasses import dataclass
from typing import List


@dataclass
class Config:
    """Settings for the ESM-2 + LoRA RL experiment.

    Fields are grouped by the section comments below. How each value is
    consumed is determined by code outside this cell.

    ``lora_target_modules`` may be left as ``None``; ``__post_init__`` then
    fills in a fresh ``["query", "key", "value"]`` list for the instance.
    """

    # Hugging Face Hub model id. ESM-2 checkpoints are published with
    # underscores (esm2_t33_650M_UR50D); the hyphenated spelling does not resolve.
    model_name: str = "facebook/esm2_t33_650M_UR50D"

    # LoRA configuration
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    # None means "use the default module list" (see __post_init__).
    lora_target_modules: Optional[List[str]] = None

    # Generation configuration
    max_seq_length: int = 64
    min_seq_length: int = 32
    temperature: float = 1.0
    top_k: int = 50
    top_p: float = 0.9

    # RL training configuration
    num_epochs: int = 5
    batch_size: int = 4
    gradient_accumulation_steps: int = 4
    num_sequences_per_batch: int = 8
    learning_rate: float = 5e-5
    kl_coef: float = 0.1
    clip_range: float = 0.2

    # Reward weights
    stability_weight: float = 1.0
    diversity_weight: float = 0.5
    constraint_weight: float = 0.5

    # Optimizer
    adam_epsilon: float = 1e-8
    max_grad_norm: float = 1.0
    warmup_steps: int = 100

    # Logging
    log_interval: int = 10
    save_interval: int = 100
    use_wandb: bool = True  # Enable W&B logging

    def __post_init__(self):
        """Replace a ``None`` target-module list with the default modules."""
        # Built per instance so configs never share one mutable list.
        if self.lora_target_modules is None:
            self.lora_target_modules = ["query", "key", "value"]

config = Config()

# Open a Weights & Biases run only when logging is switched on in the config;
# the config's instance attributes are passed along as the run configuration.
if config.use_wandb:
    wandb.init(
        name="esm2-rl-experiment",
        project="protein-rl-design",
        config=vars(config),
    )
