#### Residue–Residue Contact Prediction (ESM2-extended) 

Lightweight pipeline for quick end‑to‑end verification on a laptop. Caps file counts, crops long chains, subsamples residue pairs, and uses streaming evaluation.

**Metrics**: PR-AUC, ROC-AUC, Precision@L/L2/L5


In [1]:

from pathlib import Path
import random, os
import numpy as np
import torch
from typing import Dict, List, Tuple, Optional
from Bio.PDB import PDBParser, MMCIFParser, PPBuilder
from sklearn.model_selection import train_test_split
from sklearn.metrics import average_precision_score, roc_auc_score
from tqdm import tqdm
import pandas as pd

PDB_TRAIN_DIR = Path("../data/pdb/train")
PDB_TEST_DIR  = Path("../data/pdb/test")
MODEL_ID = "facebook/esm2_t6_8M_UR50D"
CA_DIST_THRESH = 8.0
VAL_SPLIT = 0.2
SEED = 42

MAX_TRAIN_FILES = 60
EPOCHS = 1
PAIR_SUBSAMPLE_TRAIN = 20_000
BATCH_PAIRS = 10_000
MAX_LEN_PER_CHAIN = 512
EVAL_MAX_PAIRS_PER_STRUCT = 200_000
EVAL_BATCH_PAIRS = 20_000

SAVE_PATH = Path("models/rescontact_best.pt"); SAVE_PATH.parent.mkdir(parents=True, exist_ok=True)

random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("Using device:", DEVICE)


Using device: mps


In [2]:

def _parser_for(path: Path):
    ext = path.suffix.lower()
    if ext == ".pdb": return PDBParser(QUIET=True)
    if ext in (".cif",".mmcif"): return MMCIFParser(QUIET=True)
    raise ValueError(f"Unsupported structure format: {path}")

def load_structure(path: Path):
    return _parser_for(path).get_structure(path.stem, str(path))

def extract_atom_seq_by_chain(struct):
    seqs = {}
    ppb = PPBuilder()
    model = next(iter(struct))
    for chain in model:
        polypeps = list(ppb.build_peptides(chain, aa_only=False))
        if not polypeps: continue
        parts = []
        for pp in polypeps:
            s = str(pp.get_sequence())
            if s: parts.append(s)
        if parts: seqs[chain.id] = "".join(parts)
    return seqs

def extract_ca_coords_by_chain(struct):
    ppb = PPBuilder()
    chain_coords = {}
    model = next(iter(struct))
    for chain in model:
        polypeps = ppb.build_peptides(chain)
        if not polypeps: continue
        coords = []; offset = 0
        for pp in polypeps:
            for i, res in enumerate(pp):
                if "CA" in res:
                    coords.append((offset+i, res["CA"].coord.copy()))
            offset += len(pp)
        if coords: chain_coords[chain.id] = coords
    return chain_coords

def contact_map_from_coords(coords: List[Tuple[int, np.ndarray]], L: int, thresh: float):
    has = np.zeros(L, dtype=bool); xyz = np.zeros((L,3), dtype=np.float32)
    for i,c in coords:
        if 0<=i<L: has[i]=True; xyz[i]=c
    idx = np.where(has)[0]
    contact = np.zeros((L,L), dtype=bool)
    if len(idx)>0:
        sub = xyz[idx]
        d = np.sqrt(((sub[:,None,:]-sub[None,:,:])**2).sum(-1))
        c = d < thresh
        for a,ia in enumerate(idx):
            for b,ib in enumerate(idx):
                contact[ia,ib]=c[a,b]
    return contact, has


In [3]:

class FallbackEmbedder(torch.nn.Module):
    def __init__(self, dim=64, vocab=26):
        super().__init__(); self.emb = torch.nn.Embedding(vocab, dim)
        torch.nn.init.xavier_uniform_(self.emb.weight)
    def forward(self, seq: str) -> torch.Tensor:
        idx = torch.tensor([(ord(ch)%26) for ch in seq], dtype=torch.long, device=DEVICE)
        return self.emb(idx)

