
### Residue–Residue Contact Prediction (ESM2-extended)

**Purpose:** a lighter, time-bounded first pass you can run locally to verify end‑to‑end training & testing.

### What's included
- Load PDBs from `../data/pdb/train` and `../data/pdb/test`.
- **Cap to at most 100 train structures** (randomly sampled; change `MAX_TRAIN_FILES`) to keep the first run quick.
- Train/val split = **80/20** by PDB file.
- Build **binary contact labels** from ATOM Cα distances (< 8 Å).
- Embed **ATOM-derived sequences** per chain using a **tiny ESM2** (`facebook/esm2_t6_8M_UR50D`) if available; otherwise a local fallback embedder.
- Small hyperparameter sweep (lr, hidden, dropout) → pick best, retrain on full train, save model.
- Evaluate on held-out **test** set: **PR-AUC, ROC-AUC, Precision@L, Precision@L/2, Precision@L/5**.
- Notes on multi-chain handling & robustness (unknown residues, empty chains).

> **Note:** This runs fine on CPU. If you want extra speed and have a recent PyTorch, it will auto‑use **MPS** (or CUDA) if available.


In [1]:

# -------------------------
# Config (edit as needed)
# -------------------------
from pathlib import Path
import random, warnings
import numpy as np
import torch
import os
from typing import Dict, List, Tuple, Optional
from Bio.PDB import PDBParser, MMCIFParser, PPBuilder
from sklearn.model_selection import train_test_split
from sklearn.metrics import average_precision_score, roc_auc_score
from tqdm import tqdm
import pandas as pd

# Folders (relative to this notebook)
PDB_TRAIN_DIR = Path("../data/pdb/train")    # training structures (.pdb / .cif / .mmcif)
PDB_TEST_DIR  = Path("../data/pdb/test")     # held-out test structures

# Model + training settings
MODEL_ID = "facebook/esm2_t6_8M_UR50D"       # tiny ESM2 (downloads once to HF cache)
CA_DIST_THRESH = 8.0                         # Å; Cα–Cα distance below this counts as a contact
VAL_SPLIT = 0.2                              # 80/20 split (by PDB file) inside train folder
SEED = 42                                    # seeds random / numpy / torch and the train/val split

# ---- FastStart caps ----
MAX_TRAIN_FILES = 100                        # randomly sample at most this many train files for the first run
PAIR_SUBSAMPLE_TRAIN = 50_000                # cap #pairs per structure for training (speed/memory)
BATCH_PAIRS = 25_000                         # pairs per optimization step
EPOCHS = 2                                   # keep tiny for the first run

SAVE_PATH = Path("models/rescontact_best.pt") # final model path

# Hyperparameter search space (fast run): one value per axis, i.e. a single
# configuration. Add values to any list to widen the sweep; every combination
# of lr x hidden x dropout is tried.
HPARAM_GRID = {
    "lr":       [1e-3],
    "hidden":   [128],
    "dropout":  [0.1],
}

# Reproducibility: seed the Python, NumPy and PyTorch RNGs used below.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device: prefer Apple MPS, then CUDA, otherwise CPU.
# `torch.backends.mps` is missing from older PyTorch builds, so look it up
# defensively instead of failing with AttributeError on a CPU-only install.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

# Silence harmless Biopython warnings about modified residues (e.g., 'GEK')
warnings.filterwarnings("ignore", message="Assuming residue .* is an unknown modified amino acid")


Using device: mps



#### Parsing structures → sequences & contact labels

- **Sequences**: ATOM-derived (concatenated polypeptide segments per chain). Unknown/modified residues appear as `X`.
- **Contacts**: Cα–Cα distance < 8 Å; pairs must have coordinates for both residues.
- **Multi-chain**: per-chain blocks assembled into a block-diagonal map; inter-chain pairs ignored in this fast pass.


In [2]:

def _parser_for(path: Path):
    """Pick the Biopython parser matching the extension of ``path``.

    The extension is compared case-insensitively: ``.pdb`` selects a
    ``PDBParser``; ``.cif`` and ``.mmcif`` select an ``MMCIFParser``. Both are
    created with ``QUIET=True``.

    Raises:
        ValueError: for any other extension.
    """
    suffix = path.suffix.lower()
    if suffix in (".cif", ".mmcif"):
        return MMCIFParser(QUIET=True)
    if suffix != ".pdb":
        raise ValueError(f"Unsupported structure format: {path}")
    return PDBParser(QUIET=True)

