### Residue–Residue Contact Prediction (ESM2 ± MSA, Local/Online switch)

**Purpose**: Train a binary contact predictor from *sequence only*. Ground‑truth contact maps come from PDB/mmCIF coordinates (Cα–Cα < 8 Å). The network never sees target coordinates as input.

### Highlights
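- **ESM2 embeddings** (Hugging Face `transformers`): the small `facebook/esm2_t6_8M_UR50D` checkpoint by default, with a simple fallback embedder if it cannot be loaded.
- **MSA features optional**: per‑residue 20‑AA frequencies + entropy (21 features), plus an MI+APC pair feature when the MSA has at least `MIN_MSA_DEPTH_FOR_MI` sequences and the chain has at most `MAX_L_FOR_MI` residues.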
- **Local or Online MSAs**: 
  - Local: `../data/msa/{train|test}/<pdbid>_<chain>.a3m|.fasta|.fa`
  - Online: BLASTp (NCBI) or jackhmmer (EBI), fetched on the fly and kept **in memory** only (nothing is written to disk, so every run queries again).
- **Memory‑friendly training** for 8‑GB Macs: balanced pair sampling, bilinear scorer (low‑rank), chunked eval.

### Folder layout
```
Res-contact/
├─ data/
│  ├─ pdb/
│  │  ├─ train/   # PDB/mmCIF used for training
│  │  └─ test/    # held‑out structures
│  └─ msa/
│     ├─ train/   # optional local A3M/FASTA (query first)
│     └─ test/
```

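**Tip for M‑series Macs (8 GB)**: Start with small caps (few files, a short maximum chain length, a small pair subsample). MPS is selected automatically when available. Preprocessing can be slow: the run below averaged about two minutes per training structure with online MSAs enabled, and setting `USE_ONLINE_MSA = False` skips the remote BLAST/jackhmmer calls for a faster first pass.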
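The per‑structure pair arrays are dense, so their memory grows with the square of the concatenated length of all chains. A back‑of‑the‑envelope helper (hypothetical, not part of the pipeline), assuming one byte per NumPy bool entry for `contact` and `valid_pair` and four bytes per float32 entry for `MI`, as built in `process_structure`:

```python
def dense_pair_bytes(L_total: int, with_mi: bool = True) -> int:
    """Rough size in bytes of the dense L_total x L_total arrays built per structure.

    Assumes `contact` and `valid_pair` are NumPy bool arrays (1 byte per entry)
    and `MI`, when present, is float32 (4 bytes per entry).
    """
    n = L_total * L_total
    bool_bytes = 2 * n                  # contact + valid_pair
    mi_bytes = 4 * n if with_mi else 0  # MI (float32), only if some chain had MI
    return bool_bytes + mi_bytes

# Example: four chains of 600 residues each (MAX_LEN_PER_CHAIN) concatenated
print(dense_pair_bytes(4 * 600) / 1e6, "MB")  # → 34.56 MB
```

The `np.triu_indices` arrays used in pair sampling and evaluation add roughly another 8·L² bytes on top, and the ESM2 forward pass needs its own activation memory.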


In [None]:
# =====================
# Config (edit freely)
# =====================
from pathlib import Path
import os

# --- Data ---
PDB_TRAIN_DIR = Path("../data/pdb/train")
PDB_TEST_DIR  = Path("../data/pdb/test")

# --- ESM2 backbone ---
# Tiny model first; upgrade later if needed (e.g., facebook/esm2_t12_35M_UR50D)
ESM2_MODEL_ID = "facebook/esm2_t6_8M_UR50D"

# --- Contacts ---
CA_DIST_THRESH = 8.0  # Å

# --- Train/Val/Test split ---
VAL_SPLIT = 0.2    # split inside train directory by files
SEED = 42

# --- Smooth mode knobs (for 8 GB) ---
MAX_TRAIN_FILES = 50         # cap number of training files processed
MAX_LEN_PER_CHAIN = 600      # truncate long chains (None for no cap)
PAIR_SUBSAMPLE_TRAIN = 50_000  # max positive pairs per structure per epoch (used as max_pos; negatives added at 3:1)

# --- Eval chunking ---
EVAL_MAX_PAIRS_PER_STRUCT = 500_000
EVAL_BATCH_PAIRS = 100_000

# --- MSA switches (local / online) ---
USE_LOCAL_MSA  = True            # if A3M/FA exists at ../data/msa/{split}/<pdbid>_<chain>.* use it
USE_ONLINE_MSA = True            # if no local file, try online (BLAST or jackhmmer)
MSA_PROVIDER   = "ncbi"          # "ncbi" (BLASTp) or "jackhmmer"
MSA_MAX_SEQS   = 32              # max homologs kept per MSA (plus the query row); keeps RAM small
MSA_EVALUE     = 1e-5
MSA_TIMEOUT_S  = 180
os.environ.setdefault("ENTREZ_EMAIL", "ryoji.takahashi@gmail.com")  # NCBI courtesy
os.environ.setdefault("MSA_EMAIL", "ryoji.takahashi@gmail.com")     # EBI jackhmmer

# --- MSA feature thresholds ---
MIN_MSA_DEPTH_FOR_MI = 10   # require at least this many sequences to compute MI
MAX_L_FOR_MI        = 800   # don't compute MI if sequence is longer than this

# --- Model / training ---
EPOCHS = 3                 # small for first pass
LR = 1e-3
BILINEAR_RANK = 128
MI_WEIGHT = 0.5            # scale MI pair feature into the logit (if available)

# --- Saving ---
SAVE_PATH = Path("models/rescontact_best.pt")
RESULTS_DIR = Path("results")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
SAVE_PATH.parent.mkdir(parents=True, exist_ok=True)

print("Config loaded.")

Config loaded.
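With these flags, `load_msa_rows` tries MSA sources in a fixed order and returns the first that applies. The hypothetical helper below only lists that order for a chain key of the form `<file stem>_<chain id>` (e.g. `1abc_A`), which can help check that local files are named the way the loader expects; it does not open files or contact any server. Note that an unrecognized `MSA_PROVIDER` silently falls through to the query‑only fallback.

```python
from pathlib import Path
from typing import List, Union

def msa_source_plan(chain_key: str, split: str, use_local: bool, use_online: bool,
                    provider: str, msa_root: Path = Path("../data/msa")) -> List[Union[Path, str]]:
    """List, in order, the MSA sources load_msa_rows would try (sketch, no I/O)."""
    plan: List[Union[Path, str]] = []
    if use_local:
        # Local files are checked in this extension order; the first existing one wins
        plan += [msa_root / split / f"{chain_key}{ext}" for ext in (".a3m", ".fasta", ".fa")]
    if use_online and provider in ("ncbi", "jackhmmer"):
        plan.append(f"online:{provider}")
    # Fallback: a single-row MSA (all-zero 1D features, no MI); failed online calls return this too
    plan.append("query-only")
    return plan

print(msa_source_plan("1abc_A", "train", True, True, "ncbi"))
```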


In [2]:
# =====================
# Imports & device
# =====================
import numpy as np
import torch
from typing import Dict, List, Tuple, Optional
from Bio.PDB import PDBParser, MMCIFParser, PPBuilder
from Bio import SeqIO
from Bio import pairwise2
from sklearn.model_selection import train_test_split
from sklearn.metrics import average_precision_score, roc_auc_score
from tqdm import tqdm
import pandas as pd
import warnings, io, time, json, requests

torch.manual_seed(SEED)
np.random.seed(SEED)

if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
DEVICE



device(type='mps')

#### MSA I/O — Local files (A3M/FASTA) or Online (BLAST / jackhmmer)
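In A3M, the first record is the query. Uppercase letters and `-` occupy columns aligned to the query (matches and deletions), while lowercase letters are insertions relative to the query and have no query position, which is why `read_a3m_as_rows` strips them. A standalone toy re‑implementation of that mapping (made‑up sequences; not the notebook's function, and it takes the query from the A3M rather than from the structure):

```python
from typing import List

def a3m_rows(records: List[str]) -> List[str]:
    """Project A3M records onto query positions (records[0] is the query)."""
    # Drop insertion states (lowercase) so each row has one column per aligned column
    stripped = ["".join(c for c in r if not c.islower()) for r in records]
    query = stripped[0]
    # Map alignment columns to query positions; query gaps map to None
    col2qi, qi = [], 0
    for c in query:
        if c != "-":
            col2qi.append(qi); qi += 1
        else:
            col2qi.append(None)
    L = qi
    rows = [query.replace("-", "")]
    for aligned in stripped[1:]:
        row = ["-"] * L
        for col, aa in enumerate(aligned):
            q = col2qi[col]
            if q is not None and aa != "-":
                row[q] = aa.upper()
        rows.append("".join(row))
    return rows

print(a3m_rows(["MKTAY", "MKgsTAY", "M-TAF"]))  # ['MKTAY', 'MKTAY', 'M-TAF']
```

The insertion `gs` in the second record disappears, and the deletion in the third record stays a gap at query position 1.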

In [3]:
def _a3m_strip_insertions(s: str) -> str:
    return "".join([c for c in s if not c.islower()])

def read_a3m_as_rows(path: Path, query_seq: str):
    with open(path, "r") as f:
        lines = [ln.strip() for ln in f if ln.strip()]
    seqs, cur = [], []
    for ln in lines:
        if ln.startswith(">"):
            if cur:
                seqs.append("".join(cur)); cur = []
        else:
            cur.append(ln)
    if cur: seqs.append("".join(cur))
    seqs = [_a3m_strip_insertions(s) for s in seqs]
    if not seqs: return [query_seq]
    aligned_query = seqs[0]
    L = len(query_seq)
    col2qi, qi = [], 0
    for c in aligned_query:
        if c != "-": col2qi.append(qi); qi += 1
        else:        col2qi.append(None)
    rows = [query_seq]
    for aligned in seqs[1:1+MSA_MAX_SEQS]:
        row = ["-"]*L
        for col, aa in enumerate(aligned):
            q = col2qi[col]
            if q is not None and q < L and aa != "-":
                row[q] = aa.upper()
        rows.append("".join(row))
    return rows

def read_fasta_as_rows(path: Path, query_seq: str):
    records = list(SeqIO.parse(str(path), "fasta"))
    if not records: return [query_seq]
    rows = [query_seq]; L = len(query_seq)
    for rec in records[1:1+MSA_MAX_SEQS]:
        subj = str(rec.seq).replace("-", "").replace(".", "")  # drop gap characters if the FASTA is pre-aligned
        aln = pairwise2.align.globalms(query_seq, subj, 2, -1, -4, -1, one_alignment_only=True)
        if not aln: continue
        aq, asub, *_ = aln[0]
        row = ["-"]*L; qi = 0
        for cq, cs in zip(aq, asub):
            if cq != "-":
                if cs != "-": row[qi] = cs.upper()
                qi += 1
        rows.append("".join(row))
    return rows

def read_alignment_as_rows(path: Path, query_seq: str):
    ext = path.suffix.lower()
    if ext == ".a3m": return read_a3m_as_rows(path, query_seq)
    if ext in (".fasta", ".fa"): return read_fasta_as_rows(path, query_seq)
    warnings.warn(f"Unknown MSA ext {ext}; returning query only")
    return [query_seq]

def msa_from_blast_hsps(query_seq: str, hitlist_size: int, evalue: float):
    try:
        from Bio.Blast import NCBIWWW, NCBIXML
    except Exception as e:
        warnings.warn(f"BLAST unavailable: {e}"); return [query_seq]
    L = len(query_seq); rows = [query_seq]
    try:
        handle = NCBIWWW.qblast("blastp", "nr", query_seq, hitlist_size=hitlist_size, expect=evalue, format_type="XML")
        data = handle.read(); handle.close()
        for record in NCBIXML.parse(io.StringIO(data)):
            for aln in record.alignments[:hitlist_size]:
                if not aln.hsps: continue
                hsp = max(aln.hsps, key=lambda h: h.identities)
                aq, asub = hsp.query, hsp.sbjct
                row = ["-"]*L; qi = hsp.query_start - 1
                for c_q, c_s in zip(aq, asub):
                    if c_q != "-":
                        if 0 <= qi < L and c_s != "-": row[qi] = c_s.upper()
                        qi += 1
                rows.append("".join(row))
    except Exception as e:
        warnings.warn(f"BLAST query failed: {e}")
    return rows[:(1+MSA_MAX_SEQS)]

def msa_from_jackhmmer(query_seq: str, max_seqs: int, email: str, timeout_s: int = 180):
    base = "https://www.ebi.ac.uk/Tools/hmmer"
    rows = [query_seq]
    try:
        r = requests.post(f"{base}/search/jackhmmer", data={"seq": query_seq, "email": email}, timeout=20)
        r.raise_for_status(); job = r.json()["uuid"]
        t0 = time.time()
        while True:
            s = requests.get(f"{base}/result/{job}", timeout=20).json()
            if s.get("finished", False): break
            if time.time() - t0 > timeout_s: raise TimeoutError("jackhmmer timed out")
            time.sleep(3)
        sto = requests.get(f"{base}/result/{job}/aln", timeout=20).text
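        # Stockholm may be interleaved (several blocks): join the pieces of each sequence by name
        seq_by_name = {}
        for line in sto.splitlines():
            if not line.strip() or line.startswith(("#","//")): continue
            toks = line.split()
            if len(toks) >= 2:
                seq_by_name[toks[0]] = seq_by_name.get(toks[0], "") + toks[1].replace(".", "-")
        seqs = list(seq_by_name.values())  # as before, the first record is assumed to be the query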
        if not seqs: return rows
        L = len(query_seq); aligned_query = seqs[0]
        col2qi, qi = [], 0
        for c in aligned_query:
            if c != "-": col2qi.append(qi); qi += 1
            else:        col2qi.append(None)
        for aligned in seqs[1:1+max_seqs]:
            row = ["-"]*L
            for col, aa in enumerate(aligned):
                q = col2qi[col]
                if q is not None and q < L and aa != "-": row[q] = aa.upper()
            rows.append("".join(row))
    except Exception as e:
        warnings.warn(f"jackhmmer failed: {e}")
    return rows

def load_msa_rows(chain_key: str, seq: str, split: str):
    # 1) local
    if USE_LOCAL_MSA:
        base = Path(f"../data/msa/{split}")
        for ext in (".a3m", ".fasta", ".fa"):
            p = base / f"{chain_key}{ext}"
            if p.exists(): return read_alignment_as_rows(p, seq)
    # 2) online
    if USE_ONLINE_MSA:
        if MSA_PROVIDER == "ncbi":
            return msa_from_blast_hsps(seq, hitlist_size=MSA_MAX_SEQS, evalue=MSA_EVALUE)
        elif MSA_PROVIDER == "jackhmmer":
            return msa_from_jackhmmer(seq, max_seqs=MSA_MAX_SEQS, email=os.environ.get("MSA_EMAIL",""), timeout_s=MSA_TIMEOUT_S)
    # 3) fallback
    return [seq]

AA = "ACDEFGHIKLMNPQRSTVWY"
AA_IDX = {a:i for i,a in enumerate(AA)}

def msa_1d_features(rows):
    L = len(rows[0])
    if len(rows) <= 1: return np.zeros((L,21), dtype=np.float32)
    depth = len(rows); freq = np.zeros((L,20), dtype=np.float32)
    for r in rows:
        for i, aa in enumerate(r):
            if aa in AA_IDX: freq[i, AA_IDX[aa]] += 1
    freq /= float(depth)
    with np.errstate(divide='ignore', invalid='ignore'):
        ent = -np.sum(freq * np.where(freq>0, np.log(freq+1e-12), 0.0), axis=1)
    return np.concatenate([freq, ent[:,None].astype(np.float32)], axis=1)

def _mi_from_counts(C):
    P = C / (np.sum(C) + 1e-9)
    pi = P.sum(1, keepdims=True); pj = P.sum(0, keepdims=True)
    with np.errstate(divide='ignore', invalid='ignore'):
        MI = np.nansum(P * (np.log(P+1e-12) - np.log(pi+1e-12) - np.log(pj+1e-12)))
    return float(MI)

def msa_pair_mi_apc(rows):
    L = len(rows[0]); depth = len(rows)
    if depth < MIN_MSA_DEPTH_FOR_MI or L > MAX_L_FOR_MI: return None
    X = np.full((depth, L), -1, dtype=np.int16)
    for r, row in enumerate(rows):
        for i, aa in enumerate(row):
            X[r,i] = AA_IDX.get(aa, -1)
    MI = np.zeros((L,L), dtype=np.float32)
    for i in range(L):
        for j in range(i+1, L):
            C = np.zeros((20,20), dtype=np.float32)
            for r in range(depth):
                a, b = X[r,i], X[r,j]
                if a>=0 and b>=0: C[a,b] += 1
            val = _mi_from_counts(C)
            MI[i,j] = MI[j,i] = val
    row_mean = MI.mean(axis=1, keepdims=True)
    col_mean = MI.mean(axis=0, keepdims=True)
    glob = MI.mean()
    MI_apc = MI - row_mean*col_mean/(glob+1e-9)
    np.fill_diagonal(MI_apc, 0.0)
    return MI_apc
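The double loop in `msa_pair_mi_apc` performs about L²/2 × depth Python iterations (roughly 5.9 million for L = 600 and a 33‑row MSA), which is slow on a laptop. A vectorized sketch of the same formula, assuming the rows are already aligned to the query as returned by `load_msa_rows`: it one‑hot encodes the MSA and builds joint counts for a block of columns at a time with `np.einsum`, so each intermediate holds about `block × L × 400` floats. The depth and length gates (`MIN_MSA_DEPTH_FOR_MI`, `MAX_L_FOR_MI`) are left to the caller, and the `block` default is an arbitrary choice.

```python
import numpy as np

AA20 = "ACDEFGHIKLMNPQRSTVWY"

def mi_apc_blocked(rows, block: int = 16) -> np.ndarray:
    """MI + APC between MSA columns, vectorised with one-hot counts (sketch).

    Same formula as msa_pair_mi_apc: joint counts over rows where both columns hold a
    standard amino acid, normalised to P, MI = sum P * (log P - log p_i - log p_j), then APC.
    """
    aa_idx = {a: k for k, a in enumerate(AA20)}
    depth, L = len(rows), len(rows[0])
    X = np.array([[aa_idx.get(a, -1) for a in r] for r in rows], dtype=np.int64)  # (depth, L)
    onehot = np.zeros((depth, L, 20), dtype=np.float32)
    n_idx, col_idx = np.nonzero(X >= 0)
    onehot[n_idx, col_idx, X[n_idx, col_idx]] = 1.0  # gaps / non-standard residues stay all-zero
    MI = np.zeros((L, L), dtype=np.float32)
    for s in range(0, L, block):
        e = min(s + block, L)
        # Joint counts of columns s..e-1 against all columns: shape (e-s, L, 20, 20)
        C = np.einsum("nia,njb->ijab", onehot[:, s:e], onehot)
        P = C / (C.sum(axis=(2, 3), keepdims=True) + 1e-9)
        pi = P.sum(axis=3, keepdims=True)  # marginal of column i
        pj = P.sum(axis=2, keepdims=True)  # marginal of column j
        term = P * (np.log(P + 1e-12) - np.log(pi + 1e-12) - np.log(pj + 1e-12))
        MI[s:e] = term.sum(axis=(2, 3))
    np.fill_diagonal(MI, 0.0)  # the loop version only fills i != j
    # Average product correction (APC)
    row_mean = MI.mean(axis=1, keepdims=True)
    col_mean = MI.mean(axis=0, keepdims=True)
    MI_apc = MI - row_mean * col_mean / (MI.mean() + 1e-9)
    np.fill_diagonal(MI_apc, 0.0)
    return MI_apc
```

Smaller `block` values lower peak memory at the cost of more `einsum` calls.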


#### Structure parsing → ATOM sequences & CA contacts

In [4]:
def _parser_for(path: Path):
    ext = path.suffix.lower()
    if ext == ".pdb": return PDBParser(QUIET=True)
    if ext in (".cif", ".mmcif"): return MMCIFParser(QUIET=True)
    raise ValueError(f"Unsupported structure format: {path}")

def load_structure(path: Path):
    return _parser_for(path).get_structure(path.stem, str(path))

def extract_atom_seq_by_chain(struct):
    seqs = {}; ppb = PPBuilder(); model = next(iter(struct))
    for chain in model:
        polypeps = list(ppb.build_peptides(chain, aa_only=False))
        if not polypeps: continue
        s = "".join([str(pp.get_sequence()) for pp in polypeps])
        if s: seqs[chain.id] = s
    return seqs

def extract_ca_coords_by_chain(struct):
    ppb = PPBuilder(); chain_coords = {}; model = next(iter(struct))
    for chain in model:
        polypeps = ppb.build_peptides(chain, aa_only=False)  # same setting as extract_atom_seq_by_chain so indices line up
        if not polypeps: continue
        coords = []; offset = 0
        for pp in polypeps:
            for i, res in enumerate(pp):
                if "CA" in res: coords.append((offset+i, res["CA"].coord.copy()))
            offset += len(pp)
        if coords: chain_coords[chain.id] = coords
    return chain_coords

def contact_map_from_coords(coords, L, thresh):
    """Return (contact, valid_pair, has_ca) for one chain; only residues with a CA are valid."""
    has = np.zeros(L, dtype=bool); xyz = np.zeros((L,3), dtype=np.float32)
    for i, c in coords:
        if i < L:
            has[i] = True; xyz[i] = c
    idx = np.where(has)[0]
    C = np.zeros((L,L), dtype=bool)
    V = np.zeros((L,L), dtype=bool)
    if len(idx) > 0:
        sub = xyz[idx]
        d = np.sqrt(((sub[:,None,:]-sub[None,:,:])**2).sum(-1))
        # Scatter the sub-matrix back into the full L x L grid in one vectorised step
        C[np.ix_(idx, idx)] = d < thresh
        V[np.ix_(idx, idx)] = True
        np.fill_diagonal(V, False)  # self-pairs are never valid
    return C, V, has
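With a Cα–Cα cutoff of 8 Å, sequence neighbours are contacts by construction: consecutive Cα atoms are about 3.8 Å apart, so i, i+1 pairs (and in practice i, i+2 pairs) always fall under the cutoff. `valid_pair` only excludes self‑pairs, so these trivial positives take part in both training and evaluation. The toy below uses a hypothetical, perfectly straight chain; `separation_mask` is a hypothetical helper showing one way to drop near‑diagonal pairs (|i − j| ≥ 6 is a common choice), which the notebook does not currently do.

```python
import numpy as np

def ca_contacts(xyz: np.ndarray, thresh: float = 8.0) -> np.ndarray:
    """Boolean Cα–Cα contact map (diagonal included), as in contact_map_from_coords."""
    d = np.linalg.norm(xyz[:, None, :] - xyz[None, :, :], axis=-1)
    return d < thresh

def separation_mask(L: int, min_sep: int = 6) -> np.ndarray:
    """True where |i - j| >= min_sep (hypothetical helper, not used by the pipeline)."""
    i, j = np.indices((L, L))
    return np.abs(i - j) >= min_sep

# Hypothetical fully extended chain: consecutive Cα atoms 3.8 Å apart on a straight line
xyz = np.array([[3.8 * k, 0.0, 0.0] for k in range(6)], dtype=np.float32)
C = ca_contacts(xyz)
print(C[0].astype(int))  # [1 1 1 0 0 0]: residue 0 "contacts" itself and its two neighbours
```

ANDing such a mask into each chain's `valid_pair` before concatenation would remove the trivial pairs from both sampling and metrics.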


#### Embeddings (ESM2 via HF) with in‑memory cache

In [5]:
class FallbackEmbedder(torch.nn.Module):
    def __init__(self, dim=64, vocab=26):
        super().__init__(); self.emb = torch.nn.Embedding(vocab, dim)
        torch.nn.init.xavier_uniform_(self.emb.weight)
    def forward(self, seq: str):
        idx = torch.tensor([(ord(c)%26) for c in seq], dtype=torch.long, device=DEVICE)
        return self.emb(idx)

def try_load_esm2(model_id: str):
    try:
        from transformers import AutoModel, AutoTokenizer
        tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
        mdl = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(DEVICE).eval()
        return tok, mdl
    except Exception as e:
        warnings.warn(f"Could not load {model_id}: {e}; using fallback embedder.")
        return None, FallbackEmbedder().to(DEVICE).eval()

@torch.no_grad()
def embed_sequence(seq: str, tokenizer, model):
    if tokenizer is None: return model(seq)
    tokens = tokenizer(seq, return_tensors="pt", add_special_tokens=True)
    tokens = {k: v.to(DEVICE) for k,v in tokens.items()}
    out = model(**tokens)
    H = out.last_hidden_state[0]
    if H.shape[0] >= len(seq)+2: H = H[1:1+len(seq)]
    return H.detach()

tokenizer, esm2_model = try_load_esm2(ESM2_MODEL_ID)
print("Embedder:", type(esm2_model).__name__)


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Embedder: EsmModel
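With `MAX_LEN_PER_CHAIN = None`, long chains are embedded in a single forward pass, and transformer self‑attention memory grows quadratically with sequence length. One option, not implemented in this notebook, is to embed overlapping windows and average positions covered more than once. In this sketch `embed_fn` stands in for `lambda s: embed_sequence(s, tokenizer, esm2_model)`, and the `window`/`stride` defaults are arbitrary assumptions. Because ESM2 representations are contextual, windowed embeddings only approximate full‑length ones.

```python
import torch
from typing import Callable

def embed_in_windows(seq: str, embed_fn: Callable[[str], torch.Tensor],
                     window: int = 512, stride: int = 384) -> torch.Tensor:
    """Embed seq in overlapping windows and average per-residue states where windows overlap."""
    L = len(seq)
    if L <= window:
        return embed_fn(seq)
    starts = list(range(0, L - window + 1, stride))
    if starts[-1] + window < L:
        starts.append(L - window)  # make sure the C-terminal tail is covered
    out, count = None, torch.zeros(L, 1)
    for s in starts:
        H = embed_fn(seq[s:s + window]).float().cpu()  # (window, d)
        if out is None:
            out = torch.zeros(L, H.shape[-1])
        out[s:s + window] += H
        count[s:s + window] += 1
    return out / count  # every position is covered at least once
```

The result lives on the CPU; `process_structure` already moves the concatenated features to `DEVICE`.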


#### Build per‑structure tensors (concat chains)

In [6]:
EMBED_CACHE = {}

def process_structure(path: Path, split: str):
    try:
        s = load_structure(path)
        chain_seqs = extract_atom_seq_by_chain(s)
        chain_coords = extract_ca_coords_by_chain(s)
        if not chain_seqs: return None

        H_blocks = []; contact_blocks=[]; valid_blocks=[]; mi_tiles=[]; Ls=[]; chain_ids=[]; seqs=[]
        for cid in sorted(chain_seqs.keys()):
            seq = chain_seqs[cid]
            if MAX_LEN_PER_CHAIN is not None and len(seq) > MAX_LEN_PER_CHAIN:
                seq = seq[:MAX_LEN_PER_CHAIN]
            L = len(seq)
            coords = [(i,c) for (i,c) in chain_coords.get(cid, []) if i < L]
            if L==0 or len(coords)<2: continue

            key = f"{path.stem}_{cid}_{L}"
            if key in EMBED_CACHE: H = EMBED_CACHE[key]
            else:
                H = embed_sequence(seq, tokenizer, esm2_model)
                EMBED_CACHE[key] = H

            rows = load_msa_rows(f"{path.stem}_{cid}", seq, split)
            feat1d = msa_1d_features(rows)  # [L,21]
            H_aug = torch.cat([H, torch.from_numpy(feat1d).to(H.device)], dim=1)

            C, V, has = contact_map_from_coords(coords, L, CA_DIST_THRESH)
            MI = msa_pair_mi_apc(rows)

            H_blocks.append(H_aug); contact_blocks.append(C); valid_blocks.append(V); mi_tiles.append(MI)
            Ls.append(L); seqs.append(seq); chain_ids.extend([cid]*L)

        if not H_blocks: return None
        H_full = torch.cat(H_blocks, dim=0); Ltot = sum(Ls)
        C_full = np.zeros((Ltot,Ltot), dtype=bool); V_full = np.zeros((Ltot,Ltot), dtype=bool)
        MI_full = None
        if any(m is not None for m in mi_tiles): MI_full = np.zeros((Ltot,Ltot), dtype=np.float32)
        ci = 0
        for k, L in enumerate(Ls):
            cj = ci+L
            C_full[ci:cj, ci:cj] = contact_blocks[k]
            V_full[ci:cj, ci:cj] = valid_blocks[k]
            if MI_full is not None and mi_tiles[k] is not None:
                MI_full[ci:cj, ci:cj] = mi_tiles[k]
            ci = cj

        uniq = {c:i for i,c in enumerate(sorted(set(chain_ids)))}
        chain_ids_arr = np.array([uniq[c] for c in chain_ids], dtype=np.int64)
        return dict(seq="".join(seqs), chain_ids=chain_ids_arr, contact=C_full, valid_pair=V_full,
                    H=H_full.to(DEVICE), MI=MI_full, pdb_id=path.stem, path=str(path))
    except Exception as e:
        warnings.warn(f"Error processing {path}: {e}")
        return None
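`process_structure` concatenates chains, so a residue's position in `seq`, `H` and the pair maps is a global index. The returned dict keeps per‑residue `chain_ids` (integers in sorted chain‑ID order) but not the block offsets; per‑chain lengths can be recovered with `np.bincount(S["chain_ids"])`. A small hypothetical helper that maps global indices back to (chain, residue within chain), useful when inspecting predicted pairs:

```python
import numpy as np
from typing import Sequence, Tuple

def split_global_index(chain_lengths: Sequence[int]) -> Tuple[np.ndarray, np.ndarray]:
    """For each concatenated position, return (chain index, residue index within that chain)."""
    Ls = np.asarray(chain_lengths, dtype=np.int64)
    starts = np.concatenate([[0], np.cumsum(Ls)[:-1]])  # block offsets, like ci in process_structure
    chain = np.repeat(np.arange(len(Ls)), Ls)
    local = np.arange(Ls.sum()) - starts[chain]
    return chain, local

chain, local = split_global_index([3, 2])
print(chain.tolist(), local.tolist())  # [0, 0, 0, 1, 1] [0, 1, 2, 0, 1]
```

The within‑chain index refers to the ATOM‑derived (possibly truncated) sequence, not to PDB residue numbering.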


#### Model: Low‑rank Bilinear scorer (+ optional MI term)

In [7]:
class BilinearScorer(torch.nn.Module):
    def __init__(self, d: int, rank: int = 128, mi_weight: float = MI_WEIGHT):
        super().__init__()
        self.U = torch.nn.Linear(d, rank, bias=False)
        self.V = torch.nn.Linear(d, rank, bias=False)
        self.bias = torch.nn.Parameter(torch.zeros(1))
        self.mi_weight = mi_weight
    def forward_pairs(self, H: torch.Tensor, idx: np.ndarray, MI: Optional[np.ndarray] = None):
        i = torch.as_tensor(idx[:,0], device=H.device)
        j = torch.as_tensor(idx[:,1], device=H.device)
        A = self.U(H); B = self.V(H)
        logits = (A[i] * B[j]).sum(-1) + self.bias
        if MI is not None:
            mi_val = torch.from_numpy(MI[idx[:,0], idx[:,1]]).to(logits.device)
            logits = logits + self.mi_weight * mi_val
        return logits
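Written out, `BilinearScorer.forward_pairs` computes the following logit for a pair $(i, j)$ with $i < j$. Here $h_i \in \mathbb{R}^{d}$ is the residue's ESM2 embedding concatenated with its 21 MSA features, $U, V \in \mathbb{R}^{r \times d}$ are the weights of the two bias‑free `Linear` layers with $r$ = `BILINEAR_RANK`, $b$ is the learned scalar bias, and $w$ = `MI_WEIGHT` is a fixed (not trained) weight on the APC‑corrected MI, which is zero wherever MI was not computed:

$$
s_{ij} = (U h_i)^\top (V h_j) + b + w\,\mathrm{MI}^{\mathrm{APC}}_{ij}
       = h_i^\top \left(U^\top V\right) h_j + b + w\,\mathrm{MI}^{\mathrm{APC}}_{ij},
\qquad p_{ij} = \sigma(s_{ij})
$$

So the scorer is a bilinear form whose $d \times d$ matrix $U^\top V$ has rank at most $r$. It has $2rd + 1$ trainable parameters: with the default ESM2 checkpoint (hidden size 320), $d = 320 + 21 = 341$, and $r = 128$ gives $2 \cdot 128 \cdot 341 + 1 = 87{,}297$, versus $341^2 + 1 = 116{,}282$ for an unconstrained $d \times d$ matrix. Because $U \neq V$ in general, $s_{ij} \neq s_{ji}$; this does not matter here because only $i < j$ pairs are sampled and scored.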


#### Training utils: balanced pair sampling, batching & evaluation

In [8]:
def build_train_pairs_balanced(S, max_pos: int = 5000, neg_ratio: int = 3):
    C = S["contact"]; V = S["valid_pair"]; L = C.shape[0]
    iu, ju = np.triu_indices(L, k=1)
    pos_mask = C[iu, ju] & V[iu, ju]
    neg_mask = (~C[iu, ju]) & V[iu, ju]
    pos_idx = np.stack([iu[pos_mask], ju[pos_mask]], axis=1)
    neg_idx = np.stack([iu[neg_mask], ju[neg_mask]], axis=1)
    if len(pos_idx)==0: return np.empty((0,2), dtype=int), np.empty((0,), dtype=np.float32)
    if max_pos is not None and len(pos_idx) > max_pos:
        sel = np.random.choice(len(pos_idx), size=max_pos, replace=False)
        pos_idx = pos_idx[sel]
    m = min(len(neg_idx), len(pos_idx)*neg_ratio)
    if m>0:
        sel = np.random.choice(len(neg_idx), size=m, replace=False)
        neg_idx = neg_idx[sel]
    else:
        neg_idx = np.empty((0,2), dtype=int)
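    pairs = np.vstack([pos_idx, neg_idx])
    labels = np.hstack([np.ones(len(pos_idx), dtype=np.float32), np.zeros(len(neg_idx), dtype=np.float32)])
    # Shuffle so positives and negatives are mixed within every mini-batch
    perm = np.random.permutation(len(pairs))
    return pairs[perm], labels[perm]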

def collect_struct_files(folder: Path):
    files = []
    for pat in ("*.pdb","*.PDB","*.cif","*.CIF","*.mmcif","*.MMCIF"):
        files += list(folder.glob(pat))
    return sorted(files)

@torch.no_grad()
def evaluate_model(model, structs):
    model.eval()
    all_scores, all_labels = [], []
    per_pdb = {}
    for S in structs:
        if S is None: continue
        H = S["H"]; C = S["contact"]; V = S["valid_pair"]; MI = S["MI"]
        iu, ju = np.triu_indices(C.shape[0], k=1)
        mask = V[iu, ju]
        iu, ju = iu[mask], ju[mask]
        idx = np.stack([iu, ju], axis=1)
        if len(idx)==0: continue
        probs_list, labels_list = [], []
        for s in range(0, len(idx), EVAL_BATCH_PAIRS):
            sl = idx[s:s+EVAL_BATCH_PAIRS]
            logits = model.forward_pairs(H, sl, MI)
            probs_list.append(torch.sigmoid(logits).detach().cpu().numpy())
            labels_list.append(C[sl[:,0], sl[:,1]].astype(int))
        probs = np.concatenate(probs_list); labels = np.concatenate(labels_list)
        try:
            ap = average_precision_score(labels, probs); roc = roc_auc_score(labels, probs)
        except Exception:
            ap, roc = float('nan'), float('nan')
        L = len(S["seq"]); order = np.argsort(-probs)
        def p_at(k): k = max(1, min(k, len(order))); return labels[order[:k]].mean()
        per_pdb[S["pdb_id"]] = dict(pr_auc=ap, roc_auc=roc, p_at_L=p_at(L), p_at_L2=p_at(L//2), p_at_L5=p_at(max(1,L//5)))
        all_scores.append(probs); all_labels.append(labels)
    if not all_scores: return dict(global_pr_auc=float('nan'), global_roc_auc=float('nan'), per_pdb=per_pdb)
    scores = np.concatenate(all_scores); labels = np.concatenate(all_labels)
    try:
        g_ap = average_precision_score(labels, scores); g_roc = roc_auc_score(labels, scores)
    except Exception:
        g_ap, g_roc = float('nan'), float('nan')
    return dict(global_pr_auc=g_ap, global_roc_auc=g_roc, per_pdb=per_pdb)
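`evaluate_model` ranks all valid pairs, including sequence neighbours that are contacts by construction, so its P@L is not comparable with separation‑filtered numbers; CASP‑style evaluations bin contacts into short (6 to 11), medium (12 to 23) and long (≥ 24 residues apart) range. A minimal hypothetical variant of the top‑k precision with a separation filter; inside `evaluate_model` it could be fed `probs`, `labels`, `idx[:, 0]` and `idx[:, 1]` (how to wire it in is an assumption, not part of the notebook). Since only intra‑chain pairs are valid, |i − j| on concatenated indices equals the within‑chain separation.

```python
import numpy as np

def top_k_precision(probs, labels, i_idx, j_idx, k: int, min_sep: int = 6) -> float:
    """Precision of the k highest-scoring pairs with |i - j| >= min_sep (hypothetical helper)."""
    probs, labels = np.asarray(probs), np.asarray(labels)
    keep = np.abs(np.asarray(i_idx) - np.asarray(j_idx)) >= min_sep
    if not keep.any():
        return float("nan")
    p, y = probs[keep], labels[keep]
    order = np.argsort(-p, kind="stable")  # highest score first
    k = max(1, min(k, len(order)))
    return float(y[order[:k]].mean())

# Toy pairs: (0,1) is a trivial neighbour contact that the filter removes
print(top_k_precision([0.99, 0.9, 0.2, 0.8], [1, 1, 0, 0], [0, 0, 0, 0], [1, 6, 7, 10], k=2))
```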


#### Train loop (balanced)

In [9]:
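def train_epoch(model, structs, lr=LR):
    # Reuse one AdamW instance across epochs (stored on the model) so Adam's moment
    # estimates are not reset every epoch; `lr` only applies when it is first created.
    opt = getattr(model, "_optimizer", None)
    if opt is None:
        opt = torch.optim.AdamW(model.parameters(), lr=lr)
        model._optimizer = opt
    bce = torch.nn.BCEWithLogitsLoss()
    model.train(); total = 0.0; n_batches = 0
    for S in structs:
        if S is None: continue
        H = S["H"]; MI = S["MI"]
        pairs, y = build_train_pairs_balanced(S, max_pos=PAIR_SUBSAMPLE_TRAIN, neg_ratio=3)
        if len(pairs)==0: continue
        # The training mini-batch size reuses EVAL_BATCH_PAIRS
        for s in range(0, len(pairs), EVAL_BATCH_PAIRS):
            sl = pairs[s:s+EVAL_BATCH_PAIRS]
            logits = model.forward_pairs(H, sl, MI)
            yy = torch.from_numpy(y[s:s+EVAL_BATCH_PAIRS]).to(logits.device)
            loss = bce(logits, yy)
            opt.zero_grad(); loss.backward(); opt.step()
            total += float(loss.detach().cpu()); n_batches += 1
    return total / max(n_batches, 1)  # mean BCE per mini-batch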
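`train_epoch` handles class imbalance by subsampling negatives (three per positive). An alternative that keeps more negatives is to up‑weight positives in the loss with the `pos_weight` argument of `torch.nn.BCEWithLogitsLoss`; the PyTorch documentation suggests setting it to roughly n_negative / n_positive. The notebook does not do this; the toy values below are illustrative and only check that `pos_weight` multiplies the positive term of the BCE.

```python
import torch
import torch.nn.functional as F

# Toy logits/labels: 1 positive, 3 negatives (mirrors neg_ratio=3)
logits = torch.tensor([0.5, -1.0, 0.2, -2.0])
y = torch.tensor([1.0, 0.0, 0.0, 0.0])

# pos_weight scales the positive term; 3.0 = n_neg / n_pos for this toy batch
bce_w = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([3.0]))
loss = bce_w(logits, y)

# Same quantity written out: mean of -[w * y * log σ(x) + (1 - y) * log(1 - σ(x))]
manual = -(3.0 * y * F.logsigmoid(logits) + (1 - y) * F.logsigmoid(-logits)).mean()
print(float(loss), float(manual))
```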


#### Run: split train/val, train model, evaluate on test

In [None]:
train_files_all = collect_struct_files(PDB_TRAIN_DIR)
test_files = collect_struct_files(PDB_TEST_DIR)
print(f"Found {len(train_files_all)} train files, {len(test_files)} test files.")
if MAX_TRAIN_FILES is not None and len(train_files_all) > MAX_TRAIN_FILES:
    train_files_all = train_files_all[:MAX_TRAIN_FILES]
    print(f"Capped training files to {len(train_files_all)}")

train_files, val_files = train_test_split(train_files_all, test_size=VAL_SPLIT, random_state=SEED)
print(f"Split: {len(train_files)} train / {len(val_files)} val")

print("Processing train...")
train_structs = [process_structure(p, split="train") for p in tqdm(train_files)]
train_structs = [s for s in train_structs if s is not None]
print("Processing val...")
val_structs = [process_structure(p, split="train") for p in tqdm(val_files)]
val_structs = [s for s in val_structs if s is not None]

if not train_structs:
    raise RuntimeError("No valid train structures parsed. Check your data paths.")

d_in = train_structs[0]["H"].shape[-1]
model = BilinearScorer(d=d_in, rank=BILINEAR_RANK, mi_weight=MI_WEIGHT).to(DEVICE)

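best_ap = -1.0
for ep in range(1, EPOCHS+1):
    loss = train_epoch(model, train_structs, lr=LR)
    val_scores = evaluate_model(model, val_structs)
    val_ap = val_scores['global_pr_auc']
    print(f"epoch {ep}  loss={loss:.3f}  val PR-AUC={val_ap:.4f}  ROC-AUC={val_scores['global_roc_auc']:.4f}")
    # Checkpoint on validation PR-AUC improvement; without a usable val score, keep the latest epoch
    if np.isnan(val_ap) or val_ap > best_ap:
        if not np.isnan(val_ap): best_ap = val_ap
        torch.save({"state_dict": model.state_dict(), "cfg": dict(d_in=d_in, rank=BILINEAR_RANK, mi_weight=MI_WEIGHT, lr=LR, epochs=EPOCHS)}, str(SAVE_PATH))
        print(f"  saved checkpoint from epoch {ep} to {SAVE_PATH}")

# Evaluate the best validation epoch (not necessarily the last one) on the test set
model.load_state_dict(torch.load(str(SAVE_PATH), map_location=DEVICE)["state_dict"])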

print("Processing test...")
test_structs = [process_structure(p, split="test") for p in tqdm(test_files)]
test_structs = [s for s in test_structs if s is not None]
test_scores = evaluate_model(model, test_structs)
print("Test: PR-AUC={:.4f}, ROC-AUC={:.4f}".format(test_scores['global_pr_auc'], test_scores['global_roc_auc']))

df = pd.DataFrame.from_dict(test_scores["per_pdb"], orient="index").reset_index().rename(columns={"index":"pdb_id"})
df.sort_values("pr_auc", ascending=False, inplace=True)
out_csv = RESULTS_DIR / "test_metrics.csv"
df.to_csv(out_csv, index=False)
print("Wrote", out_csv)

Found 15000 train files, 500 test files.
Capped training files to 50
Split: 40 train / 10 val
Processing train...


 55%|█████▌    | 22/40 [1:04:30<35:54, 119.69s/it]
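The run cell saves a dict with `state_dict` and `cfg` keys, so the scorer can be rebuilt later without retraining. A minimal sketch: `BilinearScorerStub` is a hypothetical stand‑in that only mirrors the parameter names of `BilinearScorer` so the block runs on its own (in the notebook, pass `model_cls=BilinearScorer`). Scoring new structures still requires the same embedding model and MSA settings used in training, since `d_in` depends on them.

```python
from pathlib import Path
import torch

class BilinearScorerStub(torch.nn.Module):
    """Stand-in with the same parameter names as BilinearScorer (U, V, bias)."""
    def __init__(self, d: int, rank: int = 128, mi_weight: float = 0.5):
        super().__init__()
        self.U = torch.nn.Linear(d, rank, bias=False)
        self.V = torch.nn.Linear(d, rank, bias=False)
        self.bias = torch.nn.Parameter(torch.zeros(1))
        self.mi_weight = mi_weight

def load_scorer(path: Path, model_cls=BilinearScorerStub, device: str = "cpu") -> torch.nn.Module:
    """Rebuild the scorer from the {'state_dict', 'cfg'} checkpoint written by the run cell."""
    ckpt = torch.load(str(path), map_location=device)
    cfg = ckpt["cfg"]
    model = model_cls(d=cfg["d_in"], rank=cfg["rank"], mi_weight=cfg["mi_weight"])
    model.load_state_dict(ckpt["state_dict"])
    return model.to(device).eval()
```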

#### Notes & Reporting
- **Method**: ESM2 per‑residue embeddings concatenated with MSA 1D features; low‑rank bilinear pair scorer; optional MI+APC term.
- **Data steps**: ATOM polypeptides (PPBuilder) → per‑chain sequences; Cα–Cα contacts below 8 Å; chains concatenated block‑diagonally (inter‑chain pairs are never scored); balanced pos/neg sampling.
- **Hyperparams**: see Config cell.
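- **Metrics**: PR‑AUC and ROC‑AUC per PDB and pooled over all test pairs (global); Precision@L, @L/2 and @L/5 per PDB only. All metrics rank every valid intra‑chain pair, with no minimum sequence separation.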
- **MSA switch**: flip `USE_LOCAL_MSA`/`USE_ONLINE_MSA`/`MSA_PROVIDER` in Config — the rest of the pipeline stays unchanged.
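The global PR‑AUC and ROC‑AUC pool all test pairs, so large structures dominate them; averaging the per‑PDB rows of `test_metrics.csv` weights every structure equally. A hypothetical summary helper for the write‑up, used e.g. as `summarize_per_pdb(pd.read_csv(RESULTS_DIR / "test_metrics.csv"))`:

```python
import numpy as np
import pandas as pd

def summarize_per_pdb(df: pd.DataFrame) -> pd.DataFrame:
    """Mean, median and non-NaN count of each per-PDB metric (NaNs skipped)."""
    cols = [c for c in ("pr_auc", "roc_auc", "p_at_L", "p_at_L2", "p_at_L5") if c in df.columns]
    return df[cols].agg(["mean", "median", "count"]).T  # one row per metric

toy = pd.DataFrame({"pdb_id": ["a", "b"], "pr_auc": [0.2, np.nan], "roc_auc": [0.8, 0.9]})
print(summarize_per_pdb(toy))
```

The `count` column makes it visible when some structures had undefined metrics (e.g. no positive pairs).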
