# 04 — Surrogate-guided closed-loop (DL / active learning)

In [None]:
!pip -q install transformers accelerate biopython pandas numpy matplotlib tqdm scikit-learn pyyaml
import torch, platform
print("torch", torch.__version__, "cuda?", torch.cuda.is_available(), "python", platform.python_version())


In [None]:
!git clone -q https://github.com/dauparas/ProteinMPNN.git
!pip -q install -r ProteinMPNN/requirements.txt


In [None]:
from pathlib import Path
import pandas as pd
import torch
from src.data.scaffolds import load_scaffold
from src.generate.proteinmpnn import run_proteinmpnn, read_fasta_sequences
from src.evaluate.esmfold_eval import evaluate_batch
from src.generate.mutations import make_mutant_pool
from src.models.surrogate import SurrogateConfig, get_embeddings, train_surrogate, predict_surrogate

# One round of surrogate-guided design:
#   1. generate sequences with ProteinMPNN and label them with ESMFold pLDDT,
#   2. fit a surrogate on ESM-2 embeddings of the labeled sequences,
#   3. score a large mutant pool with the surrogate and fold only the
#      top-ranked candidates.
OUT = Path("results")
# Use the GPU when present, otherwise fall back to the CPU instead of failing
# on a hard-coded "cuda" string. On a CUDA runtime this is unchanged.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
ESMFOLD_ID = "facebook/esmfold_v1"
N_TO_FOLD = 30  # number of top surrogate-ranked candidates sent to folding

sc = load_scaffold("1AKL", "A", OUT/"scaffolds")

# labeled set
fasta = run_proteinmpnn(sc.pdb_path, OUT/"mpnn_surrogate_r0", Path("ProteinMPNN"), num_seqs=60, sampling_temp=0.2, seed=42)
seqs0 = read_fasta_sequences(fasta)[:40]
# NOTE(review): up to 40 sequences are passed with max_n=20; if max_n caps the
# number of sequences that get folded, only 20 end up labeled — confirm intent.
lab = evaluate_batch(seqs0, model_id=ESMFOLD_ID, device=DEVICE, out_dir=OUT/"pdb"/"surrogate_r0", max_n=20)

# Training inputs/targets are taken from the returned records so that they
# stay aligned with whatever subset evaluate_batch actually produced.
train_seqs = [r.sequence for r in lab]
y = torch.tensor([r.mean_plddt for r in lab], dtype=torch.float32)

cfg = SurrogateConfig(esm2_model_id="facebook/esm2_t12_35M_UR50D", pool="mean", epochs=8, lr=1e-3, device=DEVICE)
X = get_embeddings(train_seqs, cfg.esm2_model_id, cfg.pool, cfg.device)
m = train_surrogate(X, y, cfg)

# propose many, fold few
cands = make_mutant_pool(train_seqs, n=200, rate=0.03, seed=123)
Xc = get_embeddings(cands, cfg.esm2_model_id, cfg.pool, cfg.device)
# Tensor.numpy() raises for tensors that live on a CUDA device or require
# grad, so detach and move to host memory before converting.
pred = predict_surrogate(m, Xc, cfg.device).detach().cpu().numpy()
# argsort is ascending; reverse it to rank candidates best-first.
top_idx = pred.argsort()[::-1][:N_TO_FOLD]
to_fold = [cands[i] for i in top_idx]
# NOTE(review): N_TO_FOLD candidates are selected but max_n=20 is passed; if
# max_n caps the number folded, the last 10 are never evaluated — confirm.
res = evaluate_batch(to_fold, model_id=ESMFOLD_ID, device=DEVICE, out_dir=OUT/"pdb"/"surrogate_r1", max_n=20)

# Explicit columns keep sort_values from raising KeyError when `res` is empty.
df = pd.DataFrame(
    [{"sequence": r.sequence, "mean_plddt": r.mean_plddt} for r in res],
    columns=["sequence", "mean_plddt"],
).sort_values("mean_plddt", ascending=False)
df.head(10)


In [None]:
# Persist the ranked folding results of the surrogate-guided round as a CSV
# under results/tables, creating the directory if it does not exist yet.
from pathlib import Path

OUT = Path("results")
tables_dir = OUT / "tables"
tables_dir.mkdir(parents=True, exist_ok=True)

csv_path = tables_dir / "surrogate_guided.csv"
df.to_csv(csv_path, index=False)
print("saved", csv_path)