def load_structure(path: Path):
    """Parse ``path`` with the parser chosen by ``_parser_for``.

    The file stem is used as the structure id. Returns whatever the parser's
    ``get_structure`` returns.
    """
    return _parser_for(path).get_structure(path.stem, str(path))

def extract_atom_seq_by_chain(struct):
    """Return ATOM-derived sequences per chain (concatenate polypeptides).

    Only the first model of ``struct`` is read. For every chain, the
    polypeptide segments found by ``PPBuilder`` (``aa_only=False``) are
    concatenated in order; any letter outside the 20 standard amino acids is
    replaced by ``"X"``.

    Returns:
        dict mapping chain id -> sequence string. Chains that yield no
        polypeptide are omitted. A structure with no models yields ``{}``
        (matching ``extract_ca_coords_by_chain``) instead of raising
        ``StopIteration``.
    """
    first_model = next(iter(struct), None)
    if first_model is None:
        return {}

    standard = frozenset("ACDEFGHIKLMNPQRSTVWY")
    ppb = PPBuilder()  # ATOM-based builder
    seqs = {}
    for chain in first_model:
        raw = "".join(
            str(pp.get_sequence())
            for pp in ppb.build_peptides(chain, aa_only=False)
        )
        if raw:
            seqs[chain.id] = "".join(ch if ch in standard else "X" for ch in raw)
    return seqs

def extract_ca_coords_by_chain(struct):
    """Collect Cα coordinates per chain from the first model of ``struct``.

    Returns:
        dict mapping chain id -> list of ``(position, coord)`` tuples, where
        ``position`` indexes into the concatenation of the chain's polypeptide
        segments (as built by ``PPBuilder`` with ``aa_only=False``) and
        ``coord`` is a copy of the CA atom's coordinate array. Residues without
        a CA atom are skipped but still occupy a position. Chains with no CA
        atoms are omitted.
    """
    first_model = next(iter(struct), None)
    if first_model is None:
        return {}

    builder = PPBuilder()
    per_chain = {}
    for chain in first_model:
        placed = []
        base = 0  # position of the current segment's first residue
        for peptide in builder.build_peptides(chain, aa_only=False):
            for pos, residue in enumerate(peptide, start=base):
                if "CA" in residue:
                    placed.append((pos, residue["CA"].coord.copy()))
            base += len(peptide)
        if placed:
            per_chain[chain.id] = placed
    return per_chain

def contact_map_from_coords(coords, L: int, thresh: float):
    """Build a boolean contact map from per-residue coordinates.

    Args:
        coords: iterable of ``(index, xyz)`` pairs; ``index`` is a residue
            position in ``[0, L)`` and ``xyz`` a length-3 coordinate.
        L: number of residues (side length of the returned map).
        thresh: two residues are in contact when their distance is < thresh.

    Returns:
        ``(contact, has)`` where ``contact`` is an ``(L, L)`` bool array and
        ``has`` is a length-``L`` bool array marking residues that had a
        coordinate. Rows/columns of residues without coordinates stay False.
        The diagonal is True for residues with coordinates (distance 0).
    """
    has = np.zeros(L, dtype=bool)
    xyz = np.zeros((L, 3), dtype=np.float32)
    for pos, point in coords:
        has[pos] = True
        xyz[pos] = point

    contact = np.zeros((L, L), dtype=bool)
    idx = np.flatnonzero(has)
    if idx.size:
        sub = xyz[idx]
        dist = np.sqrt(((sub[:, None, :] - sub[None, :, :]) ** 2).sum(-1))
        # Scatter the dense sub-matrix back in one vectorised assignment
        # instead of a Python-level double loop.
        contact[np.ix_(idx, idx)] = dist < thresh
    return contact, has



#### Embedding sequences (tiny ESM2 or fallback)


In [4]:

class FallbackEmbedder(torch.nn.Module):
    """Randomly initialised per-character embedding used when ESM2 cannot load.

    Each character is mapped to a row of an embedding table via
    ``ord(ch) % vocab``.

    Args:
        dim: embedding width.
        vocab: number of rows in the embedding table.
    """

    def __init__(self, dim=64, vocab=26):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab, dim)
        torch.nn.init.xavier_uniform_(self.emb.weight)

    def forward(self, seq: str) -> torch.Tensor:
        """Return a ``(len(seq), dim)`` tensor on the module's own device."""
        table = self.emb
        # Tie the modulus to the actual table size (a hard-coded 26 would index
        # out of range for vocab < 26) and create the indices on the device the
        # weights live on rather than on a global device.
        idx = torch.tensor(
            [ord(ch) % table.num_embeddings for ch in seq],
            dtype=torch.long,
            device=table.weight.device,
        )
        return table(idx)

def try_load_esm2(model_id: str):
    """Load a Hugging Face tokenizer/model pair, falling back to a local embedder.

    Args:
        model_id: Hugging Face model identifier.

    Returns:
        ``(tokenizer, model)`` with the model moved to ``DEVICE`` and put in
        eval mode. If anything goes wrong (``transformers`` missing, download
        failure, ...) a warning is printed and ``(None, FallbackEmbedder)`` is
        returned instead; a ``None`` tokenizer is how ``embed_sequence``
        recognises the fallback.
    """
    try:
        from transformers import AutoModel, AutoTokenizer
        tok = AutoTokenizer.from_pretrained(model_id, use_fast=True, local_files_only=False)
        # Do not pass trust_remote_code=True: it allows arbitrary Python from the
        # model repository to run locally, and ESM2 checkpoints load with the
        # model class that ships inside transformers.
        mdl = AutoModel.from_pretrained(model_id, local_files_only=False)
        mdl.to(DEVICE).eval()
        return tok, mdl
    except Exception as e:  # deliberate best-effort: any failure -> fallback embedder
        print(f"[warn] Could not load {model_id}: {e}\nUsing FallbackEmbedder.")
        return None, FallbackEmbedder().to(DEVICE).eval()

@torch.no_grad()
def embed_sequence(seq: str, tokenizer, model) -> torch.Tensor:
    """Embed ``seq`` into one vector per residue, without tracking gradients.

    Args:
        seq: amino-acid sequence.
        tokenizer: Hugging Face tokenizer, or ``None`` when ``model`` is the
            fallback embedder that takes the raw string.
        model: the embedder matching ``tokenizer``.

    Returns:
        A 2-D tensor, normally ``(len(seq), width)``. For a Hugging Face model
        whose output has fewer than ``len(seq) + 2`` rows, the untrimmed hidden
        states are returned and the caller must check the length.
    """
    if tokenizer is None:
        # The fallback embedder handles "" itself and returns (0, dim).
        return model(seq)
    if not seq:
        # Keep the empty result as wide as the model's real output rather than
        # a hard-coded 64, so it can be concatenated with other chains.
        width = getattr(getattr(model, "config", None), "hidden_size", 64)
        return torch.zeros(0, width, device=DEVICE)

    tokens = tokenizer(seq, return_tensors="pt", add_special_tokens=True)
    tokens = {k: v.to(DEVICE) for k, v in tokens.items()}
    hidden = model(**tokens).last_hidden_state[0]
    # Drop the leading row and keep len(seq) rows -- presumably this strips the
    # special tokens added by add_special_tokens=True (verify for other models).
    if hidden.shape[0] >= len(seq) + 2:
        hidden = hidden[1:1 + len(seq)]
    return hidden


#### Pairwise MLP contact head

In [7]:

def build_pair_batch(H: torch.Tensor, pair_idx: np.ndarray) -> torch.Tensor:
    """Build pairwise features for residue pairs.

    Args:
        H: ``(L, d)`` per-residue embeddings.
        pair_idx: ``(P, 2)`` integer array of ``(i, j)`` residue indices.

    Returns:
        ``(P, 4*d)`` tensor ``[h_i, h_j, |h_i - h_j|, h_i * h_j]`` on ``H``'s
        device. For an empty ``pair_idx`` the result is ``(0, 4*d)`` so it has
        the same rank and width as a non-empty batch (``numel()`` is still 0).
    """
    if pair_idx.size == 0:
        return H.new_zeros((0, 4 * H.shape[-1]))
    # Convert once to a long tensor on H's device rather than indexing a torch
    # tensor with a raw numpy array.
    rows = torch.as_tensor(pair_idx, dtype=torch.long, device=H.device)
    hi = H[rows[:, 0]]
    hj = H[rows[:, 1]]
    return torch.cat([hi, hj, torch.abs(hi - hj), hi * hj], dim=-1)