def try_load_esm2(model_id: str):
    try:
        from transformers import AutoModel, AutoTokenizer
        tok = AutoTokenizer.from_pretrained(model_id, use_fast=True, local_files_only=False)
        mdl = AutoModel.from_pretrained(model_id, trust_remote_code=True, local_files_only=False).to(DEVICE).eval()
        return tok, mdl
    except Exception as e:
        print("[warn] HF load failed:", e); return None, FallbackEmbedder().to(DEVICE).eval()

@torch.no_grad()
def embed_sequence(seq: str, tokenizer, model) -> torch.Tensor:
    if tokenizer is None: return model(seq)
    toks = tokenizer(seq, return_tensors="pt", add_special_tokens=True)
    toks = {k: v.to(DEVICE) for k,v in toks.items()}
    out = model(**toks)
    H = out.last_hidden_state[0]
    if H.shape[0] >= len(seq)+2: H = H[1:1+len(seq)]
    return H.detach()


In [4]:

class PairMLP(torch.nn.Module):
    def __init__(self, d_in: int, hidden: int=128, dropout: float=0.1):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(d_in, hidden), torch.nn.ReLU(), torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden, 1),
        )
    def forward(self, x): return self.net(x).squeeze(-1)

def build_pair_batch(H: torch.Tensor, pair_idx: np.ndarray) -> torch.Tensor:
    if pair_idx.size==0: return torch.empty(0, device=DEVICE)
    hi = H[pair_idx[:,0]]; hj = H[pair_idx[:,1]]
    return torch.cat([hi, hj, torch.abs(hi-hj), hi*hj], dim=-1)


In [5]:

def process_structure(path: Path, tokenizer, esm2_model):
    try:
        s = load_structure(path)
        chain_seqs   = extract_atom_seq_by_chain(s)
        chain_coords = extract_ca_coords_by_chain(s)
        if not chain_seqs: return None

        seqs, H_blocks, contacts_blocks, masks_blocks, chain_ids = [], [], [], [], []
        for cid in sorted(chain_seqs.keys()):
            seq = chain_seqs[cid]; L = len(seq)
            coords = chain_coords.get(cid, [])
            if MAX_LEN_PER_CHAIN is not None and L > MAX_LEN_PER_CHAIN:
                seq = seq[:MAX_LEN_PER_CHAIN]
                coords = [(i,c) for (i,c) in coords if i<MAX_LEN_PER_CHAIN]
                L = len(seq)

            cmat, has = contact_map_from_coords(coords, L, CA_DIST_THRESH)
            idx = np.where(has)[0]
            if len(idx)==0: continue
            valid = np.zeros((L,L), dtype=bool)
            for i in idx:
                for j in idx:
                    if i!=j: valid[i,j]=True

            H = embed_sequence(seq, tokenizer, esm2_model)
            seqs.append(seq); H_blocks.append(H); contacts_blocks.append(cmat); masks_blocks.append(valid)
            chain_ids.extend([cid]*L)

        if not seqs: return None
        Ls = [len(s) for s in seqs]; Ltot=sum(Ls)
        contact_full = np.zeros((Ltot,Ltot), dtype=bool)
        valid_full   = np.zeros((Ltot,Ltot), dtype=bool)
        ci=0
        for k,L in enumerate(Ls):
            cj=ci+L
            contact_full[ci:cj, ci:cj] = contacts_blocks[k]
            valid_full  [ci:cj, ci:cj] = masks_blocks[k]
            ci=cj
        H_full = torch.cat(H_blocks, dim=0).to(DEVICE)
        uniq = {c:i for i,c in enumerate(sorted(set(chain_ids)))}
        chain_ids_arr = np.array([uniq[c] for c in chain_ids], dtype=np.int64)

        return {"seq":"".join(seqs),"chain_ids":chain_ids_arr,"contact":contact_full,
                "valid_pair":valid_full,"H":H_full,"pdb_id":path.stem,"path":str(path)}
    except Exception as e:
        print("Error processing", path, ":", e); return None


In [6]:

def sample_pairs(valid_mask: np.ndarray, max_pairs: Optional[int]) -> np.ndarray:
    idx = np.argwhere(valid_mask)
    if idx.size==0: return idx
    if (max_pairs is not None) and (len(idx)>max_pairs):
        sel = np.random.choice(len(idx), size=max_pairs, replace=False); idx = idx[sel]
    return idx

def train_one_epoch(model, opt, train_structs, pair_cap, batch_pairs=BATCH_PAIRS):
    model.train(); tot=0.0
    bce = torch.nn.BCEWithLogitsLoss()
    for S in train_structs:
        if S is None: continue
        H = S["H"]; contact = torch.from_numpy(S["contact"]).to(DEVICE)
        pairs = sample_pairs(S["valid_pair"], pair_cap)
        if pairs.size==0: continue
        for start in range(0, len(pairs), batch_pairs):
            sl = pairs[start:start+batch_pairs]
            X = build_pair_batch(H, sl)
            if X.numel()==0: continue
            y = contact[sl[:,0], sl[:,1]].float()
            logits = model(X); loss = bce(logits, y)
            opt.zero_grad(); loss.backward(); opt.step()
            tot += float(loss.detach().cpu())
    return tot

@torch.no_grad()
def evaluate_structs(model, structs, eval_batch_pairs=EVAL_BATCH_PAIRS, max_pairs_per_struct=EVAL_MAX_PAIRS_PER_STRUCT):
    model.eval()
    all_scores, all_labels = [], []
    per_pdb = {}
    for S in structs:
        if S is None: continue
        H = S["H"]; contact = torch.from_numpy(S["contact"]).to(DEVICE)
        idx = np.argwhere(S["valid_pair"])
        if idx.size==0: continue
        if (max_pairs_per_struct is not None) and (len(idx)>max_pairs_per_struct):
            sel = np.random.choice(len(idx), size=max_pairs_per_struct, replace=False); idx = idx[sel]

        probs_parts, y_parts = [], []
        for start in range(0, len(idx), eval_batch_pairs):
            sl = idx[start:start+eval_batch_pairs]
            X = build_pair_batch(H, sl)
            logits = model(X)
            probs_parts.append(torch.sigmoid(logits).cpu().numpy())
            y_parts.append(contact[sl[:,0], sl[:,1]].cpu().numpy().astype(int))

        probs  = np.concatenate(probs_parts)
        labels = np.concatenate(y_parts)

        try:
            ap  = average_precision_score(labels, probs)
            roc = roc_auc_score(labels, probs)
        except Exception:
            ap, roc = float("nan"), float("nan")

        L = len(S["seq"]); order = np.argsort(-probs)
        def prec_at(k):
            k = max(1, min(len(order), k)); sel = order[:k]
            return labels[sel].mean()
        per_pdb[S["pdb_id"]] = dict(
            pr_auc=ap, roc_auc=roc, p_at_L=prec_at(L), p_at_L2=prec_at(max(1,L//2)), p_at_L5=prec_at(max(1,L//5))
        )
        all_scores.append(probs); all_labels.append(labels)

    if not all_scores: return dict(global_pr_auc=float("nan"), global_roc_auc=float("nan"), per_pdb=per_pdb)
    scores = np.concatenate(all_scores); labels = np.concatenate(all_labels)
    try:
        g_ap  = average_precision_score(labels, scores)
        g_roc = roc_auc_score(labels, scores)
    except Exception:
        g_ap, g_roc = float("nan"), float("nan")
    return dict(global_pr_auc=g_ap, global_roc_auc=g_roc, per_pdb=per_pdb)


In [7]:

def collect_structs(folder: Path):
    files = []
    for pat in ("*.pdb","*.PDB","*.cif","*.CIF","*.mmcif","*.MMCIF"):
        files += list(folder.glob(pat))
    return sorted(files)

def main():
    global tokenizer, esm2_model
    tokenizer, esm2_model = try_load_esm2(MODEL_ID)
    print("Embedder:", type(esm2_model).__name__)

    all_train_files = collect_structs(PDB_TRAIN_DIR)
    test_files = collect_structs(PDB_TEST_DIR)
    print(f"Found {len(all_train_files)} train, {len(test_files)} test PDBs")
    if MAX_TRAIN_FILES and len(all_train_files)>MAX_TRAIN_FILES:
        all_train_files = all_train_files[:MAX_TRAIN_FILES]
        print(f"Capped training files to {len(all_train_files)}")

    train_files, val_files = train_test_split(all_train_files, test_size=VAL_SPLIT, random_state=SEED)
    print(f"Split: {len(train_files)} train / {len(val_files)} val")

    print("Processing training structures...")
    train_structs = [process_structure(p, tokenizer, esm2_model) for p in tqdm(train_files, desc="embed/train")]
    train_structs = [S for S in train_structs if S is not None]
    print("Processing validation structures...")
    val_structs = [process_structure(p, tokenizer, esm2_model) for p in tqdm(val_files, desc="embed/val")]
    val_structs = [S for S in val_structs if S is not None]

    if not train_structs:
        print("No valid training structures."); return

    d_in = train_structs[0]["H"].shape[-1]*4
    model = PairMLP(d_in=d_in, hidden=128, dropout=0.1).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

    for ep in range(1, EPOCHS+1):
        loss = train_one_epoch(model, opt, train_structs, pair_cap=PAIR_SUBSAMPLE_TRAIN)
        print(f"[train] epoch {ep} loss {loss:.3f}")

    val_scores = evaluate_structs(model, val_structs)
    print("Val: PR-AUC={:.4f}, ROC-AUC={:.4f}".format(val_scores["global_pr_auc"], val_scores["global_roc_auc"]))

    # quick full-train
    full_structs = train_structs + val_structs
    model = PairMLP(d_in=d_in, hidden=128, dropout=0.1).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
    for ep in range(1, EPOCHS+1):
        loss = train_one_epoch(model, opt, full_structs, pair_cap=PAIR_SUBSAMPLE_TRAIN)
        print(f"[full] epoch {ep} loss {loss:.3f}")

    torch.save({"model_state": model.state_dict(), "d_in": d_in, "cfg": dict(hidden=128, dropout=0.1, lr=1e-3)}, SAVE_PATH)
    print("Saved:", SAVE_PATH)

    if test_files:
        print("Processing test structures...")
        test_structs = [process_structure(p, tokenizer, esm2_model) for p in tqdm(test_files, desc="embed/test")]
        test_structs = [S for S in test_structs if S is not None]
        test_scores = evaluate_structs(model, test_structs)
        print("Test: PR-AUC={:.4f}, ROC-AUC={:.4f}".format(test_scores["global_pr_auc"], test_scores["global_roc_auc"]))
        df = pd.DataFrame.from_dict(test_scores["per_pdb"], orient="index").reset_index().rename(columns={"index":"pdb_id"})
        out = Path("results/test_metrics.csv"); out.parent.mkdir(exist_ok=True); df.to_csv(out, index=False)
        print("Saved per-PDB metrics to:", out)
    else:
        print("No test PDBs; skipped test.")


In [8]:

# Run the smooth pipeline
if __name__ == "__main__":
    main()


  from .autonotebook import tqdm as notebook_tqdm
  Referenced from: <EB3FF92A-5EB1-3EE8-AF8B-5923C1265422> /opt/miniconda3/envs/.env_res_contact/lib/python3.11/site-packages/torchvision/image.so
  warn(
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Embedder: EsmModel
Found 15000 train, 500 test PDBs
Capped training files to 60
Split: 48 train / 12 val
Processing training structures...


embed/train: 100%|██████████| 48/48 [00:15<00:00,  3.08it/s]


Processing validation structures...


embed/val: 100%|██████████| 12/12 [00:03<00:00,  3.63it/s]


[train] epoch 1 loss 19.386
Val: PR-AUC=0.1855, ROC-AUC=0.7741
[full] epoch 1 loss 23.779
Saved: models/rescontact_best.pt
Processing test structures...


embed/test: 100%|██████████| 500/500 [09:03<00:00,  1.09s/it]


Test: PR-AUC=0.1224, ROC-AUC=0.7391
Saved per-PDB metrics to: results/test_metrics.csv
