<a href="https://colab.research.google.com/github/sean6211/PTM-predictor-evaluation/blob/main/ESMFold.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **ESMFold**
For more details, see: [Github](https://github.com/facebookresearch/esm/tree/main/esm), [Preprint](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v1)

#### **Tips and Instructions**
- Click the little ▶ play icon to the left of each cell below.
- Use "/" (or ":") to specify chain breaks (e.g. sequence="AAA/AAA").
- For homo-oligomeric predictions, set copies > 1.
- See [experimental notebook](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/beta/ESMFold_advanced.ipynb) for more advanced options (like sampling).

#### **Colab Limitations**
- For short monomeric proteins (under 400 residues), consider using the [ESMFold API](https://esmatlas.com/resources?action=fold) (no GPU needed, super fast!).
- On a Tesla T4 (the typical free Colab GPU), the maximum total length is ~900 residues.

In [None]:
%%time
#@title install
#@markdown install ESMFold, OpenFold and download Params (~2min 30s)
version = "1" # @param ["0", "1"]
# "0" selects the older esmfold_v0 weights; any other value selects "esmfold.model".
model_name = "esmfold_v0.model" if version == "0" else "esmfold.model"
import os, time
# Everything below is skipped when the weights file already exists.
if not os.path.isfile(model_name):
  # download esmfold params
  # The trailing "&" backgrounds aria2c so the pip installs below overlap with the download.
  os.system("apt-get install aria2 -qq")
  os.system(f"aria2c -q -x 16 https://colabfold.steineggerlab.workers.dev/esm/{model_name} &")

  # "finished_install" is a marker file so a re-run skips the pip installs.
  if not os.path.isfile("finished_install"):
    # install libs
    print("installing libs...")
    os.system("pip install -q omegaconf pytorch_lightning biopython ml_collections einops py3Dmol modelcif")
    os.system("pip install -q git+https://github.com/NVIDIA/dllogger.git")

    print("installing openfold...")
    # install openfold
    os.system(f"pip install -q git+https://github.com/sokrypton/openfold.git")

    print("installing esmfold...")
    # install esmfold
    os.system(f"pip install -q git+https://github.com/sokrypton/esm.git")
    os.system("touch finished_install")

  # wait for Params to finish downloading...
  # These loops assume aria2c keeps a "<file>.aria2" control file while the
  # download is in progress and removes it when done.
  # NOTE(review): there is no timeout — if the background download fails before
  # creating model_name (or leaves the .aria2 file behind), the cell spins forever.
  while not os.path.isfile(model_name):
    time.sleep(5)
  if os.path.isfile(f"{model_name}.aria2"):
    print("downloading params...")
  while os.path.isfile(f"{model_name}.aria2"):
    time.sleep(5)

installing libs...
installing openfold...
installing esmfold...
downloading params...


In [None]:
#@title ##run **ESMFold**
%%time
from string import ascii_uppercase, ascii_lowercase
import hashlib, re, os
import numpy as np
import torch
from jax.tree_util import tree_map
import matplotlib.pyplot as plt
from scipy.special import softmax
import gc

def parse_output(output):
  """Extract per-residue confidence arrays from a numpy-converted ESMFold output.

  Only residues with ``atom37_atom_exists[0, :, 1] == 1`` (atom index 1, CA in the
  usual atom37 ordering) are kept; presumably this drops chain-linker positions.

  Returns a dict with keys, in order:
    "pae":         (L, L) approximate aligned error: mean over the 64 aligned-error
                   bins of (probability * bin index), scaled by 31.
    "plddt":       per-residue pLDDT from channel 1 of ``output["plddt"]``.
    "sm_contacts": (L, L) summed distogram probability over bins whose edge is < 8.
    "xyz":         coordinates of atom 1 from the last entry of ``output["positions"]``.
  """
  keep = output["atom37_atom_exists"][0, :, 1] == 1

  # Approximate expected aligned error from the 64-bin distribution.
  pae = (output["aligned_confidence_probs"][0] * np.arange(64)).mean(-1) * 31

  # Contact probability: total mass in distogram bins whose edge is below 8.
  edges = np.append(0, np.linspace(2.3125, 21.6875, 63))
  dist_probs = softmax(output["distogram_logits"], -1)[0]
  contacts = dist_probs[..., edges < 8].sum(-1)

  return {
    "pae": pae[keep, :][:, keep],
    "plddt": output["plddt"][0, :, 1][keep],
    "sm_contacts": contacts[keep, :][:, keep],
    "xyz": output["positions"][-1, 0, :, 1][keep],
  }

def get_hash(x):
  """Return the hex SHA-1 digest of ``x`` encoded with the default (UTF-8) codec."""
  digest = hashlib.sha1(x.encode())
  return digest.hexdigest()
# Chain labels: A-Z then a-z (52 entries), used for chain colouring and plot tick labels.
alphabet_list = list(ascii_uppercase+ascii_lowercase)

jobname = "test" #@param {type:"string"}
# Keep only word characters and cap at 50 chars: the name is used in paths and shell commands below.
jobname = re.sub(r'\W+', '', jobname)[:50]

sequence = "GWSTELEKHREELKEFLKKEGITNVEIRIDNGRLEVRVEGGTERLKRFLEELRQKLEKKGYTVDIKIE" #@param {type:"string"}
# Normalize: "/" -> ":" chain separator, uppercase, drop anything that is not A-Z or ":",
# collapse repeated ":" and strip leading/trailing ":".
# NOTE(review): an input with no letters ends up as "" and is passed to model.infer as-is.
sequence = re.sub("[^A-Z:]", "", sequence.replace("/",":").upper())
sequence = re.sub(":+",":",sequence)
sequence = re.sub("^[:]+","",sequence)
sequence = re.sub("[:]+$","",sequence)
copies = 1 #@param {type:"integer"}
if copies == "" or copies <= 0: copies = 1
# Homo-oligomer: repeat the (possibly multi-chain) sequence `copies` times.
sequence = ":".join([sequence] * copies)
num_recycles = 3 #@param ["0", "1", "2", "3", "6", "12", "24"] {type:"raw"}
# Number of poly-"X" residues inserted between chains (see chain_linker= below).
chain_linker = 25

# Output directory name: jobname + first 5 hex chars of the SHA-1 of the final sequence.
ID = jobname+"_"+get_hash(sequence)[:5]
seqs = sequence.split(":")
lengths = [len(s) for s in seqs]
length = sum(lengths)
print("length",length)

# Informational only: `mode` is not referenced again in this notebook section.
u_seqs = list(set(seqs))
if len(seqs) == 1: mode = "mono"
elif len(u_seqs) == 1: mode = "homo"
else: mode = "hetero"

# (Re)load the model only if none is in memory or a different weights file was selected;
# model_name_ records which file the in-memory model came from.
if "model" not in dir() or model_name != model_name_:
  if "model" in dir():
    # delete old model from memory
    del model
    gc.collect()
    if torch.cuda.is_available():
      torch.cuda.empty_cache()

  # weights_only=False unpickles a full Python object — only load trusted files.
  model = torch.load(model_name, weights_only=False)
  model.eval().cuda().requires_grad_(False)
  model_name_ = model_name

# optimized for Tesla T4
# (smaller chunk size for long inputs, presumably to lower peak GPU memory)
if length > 700:
  model.set_chunk_size(64)
else:
  model.set_chunk_size(128)

torch.cuda.empty_cache()
# NOTE(review): residue_index_offset=512 is forwarded to ESMFold; presumably it
# separates chains in residue numbering — confirm against the ESMFold API.
output = model.infer(sequence,
                     num_recycles=num_recycles,
                     chain_linker="X"*chain_linker,
                     residue_index_offset=512)

# Build the PDB text from the raw model output, then convert every tensor to numpy.
pdb_str = model.output_to_pdb(output)[0]
output = tree_map(lambda x: x.cpu().numpy(), output)
ptm = output["ptm"][0]
# NOTE: this mean is over ALL positions; parse_output's O["plddt"] is masked, so the
# two averages can differ (e.g. when chain linkers are present).
plddt = output["plddt"][0,...,1].mean()
O = parse_output(output)
print(f'ptm: {ptm:.3f} plddt: {plddt:.3f}')
os.system(f"mkdir -p {ID}")
prefix = f"{ID}/ptm{ptm:.3f}_r{num_recycles}_default"
np.savetxt(f"{prefix}.pae.txt",O["pae"],"%.3f")
with open(f"{prefix}.pdb","w") as out:
  out.write(pdb_str)

In [None]:
#@title display (optional) {run: "auto"}
import py3Dmol
# PyMOL-style chain palette (hex colours) used by show_pdb(color="chain").
# It has 40 entries; show_pdb zips it with the chain labels, so at most 40 chains get coloured.
pymol_color_list = ["#33ff33","#00ffff","#ff33cc","#ffff00","#ff9999","#e5e5e5","#7f7fff","#ff7f00",
                    "#7fff7f","#199999","#ff007f","#ffdd5e","#8c3f99","#b2b2b2","#007fff","#c4b200",
                    "#8cb266","#00bfbf","#b27f7f","#fcd1a5","#ff7f7f","#ffbfdd","#7fffff","#ffff7f",
                    "#00ff7f","#337fcc","#d8337f","#bfff3f","#ff7fff","#d8d8ff","#3fffbf","#b78c4c",
                    "#339933","#66b2b2","#ba8c84","#84bf00","#b24c66","#7f7f7f","#3f3fa5","#a5512b"]

def show_pdb(pdb_str, show_sidechains=False, show_mainchains=False,
             color="pLDDT", chains=None, vmin=50, vmax=90,
             size=(800,480), hbondCutoff=4.0,
             Ls=None,
             animate=False):
  """Build an interactive py3Dmol view of a PDB string.

  Args:
    pdb_str: PDB-format text.
    show_sidechains: overlay side-chain sticks (GLY shown as a CA sphere).
    show_mainchains: overlay backbone (C, O, N, CA) sticks.
    color: "pLDDT" (B-factor gradient between vmin and vmax), "rainbow",
      or "chain"; any other value keeps the default style.
    chains: number of chains to colour in "chain" mode; defaults to len(Ls),
      or 1 when Ls is None.
    size: (width, height) in pixels.
    hbondCutoff: forwarded to the model loader.
    Ls: per-chain lengths (only used to infer `chains`).
    animate: load pdb_str as multiple frames and animate them.

  Returns:
    The py3Dmol view; call .show() to render it.
  """
  def white_carbon(kind):
    # A fresh {kind: {...}} style dict for every call.
    return {kind: {'colorscheme': "WhiteCarbon", 'radius': 0.3}}

  if chains is None:
    chains = len(Ls) if Ls is not None else 1

  width, height = size[0], size[1]
  view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width=width, height=height)
  loader_opts = {'hbondCutoff': hbondCutoff}
  if animate:
    view.addModelsAsFrames(pdb_str, 'pdb', loader_opts)
  else:
    view.addModel(pdb_str, 'pdb', loader_opts)

  if color == "pLDDT":
    view.setStyle({'cartoon': {'colorscheme': {'prop': 'b', 'gradient': 'roygb', 'min': vmin, 'max': vmax}}})
  elif color == "rainbow":
    view.setStyle({'cartoon': {'color': 'spectrum'}})
  elif color == "chain":
    for _, chain_id, chain_color in zip(range(chains), alphabet_list, pymol_color_list):
      view.setStyle({'chain': chain_id}, {'cartoon': {'color': chain_color}})

  if show_sidechains:
    view.addStyle({'and': [{'resn': ["GLY", "PRO"], 'invert': True}, {'atom': ['C', 'O', 'N'], 'invert': True}]},
                  white_carbon('stick'))
    view.addStyle({'and': [{'resn': "GLY"}, {'atom': 'CA'}]},
                  white_carbon('sphere'))
    view.addStyle({'and': [{'resn': "PRO"}, {'atom': ['C', 'O'], 'invert': True}]},
                  white_carbon('stick'))
  if show_mainchains:
    view.addStyle({'atom': ['C', 'O', 'N', 'CA']}, white_carbon('stick'))

  view.zoomTo()
  if animate:
    view.animate()
  return view

color = "confidence" #@param ["confidence", "rainbow", "chain"]
# show_pdb names the confidence scheme "pLDDT" (B-factor gradient).
if color == "confidence": color = "pLDDT"
show_sidechains = False #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}
# Ls=lengths lets show_pdb infer the chain count for "chain" colouring.
show_pdb(pdb_str, color=color,
         show_sidechains=show_sidechains,
         show_mainchains=show_mainchains,
         Ls=lengths).show()

In [None]:
#@title plot confidence (optional)

dpi = 100 #@param {type:"integer"}

def plot_ticks(Ls):
  """Draw black chain-boundary lines on the current square matrix plot and
  label each chain (A, B, ...) at its midpoint on the y axis.

  Args:
    Ls: list of chain lengths; the plot is assumed to span 0..sum(Ls).
  """
  total = sum(Ls)
  edge = 0
  for chain_len in Ls[:-1]:
    edge += chain_len
    plt.plot([0, total], [edge, edge], color="black")
    plt.plot([edge, edge], [0, total], color="black")
  bounds = np.cumsum([0] + Ls)
  centers = (bounds[1:] + bounds[:-1]) / 2
  plt.yticks(centers, alphabet_list[:len(centers)])

def plot_confidence(O, Ls=None, dpi=100):
  """Plot pLDDT, PAE and contact maps side by side in one matplotlib figure.

  Args:
    O: dict with "plddt" (L,), "pae" (L, L) and "sm_contacts" (L, L); if it also
      has "lm_contacts", a fourth panel is added for it.
    Ls: optional chain lengths; chain boundaries are drawn, and chain labels are
      added on the matrix panels when there is more than one chain.
    dpi: figure resolution.

  Returns:
    The pyplot module, so callers can chain savefig/show.
  """
  has_lm = "lm_contacts" in O
  ncols = 4 if has_lm else 3
  multi_chain = Ls is not None and len(Ls) > 1

  plt.figure(figsize=(20 if has_lm else 15, 4), dpi=dpi)

  # Panel 1: per-residue pLDDT with vertical chain boundaries.
  plt.subplot(1, ncols, 1)
  plt.title('Predicted lDDT')
  plt.plot(O["plddt"])
  if Ls is not None:
    edge = 0
    for chain_len in Ls[:-1]:
      edge += chain_len
      plt.plot([edge, edge], [0, 100], color="black")
  plt.xlim(0, O["plddt"].shape[0])
  plt.ylim(0, 100)
  plt.ylabel('plDDT')
  plt.xlabel('position')

  # Panel 2: predicted aligned error.
  plt.subplot(1, ncols, 2)
  plt.title('Predicted Aligned Error')
  Ln = O["pae"].shape[0]
  extent = (0, Ln, Ln, 0)
  plt.imshow(O["pae"], cmap="bwr", vmin=0, vmax=30, extent=extent)
  if multi_chain: plot_ticks(Ls)
  plt.colorbar()
  plt.xlabel('Scored residue')
  plt.ylabel('Aligned residue')

  # Optional panel 3: language-model contacts.
  panel = 3
  if has_lm:
    plt.subplot(1, ncols, panel)
    plt.title("contacts from LM")
    plt.imshow(O["lm_contacts"], cmap="Greys", vmin=0, vmax=1, extent=extent)
    if multi_chain: plot_ticks(Ls)
    panel += 1

  # Last panel: structure-module contacts.
  plt.subplot(1, ncols, panel)
  plt.title("contacts from Structure Module")
  plt.imshow(O["sm_contacts"], cmap="Greys", vmin=0, vmax=1, extent=extent)
  if multi_chain: plot_ticks(Ls)
  return plt

# Plot confidence for the last prediction and save the figure next to the PDB/PAE files.
plot_confidence(O, Ls=lengths, dpi=dpi)
plt.savefig(f'{prefix}.png',bbox_inches='tight')
plt.show()

In [None]:
#@title download predictions
from google.colab import files
# Zip every file in the job directory and trigger a browser download (Colab only).
os.system(f"zip {ID}.zip {ID}/*")
files.download(f'{ID}.zip')

In [None]:
%%time
# Install official ESM + ESMFold and utilities
!pip install -q "fair-esm[esmfold]" biotite

import math
import random
from pathlib import Path
from tempfile import NamedTemporaryFile

import numpy as np
import torch
import biotite.structure.io as bsio
import esm


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

#############################################
# Load ESM-2 (35M) sequence model
#############################################

lm_model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
lm_model = lm_model.to(device)
lm_model.eval()  # eval mode (no dropout); gradients are never needed: ES is gradient-free and sampling runs under torch.no_grad()

# Vocabulary ids of the 20 standard amino acids; LM sampling is restricted to these.
AA_LETTERS = list("ACDEFGHIKLMNPQRSTVWY")
AA_INDICES = torch.tensor([alphabet.get_idx(a) for a in AA_LETTERS], device=device)

# Special-token ids used to build the [CLS] + masked positions + [EOS] sampling row.
mask_idx = alphabet.mask_idx
cls_idx  = alphabet.cls_idx
eos_idx  = alphabet.eos_idx

print("Loaded ESM-2 35M model.")

#############################################
# Load ESMFold structure model (black-box)
#############################################

# Used only as a reward function (mean pLDDT); its weights are never updated.
fold_model = esm.pretrained.esmfold_v1()
fold_model = fold_model.eval().to(device)
fold_model.set_chunk_size(128)  # mildly more memory-efficient on Colab

print("Loaded ESMFold v1.")


In [None]:
@torch.no_grad()
def sample_sequence_from_lm(
    model,
    alphabet,
    length=100,
    temperature=1.0,
    device=device,
):
    """
    Sample a protein sequence position by position, left to right, from a
    masked language model (ESM-2 here), treating it as a pseudo-autoregressive LM.

    The token row starts as [CLS] + `length` mask tokens + [EOS]. For each
    position in turn, the whole row is run through `model`. The logits at that
    position are restricted to the 20 standard amino acids (`AA_INDICES`),
    divided by `temperature` and soft-maxed, and one residue is drawn with
    `torch.multinomial`. Positions to the right are still masked at that point.

    Returns:
        The sampled sequence as a string of `length` one-letter codes.
    """
    tokens = torch.full((1, length + 2), fill_value=mask_idx, device=device, dtype=torch.long)
    tokens[0, 0] = cls_idx
    tokens[0, -1] = eos_idx

    for pos in range(1, length + 1):
        # Position `pos` is still the mask token here; predict it.
        logits = model(tokens, repr_layers=[], return_contacts=False)["logits"][0, pos]
        probs = torch.softmax(logits[AA_INDICES] / temperature, dim=-1)
        choice = torch.multinomial(probs, num_samples=1)
        tokens[0, pos] = AA_INDICES[choice]

    return "".join(alphabet.get_tok(t) for t in tokens[0, 1:-1].tolist())


# Quick sanity check
# Runs before the ES hooks are installed (a later cell), so this samples the pretrained LM.
test_seq = sample_sequence_from_lm(lm_model, alphabet, length=60, temperature=1.0)
print("Example generated sequence:", test_seq)
print("Length:", len(test_seq))

In [None]:
@torch.no_grad()
def evaluate_sequence_with_esmfold(
    sequence: str,
    fold_model=fold_model,
    device=device,
) -> float:
    """
    Black-box reward: fold `sequence` with ESMFold and return the mean of the
    B-factor column (mean pLDDT) over all atoms of the predicted PDB.

    `device` is accepted for signature symmetry but is not used.
    """
    pdb_text = fold_model.infer_pdb(sequence)

    # biotite loads from a path, so round-trip the PDB text through a temp file.
    with NamedTemporaryFile("w+", suffix=".pdb") as handle:
        handle.write(pdb_text)
        handle.flush()
        structure = bsio.load_structure(handle.name, extra_fields=["b_factor"])

    return float(structure.b_factor.mean())


def evaluate_model_once(
    num_sequences=3,
    seq_length=80,
    temperature=1.0,
):
    """
    Sample `num_sequences` sequences from the *current* `lm_model` (including
    any active ES noise) and score each one with ESMFold.

    Returns:
        (mean_score, scores): the mean as a Python float, plus the list of
        individual per-sequence scores in sampling order.
    """
    scores = [
        evaluate_sequence_with_esmfold(
            sample_sequence_from_lm(
                lm_model,
                alphabet,
                length=seq_length,
                temperature=temperature,
                device=device,
            ),
            fold_model=fold_model,
            device=device,
        )
        for _ in range(num_sequences)
    ]
    return float(np.mean(scores)), scores

In [None]:
import torch.nn as nn

class EggrollContext:
    """
    Mutable settings shared with the hooked ``nn.Linear`` forwards.

    Attributes:
        active: when True, hooked layers add low-rank noise to their output.
        rank: rank r of the noise factors; the hooks add no noise when r <= 0.
        sigma: noise scale.
        thread_id: value mixed into each layer's RNG seed (population member).
        device: device on which the noise tensors are generated.
    """

    def __init__(self, device):
        self.active = False
        self.rank = 0
        self.sigma = 0.0
        self.thread_id = 0
        self.device = device

egg_ctx = EggrollContext(device)

# Hook all Linear layers in the ESM-2 model
# Each layer gets a random base_key (from Python's global `random`, so it is only
# reproducible if `random` was seeded beforehand). The per-call noise seed is
# (base_key ^ egg_ctx.thread_id) & 0x7FFFFFFF, which lets the ES update regenerate
# exactly the same noise later.
# NOTE(review): re-running this cell wraps the already-wrapped forwards (module.forward
# is read back below), stacking two noise hooks with different keys per layer.
lin_modules = []  # (name, module, base_key)

for name, module in lm_model.named_modules():
    if isinstance(module, nn.Linear):
        base_key = random.randint(0, 2**31 - 1)
        lin_modules.append((name, module, base_key))
        orig_forward = module.forward

        # Factory function so each wrapper captures its own module/forward/key
        # instead of the loop variables' final values.
        def make_forward(m, orig, base_key):
            def forward_with_low_rank(x, _orig=orig, _module=m, _base_key=base_key):
                # Base linear output
                out = _orig(x)
                # NOTE(review): only egg_ctx.active is checked — while active, EVERY hooked
                # Linear is perturbed, even layers excluded from the ES update.
                if egg_ctx.active and _module.weight.ndim == 2:
                    weight = _module.weight
                    out_features, in_features = weight.shape

                    # Deterministic RNG for (layer, thread_id)
                    g = torch.Generator(device=egg_ctx.device)
                    combined_seed = (_base_key ^ egg_ctx.thread_id) & 0x7FFFFFFF
                    g.manual_seed(combined_seed)

                    r = egg_ctx.rank
                    if r > 0:
                        # One draw of shape (in + out, r), split into B then A; the ES
                        # update must split it in the same order.
                        perturb = torch.randn(
                            in_features + out_features,
                            r,
                            generator=g,
                            device=egg_ctx.device,
                        )
                        B = perturb[:in_features, :]      # (in, r)
                        A = perturb[in_features:, :]      # (out, r)

                        # LoRA-style adapter: x @ B -> (..., r); then @ A^T -> (..., out).
                        # Equivalent to using weight + sigma * A @ B^T / sqrt(r), without
                        # ever modifying the weight tensor.
                        adapter = x @ B
                        adapter = adapter @ A.t()
                        out = out + egg_ctx.sigma * adapter / math.sqrt(r)

                return out
            return forward_with_low_rank

        # An instance attribute shadows nn.Linear.forward, so module(x) now calls the wrapper.
        module.forward = make_forward(module, orig_forward, base_key)

print(f"Hooked {len(lin_modules)} Linear modules for EGGROLL-style noise.")

In [None]:
# Map parameter names to tensor objects
name_to_param = dict(lm_model.named_parameters())

# Track only the 2D Linear weight matrices that correspond to hooked modules
linear_weights = []  # (param_name, param_tensor, base_key)

for name, module, base_key in lin_modules:
    w_name = name + ".weight"
    if w_name in name_to_param:
        linear_weights.append((w_name, name_to_param[w_name], base_key))

print("Number of Linear weight matrices to optimize:", len(linear_weights))

In [None]:
@torch.no_grad()
def accumulate_linear_grad_buffers(
    grad_buffers,
    linear_weights,
    rank,
    sigma,
    thread_id,
    device,
    weight,
):
    """
    Add ``weight * sigma * ΔW`` to ``grad_buffers[param_name]`` for every
    ``(param_name, param, base_key)`` in ``linear_weights``. A zero buffer shaped
    like the parameter is created on first use.

    ΔW = A @ B.T / sqrt(rank) is rebuilt from a generator seeded with
    ``(base_key ^ thread_id) & 0x7FFFFFFF``. A single ``randn(in + out, rank)``
    draw is split into B (first ``in`` rows) and A (last ``out`` rows). This is
    the same recipe as the hooked Linear forwards, so ΔW matches the evaluated
    perturbation.
    """
    for param_name, param, base_key in linear_weights:
        n_out, n_in = param.data.shape

        gen = torch.Generator(device=device)
        gen.manual_seed((base_key ^ thread_id) & 0x7FFFFFFF)
        factors = torch.randn(n_in + n_out, rank, generator=gen, device=device)
        B, A = factors[:n_in, :], factors[n_in:, :]
        delta = (A @ B.t()) / math.sqrt(rank)  # (out, in)

        buf = grad_buffers.get(param_name)
        if buf is None:
            buf = grad_buffers[param_name] = torch.zeros_like(param.data)
        buf.add_(weight * sigma * delta)


@torch.no_grad()
def es_step_eggroll(
    lm_model,
    linear_weights,
    rank,
    sigma,
    pop_size,
    lr,
    device,
    eval_fn,
    num_sequences,
    seq_length,
    temperature,
    noise_offset=None,
):
    """
    One EGGROLL-style evolution-strategies step.

      1. Evaluate ``pop_size`` perturbed models. Noise is injected on the fly by
         the hooked ``nn.Linear`` forwards through the global ``egg_ctx``, so no
         weights change during evaluation.
      2. Normalize the population rewards (zero mean; unit std when std > 1e-8).
      3. Regenerate each member's low-rank ΔW for the matrices in
         ``linear_weights`` and form the reward-weighted sum.
      4. Update in place: θ += lr / (pop_size * sigma) * Σ_j R_j ΔW_j
         (up to the 1e-8 stabilizer).

    Args:
        lm_model: not used directly (kept for interface compatibility); the
            update is applied in place to the tensors in ``linear_weights``.
        linear_weights: list of ``(param_name, param, base_key)`` to update.
        rank: low-rank dimension r of each perturbation (must be >= 1).
        sigma: perturbation scale.
        pop_size: number of population members (must be >= 1).
        lr: step size.
        device: device on which the noise is regenerated.
        eval_fn: ``eval_fn(num_sequences=, seq_length=, temperature=)`` returning
            ``(mean_score, scores)``.
        num_sequences, seq_length, temperature: forwarded to ``eval_fn``.
        noise_offset: base added to each member index to form the noise
            ``thread_id``. If None (default), a fresh 30-bit offset is drawn from
            Python's ``random`` module on every call, so each step explores new
            directions; seeding ``random`` makes runs reproducible. Pass 0 to get
            the old fixed-per-member noise.

    Returns:
        ``(mean_reward, max_reward, rewards)``, where ``rewards`` is the list of
        raw per-member rewards.

    Raises:
        ValueError: if ``pop_size < 1`` or ``rank < 1``.

    Note:
        While ``egg_ctx.active`` is True the hooks perturb every hooked Linear
        layer, not only those listed in ``linear_weights``.
    """
    import random

    if pop_size < 1:
        raise ValueError(f"pop_size must be >= 1, got {pop_size}")
    if rank < 1:
        raise ValueError(f"rank must be >= 1, got {rank}")

    # Both the hooks and the gradient reconstruction seed from (base_key ^ thread_id).
    # Using thread_id = j on every step would replay the same pop_size directions
    # forever, confining the search to a fixed random subspace. Offset per call.
    if noise_offset is None:
        noise_offset = random.getrandbits(30)
    thread_ids = [noise_offset + j for j in range(pop_size)]

    # 1) Evaluate each population member under its own perturbation.
    rewards = []
    egg_ctx.rank = rank
    egg_ctx.sigma = sigma
    egg_ctx.active = True
    try:
        for tid in thread_ids:
            egg_ctx.thread_id = tid
            mean_score, _ = eval_fn(
                num_sequences=num_sequences,
                seq_length=seq_length,
                temperature=temperature,
            )
            rewards.append(mean_score)
    finally:
        # Always switch noise off, even if evaluation raises (e.g. CUDA OOM);
        # otherwise every later "clean" evaluation would silently stay perturbed.
        egg_ctx.active = False

    rewards_t = torch.tensor(rewards, dtype=torch.float32, device=device)

    # 2) Normalize rewards (zero-mean, unit variance when possible).
    reward_std = rewards_t.std()
    if reward_std > 1e-8:
        norm_rewards = (rewards_t - rewards_t.mean()) / (reward_std + 1e-8)
    else:
        norm_rewards = rewards_t - rewards_t.mean()

    # 3) Accumulate the gradient estimate in parameter space from the same seeds.
    grad_buffers = {}
    for tid, R in zip(thread_ids, norm_rewards):
        accumulate_linear_grad_buffers(
            grad_buffers,
            linear_weights,
            rank=rank,
            sigma=sigma,
            thread_id=tid,
            device=device,
            weight=R.item(),
        )

    # 4) Apply update: θ_new = θ + lr * grad_estimate
    scale = lr / (pop_size * (sigma**2 + 1e-8))
    for param_name, param, _ in linear_weights:
        param.data.add_(scale * grad_buffers[param_name])

    return float(rewards_t.mean().item()), float(rewards_t.max().item()), rewards

In [None]:
##########################################
# ES hyperparameters (Colab-friendly)
##########################################

# Each ES step folds POP_SIZE * NUM_SEQUENCES sequences with ESMFold.
SEQ_LENGTH      = 60       # amino acids per sequence
NUM_SEQUENCES   = 2        # sequences per population member
POP_SIZE        = 4        # ES population size
RANK            = 4        # low-rank dimension r
SIGMA           = 0.05     # noise scale
LR              = 0.1      # learning rate in parameter space
NUM_STEPS       = 3        # ES iterations (increase if GPU allows)
TEMPERATURE     = 1.0      # sampling temperature for LM

print("=== ES CONFIGURATION ===")
print(f"SEQ_LENGTH    = {SEQ_LENGTH}")
print(f"NUM_SEQUENCES = {NUM_SEQUENCES}")
print(f"POP_SIZE      = {POP_SIZE}")
print(f"RANK          = {RANK}")
print(f"SIGMA         = {SIGMA}")
print(f"LR            = {LR}")
print(f"NUM_STEPS     = {NUM_STEPS}")

def eval_fn(num_sequences=NUM_SEQUENCES, seq_length=SEQ_LENGTH, temperature=TEMPERATURE):
    """Score the current model: a keyword-forwarding wrapper around
    evaluate_model_once whose defaults come from this cell's ES configuration.
    Returns evaluate_model_once's ``(mean_score, scores)`` unchanged."""
    settings = dict(
        num_sequences=num_sequences,
        seq_length=seq_length,
        temperature=temperature,
    )
    return evaluate_model_once(**settings)


##########################################
# Run ES
##########################################

# history: (mean, max) of the raw per-member rewards for each step.
# NOTE: es_step_eggroll updates lm_model's Linear weights in place, so every later
# cell (including the "baseline" measurements) sees the model as modified here.
history = []

for step in range(1, NUM_STEPS + 1):
    avg_reward, max_reward, all_rewards = es_step_eggroll(
        lm_model,
        linear_weights,
        rank=RANK,
        sigma=SIGMA,
        pop_size=POP_SIZE,
        lr=LR,
        device=device,
        eval_fn=eval_fn,
        num_sequences=NUM_SEQUENCES,
        seq_length=SEQ_LENGTH,
        temperature=TEMPERATURE,
    )
    history.append((avg_reward, max_reward))
    print(f"[Step {step}] avg reward = {avg_reward:.2f}, max reward = {max_reward:.2f}")

In [None]:
@torch.no_grad()
def sample_and_score_n(n=3, length=SEQ_LENGTH):
    """
    Draw `n` sequences of `length` residues from the current `lm_model` at
    TEMPERATURE and score each with ESMFold.

    Returns:
        (seqs, scores): parallel lists of sequences and their mean-pLDDT scores.
    """
    pairs = []
    for _ in range(n):
        seq = sample_sequence_from_lm(
            lm_model,
            alphabet,
            length=length,
            temperature=TEMPERATURE,
            device=device,
        )
        pairs.append((seq, evaluate_sequence_with_esmfold(seq, fold_model=fold_model, device=device)))
    seqs = [s for s, _ in pairs]
    scores = [sc for _, sc in pairs]
    return seqs, scores


# Sample from lm_model as it stands after the ES steps above (weights were updated in place).
print("\nSampling sequences from the (possibly optimized) model...")
final_seqs, final_scores = sample_and_score_n(n=3, length=SEQ_LENGTH)

for i, (s, sc) in enumerate(zip(final_seqs, final_scores), 1):
    print(f"\nSequence {i} (mean pLDDT {sc:.2f}):\n{s}")

In [None]:
%%time
import numpy as np

##########################################
# 1) Measure a proper baseline
##########################################
# NOTE(review): this measures the CURRENT lm_model. If the earlier "Run ES" cell was
# executed, its in-place updates are already included, so this is not the pretrained model.

# We'll use the same length we train on
TRAIN_SEQ_LENGTH = 60

baseline_mean, baseline_scores = evaluate_model_once(
    num_sequences=10,
    seq_length=TRAIN_SEQ_LENGTH,
    temperature=1.0,
)
print(f"[Baseline] mean pLDDT over 10 sequences: {baseline_mean:.2f}")
print("Individual scores:", [f"{s:.2f}" for s in baseline_scores])

##########################################
# 2) New ES hyperparameters (more stable)
##########################################

TRAIN_NUM_SEQUENCES = 4   # more sequences per member → less noisy rewards
TRAIN_POP_SIZE      = 8   # larger population if GPU allows
TRAIN_RANK          = 4   # low-rank dimension (keep small for Colab)
TRAIN_SIGMA         = 0.02
TRAIN_LR            = 0.03
TRAIN_NUM_STEPS     = 20  # more ES iterations
TRAIN_TEMPERATURE   = 1.0

print("\n=== NEW ES CONFIGURATION ===")
print(f"TRAIN_SEQ_LENGTH    = {TRAIN_SEQ_LENGTH}")
print(f"TRAIN_NUM_SEQUENCES = {TRAIN_NUM_SEQUENCES}")
print(f"TRAIN_POP_SIZE      = {TRAIN_POP_SIZE}")
print(f"TRAIN_RANK          = {TRAIN_RANK}")
print(f"TRAIN_SIGMA         = {TRAIN_SIGMA}")
print(f"TRAIN_LR            = {TRAIN_LR}")
print(f"TRAIN_NUM_STEPS     = {TRAIN_NUM_STEPS}")

##########################################
# 3) Restrict which layers we mutate
##########################################
# Instead of updating every hooked Linear weight, we only update the last
# NUM_LINEAR_TO_TRAIN (in named_modules() order).
# This often stabilizes training and keeps the model closer to its pretrained prior.
# NOTE(review): only the UPDATE is restricted. While ES noise is active, the forward
# hooks still perturb every hooked Linear layer during evaluation.

NUM_LINEAR_TO_TRAIN = 24  # you can change this (e.g., 16, 32, etc.)

if len(linear_weights) <= NUM_LINEAR_TO_TRAIN:
    trained_linear_weights = linear_weights
    print(f"\nTraining ALL {len(trained_linear_weights)} Linear weights.")
else:
    trained_linear_weights = linear_weights[-NUM_LINEAR_TO_TRAIN:]
    print(f"\nTraining ONLY the last {NUM_LINEAR_TO_TRAIN} Linear weights out of {len(linear_weights)} total.")

# Small helper so es_step_eggroll can call evaluate_model_once with named args
def train_eval_fn(num_sequences, seq_length, temperature):
    """Thin adapter so es_step_eggroll can call evaluate_model_once with named
    arguments; returns its ``(mean_score, scores)`` tuple unchanged."""
    settings = {
        "num_sequences": num_sequences,
        "seq_length": seq_length,
        "temperature": temperature,
    }
    return evaluate_model_once(**settings)


In [None]:
%%time
# Continues from the current in-place-updated lm_model; history is reset for this run.
history = []  # list of (avg_reward, max_reward)

print("\n=== Starting ES training ===")
for step in range(1, TRAIN_NUM_STEPS + 1):
    avg_reward, max_reward, all_rewards = es_step_eggroll(
        lm_model,
        trained_linear_weights,      # <--- restricted subset
        rank=TRAIN_RANK,
        sigma=TRAIN_SIGMA,
        pop_size=TRAIN_POP_SIZE,
        lr=TRAIN_LR,
        device=device,
        eval_fn=train_eval_fn,
        num_sequences=TRAIN_NUM_SEQUENCES,
        seq_length=TRAIN_SEQ_LENGTH,
        temperature=TRAIN_TEMPERATURE,
    )
    history.append((avg_reward, max_reward))
    print(
        f"[Step {step:02d}] "
        f"avg reward = {avg_reward:.2f}, "
        f"max reward = {max_reward:.2f}, "
        f"pop rewards = {[f'{r:.2f}' for r in all_rewards]}"
    )

##########################################
# Post-training evaluation (same as baseline)
##########################################
# Same sample count, length and temperature as the baseline cell, with ES noise off.
post_mean, post_scores = evaluate_model_once(
    num_sequences=10,
    seq_length=TRAIN_SEQ_LENGTH,
    temperature=1.0,
)
print("\n=== Baseline vs Post-training ===")
print(f"Baseline mean pLDDT over 10 seqs: {baseline_mean:.2f}")
print(f"Post-train mean pLDDT over 10 seqs: {post_mean:.2f}")
print("Post-train individual scores:", [f"{s:.2f}" for s in post_scores])


In [None]:
import matplotlib.pyplot as plt

##########################################
# Plot avg and max reward vs training step
##########################################

# Reward curves from the most recent `history` (the TRAIN_* run above).
avg_hist = [a for (a, m) in history]
max_hist = [m for (a, m) in history]

plt.figure(figsize=(6,4))
plt.plot(avg_hist, label="avg reward")
plt.plot(max_hist, label="max reward")
plt.xlabel("ES step")
plt.ylabel("Reward (mean pLDDT)")
plt.title("ES reward during training")
plt.legend()
plt.grid(True)
plt.show()

##########################################
# Sample and score a few sequences from the final model
##########################################

@torch.no_grad()
def sample_and_score_n(n=3, length=TRAIN_SEQ_LENGTH):
    """
    Draw `n` sequences of `length` residues from the current `lm_model` at
    TRAIN_TEMPERATURE and score each with ESMFold.

    Returns:
        (seqs, scores): parallel lists of sequences and their mean-pLDDT scores.
    """
    pairs = []
    for _ in range(n):
        seq = sample_sequence_from_lm(
            lm_model,
            alphabet,
            length=length,
            temperature=TRAIN_TEMPERATURE,
            device=device,
        )
        pairs.append((seq, evaluate_sequence_with_esmfold(seq, fold_model=fold_model, device=device)))
    seqs = [s for s, _ in pairs]
    scores = [sc for _, sc in pairs]
    return seqs, scores

# Sample from the model after the TRAIN_* ES run (weights updated in place).
print("\n=== Sampling from final (ES-updated) model ===")
final_seqs, final_scores = sample_and_score_n(n=3, length=TRAIN_SEQ_LENGTH)

for i, (s, sc) in enumerate(zip(final_seqs, final_scores), 1):
    print(f"\nSequence {i} (mean pLDDT {sc:.2f}):\n{s}")


In [None]:
%%time
import copy
import numpy as np
import matplotlib.pyplot as plt
import torch

##########################################
# 1. CONFIG (edit these to trade off runtime vs robustness)
##########################################

N_EXPERIMENTS          = 3      # number of ES runs with different seeds (e.g. 3–5)
BASELINE_N_SEQS        = 30     # sequences for baseline evaluation (30–50 recommended)
POST_N_SEQS            = 30     # sequences for post-training evaluation
EVAL_SEQ_LENGTH        = 60
EVAL_TEMPERATURE       = 1.0

# ES training hyperparameters PER EXPERIMENT
# (these re-assign the TRAIN_* globals set by earlier cells)
TRAIN_SEQ_LENGTH       = 60
TRAIN_NUM_SEQUENCES    = 4      # sequences per population member per step
TRAIN_POP_SIZE         = 8      # population size (try 8–16 if GPU allows)
TRAIN_RANK             = 4      # low-rank dimension
TRAIN_SIGMA            = 0.02
TRAIN_LR               = 0.03
TRAIN_NUM_STEPS        = 15     # steps per experiment (30 is stronger but slower)
TRAIN_TEMPERATURE      = 1.0

# Which subset of linear layers to train (last K)
NUM_LINEAR_TO_TRAIN    = 24

print("=== MULTI-RUN CONFIG ===")
print(f"N_EXPERIMENTS         = {N_EXPERIMENTS}")
print(f"BASELINE_N_SEQS       = {BASELINE_N_SEQS}")
print(f"POST_N_SEQS           = {POST_N_SEQS}")
print(f"TRAIN_POP_SIZE        = {TRAIN_POP_SIZE}")
print(f"TRAIN_NUM_STEPS       = {TRAIN_NUM_STEPS}")
print(f"TRAIN_NUM_SEQUENCES   = {TRAIN_NUM_SEQUENCES}")
print(f"NUM_LINEAR_TO_TRAIN   = {NUM_LINEAR_TO_TRAIN}")

##########################################
# 2. Choose subset of linear weights to train
##########################################

if len(linear_weights) <= NUM_LINEAR_TO_TRAIN:
    trained_linear_weights_global = linear_weights
    print(f"\nTraining ALL {len(trained_linear_weights_global)} Linear weights.")
else:
    trained_linear_weights_global = linear_weights[-NUM_LINEAR_TO_TRAIN:]
    print(f"\nTraining ONLY the last {NUM_LINEAR_TO_TRAIN} Linear weights out of {len(linear_weights)} total.")

##########################################
# 3. Save a base copy of the model to reset each experiment
##########################################

# CPU snapshot of the CURRENT weights; load_state_dict later copies it back into the
# same parameter tensors, so the references in linear_weights stay valid.
# NOTE(review): if the earlier ES cells were run, this snapshot already contains their
# in-place updates — it is the reset point, not the pretrained ESM-2 weights.
base_state_dict = {k: v.detach().cpu().clone() for k, v in lm_model.state_dict().items()}
print("\nSaved base model state for resetting between experiments.")


##########################################
# 4. Helper: evaluate model distribution (mean, std, tails)
##########################################

@torch.no_grad()
def evaluate_model_distribution(
    n_sequences: int,
    seq_length: int,
    temperature: float,
):
    """
    Sample `n_sequences` sequences from the *current* `lm_model`, score each
    with ESMFold, and summarize the score distribution.

    Returns:
        dict with
          "mean":       mean score,
          "std":        sample std (ddof=1), or 0.0 for fewer than 2 scores,
          "frac_gt_60": fraction of scores > 60,
          "frac_lt_30": fraction of scores < 30,
          "scores":     the raw scores as a float32 numpy array.
    """
    scores = np.array(
        [
            evaluate_sequence_with_esmfold(
                sample_sequence_from_lm(
                    lm_model,
                    alphabet,
                    length=seq_length,
                    temperature=temperature,
                    device=device,
                ),
                fold_model=fold_model,
                device=device,
            )
            for _ in range(n_sequences)
        ],
        dtype=np.float32,
    )

    return {
        "mean": float(scores.mean()),
        "std": float(scores.std(ddof=1)) if len(scores) > 1 else 0.0,
        "frac_gt_60": float((scores > 60.0).mean()),
        "frac_lt_30": float((scores < 30.0).mean()),
        "scores": scores,
    }


In [None]:
%%time
import random

def run_es_experiment(exp_id: int, seed: int):
    """
    Run one complete, seeded ES experiment.

    Steps:
      1) seed Python/NumPy/torch RNGs (and CUDA, if available) with `seed`;
      2) restore `lm_model` from `base_state_dict`;
      3) measure the baseline score distribution (BASELINE_N_SEQS samples);
      4) run TRAIN_NUM_STEPS ES steps on `trained_linear_weights_global`;
      5) measure the post-training distribution (POST_N_SEQS samples).

    Returns:
        dict with "exp_id", "seed", "baseline" and "post" (each as returned by
        evaluate_model_distribution), and "history" — a list of
        (avg_reward, max_reward) per ES step.
    """
    def report(label, metrics):
        # One summary line per evaluated distribution.
        print(
            f"[Exp {exp_id}] {label} mean pLDDT = {metrics['mean']:.2f} "
            f"+/- {metrics['std']:.2f}, "
            f"frac>60 = {metrics['frac_gt_60']:.2f}, "
            f"frac<30 = {metrics['frac_lt_30']:.2f}"
        )

    print("\n====================")
    print(f"Starting experiment {exp_id} with seed {seed}")
    print("====================")

    # Seed every RNG the pipeline touches.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Restore the shared starting weights.
    lm_model.load_state_dict(base_state_dict)
    lm_model.to(device)
    lm_model.eval()

    baseline_metrics = evaluate_model_distribution(
        n_sequences=BASELINE_N_SEQS,
        seq_length=EVAL_SEQ_LENGTH,
        temperature=EVAL_TEMPERATURE,
    )
    report("Baseline", baseline_metrics)

    # evaluate_model_once already takes (num_sequences, seq_length, temperature)
    # as keywords, which is exactly how es_step_eggroll calls eval_fn.
    history = []
    for step in range(1, TRAIN_NUM_STEPS + 1):
        avg_reward, max_reward, _ = es_step_eggroll(
            lm_model,
            trained_linear_weights_global,
            rank=TRAIN_RANK,
            sigma=TRAIN_SIGMA,
            pop_size=TRAIN_POP_SIZE,
            lr=TRAIN_LR,
            device=device,
            eval_fn=evaluate_model_once,
            num_sequences=TRAIN_NUM_SEQUENCES,
            seq_length=TRAIN_SEQ_LENGTH,
            temperature=TRAIN_TEMPERATURE,
        )
        history.append((avg_reward, max_reward))
        print(f"[Exp {exp_id} | Step {step:02d}] avg = {avg_reward:.2f}, max = {max_reward:.2f}")

    post_metrics = evaluate_model_distribution(
        n_sequences=POST_N_SEQS,
        seq_length=EVAL_SEQ_LENGTH,
        temperature=EVAL_TEMPERATURE,
    )
    report("Post-train", post_metrics)

    return {
        "exp_id": exp_id,
        "seed": seed,
        "baseline": baseline_metrics,
        "post": post_metrics,
        "history": history,
    }


##########################################
# Run multiple experiments with different seeds
##########################################

# Experiment i uses seed base_seed + i; each run resets lm_model from base_state_dict,
# so the runs are independent of one another.
experiment_results = []

base_seed = 12345  # change this if you like
for i in range(N_EXPERIMENTS):
    seed = base_seed + i
    res = run_es_experiment(exp_id=i, seed=seed)
    experiment_results.append(res)

print("\n=== Finished all experiments ===")


In [None]:
%%time
# Aggregate baseline vs post across experiments

# One value per experiment for each summary metric (baseline vs post-training).
baseline_means = [r["baseline"]["mean"] for r in experiment_results]
post_means     = [r["post"]["mean"] for r in experiment_results]

baseline_fracs_gt60 = [r["baseline"]["frac_gt_60"] for r in experiment_results]
post_fracs_gt60     = [r["post"]["frac_gt_60"] for r in experiment_results]

baseline_fracs_lt30 = [r["baseline"]["frac_lt_30"] for r in experiment_results]
post_fracs_lt30     = [r["post"]["frac_lt_30"] for r in experiment_results]

def summarize(arr):
    """
    Return ``(mean, sample_std)`` of a sequence of numbers as Python floats.

    - Computed in float64. The previous float32 cast added ~1e-8 rounding
      noise to the reported aggregates.
    - The sample std uses ddof=1. With a single value it is reported as 0.0.
    - An empty input yields ``(nan, nan)``. This replaces numpy's
      "Mean of empty slice" RuntimeWarning and a misleading std of 0.0.
    """
    values = np.asarray(arr, dtype=np.float64)
    if values.size == 0:
        return float("nan"), float("nan")
    std = float(values.std(ddof=1)) if values.size > 1 else 0.0
    return float(values.mean()), std

# (mean, std) across experiments for every metric, before and after ES.
b_mean, b_std = summarize(baseline_means)
p_mean, p_std = summarize(post_means)
b_gt60_mean, b_gt60_std = summarize(baseline_fracs_gt60)
p_gt60_mean, p_gt60_std = summarize(post_fracs_gt60)
b_lt30_mean, b_lt30_std = summarize(baseline_fracs_lt30)
p_lt30_mean, p_lt30_std = summarize(post_fracs_lt30)

# One report; empty strings become the blank separator lines.
print("\n".join([
    "=== Aggregate over experiments ===",
    f"Baseline mean pLDDT: {b_mean:.2f} +/- {b_std:.2f}",
    f"Post-train mean pLDDT: {p_mean:.2f} +/- {p_std:.2f}",
    "",
    f"Baseline frac>60: {b_gt60_mean:.2f} +/- {b_gt60_std:.2f}",
    f"Post-train  frac>60: {p_gt60_mean:.2f} +/- {p_gt60_std:.2f}",
    "",
    f"Baseline frac<30: {b_lt30_mean:.2f} +/- {b_lt30_std:.2f}",
    f"Post-train  frac<30: {p_lt30_mean:.2f} +/- {p_lt30_std:.2f}",
]))

##########################################
# Optional: visualize distributions for one experiment
##########################################

exp_to_plot = 0  # index of experiment to inspect

baseline_scores = experiment_results[exp_to_plot]["baseline"]["scores"]
post_scores     = experiment_results[exp_to_plot]["post"]["scores"]

# Use ONE set of bin edges for both histograms. Passing bins=15 to each
# plt.hist call lets each derive edges from its own min/max, so the overlaid
# bars would be misaligned and their heights not directly comparable.
shared_bins = np.histogram_bin_edges(
    np.concatenate([
        np.asarray(baseline_scores, dtype=float).ravel(),
        np.asarray(post_scores, dtype=float).ravel(),
    ]),
    bins=15,
)

plt.figure(figsize=(7,4))
plt.hist(baseline_scores, bins=shared_bins, alpha=0.6, label="baseline")
plt.hist(post_scores, bins=shared_bins, alpha=0.6, label="post-train")
plt.xlabel("pLDDT")
plt.ylabel("Count")
plt.title(f"pLDDT distribution (experiment {exp_to_plot})")
plt.legend()
plt.grid(True)
plt.show()

##########################################
# Optional: plot ES reward curves for one experiment
##########################################

history0 = experiment_results[exp_to_plot]["history"]  # (avg, max) per ES step
avg_hist = [pair[0] for pair in history0]
max_hist = [pair[1] for pair in history0]
# The training log numbers ES steps from 1 ("Step 01", ...). Plot against that
# same numbering instead of the implicit 0-based list index.
es_steps = list(range(1, len(history0) + 1))

plt.figure(figsize=(6,4))
plt.plot(es_steps, avg_hist, label="avg reward")
plt.plot(es_steps, max_hist, label="max reward")
plt.xlabel("ES step")
plt.ylabel("Reward (mean pLDDT)")
plt.title(f"ES reward during training (experiment {exp_to_plot})")
plt.legend()
plt.grid(True)
plt.show()


In [None]:
%%time
import copy
import random
import numpy as np
import torch

##########################################
# Early-stopping / checkpoint config
##########################################

# Validation: how many sequences per pass, and how often (in ES steps).
ES_VALID_N_SEQS = 10
ES_VALID_EVERY = 2
# Stopping rule: give up after this many validations without a new best,
# but never before this many ES steps have run.
ES_PATIENCE = 4
ES_MIN_STEPS = 5

print(
    "=== ES early-stopping CONFIG ===",
    f"ES_VALID_N_SEQS = {ES_VALID_N_SEQS}",
    f"ES_VALID_EVERY  = {ES_VALID_EVERY}",
    f"ES_PATIENCE     = {ES_PATIENCE}",
    f"ES_MIN_STEPS    = {ES_MIN_STEPS}",
    sep="\n",
)


def run_es_experiment_with_checkpoint(exp_id: int, seed: int):
    """
    ES experiment with:
      1) reset to base_state_dict
      2) baseline evaluation
      3) ES training with:
         - best-checkpoint tracking
         - optional early stopping
      4) post-training evaluation using BEST checkpoint

    Details:
      - Validation (ES_VALID_N_SEQS samples via evaluate_model_distribution)
        runs every ES_VALID_EVERY steps and always on the final step.
      - A CPU copy of lm_model.state_dict() is saved whenever the validation
        mean pLDDT beats the best seen so far.
      - From step ES_MIN_STEPS onward, training stops once
        (step - best_step) // ES_VALID_EVERY >= ES_PATIENCE, i.e. after
        ES_PATIENCE validation intervals without a new best.
      - Only states reached after at least one ES step are candidates; the
        untrained base model (step 0) is never selected as "best", even if
        every ES update lowered validation pLDDT.

    Requires (module globals):
      - lm_model, device
      - base_state_dict
      - evaluate_model_distribution(), evaluate_model_once()
      - es_step_eggroll()
      - trained_linear_weights_global
      - TRAIN_* hyperparameters
      - BASELINE_N_SEQS, POST_N_SEQS, EVAL_SEQ_LENGTH, EVAL_TEMPERATURE
      - ES_VALID_N_SEQS, ES_VALID_EVERY, ES_PATIENCE, ES_MIN_STEPS

    Returns:
      dict with exp_id, seed, "baseline" and "post" metric dicts (post is
      measured on the restored best checkpoint), "history" (one
      (avg_reward, max_reward) tuple per ES step actually run), "best_step"
      and "best_valid_mean" (0 and -1e9 if no validation ran).
    """
    print(f"\n====================")
    print(f"Starting ES experiment {exp_id} with seed {seed}")
    print(f"====================")

    # --- Set seeds for reproducibility ---
    # NOTE(review): validation sampling below presumably draws from the same
    # RNG streams as training, so the trajectory is not identical to
    # run_es_experiment with the same seed — confirm if paired runs matter.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # --- Reset model to base state ---
    lm_model.load_state_dict(base_state_dict)
    lm_model.to(device)
    lm_model.eval()

    # --- Evaluate baseline distribution ---
    baseline_metrics = evaluate_model_distribution(
        n_sequences=BASELINE_N_SEQS,
        seq_length=EVAL_SEQ_LENGTH,
        temperature=EVAL_TEMPERATURE,
    )
    print(
        f"[Exp {exp_id}] Baseline mean pLDDT = {baseline_metrics['mean']:.2f} "
        f"+/- {baseline_metrics['std']:.2f}, "
        f"frac>60 = {baseline_metrics['frac_gt_60']:.2f}, "
        f"frac<30 = {baseline_metrics['frac_lt_30']:.2f}"
    )

    # --- ES training loop with best checkpoint ---
    history = []  # (avg_reward, max_reward)
    best_valid_mean = -1e9  # sentinel: the first finite validation mean always beats it
    best_state_dict = None
    best_step = 0
    n_valid_calls = 0  # counted but not otherwise used

    # Reward function handed to es_step_eggroll: plain pLDDT-based evaluation.
    def train_eval_fn(num_sequences, seq_length, temperature):
        return evaluate_model_once(
            num_sequences=num_sequences,
            seq_length=seq_length,
            temperature=temperature,
        )

    for step in range(1, TRAIN_NUM_STEPS + 1):
        avg_reward, max_reward, all_rewards = es_step_eggroll(
            lm_model,
            trained_linear_weights_global,
            rank=TRAIN_RANK,
            sigma=TRAIN_SIGMA,
            pop_size=TRAIN_POP_SIZE,
            lr=TRAIN_LR,
            device=device,
            eval_fn=train_eval_fn,
            num_sequences=TRAIN_NUM_SEQUENCES,
            seq_length=TRAIN_SEQ_LENGTH,
            temperature=TRAIN_TEMPERATURE,
        )
        history.append((avg_reward, max_reward))
        print(
            f"[Exp {exp_id} | Step {step:02d}] "
            f"avg = {avg_reward:.2f}, max = {max_reward:.2f}"
        )

        # ---- Validation + checkpointing ----
        # Validate every ES_VALID_EVERY steps, and always on the last step so
        # the final state is at least considered.
        do_valid = (step % ES_VALID_EVERY == 0) or (step == TRAIN_NUM_STEPS)
        if do_valid:
            n_valid_calls += 1
            valid_metrics = evaluate_model_distribution(
                n_sequences=ES_VALID_N_SEQS,
                seq_length=EVAL_SEQ_LENGTH,
                temperature=EVAL_TEMPERATURE,
            )
            valid_mean = valid_metrics["mean"]
            print(
                f"[Exp {exp_id} | Step {step:02d}] "
                f"VALID mean pLDDT = {valid_mean:.2f}"
            )

            # update best checkpoint
            if valid_mean > best_valid_mean:
                best_valid_mean = valid_mean
                # Detached CPU clones, so the snapshot is independent of any
                # later changes to lm_model's tensors and holds no GPU memory.
                best_state_dict = {
                    k: v.detach().cpu().clone() for k, v in lm_model.state_dict().items()
                }
                best_step = step
                print(
                    f"[Exp {exp_id}] New BEST checkpoint at step {step} "
                    f"(valid mean = {valid_mean:.2f})"
                )

            # early stopping: if we have run at least ES_MIN_STEPS and
            # no improvement for ES_PATIENCE validations
            if step >= ES_MIN_STEPS:
                steps_since_best = step - best_step
                # number of validations since best (exact when validations
                # fall on multiples of ES_VALID_EVERY)
                val_since_best = steps_since_best // ES_VALID_EVERY
                if val_since_best >= ES_PATIENCE:
                    print(
                        f"[Exp {exp_id}] Early stopping at step {step} "
                        f"(no improvement for {val_since_best} validations)."
                    )
                    break

    # If we never improved, fall back to final state.
    # In practice this branch runs only when no validation produced a finite
    # mean above the -1e9 sentinel (e.g. TRAIN_NUM_STEPS == 0, or every
    # validation mean was NaN), since the first validation otherwise always
    # sets a checkpoint.
    if best_state_dict is None:
        print(f"[Exp {exp_id}] No improvement found; using final model.")
    else:
        print(
            f"[Exp {exp_id}] Loading BEST checkpoint from step {best_step} "
            f"(valid mean = {best_valid_mean:.2f})"
        )
        lm_model.load_state_dict(best_state_dict)
        lm_model.to(device)
        lm_model.eval()

    # --- Evaluate post-training distribution from BEST checkpoint ---
    # Fresh samples (POST_N_SEQS), independent of the small validation sample
    # used for checkpoint selection.
    post_metrics = evaluate_model_distribution(
        n_sequences=POST_N_SEQS,
        seq_length=EVAL_SEQ_LENGTH,
        temperature=EVAL_TEMPERATURE,
    )
    print(
        f"[Exp {exp_id}] Post-train mean pLDDT = {post_metrics['mean']:.2f} "
        f"+/- {post_metrics['std']:.2f}, "
        f"frac>60 = {post_metrics['frac_gt_60']:.2f}, "
        f"frac<30 = {post_metrics['frac_lt_30']:.2f}"
    )

    return {
        "exp_id": exp_id,
        "seed": seed,
        "baseline": baseline_metrics,
        "post": post_metrics,
        "history": history,
        "best_step": best_step,
        "best_valid_mean": best_valid_mean,
    }

In [None]:
%%time
# Same seed family as the plain ES runs above, so the two sets are paired
# seed-by-seed.
base_seed = 12345
experiment_results_es = []

for i, seed in enumerate(range(base_seed, base_seed + N_EXPERIMENTS)):
    res = run_es_experiment_with_checkpoint(exp_id=i, seed=seed)
    experiment_results_es.append(res)

print("\n=== Finished ES experiments with checkpointing ===")

In [None]:
%%time
import random
import numpy as np

##########################################
# Sequence-space evolution baseline config
##########################################

# Rough budget matching with ES: one individual per (ES population member x
# sequence sampled per member), and one generation per ES step.
SEQ_EVO_POP_SIZE = TRAIN_POP_SIZE * TRAIN_NUM_SEQUENCES
SEQ_EVO_N_GENERATIONS = TRAIN_NUM_STEPS
# Parents: the top quarter of the population, but never fewer than 4.
SEQ_EVO_N_PARENTS = max(4, SEQ_EVO_POP_SIZE // 4)
SEQ_EVO_MUT_RATE = 0.05  # probability that any single position is resampled
SEQ_EVO_SEQ_LENGTH = TRAIN_SEQ_LENGTH
SEQ_EVO_TEMPERATURE = TRAIN_TEMPERATURE

print(
    "=== Sequence-space evolution CONFIG ===",
    f"POP_SIZE      = {SEQ_EVO_POP_SIZE}",
    f"N_GENERATIONS = {SEQ_EVO_N_GENERATIONS}",
    f"N_PARENTS     = {SEQ_EVO_N_PARENTS}",
    f"MUT_RATE      = {SEQ_EVO_MUT_RATE}",
    f"SEQ_LENGTH    = {SEQ_EVO_SEQ_LENGTH}",
    sep="\n",
)


def mutate_sequence(seq: str, mut_rate: float, aa_letters=None) -> str:
    """
    Apply independent point mutations to an amino-acid sequence.

    Each position is, with probability ``mut_rate``, replaced by a residue
    drawn uniformly from ``aa_letters`` (default: the global ``AA_LETTERS``).
    The draw may return the residue already present. The expected fraction of
    positions that actually change is therefore a bit below ``mut_rate``.

    Randomness comes from Python's ``random`` module: one ``random.random()``
    per position, plus one ``random.choice`` per hit. Seed it for
    reproducibility.
    """
    # Reading a module global needs no `global` statement.
    letters = AA_LETTERS if aa_letters is None else aa_letters
    return "".join(
        random.choice(letters) if random.random() < mut_rate else aa
        for aa in seq
    )


@torch.no_grad()
def sequence_space_evolution_baseline(
    pop_size: int,
    n_generations: int,
    n_parents: int,
    seq_length: int,
    temperature: float,
    mut_rate: float,
    seed: int = 0,
):
    """
    Mutation-only genetic algorithm in sequence space, used as a baseline for
    ES in parameter space.

    The starting population is sampled from the *unmodified* LM (weights
    reloaded from ``base_state_dict``). In every generation:
      - each member is scored with ESMFold mean pLDDT;
      - the ``n_parents`` best become parents;
      - the best parent is copied over unchanged (elitism);
      - the remaining slots are filled with point-mutated copies of randomly
        chosen parents.
    The elite is re-scored every generation.

    Returns a dict with:
      history          - list of (avg_score, max_score), one per generation
      best_seq         - best sequence seen in any generation (None if
                         n_generations == 0)
      best_score       - its pLDDT (-1e9 if nothing was scored)
      final_population - the last population that was scored
    """
    print(f"\n=== Running sequence-space evolution baseline (seed={seed}) ===")

    # Seed every RNG that sampling or mutation may touch.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Start from the base LM so the comparison with ES is fair.
    lm_model.load_state_dict(base_state_dict)
    lm_model.to(device)
    lm_model.eval()

    population = [
        sample_sequence_from_lm(
            lm_model,
            alphabet,
            length=seq_length,
            temperature=temperature,
            device=device,
        )
        for _ in range(pop_size)
    ]

    history = []
    best_seq, best_score = None, -1e9
    last_gen = n_generations - 1

    for gen in range(n_generations):
        fitness = np.array(
            [
                evaluate_sequence_with_esmfold(member, fold_model=fold_model, device=device)
                for member in population
            ],
            dtype=np.float32,
        )
        gen_avg = float(fitness.mean())
        gen_max = float(fitness.max())
        history.append((gen_avg, gen_max))
        print(f"[SeqEvo | Gen {gen:02d}] avg = {gen_avg:.2f}, max = {gen_max:.2f}")

        top = int(fitness.argmax())
        if gen_max > best_score:
            best_score, best_seq = gen_max, population[top]
            print(f"[SeqEvo] New best sequence with pLDDT = {best_score:.2f}")

        if gen == last_gen:
            break  # the final population is returned as scored; no breeding

        ranked = fitness.argsort()[::-1][:n_parents]
        parents = [population[k] for k in ranked]

        offspring = [parents[0]]  # elitism: the best individual survives as-is
        while len(offspring) < pop_size:
            offspring.append(mutate_sequence(random.choice(parents), mut_rate))
        population = offspring

    return {
        "history": history,
        "best_seq": best_seq,
        "best_score": best_score,
        "final_population": population,
    }

In [None]:
%%time
# Single SeqEvo run with a fixed seed, using the config defined above.
seq_evo_result = sequence_space_evolution_baseline(
    seed=999,
    pop_size=SEQ_EVO_POP_SIZE,
    n_parents=SEQ_EVO_N_PARENTS,
    n_generations=SEQ_EVO_N_GENERATIONS,
    mut_rate=SEQ_EVO_MUT_RATE,
    seq_length=SEQ_EVO_SEQ_LENGTH,
    temperature=SEQ_EVO_TEMPERATURE,
)

print(
    "\n=== Sequence-space evolution summary ===",
    f"Best pLDDT found: {seq_evo_result['best_score']:.2f}",
    f"Best sequence:\n{seq_evo_result['best_seq']}",
    sep="\n",
)

In [None]:
import matplotlib.pyplot as plt

# Overlay ES experiment 0 (checkpointed run) and the SeqEvo baseline.
# Solid lines are ES (per step); dashed lines are SeqEvo (per generation).
es_hist = experiment_results_es[0]["history"]
es_avg = [pair[0] for pair in es_hist]
es_max = [pair[1] for pair in es_hist]

seq_hist = seq_evo_result["history"]
seq_avg = [pair[0] for pair in seq_hist]
seq_max = [pair[1] for pair in seq_hist]

plt.figure(figsize=(7, 4))
plt.plot(es_avg, label="ES avg")
plt.plot(es_max, label="ES max")
plt.plot(seq_avg, label="SeqEvo avg", linestyle="--")
plt.plot(seq_max, label="SeqEvo max", linestyle="--")
plt.xlabel("Step / Generation")
plt.ylabel("pLDDT")
plt.title("ES vs sequence-space evolution (single seed)")
plt.legend()
plt.grid(True)
plt.show()

In [None]:
%%time
"""
Post-ES utilities:
- Multi-objective reward (pLDDT + LM prior)
- ESMFold compute accounting for ES vs SeqEvo
- Multi-seed SeqEvo baseline
- Diversity analysis (sequence identity & entropy)
- Example small ES vs SeqEvo comparison using the new objective

This cell assumes you already defined:
  - lm_model, alphabet, AA_LETTERS, device
  - sample_sequence_from_lm(...)
  - evaluate_sequence_with_esmfold(...)
  - es_step_eggroll(...)
  - evaluate_model_distribution(...)
  - trained_linear_weights_global
  - base_state_dict
"""

import math
import numpy as np
import torch
import random
from collections import Counter

##########################################
# 1. Global counters for ESMFold calls
##########################################

# Running totals of ESMFold folds, split by the optimizer that requested them.
ESMFOLD_CALLS_ES = 0      # calls used by ES experiments
ESMFOLD_CALLS_SEQEVO = 0  # calls used by SeqEvo experiments


def reset_esmfold_counters():
    """Set both module-level ESMFold call counters back to zero."""
    global ESMFOLD_CALLS_ES, ESMFOLD_CALLS_SEQEVO
    ESMFOLD_CALLS_ES = ESMFOLD_CALLS_SEQEVO = 0


print("Initialized ESMFold call counters.")


##########################################
# 2. LM pseudo-likelihood (prior term)
##########################################

@torch.no_grad()
def lm_pseudo_log_prob_per_residue(seq: str, batch_size: int = 16) -> float:
    """
    Mean masked-LM pseudo log-likelihood per residue of ``seq`` under ``lm_model``.

    ESM-2 is a masked LM, so log p(x) is approximated by
        sum_i log p(x_i | x_{-i}),
    obtained by masking one position at a time and reading off the
    log-probability of the true residue there. The result is the mean over
    positions (natural log; at most 0; higher is more LM-plausible).

    The single-mask variants are independent of each other. They are scored
    ``batch_size`` rows per forward pass instead of one forward pass per
    residue, as the previous implementation did (L sequential passes). Values
    agree with the sequential version up to floating-point rounding. Lower
    ``batch_size`` if memory is tight.

    Returns nan for an empty sequence.

    Raises:
        ValueError: if ``batch_size < 1``.
    """
    n_res = len(seq)
    if n_res == 0:
        # Nothing to score; avoid numpy's "mean of empty slice" warning.
        return float("nan")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Unmasked token row: <cls> residues... <eos>
    base = torch.tensor(
        [alphabet.cls_idx] + [alphabet.get_idx(aa) for aa in seq] + [alphabet.eos_idx],
        dtype=torch.long,
        device=device,
    )

    log_probs = []
    # Residue positions are 1..n_res (0 is <cls>, n_res + 1 is <eos>).
    for start in range(1, n_res + 1, batch_size):
        positions = torch.arange(start, min(start + batch_size, n_res + 1), device=device)
        rows = torch.arange(positions.numel(), device=device)
        # Row r is a copy of the full sequence with only positions[r] masked.
        batch = base.unsqueeze(0).repeat(positions.numel(), 1)
        batch[rows, positions] = alphabet.mask_idx
        logits = lm_model(batch, repr_layers=[], return_contacts=False)["logits"]
        masked_logp = torch.log_softmax(logits[rows, positions], dim=-1)
        # Log-probability of the true residue at each row's masked position.
        log_probs.extend(masked_logp[rows, base[positions]].tolist())

    return float(np.mean(log_probs))


##########################################
# 3. Multi-objective reward
##########################################

def compute_multiobjective_reward(
    seq: str,
    weight_lm_prior: float = 0.1,
    caller: str = "ES",
):
    """
    Composite reward:
        R(seq) = pLDDT(seq) + weight_lm_prior * normalized_LM_prior(seq)

    where normalized_LM_prior = 1 + lm_logp / ln(20) and lm_logp is the mean
    per-residue pseudo log-likelihood (natural log). The prior is:
      - 0 for a model no better than uniform over 20 amino acids
        (lm_logp = ln(1/20));
      - > 0 for better-than-uniform;
      - 1 for a perfectly confident model.
    The previous expression, lm_logp / ln(20), put uniform at -1 and every
    realistic value below 0, contradicting that scaling. The fix shifts every
    reward by the constant +weight_lm_prior and leaves rankings unchanged.

    caller: "ES" or "SEQEVO" (case-insensitive) selects which global ESMFold
            call counter to increment; any other value is not counted.

    Returns:
        (reward, plddt, lm_logp, norm_lm)
    """
    global ESMFOLD_CALLS_ES, ESMFOLD_CALLS_SEQEVO

    # 1) structural term: mean pLDDT from ESMFold
    plddt = evaluate_sequence_with_esmfold(seq, fold_model=fold_model, device=device)

    tag = caller.upper()
    if tag == "ES":
        ESMFOLD_CALLS_ES += 1
    elif tag == "SEQEVO":
        ESMFOLD_CALLS_SEQEVO += 1

    # 2) LM prior term, rescaled so that "uniform over 20 AAs" scores 0.
    lm_logp = lm_pseudo_log_prob_per_residue(seq)
    norm_lm = 1.0 + lm_logp / math.log(20.0)

    reward = float(plddt + weight_lm_prior * norm_lm)

    return reward, plddt, lm_logp, norm_lm


##########################################
# 4. ES training with the new objective & compute tracking
##########################################

def evaluate_model_once_objective(
    num_sequences: int,
    seq_length: int,
    temperature: float,
    weight_lm_prior: float = 0.1,
):
    """
    Score the current ``lm_model`` parameters with the composite reward.

    Samples ``num_sequences`` sequences of length ``seq_length`` at
    ``temperature`` from ``lm_model``. Each is scored with
    ``compute_multiobjective_reward`` (its ESMFold calls are counted under
    "ES").

    Returns:
        (mean reward, list of per-sequence rewards)
    """
    rewards = []
    for _ in range(num_sequences):
        seq = sample_sequence_from_lm(
            lm_model,
            alphabet,
            length=seq_length,
            temperature=temperature,
            device=device,
        )
        # Take only element 0 (the reward). compute_multiobjective_reward is
        # redefined later in this notebook with an extra return value
        # (hydrophobic fraction). Unpacking a fixed 4-tuple here would raise
        # ValueError once that later cell has run.
        reward = compute_multiobjective_reward(
            seq,
            weight_lm_prior=weight_lm_prior,
            caller="ES",
        )[0]
        rewards.append(reward)
    return float(np.mean(rewards)), rewards


def run_es_multiobjective_experiment(
    exp_id: int,
    seed: int,
    train_steps: int,
    train_pop_size: int,
    train_num_sequences: int,
    train_seq_length: int,
    train_temperature: float,
    train_sigma: float,
    train_lr: float,
    weight_lm_prior: float = 0.1,
    eval_n_sequences: int = 30,
):
    """
    Single ES experiment with the composite (pLDDT + LM prior) objective.

    Resets ``lm_model`` to ``base_state_dict``. It then runs ``train_steps``
    calls of ``es_step_eggroll`` with ``evaluate_model_once_objective`` as the
    reward, and finally evaluates the trained model with
    ``evaluate_model_distribution`` on ``eval_n_sequences`` samples.
    The low-rank perturbation rank still comes from the global TRAIN_RANK.
    The trained weights are left in ``lm_model``; base_state_dict is not
    restored afterwards.

    Returns:
        dict with exp_id, seed, history (one (avg_reward, max_reward) per
        step) and eval_metrics (the evaluate_model_distribution result).
    """
    print("\n====================")
    print(f"Starting ES multi-objective experiment {exp_id} (seed={seed})")
    print("====================")

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # reset model to base state
    lm_model.load_state_dict(base_state_dict)
    lm_model.to(device)
    lm_model.eval()

    history = []

    def train_eval_fn(num_sequences, seq_length, temperature):
        return evaluate_model_once_objective(
            num_sequences=num_sequences,
            seq_length=seq_length,
            temperature=temperature,
            weight_lm_prior=weight_lm_prior,
        )

    for step in range(1, train_steps + 1):
        avg_reward, max_reward, _ = es_step_eggroll(
            lm_model,
            trained_linear_weights_global,
            rank=TRAIN_RANK,
            sigma=train_sigma,
            pop_size=train_pop_size,
            lr=train_lr,
            device=device,
            eval_fn=train_eval_fn,
            num_sequences=train_num_sequences,
            seq_length=train_seq_length,
            temperature=train_temperature,
        )
        history.append((avg_reward, max_reward))
        print(
            f"[ES-Obj Exp {exp_id} | Step {step:02d}] "
            f"avg_reward={avg_reward:.2f}, max_reward={max_reward:.2f}"
        )

    # After training, sample sequences from the trained model and collect
    # evaluate_model_distribution's metrics. The LM prior is not measured
    # here; only what evaluate_model_distribution reports.
    eval_metrics = evaluate_model_distribution(
        n_sequences=eval_n_sequences,
        seq_length=train_seq_length,
        temperature=train_temperature,
    )

    return {
        "exp_id": exp_id,
        "seed": seed,
        "history": history,
        "eval_metrics": eval_metrics,
    }


##########################################
# 5. SeqEvo baseline with the same objective & compute tracking
##########################################

def mutate_sequence_with_mask(seq: str, mut_rate: float, mutable_mask=None, aa_letters=None) -> str:
    """
    Point-mutate ``seq``, optionally restricted to a subset of positions.

    Args:
        seq: amino-acid sequence.
        mut_rate: probability that a mutable position is resampled.
        mutable_mask: None (every position may mutate) or a sequence of
            booleans, one per residue. Positions whose entry is falsy are never
            changed and consume no random draws.
        aa_letters: residues to draw replacements from (uniformly). Defaults to
            the global ``AA_LETTERS``. A draw may return the current residue.

    Raises:
        ValueError: if ``mutable_mask`` is given and its length differs from
            ``len(seq)``. Previously a shorter mask raised IndexError and a
            longer one was silently truncated.
    """
    if mutable_mask is not None and len(mutable_mask) != len(seq):
        raise ValueError(
            f"mutable_mask has length {len(mutable_mask)}, expected {len(seq)}"
        )
    # Reading a module global needs no `global` statement.
    letters = AA_LETTERS if aa_letters is None else aa_letters
    mask = [True] * len(seq) if mutable_mask is None else mutable_mask

    residues = list(seq)
    for i, mutable in enumerate(mask):
        if not mutable:
            continue
        if random.random() < mut_rate:
            residues[i] = random.choice(letters)
    return "".join(residues)


@torch.no_grad()
def sequence_space_evolution_objective(
    exp_id: int,
    seed: int,
    pop_size: int,
    n_generations: int,
    n_parents: int,
    seq_length: int,
    temperature: float,
    mut_rate: float,
    weight_lm_prior: float = 0.1,
    mutable_mask=None,
):
    """
    SeqEvo baseline using the same composite reward as ES.

    Each generation:
      - every member is scored with ``compute_multiobjective_reward``
        (ESMFold calls counted under "SEQEVO");
      - the ``n_parents`` best become parents;
      - the best parent survives unchanged;
      - the rest of the population is filled with masked point mutants of
        randomly chosen parents.

    mutable_mask: optional list[bool] of length seq_length. Positions marked
                  False are never mutated. NOTE: the initial population is
                  sampled from the unconstrained LM, so frozen positions keep
                  each individual's own sampled residue; there is no shared
                  fixed scaffold across individuals.

    Raises:
        ValueError: if mutable_mask is given with a length other than
            seq_length. Checked up front, before any ESMFold calls.

    Returns:
        dict with exp_id, seed, history (per-generation (avg, max) reward),
        best_seq / best_reward / best_plddt (best individual seen) and
        final_population (the last population that was scored).
    """
    if mutable_mask is not None and len(mutable_mask) != seq_length:
        raise ValueError(
            f"mutable_mask has length {len(mutable_mask)}, expected seq_length={seq_length}"
        )

    print(f"\n=== SeqEvo multi-objective experiment {exp_id} (seed={seed}) ===")

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # ensure base LM (for fairness)
    lm_model.load_state_dict(base_state_dict)
    lm_model.to(device)
    lm_model.eval()

    # Initialize population
    population = []
    for _ in range(pop_size):
        seq = sample_sequence_from_lm(
            lm_model,
            alphabet,
            length=seq_length,
            temperature=temperature,
            device=device,
        )
        population.append(seq)

    best_seq = None
    best_reward = -1e9
    best_plddt = -1e9
    history = []

    for gen in range(n_generations):
        rewards = []
        plddts = []
        for seq in population:
            # Index instead of unpacking a fixed-size tuple: the later
            # redefinition of compute_multiobjective_reward in this notebook
            # returns an extra element (hydrophobic fraction). Element 0 is
            # the reward and element 1 the pLDDT in both versions.
            outcome = compute_multiobjective_reward(
                seq,
                weight_lm_prior=weight_lm_prior,
                caller="SEQEVO",
            )
            rewards.append(outcome[0])
            plddts.append(outcome[1])

        rewards = np.array(rewards, dtype=np.float32)
        plddts = np.array(plddts, dtype=np.float32)

        avg_r = float(rewards.mean())
        max_r = float(rewards.max())
        history.append((avg_r, max_r))

        print(
            f"[SeqEvo-Obj Exp {exp_id} | Gen {gen:02d}] "
            f"avg_reward={avg_r:.2f}, max_reward={max_r:.2f}, "
            f"avg_plddt={float(plddts.mean()):.2f}"
        )

        # update best
        idx_best = int(rewards.argmax())
        if rewards[idx_best] > best_reward:
            best_reward = float(rewards[idx_best])
            best_plddt = float(plddts[idx_best])
            best_seq = population[idx_best]
            print(
                f"[SeqEvo-Obj Exp {exp_id}] New best: "
                f"reward={best_reward:.2f}, pLDDT={best_plddt:.2f}"
            )

        if gen == n_generations - 1:
            break

        # parent selection
        parent_indices = rewards.argsort()[::-1][:n_parents]
        parents = [population[i] for i in parent_indices]

        new_population = []
        # elitism
        new_population.append(parents[0])

        while len(new_population) < pop_size:
            parent = random.choice(parents)
            child = mutate_sequence_with_mask(parent, mut_rate, mutable_mask)
            new_population.append(child)

        population = new_population

    return {
        "exp_id": exp_id,
        "seed": seed,
        "history": history,
        "best_seq": best_seq,
        "best_reward": best_reward,
        "best_plddt": best_plddt,
        "final_population": population,
    }


##########################################
# 6. Diversity analysis
##########################################

def sequence_identity(s1: str, s2: str) -> float:
    """
    Fraction of identical positions between two equal-length sequences.

    Raises:
        ValueError: if the lengths differ or both sequences are empty.
            Previously this was an ``assert``, which is stripped under
            ``python -O``, and a ZeroDivisionError for empty input.
    """
    if len(s1) != len(s2):
        raise ValueError(
            f"sequences must have equal length, got {len(s1)} and {len(s2)}"
        )
    if not s1:
        raise ValueError("cannot compute identity of empty sequences")
    return sum(a == b for a, b in zip(s1, s2)) / len(s1)


def analyze_sequence_set(seqs, label="set"):
    """
    Summarize the diversity of a set of equal-length sequences.

    Metrics:
      - mean pairwise identity: the fraction of matching positions, averaged
        over all unordered pairs. It is nan when fewer than two sequences
        remain or the length is 0.
      - mean per-position Shannon entropy in bits: 0 means the column is
        fully conserved. It is nan when the length is 0.

    Only sequences with the same length as the first one are analyzed. The
    number discarded is returned as ``n_dropped`` and printed when non-zero;
    previously they were dropped silently.

    Returns {} for an empty input. Otherwise returns a dict with keys
    n, n_dropped, mean_identity, mean_entropy and entropies (one per position).
    """
    seqs = list(seqs)
    if not seqs:
        print(f"[Diversity] {label}: no sequences.")
        return {}

    L = len(seqs[0])
    kept = [s for s in seqs if len(s) == L]
    n_dropped = len(seqs) - len(kept)
    if n_dropped:
        print(
            f"[Diversity] {label}: ignoring {n_dropped} sequence(s) "
            f"whose length differs from {L}."
        )
    n = len(kept)

    # pairwise identities (computed inline so this function is self-contained)
    if n > 1 and L > 0:
        ids = [
            sum(x == y for x, y in zip(kept[i], kept[j])) / L
            for i in range(n)
            for j in range(i + 1, n)
        ]
        mean_id = float(np.mean(ids))
    else:
        mean_id = float("nan")

    # Per-position entropy, H = sum p * log2(1/p). Every observed residue has
    # p > 0, so no epsilon is needed. This form also yields +0.0 (not -0.0)
    # for a fully conserved column.
    entropies = []
    for column in zip(*kept):
        counts = np.array(list(Counter(column).values()), dtype=np.float64)
        probs = counts / counts.sum()
        entropies.append(float((probs * np.log2(1.0 / probs)).sum()))

    mean_entropy = float(np.mean(entropies)) if entropies else float("nan")

    print(
        f"[Diversity] {label}: n={n}, "
        f"mean pairwise identity={mean_id:.3f}, "
        f"mean per-pos entropy={mean_entropy:.3f}"
    )

    return {
        "n": n,
        "n_dropped": n_dropped,
        "mean_identity": mean_id,
        "mean_entropy": mean_entropy,
        "entropies": entropies,
    }


##########################################
# 7. Example small-scale ES vs SeqEvo objective run
##########################################

# Deliberately small settings so this demo stays cheap; scale up as needed.

# ES side
OBJ_ES_N_EXPERIMENTS = 1
OBJ_ES_TRAIN_STEPS = 5
OBJ_ES_POP_SIZE = 4
OBJ_ES_NUM_SEQUENCES = 2
OBJ_ES_SEQ_LENGTH = 60
OBJ_ES_TEMPERATURE = 1.0
OBJ_ES_SIGMA = TRAIN_SIGMA  # same ES noise scale as the main experiments
OBJ_ES_LR = TRAIN_LR        # same ES learning rate as the main experiments
OBJ_WEIGHT_LM_PRIOR = 0.1   # shared by ES and SeqEvo rewards

# SeqEvo side: one individual per (ES population member x sampled sequence)
# and one generation per ES step, to roughly match the ES budget.
OBJ_SEQEVO_N_EXPERIMENTS = 1
OBJ_SEQEVO_POP_SIZE = OBJ_ES_POP_SIZE * OBJ_ES_NUM_SEQUENCES
OBJ_SEQEVO_GENERATIONS = OBJ_ES_TRAIN_STEPS
OBJ_SEQEVO_N_PARENTS = max(4, OBJ_SEQEVO_POP_SIZE // 4)
OBJ_SEQEVO_MUT_RATE = 0.05
OBJ_SEQEVO_SEQ_LENGTH = OBJ_ES_SEQ_LENGTH
OBJ_SEQEVO_TEMPERATURE = OBJ_ES_TEMPERATURE

print(
    "\n=== Multi-objective small run CONFIG ===",
    f"ES: {OBJ_ES_N_EXPERIMENTS} exp, steps={OBJ_ES_TRAIN_STEPS}, pop={OBJ_ES_POP_SIZE}",
    f"SeqEvo: {OBJ_SEQEVO_N_EXPERIMENTS} exp, gens={OBJ_SEQEVO_GENERATIONS}, pop={OBJ_SEQEVO_POP_SIZE}",
    sep="\n",
)

reset_esmfold_counters()

# ---- ES with the composite objective (small run) ----
base_seed = 2025
es_obj_results = []
for i in range(OBJ_ES_N_EXPERIMENTS):
    res = run_es_multiobjective_experiment(
        exp_id=i,
        seed=base_seed + i,
        train_steps=OBJ_ES_TRAIN_STEPS,
        train_pop_size=OBJ_ES_POP_SIZE,
        train_num_sequences=OBJ_ES_NUM_SEQUENCES,
        train_seq_length=OBJ_ES_SEQ_LENGTH,
        train_temperature=OBJ_ES_TEMPERATURE,
        train_sigma=OBJ_ES_SIGMA,
        train_lr=OBJ_ES_LR,
        weight_lm_prior=OBJ_WEIGHT_LM_PRIOR,
    )
    es_obj_results.append(res)

# lm_model still holds the weights left by the last ES experiment above
# (run_es_multiobjective_experiment resets to base only at its start).
# Draw 32 sequences from it for the diversity comparison.
es_sample_seqs = [
    sample_sequence_from_lm(
        lm_model,
        alphabet,
        length=OBJ_ES_SEQ_LENGTH,
        temperature=OBJ_ES_TEMPERATURE,
        device=device,
    )
    for _ in range(32)
]

# ---- SeqEvo with the same composite objective (small run) ----
seqevo_obj_results = []
for i in range(OBJ_SEQEVO_N_EXPERIMENTS):
    res = sequence_space_evolution_objective(
        exp_id=i,
        seed=base_seed + 100 + i,  # offset keeps SeqEvo seeds distinct from ES seeds
        pop_size=OBJ_SEQEVO_POP_SIZE,
        n_generations=OBJ_SEQEVO_GENERATIONS,
        n_parents=OBJ_SEQEVO_N_PARENTS,
        seq_length=OBJ_SEQEVO_SEQ_LENGTH,
        temperature=OBJ_SEQEVO_TEMPERATURE,
        mut_rate=OBJ_SEQEVO_MUT_RATE,
        weight_lm_prior=OBJ_WEIGHT_LM_PRIOR,
        mutable_mask=None,  # could freeze scaffold here
    )
    seqevo_obj_results.append(res)

seqevo_final_pop = seqevo_obj_results[0]["final_population"]

# Counters only see folds routed through compute_multiobjective_reward.
# NOTE(review): evaluate_model_distribution (ES post-eval) is presumably
# not counted; confirm before comparing budgets.
print(
    "\n=== ESMFold compute summary ===",
    f"ES ESMFold calls:      {ESMFOLD_CALLS_ES}",
    f"SeqEvo ESMFold calls:  {ESMFOLD_CALLS_SEQEVO}",
    sep="\n",
)

# ---- Diversity comparison ----
div_es = analyze_sequence_set(es_sample_seqs, label="ES-final-samples")
div_seqevo = analyze_sequence_set(seqevo_final_pop, label="SeqEvo-final-population")

# ---- Reward curves of the first ES run vs the first SeqEvo run ----
# Solid lines are ES (per step); dashed lines are SeqEvo (per generation).
import matplotlib.pyplot as plt

es_hist = es_obj_results[0]["history"]
es_avg = [pair[0] for pair in es_hist]
es_max = [pair[1] for pair in es_hist]

seq_hist = seqevo_obj_results[0]["history"]
seq_avg = [pair[0] for pair in seq_hist]
seq_max = [pair[1] for pair in seq_hist]

plt.figure(figsize=(7, 4))
plt.plot(es_avg, label="ES avg")
plt.plot(es_max, label="ES max")
plt.plot(seq_avg, label="SeqEvo avg", linestyle="--")
plt.plot(seq_max, label="SeqEvo max", linestyle="--")
plt.xlabel("Step / Generation")
plt.ylabel("Multi-objective reward")
plt.title("ES vs sequence-space evolution (multi-objective, small run)")
plt.legend()
plt.grid(True)
plt.show()

In [None]:
%%time
"""
Constrained, multi-objective ES vs SeqEvo with multiple seeds.

Adds:
  - richer composite reward (pLDDT + LM prior - hydrophobic penalty)
  - constrained loop regime (fixed scaffold, mutable internal region)
  - multi-seed ES vs SeqEvo comparison with matched ESMFold training budget
  - summary of rewards and diversity across seeds

ASSUMES the previous cell has already defined:
  - lm_model, alphabet, AA_LETTERS, device
  - sample_sequence_from_lm(...)
  - evaluate_sequence_with_esmfold(...)
  - ESMFold call counters: ESMFOLD_CALLS_ES, ESMFOLD_CALLS_SEQEVO, reset_esmfold_counters()
  - diversity helpers: analyze_sequence_set(...)
  - ES machinery: es_step_eggroll(...), trained_linear_weights_global, TRAIN_RANK, TRAIN_SIGMA, TRAIN_LR
  - base_state_dict (saved initial LM weights)
"""

import math
import numpy as np
import torch
import random
from collections import Counter

##########################################
# 1. Updated multi-objective reward
##########################################

# Residues counted as hydrophobic by the reward's hydrophobicity penalty.
HYDROPHOBIC_AA = set("AILMFWVY")


def hydrophobic_fraction(seq: str) -> float:
    """Share of residues in ``seq`` that are in HYDROPHOBIC_AA (0.0 for an empty sequence)."""
    if not seq:
        return 0.0
    hits = len([aa for aa in seq if aa in HYDROPHOBIC_AA])
    return hits / len(seq)


@torch.no_grad()
def lm_pseudo_log_prob_per_residue(seq: str, batch_size: int = 16) -> float:
    """
    Mean masked-LM pseudo log-likelihood per residue of ``seq`` under ``lm_model``.

    Repeated here to keep this cell self-contained. It computes the same
    quantity as the earlier definition: mask one position at a time, read
    the natural-log probability of the true residue, and average over
    positions (at most 0; higher is more LM-plausible).

    The single-mask variants are independent of each other. They are scored
    ``batch_size`` rows per forward pass instead of one forward pass per
    residue, as before (L sequential passes per call). Values agree with the
    sequential version up to floating-point rounding. Lower ``batch_size`` if
    memory is tight.

    Returns nan for an empty sequence.

    Raises:
        ValueError: if ``batch_size < 1``.
    """
    n_res = len(seq)
    if n_res == 0:
        # Nothing to score; avoid numpy's "mean of empty slice" warning.
        return float("nan")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Unmasked token row: <cls> residues... <eos>
    base = torch.tensor(
        [alphabet.cls_idx] + [alphabet.get_idx(aa) for aa in seq] + [alphabet.eos_idx],
        dtype=torch.long,
        device=device,
    )

    log_probs = []
    # Residue positions are 1..n_res (0 is <cls>, n_res + 1 is <eos>).
    for start in range(1, n_res + 1, batch_size):
        positions = torch.arange(start, min(start + batch_size, n_res + 1), device=device)
        rows = torch.arange(positions.numel(), device=device)
        # Row r is a copy of the full sequence with only positions[r] masked.
        batch = base.unsqueeze(0).repeat(positions.numel(), 1)
        batch[rows, positions] = alphabet.mask_idx
        logits = lm_model(batch, repr_layers=[], return_contacts=False)["logits"]
        masked_logp = torch.log_softmax(logits[rows, positions], dim=-1)
        # Log-probability of the true residue at each row's masked position.
        log_probs.extend(masked_logp[rows, base[positions]].tolist())

    return float(np.mean(log_probs))


def compute_multiobjective_reward(
    seq: str,
    weight_lm_prior: float = 0.1,
    target_hydro: float = 0.50,
    lambda_hydro: float = 20.0,
    caller: str = "ES",
):
    """
    Composite reward used for *both* ES and SeqEvo:

      R(seq) =
          pLDDT(seq)
        + weight_lm_prior * normalized_LM_prior(seq)
        - lambda_hydro * |hydrophobic_fraction(seq) - target_hydro|

    where

      normalized_LM_prior = (mean_logp - log(1/20)) / log(20)
                          = 1 + mean_logp / log(20)

    and mean_logp is the per-residue pseudo log-likelihood from
    lm_pseudo_log_prob_per_residue. A uniform model over the 20 amino acids
    (mean_logp = log(1/20)) maps to 0. Better-than-uniform scores are > 0, and
    a perfect model (mean_logp = 0) maps to 1.

    For compute accounting, the matching ESMFold call counter is incremented
    (ESMFOLD_CALLS_ES for caller "ES", ESMFOLD_CALLS_SEQEVO for "SEQEVO",
    case-insensitive). Any other ``caller`` value is not counted.

    Returns:
        (reward, plddt, lm_logp, norm_lm, hydrophobic_fraction)
    """
    global ESMFOLD_CALLS_ES, ESMFOLD_CALLS_SEQEVO

    # 1) ESMFold term
    plddt = evaluate_sequence_with_esmfold(seq, fold_model=fold_model, device=device)
    caller_key = caller.upper()
    if caller_key == "ES":
        ESMFOLD_CALLS_ES += 1
    elif caller_key == "SEQEVO":
        ESMFOLD_CALLS_SEQEVO += 1

    # 2) LM prior, shifted so the uniform-over-20-AA baseline maps to 0 as
    #    documented. Dividing by log(20) alone maps uniform to -1. The shift
    #    adds the same constant (weight_lm_prior) to every reward, so
    #    selection and ranking are unaffected.
    lm_logp = lm_pseudo_log_prob_per_residue(seq)
    log_uniform = math.log(1.0 / 20.0)
    norm_lm = (lm_logp - log_uniform) / -log_uniform

    # 3) hydrophobicity penalty (symmetric around target_hydro)
    h = hydrophobic_fraction(seq)
    hydro_pen = lambda_hydro * abs(h - target_hydro)

    reward = float(plddt + weight_lm_prior * norm_lm - hydro_pen)
    return reward, plddt, lm_logp, norm_lm, h


##########################################
# 2. Constrained sequence generator (fixed scaffold, mutable loop)
##########################################

@torch.no_grad()
def sample_sequence_from_lm_constrained(
    model,
    alphabet,
    base_seq: str,
    mutable_mask,
    temperature: float = 1.0,
    device=device,
):
    """
    Resample the mutable positions of ``base_seq`` with ``model`` and keep the rest.

    Positions where ``mutable_mask[i]`` is falsy keep ``base_seq[i]``. Mutable
    positions are visited left to right. For each one:
      - the slot is replaced by the mask token;
      - the model is run on the whole sequence (scaffold, already-resampled
        positions, and not-yet-visited base residues);
      - a new amino acid is drawn from the temperature-scaled softmax
        restricted to ``AA_INDICES``.

    Returns a string with the same length as ``base_seq``.
    """
    seq_len = len(base_seq)
    assert len(mutable_mask) == seq_len

    tokens = torch.full(
        (1, seq_len + 2),
        fill_value=alphabet.mask_idx,
        device=device,
        dtype=torch.long,
    )
    tokens[0, 0] = alphabet.cls_idx
    tokens[0, -1] = alphabet.eos_idx
    for offset, residue in enumerate(base_seq):
        tokens[0, offset + 1] = alphabet.get_idx(residue)

    # token slots (offset by the CLS token) that are allowed to change
    mutable_slots = [offset + 1 for offset in range(seq_len) if mutable_mask[offset]]
    for slot in mutable_slots:
        tokens[0, slot] = alphabet.mask_idx
        slot_logits = model(tokens, repr_layers=[], return_contacts=False)["logits"][0, slot]
        aa_probs = torch.softmax(slot_logits[AA_INDICES] / temperature, dim=-1)
        choice = torch.multinomial(aa_probs, num_samples=1)
        tokens[0, slot] = AA_INDICES[choice]

    return "".join(alphabet.get_tok(tok) for tok in tokens[0, 1:-1].tolist())


##########################################
# 3. ES multi-objective experiment (with constraints)
##########################################

def evaluate_model_once_objective(
    num_sequences: int,
    seq_length: int,
    temperature: float,
    weight_lm_prior: float,
    target_hydro: float,
    lambda_hydro: float,
    base_seq: str = None,
    mutable_mask=None,
):
    """
    Sample ``num_sequences`` sequences from the global ``lm_model`` and score
    each with compute_multiobjective_reward (counted as "ES" calls).

    When both ``base_seq`` and ``mutable_mask`` are given, sampling only changes
    the mutable positions of ``base_seq``. Otherwise free sequences of length
    ``seq_length`` are drawn with sample_sequence_from_lm.

    Returns:
        (mean reward, list of per-sequence rewards)
    """
    constrained = base_seq is not None and mutable_mask is not None

    def _draw():
        if constrained:
            return sample_sequence_from_lm_constrained(
                lm_model,
                alphabet,
                base_seq=base_seq,
                mutable_mask=mutable_mask,
                temperature=temperature,
                device=device,
            )
        return sample_sequence_from_lm(
            lm_model,
            alphabet,
            length=seq_length,
            temperature=temperature,
            device=device,
        )

    rewards = []
    for _ in range(num_sequences):
        scored = compute_multiobjective_reward(
            _draw(),
            weight_lm_prior=weight_lm_prior,
            target_hydro=target_hydro,
            lambda_hydro=lambda_hydro,
            caller="ES",
        )
        rewards.append(scored[0])

    return float(np.mean(rewards)), rewards


def run_es_multiobjective_experiment_constrained(
    exp_id: int,
    seed: int,
    train_steps: int,
    train_pop_size: int,
    train_num_sequences: int,
    seq_length: int,
    temperature: float,
    sigma: float,
    lr: float,
    weight_lm_prior: float,
    target_hydro: float,
    lambda_hydro: float,
    base_seq: str,
    mutable_mask,
):
    """
    EGGROLL-style ES on the global ``lm_model`` with constrained sampling.

    Seeds every RNG with ``seed`` and resets ``lm_model`` to ``base_state_dict``.
    It then runs ``train_steps`` calls to es_step_eggroll over
    ``trained_linear_weights_global``. The fitness of each candidate is
    evaluate_model_once_objective restricted to the mutable positions of
    ``base_seq``. Finally it draws 32 constrained samples from the trained
    model for diversity analysis.

    The global ``lm_model`` is left in its ES-updated state on return.

    Returns:
        dict with "exp_id", "seed", "history" (list of (avg, max) reward per
        step), "best_training_reward" (max over steps of the per-step max;
        -1e9 if no steps ran) and "final_seqs" (the 32 post-training samples).
    """
    banner = "=============================="
    print("\n" + banner)
    print(f"ES constrained exp {exp_id} (seed={seed})")
    print(banner)

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    lm_model.load_state_dict(base_state_dict)
    lm_model.to(device)
    lm_model.eval()

    def train_eval_fn(num_sequences, seq_length, temperature):
        # Fitness callback for es_step_eggroll: constrained sampling + composite reward.
        return evaluate_model_once_objective(
            num_sequences=num_sequences,
            seq_length=seq_length,
            temperature=temperature,
            weight_lm_prior=weight_lm_prior,
            target_hydro=target_hydro,
            lambda_hydro=lambda_hydro,
            base_seq=base_seq,
            mutable_mask=mutable_mask,
        )

    history = []
    best_training_reward = -1e9
    for step in range(1, train_steps + 1):
        step_avg, step_max, _ = es_step_eggroll(
            lm_model,
            trained_linear_weights_global,
            rank=TRAIN_RANK,
            sigma=sigma,
            pop_size=train_pop_size,
            lr=lr,
            device=device,
            eval_fn=train_eval_fn,
            num_sequences=train_num_sequences,
            seq_length=seq_length,
            temperature=temperature,
        )
        history.append((step_avg, step_max))
        best_training_reward = max(best_training_reward, step_max)
        print(
            f"[ES-constr Exp {exp_id} | Step {step:02d}] "
            f"avg={step_avg:.2f}, max={step_max:.2f}"
        )

    # Samples from the trained model, used only for diversity statistics.
    es_final_seqs = [
        sample_sequence_from_lm_constrained(
            lm_model,
            alphabet,
            base_seq=base_seq,
            mutable_mask=mutable_mask,
            temperature=temperature,
            device=device,
        )
        for _ in range(32)
    ]

    return {
        "exp_id": exp_id,
        "seed": seed,
        "history": history,
        "best_training_reward": best_training_reward,
        "final_seqs": es_final_seqs,
    }


##########################################
# 4. SeqEvo multi-objective, constrained
##########################################

def mutate_sequence_with_mask(seq: str, mut_rate: float, mutable_mask=None, letters=None) -> str:
    """
    Point-mutate ``seq`` at its mutable positions.

    Each position i with a truthy ``mutable_mask[i]`` (every position when the
    mask is None) is substituted with probability ``mut_rate``. The new residue
    is drawn uniformly from ``letters`` (default: the global ``AA_LETTERS``),
    excluding the current residue.

    Excluding the current residue makes ``mut_rate`` the true per-position
    substitution rate, as in ``mutate_loop``. Drawing from all letters would
    silently turn 1/len(letters) of the "mutations" into no-ops.

    Args:
        seq: sequence to mutate.
        mut_rate: per-position substitution probability.
        mutable_mask: per-position booleans; None means all positions are mutable.
        letters: alphabet to draw substitutions from; defaults to AA_LETTERS.

    Raises:
        ValueError: if ``mutable_mask`` and ``seq`` differ in length.
    """
    if letters is None:
        letters = AA_LETTERS
    residues = list(seq)
    if mutable_mask is None:
        mutable_mask = [True] * len(residues)
    if len(mutable_mask) != len(residues):
        raise ValueError(
            f"mutable_mask length {len(mutable_mask)} != sequence length {len(residues)}"
        )

    for i, is_mutable in enumerate(mutable_mask):
        if not is_mutable:
            continue
        if random.random() < mut_rate:
            current = residues[i]
            choices = [aa for aa in letters if aa != current]
            if choices:  # a one-letter alphabet equal to the current residue has nothing to swap to
                residues[i] = random.choice(choices)
    return "".join(residues)


def sequence_space_evolution_objective_constrained(
    exp_id: int,
    seed: int,
    pop_size: int,
    n_generations: int,
    n_parents: int,
    seq_length: int,
    temperature: float,
    mut_rate: float,
    weight_lm_prior: float,
    target_hydro: float,
    lambda_hydro: float,
    base_seq: str,
    mutable_mask,
):
    """
    SeqEvo baseline with:
      - composite reward (same as ES)
      - constrained mutations (only mutable_mask positions can change)
      - initial population generated by constrained LM sampling

    Procedure:
      - seed all RNGs with ``seed`` and reset ``lm_model`` to ``base_state_dict``;
      - sample ``pop_size`` constrained sequences;
      - for each of ``n_generations`` generations, score every member with
        compute_multiobjective_reward(caller="SEQEVO"), keep the top
        ``n_parents`` by reward, and refill the population with the best parent
        (elitism) plus mutated copies of randomly chosen parents.
    The last generation is scored but not mutated, so ``final_population`` is
    the last scored population.

    Budget: n_generations * pop_size reward evaluations (each counted in
    ESMFOLD_CALLS_SEQEVO). The elite is re-scored every generation, so it uses
    one of those evaluations each time.

    ``seq_length`` is accepted for signature parity with the ES runner but is
    not used; the length comes from ``base_seq``.

    Returns:
        dict with "exp_id", "seed", "history" (list of (avg, max) reward per
        generation), "best_reward", "best_plddt" (pLDDT of the best-reward
        sequence), "best_seq" and "final_population".
    """
    print(f"\n=== SeqEvo constrained exp {exp_id} (seed={seed}) ===")

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # SeqEvo never trains the LM; resetting makes the initial population and the
    # LM-prior term independent of any preceding ES run.
    lm_model.load_state_dict(base_state_dict)
    lm_model.to(device)
    lm_model.eval()

    # initial population
    population = []
    for _ in range(pop_size):
        s = sample_sequence_from_lm_constrained(
            lm_model,
            alphabet,
            base_seq=base_seq,
            mutable_mask=mutable_mask,
            temperature=temperature,
            device=device,
        )
        population.append(s)

    best_reward = -1e9
    best_plddt = -1e9
    best_seq = None
    history = []

    for gen in range(n_generations):
        rewards = []
        plddts = []
        for seq in population:
            # lm_logp / norm_lm / h are returned for diagnostics but unused here.
            r, plddt, lm_logp, norm_lm, h = compute_multiobjective_reward(
                seq,
                weight_lm_prior=weight_lm_prior,
                target_hydro=target_hydro,
                lambda_hydro=lambda_hydro,
                caller="SEQEVO",
            )
            rewards.append(r)
            plddts.append(plddt)

        rewards = np.array(rewards, dtype=np.float32)
        plddts = np.array(plddts, dtype=np.float32)

        avg_r = float(rewards.mean())
        max_r = float(rewards.max())
        history.append((avg_r, max_r))

        print(
            f"[SeqEvo-constr Exp {exp_id} | Gen {gen:02d}] "
            f"avg={avg_r:.2f}, max={max_r:.2f}, avg_pLDDT={float(plddts.mean()):.2f}"
        )

        # Track the best sequence ever scored (by composite reward, not pLDDT).
        idx_best = int(rewards.argmax())
        if rewards[idx_best] > best_reward:
            best_reward = float(rewards[idx_best])
            best_plddt = float(plddts[idx_best])
            best_seq = population[idx_best]
            print(
                f"[SeqEvo-constr Exp {exp_id}] "
                f"New best: reward={best_reward:.2f}, pLDDT={best_plddt:.2f}"
            )

        # Stop before breeding so the returned population is the last scored one.
        if gen == n_generations - 1:
            break

        # Truncation selection: indices of the n_parents highest rewards, best first.
        parent_indices = rewards.argsort()[::-1][:n_parents]
        parents = [population[i] for i in parent_indices]

        new_population = [parents[0]]  # elitism
        while len(new_population) < pop_size:
            parent = random.choice(parents)
            child = mutate_sequence_with_mask(parent, mut_rate, mutable_mask)
            new_population.append(child)

        population = new_population

    return {
        "exp_id": exp_id,
        "seed": seed,
        "history": history,
        "best_reward": best_reward,
        "best_plddt": best_plddt,
        "best_seq": best_seq,
        "final_population": population,
    }


##########################################
# 5. Set up constrained loop regime
##########################################

# Draw the scaffold from the base LM. Every RNG is seeded with 777 first so the
# same scaffold comes out on every notebook run.
random.seed(777)
np.random.seed(777)
torch.manual_seed(777)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(777)

SCAFFOLD_LENGTH = 60
scaffold_seq = sample_sequence_from_lm(
    lm_model, alphabet, length=SCAFFOLD_LENGTH, temperature=1.0, device=device
)

# Central mutable loop: 0-based positions loop_start .. loop_end - 1.
loop_start, loop_end = 20, 40  # loop_end is exclusive
mutable_mask = [loop_start <= pos < loop_end for pos in range(SCAFFOLD_LENGTH)]

print("\n=== Constrained loop regime ===")
print(f"Scaffold sequence (L={SCAFFOLD_LENGTH}):")
print(scaffold_seq)
print(f"Mutable region: positions {loop_start}..{loop_end-1} (0-based)")
print("Mutable mask (first 60 positions):")
print("".join(["X" if flag else "." for flag in mutable_mask]))


##########################################
# 6. Multi-seed ES vs SeqEvo comparison
##########################################

# Multi-seed ES vs SeqEvo comparison on the constrained loop regime set up above.
N_SEEDS = 3  # bump this up when you have more compute

# Match total training ESMFold calls per seed:
# ES: steps * pop_size * num_sequences
# NOTE(review): this assumes es_step_eggroll calls eval_fn exactly once per
# population member (no antithetic +/- pairs) - TODO confirm. The compute
# summary printed after the loop reports the counts actually used.
CONSTR_ES_STEPS = 5
CONSTR_ES_POP = 4
CONSTR_ES_NUM_SEQ = 2
ES_BUDGET = CONSTR_ES_STEPS * CONSTR_ES_POP * CONSTR_ES_NUM_SEQ

# SeqEvo: generations * population size
# (every generation, including the last, scores the full population)
CONSTR_SEQ_POP = 8
CONSTR_SEQ_GENS = 5
SEQ_BUDGET = CONSTR_SEQ_POP * CONSTR_SEQ_GENS

print("\n=== Constrained multi-seed CONFIG ===")
print(f"ES budget:    {ES_BUDGET} ESMFold calls (steps={CONSTR_ES_STEPS}, pop={CONSTR_ES_POP}, seq/worker={CONSTR_ES_NUM_SEQ})")
print(f"SeqEvo budget:{SEQ_BUDGET} ESMFold calls (gens={CONSTR_SEQ_GENS}, pop={CONSTR_SEQ_POP})")

# Composite objective weights (shared by ES and SeqEvo)
WEIGHT_LM_PRIOR = 0.1
TARGET_HYDRO = 0.50
LAMBDA_HYDRO = 20.0

# Per-seed results and diversity summaries, consumed by the aggregation below.
es_results_all = []
seqevo_results_all = []
es_div_all = []
seq_div_all = []

# Counters are reset once, so the summary after the loop totals all N_SEEDS.
reset_esmfold_counters()

base_seed = 3000

for i in range(N_SEEDS):
    seed = base_seed + i

    # --- ES ---
    es_res = run_es_multiobjective_experiment_constrained(
        exp_id=i,
        seed=seed,
        train_steps=CONSTR_ES_STEPS,
        train_pop_size=CONSTR_ES_POP,
        train_num_sequences=CONSTR_ES_NUM_SEQ,
        seq_length=SCAFFOLD_LENGTH,
        temperature=1.0,
        sigma=TRAIN_SIGMA,
        lr=TRAIN_LR,
        weight_lm_prior=WEIGHT_LM_PRIOR,
        target_hydro=TARGET_HYDRO,
        lambda_hydro=LAMBDA_HYDRO,
        base_seq=scaffold_seq,
        mutable_mask=mutable_mask,
    )
    es_results_all.append(es_res)

    # diversity of ES samples
    es_div = analyze_sequence_set(es_res["final_seqs"], label=f"ES-exp{i}")
    es_div_all.append(es_div)

    # --- SeqEvo ---
    # Seed offset keeps SeqEvo's RNG stream distinct from the ES run's.
    # n_parents = max(4, 8 // 4) = 4 with the config above.
    seq_res = sequence_space_evolution_objective_constrained(
        exp_id=i,
        seed=seed + 10_000,
        pop_size=CONSTR_SEQ_POP,
        n_generations=CONSTR_SEQ_GENS,
        n_parents=max(4, CONSTR_SEQ_POP // 4),
        seq_length=SCAFFOLD_LENGTH,
        temperature=1.0,
        mut_rate=0.05,
        weight_lm_prior=WEIGHT_LM_PRIOR,
        target_hydro=TARGET_HYDRO,
        lambda_hydro=LAMBDA_HYDRO,
        base_seq=scaffold_seq,
        mutable_mask=mutable_mask,
    )
    seqevo_results_all.append(seq_res)

    seq_div = analyze_sequence_set(seq_res["final_population"], label=f"SeqEvo-exp{i}")
    seq_div_all.append(seq_div)

print("\n=== Training ESMFold compute summary (constrained runs) ===")
print(f"ES ESMFold calls used:      {ESMFOLD_CALLS_ES}")
print(f"SeqEvo ESMFold calls used:  {ESMFOLD_CALLS_SEQEVO}")


##########################################
# 7. Aggregate stats across seeds
##########################################

def summarize_best_rewards(es_results_all, seqevo_results_all):
    """
    Print mean ± population std of each method's best training reward across
    seeds, and return the two per-seed lists (ES first, then SeqEvo).

    ES runs are read from key "best_training_reward"; SeqEvo runs from "best_reward".
    """
    es_best, se_best = (
        [run["best_training_reward"] for run in es_results_all],
        [run["best_reward"] for run in seqevo_results_all],
    )

    print("\n=== Best training reward across seeds ===")
    print(f"ES:    mean={np.mean(es_best):.2f} ± {np.std(es_best):.2f}")
    print(f"SeqEV: mean={np.mean(se_best):.2f} ± {np.std(se_best):.2f}")
    return es_best, se_best


def summarize_diversity(div_list, label):
    """
    Print mean ± std of per-seed diversity statistics for one method.

    Reads "mean_identity" (NaN entries are skipped) and "mean_entropy" (all
    entries kept) from each dict in ``div_list``.

    Returns:
        (identities, entropies) as lists.
    """
    ids = [stats["mean_identity"] for stats in div_list if not math.isnan(stats["mean_identity"])]
    ents = [stats["mean_entropy"] for stats in div_list]

    print(f"\n=== Diversity summary: {label} ===")
    print(f"Mean pairwise identity: {np.mean(ids):.3f} ± {np.std(ids):.3f}")
    print(f"Mean per-pos entropy:   {np.mean(ents):.3f} ± {np.std(ents):.3f}")
    return ids, ents


# Cross-seed aggregation: best reward per method, then identity/entropy
# diversity for ES final samples and SeqEvo final populations. The returned
# lists feed the bar plots below.
es_best, se_best = summarize_best_rewards(es_results_all, seqevo_results_all)
es_ids, es_ents = summarize_diversity(es_div_all, "ES")
se_ids, se_ents = summarize_diversity(seq_div_all, "SeqEvo")


##########################################
# 8. Simple bar plots: reward vs diversity
##########################################

import matplotlib.pyplot as plt

x = np.arange(N_SEEDS)

# Per-seed best composite reward: ES and SeqEvo bars side by side.
plt.figure(figsize=(7, 4))
plt.bar(x - 0.15, es_best, width=0.3, label="ES best reward")
plt.bar(x + 0.15, se_best, width=0.3, label="SeqEvo best reward")
plt.xlabel("Seed")
plt.ylabel("Best training reward")
plt.title("Best composite reward per seed (constrained loop regime)")
plt.legend()
plt.grid(True, axis="y")
plt.show()

# Seed-aggregated diversity: one bar per method, error bar = std across seeds.
for es_vals, se_vals, y_label, plot_title in (
    (es_ids, se_ids, "Mean pairwise identity",
     "Diversity: mean pairwise identity across seeds"),
    (es_ents, se_ents, "Mean per-position entropy (bits)",
     "Diversity: mean entropy across seeds"),
):
    plt.figure(figsize=(7, 4))
    plt.bar(
        [-0.15, 0.15],
        [np.mean(es_vals), np.mean(se_vals)],
        yerr=[np.std(es_vals), np.std(se_vals)],
        width=0.3,
        tick_label=["ES", "SeqEvo"],
    )
    plt.ylabel(y_label)
    plt.title(plot_title)
    plt.grid(True, axis="y")
    plt.show()

In [None]:
%%time
"""
Publication-oriented constrained loop benchmark:
ES (EGGROLL-style) vs sequence-space evolution, multi-seed / multi-loop,
with a stronger multi-objective reward.

This cell depends on:
- lm_model, alphabet, fold_model, device
- AA_LETTERS, cls_idx, eos_idx
- evaluate_sequence_with_esmfold(sequence) -> mean_pLDDT
- egg_ctx, linear_weights, es_step_eggroll (EGGROLL-style ES)
"""

import math
import random
from collections import defaultdict

import numpy as np
import torch
import matplotlib.pyplot as plt

# -----------------------------
# 0. GLOBAL CONFIG
# -----------------------------

# Two loop regimes on the same scaffold (you can add more)
SCAFFOLD_SEQ = "KSGISCHGGIWIASFGKHKKRCKAKYERQYVRLIYKNKDKKFSTIKGLWKMIEAEYPDKI"
assert len(SCAFFOLD_SEQ) == 60

# Each setting mutates the half-open window [mut_start, mut_end) of its scaffold.
CONSTRAINED_SETTINGS = [
    {
        "name": "loop20_39",
        "scaffold": SCAFFOLD_SEQ,
        "mut_start": 20,   # inclusive
        "mut_end": 40,     # exclusive
    },
    {
        "name": "loop10_29",
        "scaffold": SCAFFOLD_SEQ,
        "mut_start": 10,
        "mut_end": 30,
    },
]

# Multi-objective weights
# The *_TARGET values are one-sided caps: loop_biophysics_penalty only penalises
# fractions above them, as SCALE * excess**2; fractions below cost nothing.
LM_NLL_SCALE = 10.0        # multiplier on the LM negative log-likelihood per residue
HYDRO_TARGET = 0.45        # cap on hydrophobic fraction in the loop
HYDRO_SCALE = 80.0
POS_CHARGE_TARGET = 0.35   # cap on K,R,H fraction in the loop
NEG_CHARGE_TARGET = 0.35   # cap on D,E fraction in the loop
CHARGE_SCALE = 60.0

# ES config (per experiment)
ES_TOTAL_EVALS = 40        # total ESMFold evals per method, per setting, per seed
ES_POP_SIZE = 4
ES_NUM_SEQ_PER_MEMBER = 2
# Floor division: 40 // (4 * 2) = 5 steps. If not evenly divisible, ES uses
# fewer than ES_TOTAL_EVALS evals. NOTE(review): assumes es_step_eggroll calls
# eval_fn once per population member - TODO confirm.
ES_NUM_STEPS = ES_TOTAL_EVALS // (ES_POP_SIZE * ES_NUM_SEQ_PER_MEMBER)
ES_RANK = 4
ES_SIGMA = 0.02
ES_LR = 0.03
ES_TEMP = 1.0

# Sequence evolution config (per experiment)
SEQEVO_TOTAL_EVALS = ES_TOTAL_EVALS  # match ES
SEQEVO_POP_SIZE = 8
SEQEVO_N_GENERATIONS = SEQEVO_TOTAL_EVALS // SEQEVO_POP_SIZE  # floor division: 5
SEQEVO_MUT_RATE = 0.05  # per-position in loop

# Which Linear weights to train in ES (subset of `linear_weights`)
NUM_LINEAR_TO_TRAIN = 24

# Seeds and settings
N_SEEDS = 3
BASE_SEED = 4000

print("=== Constrained benchmark CONFIG ===")
print(f"SCAFFOLD_SEQ (L={len(SCAFFOLD_SEQ)}): {SCAFFOLD_SEQ}")
for cfg in CONSTRAINED_SETTINGS:
    print(
        f"  Setting {cfg['name']}: loop {cfg['mut_start']}..{cfg['mut_end']-1} "
        f"(length={cfg['mut_end']-cfg['mut_start']})"
    )
print(f"ES:     evals={ES_TOTAL_EVALS}, steps={ES_NUM_STEPS}, pop={ES_POP_SIZE}, seq/worker={ES_NUM_SEQ_PER_MEMBER}")
print(f"SeqEvo: evals={SEQEVO_TOTAL_EVALS}, gens={SEQEVO_N_GENERATIONS}, pop={SEQEVO_POP_SIZE}")
# If linear_weights has fewer than NUM_LINEAR_TO_TRAIN entries, the slice below
# takes all of them and this message overstates the count.
print(f"ES trains last {NUM_LINEAR_TO_TRAIN} Linear weights out of {len(linear_weights)} total.")

# Subset of Linear weights that ES will perturb
trained_linear_weights_pub = linear_weights[-NUM_LINEAR_TO_TRAIN:]


# -----------------------------
# 1. BASIC UTILITIES
# -----------------------------

# Residue classes used by loop_biophysics_penalty.
AA_SET = set(AA_LETTERS)  # NOTE(review): not referenced in the visible part of this cell
HYDROPHOBIC = set("AILMFWYV")  # same members as HYDROPHOBIC_AA in the previous cell
POSITIVE = set("KRH")
NEGATIVE = set("DE")


def enforce_scaffold_constraint(scaffold: str, loop_seq: str, mut_start: int, mut_end: int) -> str:
    """
    Splice ``loop_seq`` into ``scaffold`` over the half-open window [mut_start, mut_end).

    Residues outside the window come from ``scaffold``, so the result always
    has the same length as ``scaffold``. Scaffolds of any length are accepted.
    The window is validated against the scaffold itself rather than against
    the global SCAFFOLD_SEQ. A window running past the scaffold end is
    rejected; it used to silently produce a longer sequence.

    Raises:
        ValueError: if the window is empty or not inside ``scaffold``, or if
            ``len(loop_seq) != mut_end - mut_start``.
    """
    if not 0 <= mut_start < mut_end <= len(scaffold):
        raise ValueError(
            f"invalid loop window [{mut_start}, {mut_end}) for scaffold of length {len(scaffold)}"
        )
    if len(loop_seq) != mut_end - mut_start:
        raise ValueError(
            f"loop_seq has length {len(loop_seq)}, expected {mut_end - mut_start}"
        )
    return scaffold[:mut_start] + loop_seq + scaffold[mut_end:]


@torch.no_grad()
def random_loop_from_lm(loop_len: int, temperature: float = 1.0) -> str:
    """
    Sample a loop of ``loop_len`` residues from the global ``lm_model``.

    Starting from an all-mask sequence framed by CLS/EOS, positions are filled
    left to right. Each residue is drawn from the temperature-scaled softmax
    over the LM logits, restricted to ``AA_INDICES``. The loop is sampled
    *without* the surrounding scaffold as context; callers splice it in
    afterwards.

    Runs under ``torch.no_grad()``. The sampled tokens never need gradients,
    and not all callers disable autograd themselves. Without it, each of the
    ``loop_len`` forward passes built and kept an autograd graph.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    tokens = torch.full(
        (1, loop_len + 2),
        fill_value=alphabet.mask_idx,
        device=device,
        dtype=torch.long,
    )
    tokens[0, 0] = cls_idx
    tokens[0, -1] = eos_idx

    for pos in range(1, loop_len + 1):
        out = lm_model(tokens, repr_layers=[], return_contacts=False)
        logits = out["logits"][0, pos]
        aa_logits = logits[AA_INDICES]
        probs = torch.softmax(aa_logits / temperature, dim=-1)
        aa_idx = torch.multinomial(probs, num_samples=1)
        tokens[0, pos] = AA_INDICES[aa_idx]

    seq_tokens = tokens[0, 1:-1].tolist()
    return "".join(alphabet.get_tok(t) for t in seq_tokens)


def mutate_loop(seq: str, mut_start: int, mut_end: int, mut_rate: float, letters=None) -> str:
    """
    Point-mutate ``seq`` inside the half-open window [mut_start, mut_end).

    Each position in the window is independently substituted with probability
    ``mut_rate``. The new residue is drawn uniformly from ``letters`` (default:
    the global ``AA_LETTERS``), excluding the current residue. Positions
    outside the window are never touched.

    Args:
        seq: sequence to mutate.
        mut_start: first mutable position (0-based, inclusive).
        mut_end: end of the mutable window (exclusive).
        mut_rate: per-position substitution probability.
        letters: alphabet to draw substitutions from; defaults to AA_LETTERS.

    Raises:
        ValueError: unless 0 <= mut_start <= mut_end <= len(seq). Without this
            check, a negative start silently mutated positions counted from
            the end, and an out-of-range end raised IndexError only when the
            RNG happened to hit it.
    """
    if not 0 <= mut_start <= mut_end <= len(seq):
        raise ValueError(
            f"invalid loop window [{mut_start}, {mut_end}) for sequence of length {len(seq)}"
        )
    if letters is None:
        letters = AA_LETTERS

    seq_list = list(seq)
    for global_pos in range(mut_start, mut_end):
        if random.random() < mut_rate:
            orig = seq_list[global_pos]
            choices = [aa for aa in letters if aa != orig]
            if choices:  # a one-letter alphabet equal to orig has nothing to swap to
                seq_list[global_pos] = random.choice(choices)
    return "".join(seq_list)


def sequence_identity(s1: str, s2: str) -> float:
    """Return the fraction of aligned positions where ``s1`` and ``s2`` agree.

    No alignment or gap handling is done. The sequences must have equal length
    (checked by assert), and empty inputs raise ZeroDivisionError.
    """
    assert len(s1) == len(s2)
    matches = 0
    for left, right in zip(s1, s2):
        if left == right:
            matches += 1
    return matches / len(s1)


def mean_pairwise_identity(seqs):
    """
    Average sequence_identity over all unordered pairs of ``seqs``.

    With fewer than two sequences there is nothing to compare, and 1.0 is returned.
    """
    n = len(seqs)
    if n < 2:
        return 1.0
    pair_scores = [
        sequence_identity(seqs[a], seqs[b])
        for a in range(n)
        for b in range(a + 1, n)
    ]
    return float(np.mean(pair_scores))


def mean_position_entropy(seqs):
    """
    Mean column-wise Shannon entropy (bits) of a set of equal-length sequences.

    For each position, residue frequencies across ``seqs`` are tallied and
    H = -sum(p * log2(p + 1e-12)) is computed; the result is the average of H
    over positions. An empty set returns 0.0. Every sequence must match the
    length of the first one (checked by assert).
    """
    if len(seqs) == 0:
        return 0.0
    n_seqs = len(seqs)
    width = len(seqs[0])
    for s in seqs:
        assert len(s) == width

    column_entropies = []
    for column in zip(*seqs):
        tally = defaultdict(int)
        for residue in column:
            tally[residue] += 1
        probs = np.array([count / n_seqs for count in tally.values()], dtype=float)
        column_entropies.append(-(probs * np.log2(probs + 1e-12)).sum())
    return float(np.mean(column_entropies))


# -----------------------------
# 2. LM PRIOR & LOOP PENALTIES
# -----------------------------

@torch.no_grad()
def lm_sequence_avg_nll(seq: str, masked: bool = False) -> float:
    """
    Average negative log-likelihood per residue of ``seq`` under the global ``lm_model``.

    masked=False (default; the original behaviour):
        One forward pass on the *unmasked* sequence. Each residue is scored
        from logits computed while the model can see that very residue. This
        is cheap, but it is NOT a pseudo-likelihood: the model can largely
        copy its input, so the NLL is optimistic and weakly discriminative.
    masked=True:
        Proper masked pseudo-log-likelihood, the same quantity as the masked
        scoring in lm_pseudo_log_prob_per_residue. It builds a batch of
        len(seq) copies, copy i having residue i replaced by the mask token,
        runs one batched forward pass, and scores residue i from copy i.
        Memory grows with len(seq) because the batch size is len(seq).

    Returns NaN for an empty sequence.
    """
    if not seq:
        return float("nan")

    token_ids = [alphabet.get_idx(a) for a in seq]
    n = len(token_ids)
    tokens = torch.tensor(
        [[cls_idx] + token_ids + [eos_idx]],
        device=device,
        dtype=torch.long,
    )
    idxs = torch.tensor(token_ids, device=device, dtype=torch.long)
    rows = torch.arange(n, device=device)

    if masked:
        batch = tokens.repeat(n, 1)  # (n, n + 2)
        batch[rows, rows + 1] = alphabet.mask_idx  # row i masks residue i (+1 skips CLS)
        logits = lm_model(batch, repr_layers=[], return_contacts=False)["logits"]
        log_probs = torch.log_softmax(logits[rows, rows + 1], dim=-1)  # (n, vocab)
    else:
        logits = lm_model(tokens, repr_layers=[], return_contacts=False)["logits"][0, 1:-1]
        log_probs = torch.log_softmax(logits, dim=-1)  # (n, vocab)

    token_logp = log_probs[rows, idxs]
    return float(-token_logp.mean().item())


def loop_biophysics_penalty(seq: str, mut_start: int, mut_end: int):
    """
    Quadratic composition penalty on the loop ``seq[mut_start:mut_end]``.

    The loop's hydrophobic (HYDROPHOBIC), positive (POSITIVE) and negative
    (NEGATIVE) residue fractions are computed. Each is penalised only when it
    *exceeds* its target (HYDRO_TARGET, POS_CHARGE_TARGET, NEG_CHARGE_TARGET),
    as scale * excess**2. Fractions below target cost nothing, so the targets
    act as upper caps.

    Returns:
        (total_penalty, comps): comps holds the three fractions and the three
        per-term penalties.

    An empty loop slice raises ZeroDivisionError.
    """
    loop = seq[mut_start:mut_end]
    loop_len = len(loop)

    def _frac(residue_class):
        # share of loop residues belonging to residue_class
        return sum(aa in residue_class for aa in loop) / loop_len

    def _over(value, cap, scale):
        # one-sided quadratic penalty above cap
        return scale * max(0.0, value - cap) ** 2

    hyd, pos, neg = _frac(HYDROPHOBIC), _frac(POSITIVE), _frac(NEGATIVE)
    hyd_pen = _over(hyd, HYDRO_TARGET, HYDRO_SCALE)
    pos_pen = _over(pos, POS_CHARGE_TARGET, CHARGE_SCALE)
    neg_pen = _over(neg, NEG_CHARGE_TARGET, CHARGE_SCALE)

    comps = {
        "hydrophobic_frac": hyd,
        "pos_frac": pos,
        "neg_frac": neg,
        "hyd_pen": hyd_pen,
        "pos_pen": pos_pen,
        "neg_pen": neg_pen,
    }
    return hyd_pen + pos_pen + neg_pen, comps


# -----------------------------
# 3. MULTI-OBJECTIVE REWARD
# -----------------------------

@torch.no_grad()
def evaluate_sequence_multiobjective(
    full_seq: str,
    mut_start: int,
    mut_end: int,
):
    """
    Composite reward for the constrained benchmark:

        reward = mean pLDDT (evaluate_sequence_with_esmfold)
                 - LM_NLL_SCALE * lm_sequence_avg_nll(full_seq)
                 - loop_biophysics_penalty(full_seq, mut_start, mut_end)[0]

    lm_sequence_avg_nll is called with its default arguments, so the LM term
    comes from an unmasked forward pass of the global lm_model (see that
    function for why this is not a true pseudo-likelihood).

    Returns:
        (reward, comps): comps holds "plddt", "avg_nll", "lm_penalty" and
        "loop_penalty", followed by the composition entries from
        loop_biophysics_penalty.
    """
    plddt = evaluate_sequence_with_esmfold(full_seq)

    nll = lm_sequence_avg_nll(full_seq)
    lm_term = LM_NLL_SCALE * nll

    loop_term, loop_details = loop_biophysics_penalty(full_seq, mut_start, mut_end)

    comps = {
        "plddt": plddt,
        "avg_nll": nll,
        "lm_penalty": lm_term,
        "loop_penalty": loop_term,
    }
    comps.update(loop_details)
    return plddt - lm_term - loop_term, comps


# -----------------------------
# 4. ES EXPERIMENT (per setting, per seed)
# -----------------------------

def es_constrained_eval_once(
    scaffold_cfg,
    num_sequences: int,
    temperature: float,
):
    """
    Score the current global lm_model on one constrained setting.

    Draws ``num_sequences`` loops with random_loop_from_lm and splices each
    into scaffold_cfg["scaffold"] over [mut_start, mut_end). Each result is
    scored with evaluate_sequence_multiobjective. Sampling uses whatever
    weights lm_model holds at call time, which for an ES step includes any
    in-place perturbation.

    Returns:
        (mean reward, list of per-sequence rewards)
    """
    start, end = scaffold_cfg["mut_start"], scaffold_cfg["mut_end"]
    scaffold = scaffold_cfg["scaffold"]

    rewards = []
    for _ in range(num_sequences):
        candidate = enforce_scaffold_constraint(
            scaffold,
            random_loop_from_lm(end - start, temperature=temperature),
            start,
            end,
        )
        rewards.append(evaluate_sequence_multiobjective(candidate, start, end)[0])
    return float(np.mean(rewards)), rewards


def run_es_constrained_single(
    scaffold_cfg,
    seed: int,
):
    """
    Run one ES experiment for a single scaffold+loop setting.

    Seeds all RNGs with ``seed`` and runs ES_NUM_STEPS EGGROLL-style steps on
    ``trained_linear_weights_pub``, using es_constrained_eval_once as the
    fitness. It then draws ES_POP_SIZE * ES_NUM_SEQ_PER_MEMBER * 2 constrained
    samples from the ES-updated model for diversity analysis.

    The LM weights are snapshotted (on CPU) on entry and restored on exit,
    including when an error is raised. Without this, ES updates leaked into
    later ES seeds/settings and into the SeqEvo baseline, which samples its
    initial population and scores its LM prior with the same global lm_model.

    Returns:
        dict with keys
          - "history": list of (avg, max) reward per ES step
          - "best_reward": max of the per-step max rewards (-1e9 if no steps ran)
          - "final_samples": full sequences sampled after training
          - "div_identity": mean pairwise identity of final_samples
          - "div_entropy": mean per-position entropy (bits) of final_samples
    """
    print("\n==============================")
    print(f"ES constrained run: setting={scaffold_cfg['name']} seed={seed}")
    print("==============================")

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    mut_start = scaffold_cfg["mut_start"]
    mut_end = scaffold_cfg["mut_end"]
    loop_len = mut_end - mut_start
    scaffold = scaffold_cfg["scaffold"]

    def eval_fn(num_sequences, seq_length_unused, temperature):
        # es_step_eggroll passes a seq_length; the loop length comes from scaffold_cfg instead.
        return es_constrained_eval_once(
            scaffold_cfg,
            num_sequences=num_sequences,
            temperature=temperature,
        )

    # CPU copy of the starting weights (avoids doubling GPU memory), restored below.
    initial_state = {
        name: tensor.detach().to("cpu", copy=True)
        for name, tensor in lm_model.state_dict().items()
    }

    history = []
    best_reward = -1e9
    try:
        for step in range(1, ES_NUM_STEPS + 1):
            avg_r, max_r, _all_r = es_step_eggroll(
                lm_model,
                trained_linear_weights_pub,
                rank=ES_RANK,
                sigma=ES_SIGMA,
                pop_size=ES_POP_SIZE,
                lr=ES_LR,
                device=device,
                eval_fn=eval_fn,
                num_sequences=ES_NUM_SEQ_PER_MEMBER,
                seq_length=len(scaffold),
                temperature=ES_TEMP,
            )
            history.append((avg_r, max_r))
            print(f"[ES {scaffold_cfg['name']} | step {step:02d}] avg={avg_r:.2f}, max={max_r:.2f}")

            if max_r > best_reward:
                best_reward = max_r

        # Sample from the final ES-updated model *before* the weights are restored.
        with torch.no_grad():
            final_samples = [
                enforce_scaffold_constraint(
                    scaffold,
                    random_loop_from_lm(loop_len, temperature=ES_TEMP),
                    mut_start,
                    mut_end,
                )
                for _ in range(ES_POP_SIZE * ES_NUM_SEQ_PER_MEMBER * 2)
            ]
    finally:
        # load_state_dict copies values into the existing parameters in place, so
        # tensors referenced elsewhere (e.g. trained_linear_weights_pub) stay valid.
        lm_model.load_state_dict(initial_state)

    div_identity = mean_pairwise_identity(final_samples)
    div_entropy = mean_position_entropy(final_samples)
    print(
        f"[ES diversity] n={len(final_samples)}, "
        f"mean identity={div_identity:.3f}, mean entropy={div_entropy:.3f}"
    )

    return {
        "history": history,
        "best_reward": best_reward,
        "final_samples": final_samples,
        "div_identity": div_identity,
        "div_entropy": div_entropy,
    }


# -----------------------------
# 5. SEQUENCE EVOLUTION EXPERIMENT
# -----------------------------

def run_seqevo_constrained_single(
    scaffold_cfg,
    seed: int,
):
    """
    Simple sequence-space evolution baseline under same scaffold+loop constraint.
    Mutations restricted to the loop.

    Procedure:
      - seed all RNGs with ``seed``;
      - sample SEQEVO_POP_SIZE loops from the global lm_model (temperature
        ES_TEMP) and splice them into the scaffold;
      - for each of SEQEVO_N_GENERATIONS generations: score every member with
        evaluate_sequence_multiobjective, keep the top
        max(2, SEQEVO_POP_SIZE // 4) by reward as parents, and build a fully
        new population of mutate_loop children of randomly chosen parents.

    Notes:
      - No elitism: the best parent is not copied forward; it is only tracked
        in best_reward / best_seq.
      - Selection and mutation also run after the last generation, so
        "final_population" (and the reported diversity) is the unscored
        offspring of the last selection. NOTE(review): confirm this is
        intended; ES diversity is likewise measured on unscored samples.
      - lm_model is not reset here. Sampling and the LM-prior term use
        whatever weights lm_model holds when this is called.

    Returns:
        dict with "history" (list of (avg, max) reward per generation),
        "best_reward", "best_seq", "final_population", "div_identity" and
        "div_entropy".
    """
    print("\n=== SeqEvo constrained run: setting={} seed={} ===".format(scaffold_cfg["name"], seed))

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    mut_start = scaffold_cfg["mut_start"]
    mut_end = scaffold_cfg["mut_end"]
    loop_len = mut_end - mut_start
    scaffold = scaffold_cfg["scaffold"]

    # Initialize population by sampling random loops from LM
    population = []
    for _ in range(SEQEVO_POP_SIZE):
        loop_seq = random_loop_from_lm(loop_len, temperature=ES_TEMP)
        full_seq = enforce_scaffold_constraint(scaffold, loop_seq, mut_start, mut_end)
        population.append(full_seq)

    history = []
    best_reward = -1e9
    best_seq = None

    for gen in range(SEQEVO_N_GENERATIONS):
        # One evaluate_sequence_multiobjective call (one ESMFold fold) per member.
        rewards = []
        pl_ddts = []
        for seq in population:
            r, comps = evaluate_sequence_multiobjective(seq, mut_start, mut_end)
            rewards.append(r)
            pl_ddts.append(comps["plddt"])
        rewards = np.array(rewards, dtype=float)
        history.append((float(rewards.mean()), float(rewards.max())))
        print(
            f"[SeqEvo {scaffold_cfg['name']} | gen {gen:02d}] "
            f"avg={rewards.mean():.2f}, max={rewards.max():.2f}, avg_pLDDT={np.mean(pl_ddts):.2f}"
        )

        # Track best
        max_idx = int(rewards.argmax())
        if rewards[max_idx] > best_reward:
            best_reward = float(rewards[max_idx])
            best_seq = population[max_idx]

        # Selection + mutation (truncation selection: highest rewards first)
        parent_indices = rewards.argsort()[::-1][: max(2, SEQEVO_POP_SIZE // 4)]
        parents = [population[i] for i in parent_indices]

        # Every member of the next generation is a mutated child (no elite copy).
        new_population = []
        while len(new_population) < SEQEVO_POP_SIZE:
            parent = random.choice(parents)
            child = mutate_loop(parent, mut_start, mut_end, SEQEVO_MUT_RATE)
            new_population.append(child)
        population = new_population

    # Diversity on final population
    div_identity = mean_pairwise_identity(population)
    div_entropy = mean_position_entropy(population)
    print(
        f"[SeqEvo diversity] n={len(population)}, "
        f"mean identity={div_identity:.3f}, mean entropy={div_entropy:.3f}"
    )

    return {
        "history": history,
        "best_reward": best_reward,
        "best_seq": best_seq,
        "final_population": population,
        "div_identity": div_identity,
        "div_entropy": div_entropy,
    }


# -----------------------------
# 6. FULL BENCHMARK OVER LOOPS & SEEDS
# -----------------------------

def _collect_metric(runs, key):
    """Return ``run[key]`` for every per-seed result dict in *runs* as a float array."""
    return np.array([r[key] for r in runs], dtype=float)


def _mean_pm_std(values, prec):
    """Format *values* as ``"mean ± std"`` with *prec* decimals (population std, ddof=0)."""
    return f"{values.mean():.{prec}f} ± {values.std():.{prec}f}"


def _print_constrained_summary(name, es_best, se_best, es_id, se_id, es_ent, se_ent):
    """Print the best-reward and diversity summary lines for one setting."""
    print("\n=== Summary for setting:", name, "===")
    print(f"Best reward ES:     {_mean_pm_std(es_best, 2)} (n={len(es_best)})")
    print(f"Best reward SeqEvo: {_mean_pm_std(se_best, 2)} (n={len(se_best)})")
    print(
        f"Diversity (identity) ES: {_mean_pm_std(es_id, 3)}; "
        f"SeqEvo: {_mean_pm_std(se_id, 3)}"
    )
    print(
        f"Diversity (entropy)  ES: {_mean_pm_std(es_ent, 3)}; "
        f"SeqEvo: {_mean_pm_std(se_ent, 3)}"
    )


def _plot_best_reward_per_seed(name, es_best, se_best):
    """Grouped bar chart: ES vs SeqEvo best reward, one group per seed index."""
    seeds = np.arange(len(es_best))
    width = 0.35
    plt.figure(figsize=(5, 3.5))
    plt.bar(seeds - width / 2, es_best, width, label="ES best")
    plt.bar(seeds + width / 2, se_best, width, label="SeqEvo best")
    # Seed indices are categorical integers; without explicit ticks the
    # automatic locator can place fractional labels (0.5, 1.5, ...).
    plt.xticks(seeds)
    plt.xlabel("Seed")
    plt.ylabel("Best composite reward")
    plt.title(f"Best reward per seed ({name})")
    plt.legend()
    plt.tight_layout()
    plt.show()


def _plot_method_diversity(es_vals, se_vals, ylabel, title):
    """Two-bar chart (ES, SeqEvo) of the across-seed mean with population-std error bars."""
    plt.figure(figsize=(4, 3.5))
    plt.bar([0, 1], [es_vals.mean(), se_vals.mean()], yerr=[es_vals.std(), se_vals.std()])
    plt.xticks([0, 1], ["ES", "SeqEvo"])
    plt.ylabel(ylabel)
    plt.title(title)
    plt.tight_layout()
    plt.show()


def run_full_constrained_benchmark():
    """
    Run ES and SeqEvo for each constrained setting and seed.
    Aggregate performance and diversity.

    For every ``cfg`` in ``CONSTRAINED_SETTINGS`` and every seed index ``s`` in
    ``range(N_SEEDS)``, ES runs with seed ``BASE_SEED + s`` and SeqEvo with
    seed ``BASE_SEED + 1000 + s``. The offset gives the two methods distinct
    but reproducible seeds.

    After all runs, for each setting this function prints and plots:
    the best reward per run, the mean pairwise identity of the final
    population, and the mean per-position entropy. The values are read from
    the ``"best_reward"``, ``"div_identity"`` and ``"div_entropy"`` keys of
    each per-seed result dict. Reported "±" values and error bars are
    population standard deviations across seeds (``np.std`` default,
    ``ddof=0``). Settings with no runs (e.g. ``N_SEEDS == 0``) are reported
    as skipped rather than producing NaN statistics and empty plots.

    Returns:
        dict with keys ``"ES"`` and ``"SeqEvo"``. Each value maps a setting
        name (``cfg["name"]``) to the list of per-seed result dicts returned
        by ``run_es_constrained_single`` / ``run_seqevo_constrained_single``.
    """
    all_results = {
        "ES": defaultdict(list),      # setting -> list of per-seed dicts
        "SeqEvo": defaultdict(list),
    }

    # Phase 1: run every (setting, seed) pair for both methods.
    for cfg in CONSTRAINED_SETTINGS:
        for s in range(N_SEEDS):
            es_seed = BASE_SEED + s
            seq_seed = BASE_SEED + 1000 + s

            es_res = run_es_constrained_single(cfg, es_seed)
            se_res = run_seqevo_constrained_single(cfg, seq_seed)

            all_results["ES"][cfg["name"]].append(es_res)
            all_results["SeqEvo"][cfg["name"]].append(se_res)

    # Phase 2: aggregate statistics and plot, one setting at a time.
    for cfg in CONSTRAINED_SETTINGS:
        name = cfg["name"]
        es_runs = all_results["ES"][name]
        se_runs = all_results["SeqEvo"][name]

        if not es_runs or not se_runs:
            # mean()/std() of an empty array is NaN (with a RuntimeWarning).
            print(f"\n=== Summary for setting: {name} === (no runs; skipped)")
            continue

        es_best = _collect_metric(es_runs, "best_reward")
        se_best = _collect_metric(se_runs, "best_reward")
        es_id = _collect_metric(es_runs, "div_identity")
        se_id = _collect_metric(se_runs, "div_identity")
        es_ent = _collect_metric(es_runs, "div_entropy")
        se_ent = _collect_metric(se_runs, "div_entropy")

        _print_constrained_summary(name, es_best, se_best, es_id, se_id, es_ent, se_ent)

        _plot_best_reward_per_seed(name, es_best, se_best)
        _plot_method_diversity(
            es_id, se_id,
            ylabel="Mean pairwise identity",
            title=f"Diversity: identity ({name})",
        )
        _plot_method_diversity(
            es_ent, se_ent,
            ylabel="Mean per-position entropy (bits)",
            title=f"Diversity: entropy ({name})",
        )

    return all_results


# -----------------------------
# 7. RUN THE BENCHMARK
# Executes run_full_constrained_benchmark() over every constrained setting
# and seed, for both ES and SeqEvo. That call prints and plots the
# per-setting summaries; its raw per-seed results are kept in
# `full_results` (keys "ES" / "SeqEvo" -> setting name -> list of run dicts).
# NOTE(review): this runs at module level, so importing this file as a
# module also triggers the full (long) benchmark.
# -----------------------------

full_results = run_full_constrained_benchmark()
print("\nBenchmark finished.")