class PairMLP(torch.nn.Module):
    """Two-layer MLP producing one contact logit per residue pair.

    Layout: Linear(d_in -> hidden) -> ReLU -> Dropout -> Linear(hidden -> 1).
    The layers live in ``self.net`` (a ``Sequential``), so state-dict keys are
    ``net.0.*`` and ``net.3.*``.
    """

    def __init__(self, d_in: int, hidden: int = 128, dropout: float = 0.1):
        super().__init__()
        nn = torch.nn
        layers = [
            nn.Linear(d_in, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``(..., d_in)`` features to ``(...)`` logits."""
        logits = self.net(x)
        return logits.squeeze(-1)


#### Process structures → tensors

In [8]:

def collect_structs(folder: Path):
    """Return the sorted structure files directly inside ``folder``.

    A file qualifies when its suffix, compared case-insensitively, is ``.pdb``,
    ``.cif`` or ``.mmcif`` -- the same rule ``_parser_for`` uses. Matching on
    the lower-cased suffix (rather than globbing each case variant) cannot
    return a file twice where glob matching is case-insensitive, and it also
    picks up mixed-case suffixes such as ``.Pdb``.

    Returns an empty list when ``folder`` does not exist or is not a directory.
    """
    if not folder.is_dir():
        return []
    wanted = {".pdb", ".cif", ".mmcif"}
    return sorted(
        entry
        for entry in folder.iterdir()
        if entry.is_file() and entry.suffix.lower() in wanted
    )

def process_structure(path: Path, tokenizer, esm2_model):
    """Turn one structure file into embeddings, contact labels and a pair mask.

    Chains are processed in sorted chain-id order and laid out one after
    another. Contacts and valid pairs are only defined within a chain
    (block-diagonal); inter-chain entries stay False.

    Args:
        path: structure file (.pdb / .cif / .mmcif).
        tokenizer: tokenizer passed through to ``embed_sequence``.
        esm2_model: embedder passed through to ``embed_sequence``.

    Returns:
        ``None`` if the file has no usable chain or any error occurs (the error
        is printed), otherwise a dict with keys:
        ``seq`` (concatenated sequence), ``chain_ids`` (int64 array, one entry
        per residue), ``contact`` and ``valid_pair`` (``(Ltot, Ltot)`` bool
        arrays), ``H`` (``(Ltot, d)`` tensor on ``DEVICE``), ``pdb_id``, ``path``.
    """
    try:
        s = load_structure(path)
        chain_seqs = extract_atom_seq_by_chain(s)
        chain_coords = extract_ca_coords_by_chain(s)
        if not chain_seqs:
            return None

        seqs, chain_ids = [], []
        contacts_blocks, masks_blocks, H_blocks = [], [], []

        for cid in sorted(chain_seqs):
            seq = chain_seqs[cid]
            L = len(seq)
            if L == 0:
                continue
            cmat, has = contact_map_from_coords(chain_coords.get(cid, []), L, CA_DIST_THRESH)

            # A pair is usable when both residues have a CA coordinate and
            # i != j. Vectorised equivalent of the double loop over indices.
            valid = has[:, None] & has[None, :]
            np.fill_diagonal(valid, False)

            H = embed_sequence(seq, tokenizer, esm2_model)
            if H.shape[0] != L:
                # Embedding length does not match the sequence: skip the chain.
                continue

            H_blocks.append(H)
            seqs.append(seq)
            chain_ids.extend([cid] * L)
            contacts_blocks.append(cmat)
            masks_blocks.append(valid)

        if not seqs:
            return None

        Ltot = sum(len(sq) for sq in seqs)
        contact_full = np.zeros((Ltot, Ltot), dtype=bool)
        valid_full = np.zeros((Ltot, Ltot), dtype=bool)

        # Place each chain's block on the diagonal.
        start = 0
        for cmat, valid in zip(contacts_blocks, masks_blocks):
            stop = start + cmat.shape[0]
            contact_full[start:stop, start:stop] = cmat
            valid_full[start:stop, start:stop] = valid
            start = stop

        H_full = torch.cat(H_blocks, dim=0).to(DEVICE)
        uniq = {c: i for i, c in enumerate(sorted(set(chain_ids)))}
        chain_ids_arr = np.array([uniq[c] for c in chain_ids], dtype=np.int64)
        return {
            "seq": "".join(seqs),
            "chain_ids": chain_ids_arr,
            "contact": contact_full,
            "valid_pair": valid_full,
            "H": H_full,
            "pdb_id": path.stem,
            "path": str(path),
        }
    except Exception as e:  # best-effort: one bad file must not stop the run
        print(f"[error] {path.name}: {e}")
        return None


#### Training / Evaluation helpers

In [9]:

def sample_pairs(valid_mask: np.ndarray, max_pairs: Optional[int]) -> np.ndarray:
    """Return index pairs of the True entries of ``valid_mask``.

    The result is the ``np.argwhere`` array (one row per True entry). When
    ``max_pairs`` is given and there are more rows than that, ``max_pairs``
    rows are drawn without replacement using NumPy's global RNG.
    """
    candidates = np.argwhere(valid_mask)
    total = len(candidates)
    if candidates.size == 0 or max_pairs is None or total <= max_pairs:
        return candidates
    chosen = np.random.choice(total, size=max_pairs, replace=False)
    return candidates[chosen]

def train_one_epoch(model, opt, train_structs, pair_cap, batch_pairs=BATCH_PAIRS):
    """Run one pass over ``train_structs`` and return the summed batch loss.

    Args:
        model: pair classifier taking ``build_pair_batch`` features.
        opt: optimizer over ``model``'s parameters.
        train_structs: dicts from ``process_structure`` (``None`` entries skipped).
        pair_cap: max pairs sampled per structure (``None`` = all valid pairs).
        batch_pairs: number of pairs per optimization step.

    Returns:
        Sum (not mean) of the per-batch BCE-with-logits losses as a float.
    """
    model.train()
    loss_fn = torch.nn.BCEWithLogitsLoss()
    tot_loss = 0.0
    for S in train_structs:
        if S is None:
            continue
        pairs = sample_pairs(S["valid_pair"], pair_cap)
        if pairs.size == 0:
            continue
        H = S["H"]
        contact = S["contact"]  # stays in NumPy; only sampled labels are moved
        for start in range(0, len(pairs), batch_pairs):
            sl = pairs[start:start + batch_pairs]
            X = build_pair_batch(H, sl)
            if X.numel() == 0:
                continue
            # Gather labels on the host first so the full L x L map is never
            # uploaded to the device.
            y = torch.from_numpy(contact[sl[:, 0], sl[:, 1]]).float().to(DEVICE)
            loss = loss_fn(model(X), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            tot_loss += loss.item()
    return tot_loss

def _metric_or_nan(metric, labels, scores):
    """Call ``metric(labels, scores)``; return NaN when it raises ValueError."""
    try:
        return metric(labels, scores)
    except ValueError:
        return float("nan")


def _precision_at_k(labels, order, k):
    """Fraction of positives among the ``k`` top-ranked pairs (k clamped to [1, n])."""
    k = max(1, min(len(order), k))
    return labels[order[:k]].mean()


@torch.no_grad()
def evaluate_structs(model, structs, batch_pairs=BATCH_PAIRS):
    """Score every valid pair of every structure and compute ranking metrics.

    Args:
        model: pair classifier taking ``build_pair_batch`` features.
        structs: dicts from ``process_structure`` (``None`` entries skipped).
        batch_pairs: pairs scored per forward pass. Bounds peak memory: the
            feature matrix for all valid pairs of a long chain can be very large.

    Returns:
        dict with ``global_pr_auc`` and ``global_roc_auc`` (pooled over all
        pairs, NaN when undefined) and ``per_pdb``: ``pdb_id`` -> dict with
        ``pr_auc``, ``roc_auc``, ``p_at_L``, ``p_at_L2``, ``p_at_L5``, where L
        is the length of the structure's concatenated sequence.
    """
    model.eval()
    all_scores, all_labels = [], []
    per_pdb_metrics = {}
    for S in structs:
        if S is None:
            continue
        idx = np.argwhere(S["valid_pair"])
        if idx.size == 0:
            continue
        H = S["H"]

        chunks = []
        for start in range(0, len(idx), batch_pairs):
            logits = model(build_pair_batch(H, idx[start:start + batch_pairs]))
            chunks.append(torch.sigmoid(logits).cpu().numpy())
        probs = np.concatenate(chunks)
        # Labels are read straight from the NumPy map; no device round-trip.
        labels = S["contact"][idx[:, 0], idx[:, 1]].astype(int)

        # Evaluate each metric on its own so an undefined ROC-AUC (only one
        # class present) does not discard a valid PR-AUC.
        ap = _metric_or_nan(average_precision_score, labels, probs)
        roc = _metric_or_nan(roc_auc_score, labels, probs)

        L = len(S["seq"])
        order = np.argsort(-probs)
        pL = _precision_at_k(labels, order, L)
        pL2 = _precision_at_k(labels, order, max(1, L // 2))
        pL5 = _precision_at_k(labels, order, max(1, L // 5))

        per_pdb_metrics[S["pdb_id"]] = dict(pr_auc=ap, roc_auc=roc, p_at_L=pL, p_at_L2=pL2, p_at_L5=pL5)
        all_scores.append(probs)
        all_labels.append(labels)

    if not all_scores:
        return dict(global_pr_auc=float("nan"), global_roc_auc=float("nan"), per_pdb=per_pdb_metrics)

    scores = np.concatenate(all_scores)
    labels = np.concatenate(all_labels)
    g_ap = _metric_or_nan(average_precision_score, labels, scores)
    g_roc = _metric_or_nan(roc_auc_score, labels, scores)
    return dict(global_pr_auc=g_ap, global_roc_auc=g_roc, per_pdb=per_pdb_metrics)


#### Build training model: cap→split→sweep→final train→test

In [10]:

# Load embedder
tokenizer, esm2_model = try_load_esm2(MODEL_ID)
print(f"Embedder: {type(esm2_model).__name__}")

# Collect files
train_all_files = collect_structs(PDB_TRAIN_DIR)
test_files = collect_structs(PDB_TEST_DIR)
print(f"Found {len(train_all_files)} train files, {len(test_files)} test files.")

# Fail early with an actionable message instead of an opaque error from the
# train/val split further down.
if not train_all_files:
    raise ValueError(f"No structure files found in {PDB_TRAIN_DIR.resolve()}")

# Cap training files for the first run (random sample; reproducible via SEED)
if len(train_all_files) > MAX_TRAIN_FILES:
    train_all_files = random.sample(train_all_files, MAX_TRAIN_FILES)
    print(f"Capped training files to {len(train_all_files)} for a quick run.")

# Split into train/val
train_files, val_files = train_test_split(train_all_files, test_size=VAL_SPLIT, random_state=SEED)
print(f"Split: {len(train_files)} train / {len(val_files)} val")

def hyperparam_search(train_files, val_files):
    """Grid-search ``HPARAM_GRID`` and return the best configuration.

    Every combination of lr / hidden / dropout is trained for ``EPOCHS`` epochs
    on the processed ``train_files`` and scored by global PR-AUC on the
    processed ``val_files``.

    Returns:
        dict with ``lr``, ``hidden``, ``dropout`` and ``score`` of the best
        configuration, or ``None`` if the grid is empty. A NaN score (e.g. no
        usable validation structure) never beats a real score.

    Raises:
        ValueError: if no training structure could be processed.
    """
    def _load(files, desc):
        # Process files, dropping those process_structure rejected.
        loaded = []
        for p in tqdm(files, desc=desc):
            S = process_structure(p, tokenizer, esm2_model)
            if S is not None:
                loaded.append(S)
        return loaded

    print("Processing training structures...")
    train_structs = _load(train_files, "embed/train")

    print("Processing validation structures...")
    val_structs = _load(val_files, "embed/val")

    if not train_structs:
        raise ValueError("No valid training structures found!")

    d_in = train_structs[0]["H"].shape[-1]*4
    best = None
    for lr in HPARAM_GRID["lr"]:
        for hidden in HPARAM_GRID["hidden"]:
            for dropout in HPARAM_GRID["dropout"]:
                model = PairMLP(d_in=d_in, hidden=hidden, dropout=dropout).to(DEVICE)
                opt = torch.optim.AdamW(model.parameters(), lr=lr)
                for _ in range(EPOCHS):
                    train_one_epoch(model, opt, train_structs, pair_cap=PAIR_SUBSAMPLE_TRAIN)
                score = evaluate_structs(model, val_structs)["global_pr_auc"]
                print(f"[HP] lr={lr}, hidden={hidden}, dropout={dropout} -> val PR-AUC={score:.4f}")
                # `score > nan` is always False, so a NaN incumbent must be
                # replaceable explicitly by the first real score.
                beats_nan = best is not None and np.isnan(best["score"]) and not np.isnan(score)
                if best is None or beats_nan or score > best["score"]:
                    best = dict(lr=lr, hidden=hidden, dropout=dropout, score=score)
    return best

best_cfg = hyperparam_search(train_files, val_files)
print("Best config:", best_cfg)
if best_cfg is None:
    raise ValueError("HPARAM_GRID produced no configuration; nothing to train.")

# Retrain best on full (train + val)
# NOTE: this re-parses and re-embeds files hyperparam_search already processed.
full_files = train_files + val_files
print("Processing full training set...")
full_structs = []
for p in tqdm(full_files, desc="embed/full-train"):
    S = process_structure(p, tokenizer, esm2_model)
    if S is not None:
        full_structs.append(S)

if not full_structs:
    # Without this, full_structs[0] below fails with a bare IndexError.
    raise ValueError("No valid training structures found!")

d_in = full_structs[0]["H"].shape[-1]*4
final_model = PairMLP(d_in=d_in, hidden=best_cfg["hidden"], dropout=best_cfg["dropout"]).to(DEVICE)
opt = torch.optim.AdamW(final_model.parameters(), lr=best_cfg["lr"])

for ep in range(1, EPOCHS+1):
    loss = train_one_epoch(final_model, opt, full_structs, pair_cap=PAIR_SUBSAMPLE_TRAIN)
    print(f"[final] epoch {ep} loss {loss:.3f}")

# Save
SAVE_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save({
    "model_state": final_model.state_dict(),
    "cfg": best_cfg,
    "model_class": "PairMLP",
    "d_in": d_in,
}, SAVE_PATH)
print(f"Saved: {SAVE_PATH}")

# Evaluate on test
print("\nEvaluating on test set...")
test_structs = []
for p in tqdm(test_files, desc="embed/test"):
    S = process_structure(p, tokenizer, esm2_model)
    if S is not None:
        test_structs.append(S)

test_scores = evaluate_structs(final_model, test_structs)
print("Test (global): PR-AUC={:.4f}, ROC-AUC={:.4f}".format(
    test_scores["global_pr_auc"], test_scores["global_roc_auc"]))

# Save per-PDB
df_test = pd.DataFrame.from_dict(test_scores["per_pdb"], orient="index").reset_index().rename(columns={"index":"pdb_id"})
# With no evaluated test structure the frame has no metric columns; sorting by
# "pr_auc" would raise KeyError, so only sort when the column exists.
if "pr_auc" in df_test.columns:
    df_test.sort_values("pr_auc", ascending=False, inplace=True)
results_path = Path("results/test_metrics.csv")
results_path.parent.mkdir(exist_ok=True, parents=True)
df_test.to_csv(results_path, index=False)
print(f"Saved test results to: {results_path}")


  from .autonotebook import tqdm as notebook_tqdm
  Referenced from: <EB3FF92A-5EB1-3EE8-AF8B-5923C1265422> /opt/miniconda3/envs/.env_res_contact/lib/python3.11/site-packages/torchvision/image.so
  warn(
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Embedder: EsmModel
Found 15000 train files, 500 test files.
Capped training files to 100 for a quick run.
Split: 80 train / 20 val
Processing training structures...


embed/train: 100%|██████████| 80/80 [00:45<00:00,  1.75it/s]


Processing validation structures...


embed/val: 100%|██████████| 20/20 [00:07<00:00,  2.62it/s]


KeyboardInterrupt: 


#### Notes & Next steps
- Add inter-chain contacts and report intra vs inter separately.
- Add long-range only evaluation (e.g., |i−j| ≥ 24).
- Integrate template/retrieval channels from similar training structures (for the “use structural data from similar sequences” criterion).
- Track time spent per section in `roadmap.txt`.
