# Residue Contact Prediction — Fixed (FINAL2)

# In [None]:

# Residue Contact Prediction — Fixed (FINAL2)
# - Forces float32 end-to-end (H, logits, labels)
# - Embedding cache loaded as float32 (prevents fp16 leakage into Linear)
# - Robust forward_pairs (idx as np or torch; empty-safe)
# - Pairs/labels yielded on DEVICE; train loop recasts H->float32 on DEVICE
# - BCEWithLogitsLoss target/shape alignment before loss
# - Dummy path if no PDBs (Biopython optional)

import os, gc, hashlib, random
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn.metrics import average_precision_score, roc_auc_score
from tqdm import tqdm

# ---- Optional Biopython (guarded) ----
try:
    from Bio.PDB import PDBParser, MMCIFParser, PPBuilder
    BIO_AVAILABLE = True
except Exception:
    BIO_AVAILABLE = False
    PDBParser = MMCIFParser = PPBuilder = None

# ---- Config ----
# Input structure directories (resolved relative to the working directory).
PDB_TRAIN_DIR = Path("../data/pdb/train")
PDB_TEST_DIR = Path("../data/pdb/test")
# Per-chain embedding cache; created eagerly at import time.
CACHE_DIR = Path("./cache/embeds")
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Embedding backbone (HuggingFace id) and the C-alpha distance cutoff below
# which two residues are labelled as a contact (units are whatever the
# structure files use — presumably Angstrom).
MODEL_ID = "facebook/esm2_t6_8M_UR50D"
CA_DIST_THRESH = 8.0

# Data limits: training-file cap, per-chain residue cap, validation fraction,
# and the global RNG seed.
MAX_TRAIN_FILES = 200
MAX_LEN_PER_CHAIN = 600
VAL_SPLIT = 0.2
SEED = 42

# Sampler / model / optimisation knobs. The sampler and batching values are
# intended for the max_pos / neg_ratio / batch_pairs / chunk arguments of the
# training and evaluation helpers.
MAX_POS_PER_STRUCT = 5000
NEG_RATIO = 3
RANK = 128
LR = 1e-3
EPOCHS = 5
PAIRS_BATCH = 50000
EVAL_CHUNK = 200000

SAVE_DIR = Path("./models")
SAVE_DIR.mkdir(parents=True, exist_ok=True)
MODEL_PATH = SAVE_DIR / "rescontact_bilin.pt"

# Seed every RNG the pipeline draws from.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


def _select_device() -> torch.device:
    """Return the first available backend: Apple MPS, then CUDA, else CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


DEVICE = _select_device()
print("Using device:", DEVICE)

# Newly created floating-point tensors default to float32.
torch.set_default_dtype(torch.float32)

# NOTE(review): torch and numpy are already imported at this point, so these
# thread-count variables may not reach thread pools initialised on import.
# setdefault also leaves any value exported by the user untouched.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")

# ---- ESM2 loader + fallback + cache ----
def _chain_cache_key(seq: str, model_id: str) -> str:
    return hashlib.sha1((seq + "|" + model_id).encode()).hexdigest()[:16] + ".npz"

def _cache_path_for(seq: str, model_id: str) -> Path:
    """Return the ``.npz`` cache location for ``seq`` embedded by ``model_id`` inside CACHE_DIR."""
    return CACHE_DIR.joinpath(_chain_cache_key(seq, model_id))

def try_load_esm2(model_id: str):
    """Load a HuggingFace tokenizer/encoder pair, or fall back to a random embedder.

    Returns:
        ``(tokenizer, model)`` with the model moved to DEVICE in eval mode.
        If any step of the HuggingFace path fails (import, download, device
        move), a warning is printed and ``(None, Fallback)`` is returned.
        ``Fallback`` maps each character ``c`` to a Xavier-initialised 64-d
        embedding row ``ord(c) % 26``. ``embed_sequence`` treats
        ``tokenizer is None`` as the signal for this fallback.
    """
    try:
        from transformers import AutoModel, AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, local_files_only=False)
        # NOTE(review): trust_remote_code=True lets the hub repository run custom
        # Python code at load time. ESM2 appears to be supported natively by
        # transformers, so the flag is likely unnecessary — consider dropping it.
        encoder = AutoModel.from_pretrained(model_id, trust_remote_code=True, local_files_only=False)
        encoder.to(DEVICE).eval()
    except Exception as e:
        print(f"[warn] Could not load {model_id}: {e}\nUsing fallback random embedder.")

        class Fallback(torch.nn.Module):
            """Character-bucket embedder used when no pretrained model is available."""

            def __init__(self, dim=64, vocab=26):
                super().__init__()
                self.emb = torch.nn.Embedding(vocab, dim)
                torch.nn.init.xavier_uniform_(self.emb.weight)

            def forward(self, s: str):
                codes = [ord(ch) % 26 for ch in s]
                return self.emb(torch.tensor(codes, dtype=torch.long, device=DEVICE))

        return None, Fallback().to(DEVICE).eval()
    return tokenizer, encoder

@torch.no_grad()
def embed_sequence(seq: str, tokenizer, model) -> torch.Tensor:
    """Return per-residue embeddings of ``seq`` as a float32 tensor on DEVICE.

    HuggingFace path (``tokenizer`` given): the sequence is tokenised with
    special tokens and encoded. When the output has at least len(seq) + 2 rows,
    the first row and everything after len(seq) residues are dropped.
    Results are cached on disk as float16 ``.npz`` files keyed by
    (seq, MODEL_ID). A cached entry is reused only if it is 2-D, has one row
    per residue, and — when the model exposes ``config.hidden_size`` — has the
    matching feature width. Otherwise it is recomputed and overwritten.

    Fallback path (``tokenizer is None``): the random embedder from
    ``try_load_esm2`` is called directly and the disk cache is neither read
    nor written. This prevents its random vectors from being stored under the
    MODEL_ID key and later served as real ESM embeddings.
    """
    use_cache = tokenizer is not None
    cp = _cache_path_for(seq, MODEL_ID)
    if use_cache and cp.exists():
        # The context manager closes the NpzFile handle that np.load keeps open.
        with np.load(cp) as z:
            cached = z["emb"]
        hidden = getattr(getattr(model, "config", None), "hidden_size", None)
        if (cached.ndim == 2 and cached.shape[0] == len(seq)
                and (hidden is None or cached.shape[1] == hidden)):
            # Stored as float16; always hand float32 to the Linear layers downstream.
            return torch.from_numpy(cached).to(DEVICE, dtype=torch.float32)
        # Otherwise the entry is stale or foreign (e.g. an older entry written by
        # the fallback embedder): recompute and overwrite it below.
    if tokenizer is None:
        H = model(seq)  # fallback produces float32
    else:
        toks = tokenizer(seq, return_tensors="pt", add_special_tokens=True)
        toks = {k: v.to(DEVICE) for k, v in toks.items()}
        H = model(**toks).last_hidden_state[0]
        if H.shape[0] >= len(seq) + 2:
            H = H[1:1 + len(seq)]
    if use_cache:
        # Save float16 to disk to keep the cache small; training still gets float32.
        np.savez_compressed(cp, emb=H.detach().to(torch.float16).cpu().numpy())
    return H.detach().to(DEVICE, dtype=torch.float32)

# ---- PDB parsing & labels ----
def _parser_for(path: Path):
    """Return a quiet Biopython parser chosen by ``path``'s (case-insensitive) suffix.

    Raises:
        RuntimeError: if Biopython is not installed.
        ValueError: if the suffix is not .pdb, .cif or .mmcif.
    """
    if not BIO_AVAILABLE:
        raise RuntimeError("Biopython not available; cannot parse PDB/mmCIF. Use dummy mode.")
    factories = {".pdb": PDBParser, ".cif": MMCIFParser, ".mmcif": MMCIFParser}
    factory = factories.get(path.suffix.lower())
    if factory is None:
        raise ValueError(f"Unsupported: {path}")
    return factory(QUIET=True)

def load_structure(path: Path):
    """Parse ``path`` with the suffix-appropriate parser; the file stem becomes the structure id."""
    parser = _parser_for(path)
    return parser.get_structure(path.stem, str(path))

def extract_atom_seq_by_chain(struct) -> Dict[str,str]:
    """Return ``{chain_id: one-letter sequence}`` for the first model of ``struct``.

    Each chain's sequence is the concatenation of its PPBuilder polypeptides,
    built with ``aa_only=False``. Gaps between polypeptides are not
    represented. The sequence is truncated to MAX_LEN_PER_CHAIN when that is
    set. Chains that yield no polypeptides are omitted. Returns {} when
    Biopython is unavailable.
    """
    if not BIO_AVAILABLE:
        return {}
    builder = PPBuilder()
    first_model = next(iter(struct))
    result = {}
    for chain in first_model:
        peptides = list(builder.build_peptides(chain, aa_only=False))
        sequence = "".join(str(pp.get_sequence()) for pp in peptides)
        if not sequence:
            continue
        if MAX_LEN_PER_CHAIN is not None:
            sequence = sequence[:MAX_LEN_PER_CHAIN]
        result[chain.id] = sequence
    return result

def extract_ca_coords_by_chain(struct) -> Dict[str, List[Tuple[int, np.ndarray]]]:
    """Collect C-alpha coordinates per chain, indexed like the extracted sequence.

    Uses the first model of ``struct``. Residue positions are counted across
    the chain's PPBuilder polypeptides in order, with the same builder settings
    (``aa_only=False``) that ``extract_atom_seq_by_chain`` uses. Position k here
    is therefore character k of that chain's sequence.

    Residues without a "CA" atom are skipped, but their position still
    counts. Positions at or beyond MAX_LEN_PER_CHAIN are dropped.

    Returns:
        ``{chain_id: [(position, xyz_copy), ...]}`` for chains with at least
        one C-alpha; {} when Biopython is unavailable.
    """
    if not BIO_AVAILABLE:
        return {}
    chain_coords = {}
    ppb = PPBuilder()
    model = next(iter(struct))
    for chain in model:
        # aa_only must match extract_atom_seq_by_chain (which passes False).
        # With the default, any residue accepted only in non-standard mode is in
        # the embedded sequence but not in these peptides. That would shift every
        # later position and misalign the contact map.
        polypeps = list(ppb.build_peptides(chain, aa_only=False))
        if not polypeps:
            continue
        coords, off = [], 0
        for pp in polypeps:
            for i, res in enumerate(pp):
                pos = off + i
                if (MAX_LEN_PER_CHAIN is not None) and (pos >= MAX_LEN_PER_CHAIN):
                    break  # positions only grow from here on
                if "CA" in res:
                    coords.append((pos, res["CA"].coord.copy()))
            off += len(pp)
        if coords:
            chain_coords[chain.id] = coords
    return chain_coords

def contact_map_from_coords(coords: List[Tuple[int, np.ndarray]], L: int, thresh: float):
    """Build a boolean contact map and its validity mask for a chain of length ``L``.

    Args:
        coords: ``(position, xyz)`` pairs. Positions outside ``[0, L)`` are
            ignored; a repeated position keeps its last coordinate.
        L: chain length (side of the returned matrices).
        thresh: pairs whose Euclidean distance is strictly below this are contacts.

    Returns:
        ``(contact, valid)``, both ``(L, L)`` bool arrays. ``valid[i, j]`` is True
        iff ``i != j`` and both positions have coordinates. ``contact`` can only
        be True where ``valid`` is. Both diagonals are False.
    """
    has = np.zeros(L, dtype=bool)
    xyz = np.zeros((L, 3), dtype=np.float32)
    for i, c in coords:
        if 0 <= i < L:
            has[i] = True
            xyz[i] = c
    idx = np.flatnonzero(has)
    contact = np.zeros((L, L), dtype=bool)
    valid = np.zeros((L, L), dtype=bool)
    if idx.size > 0:
        sub = xyz[idx]
        d = np.sqrt(((sub[:, None, :] - sub[None, :, :]) ** 2).sum(-1))
        # Scatter the |idx| x |idx| sub-matrix back in one vectorised step
        # (replaces an O(L^2) Python double loop), then clear self-pairs.
        block = np.ix_(idx, idx)
        contact[block] = d < thresh
        valid[block] = True
        np.fill_diagonal(contact, False)
        np.fill_diagonal(valid, False)
    return contact, valid

# ---- Structure → seq/H/contact ----
def process_structure(path: Path, tokenizer, esm2_model):
    """Turn one PDB/mmCIF file into a record with embeddings and contact labels.

    Chains are processed in sorted chain-id order. A chain is kept only if it
    has at least one valid coordinate pair and its embedding has exactly one
    row per sequence position. Kept chains are concatenated:
      - ``H`` has Ltot rows (float32 on DEVICE);
      - ``contact`` / ``valid_pair`` are (Ltot, Ltot) bool block-diagonal
        matrices, so inter-chain pairs are never valid;
      - ``chain_ids`` gives each residue a 0-based index of its chain.

    Returns:
        dict with keys seq, H, contact, valid_pair, chain_ids, pdb_id, path;
        or None if nothing usable was found or any error occurred. A warning
        is printed on errors.
    """
    try:
        st = load_structure(path)
        chain_seqs = extract_atom_seq_by_chain(st)
        chain_coords = extract_ca_coords_by_chain(st)
        if not chain_seqs:
            return None

        seqs, Hs, Cs, Vs, chain_ids = [], [], [], [], []
        for cid in sorted(chain_seqs):
            seq = chain_seqs[cid]
            L = len(seq)
            C, V = contact_map_from_coords(chain_coords.get(cid, []), L, CA_DIST_THRESH)
            if not V.any():   # no coord pairs
                continue
            H = embed_sequence(seq, tokenizer, esm2_model)  # returns float32
            if H.shape[0] != L:
                # Rows of H must line up with contact/valid indices; otherwise pair
                # indices would point at the wrong (or non-existent) residues.
                print(f"[warn] {path.stem} chain {cid}: embedding length {H.shape[0]} != sequence length {L}; skipping chain")
                continue
            Hs.append(H); seqs.append(seq); Cs.append(C); Vs.append(V)
            chain_ids.extend([cid] * L)

        if not Hs:
            return None

        full_seq = "".join(seqs)
        H_full = torch.cat(Hs, dim=0).to(DEVICE, dtype=torch.float32)
        Ls = [len(s) for s in seqs]
        Ltot = sum(Ls)
        C_full = np.zeros((Ltot, Ltot), dtype=bool)
        V_full = np.zeros((Ltot, Ltot), dtype=bool)
        ci = 0
        for C, V, L in zip(Cs, Vs, Ls):
            cj = ci + L
            C_full[ci:cj, ci:cj] = C
            V_full[ci:cj, ci:cj] = V
            ci = cj

        uniq = {c: i for i, c in enumerate(sorted(set(chain_ids)))}
        chain_ids_arr = np.array([uniq[c] for c in chain_ids], dtype=np.int64)
        return {"seq": full_seq, "H": H_full, "contact": C_full, "valid_pair": V_full,
                "chain_ids": chain_ids_arr, "pdb_id": path.stem, "path": str(path)}
    except Exception as e:
        # Best-effort: one bad file must not abort the whole dataset build.
        print(f"[warn] skip {path}: {e}")
        return None

# ---- Balanced sampler + Bilinear scorer ----
def build_train_pairs_balanced(S, max_pos=None, neg_ratio=3, upper_tri=True):
    """Sample contact (positive) and non-contact (negative) residue pairs.

    Candidates are pairs with ``S["valid_pair"]`` set, restricted to the strict
    upper triangle when ``upper_tri`` is True, otherwise to every off-diagonal
    entry. Positives are optionally subsampled to ``max_pos``. Then up to
    ``int(n_pos * neg_ratio)`` negatives are drawn without replacement using
    numpy's global RNG.

    Returns:
        ``(pairs, labels)``: ``pairs`` is ``(N, 2)`` integer indices with all
        positives first, ``labels`` is ``(N,)`` float32 (1 for positives, 0 for
        negatives). Both are empty when there are no positive pairs.
    """
    contact, valid = S["contact"], S["valid_pair"]
    region = np.ones_like(contact, dtype=bool)
    if upper_tri:
        region = np.triu(region, k=1)
    else:
        np.fill_diagonal(region, False)
    candidates = valid & region
    pos_idx = np.argwhere(contact & candidates)
    neg_idx = np.argwhere(~contact & candidates)

    n_pos = pos_idx.shape[0]
    if n_pos == 0:
        return np.empty((0, 2), dtype=int), np.empty((0,), dtype=np.float32)
    if max_pos is not None and n_pos > max_pos:
        keep = np.random.choice(n_pos, size=max_pos, replace=False)
        pos_idx = pos_idx[keep]
        n_pos = pos_idx.shape[0]

    n_neg = min(neg_idx.shape[0], int(n_pos * neg_ratio))
    if n_neg <= 0:
        neg_idx = np.empty((0, 2), dtype=int)
    else:
        neg_idx = neg_idx[np.random.choice(neg_idx.shape[0], size=n_neg, replace=False)]

    pairs = np.concatenate([pos_idx, neg_idx], axis=0)
    labels = np.concatenate([
        np.ones(pos_idx.shape[0], dtype=np.float32),
        np.zeros(neg_idx.shape[0], dtype=np.float32),
    ])
    return pairs, labels

class BilinearScorer(torch.nn.Module):
    """Low-rank bilinear pair scorer: ``score(i, j) = <U h_i, V h_j> + bias``.

    ``U`` and ``V`` are bias-free ``d -> rank`` linear projections. ``bias`` is
    a single learned scalar. All parameters are cast to float32.
    """

    def __init__(self, d, rank=128):
        super().__init__()
        self.U = torch.nn.Linear(d, rank, bias=False)
        self.V = torch.nn.Linear(d, rank, bias=False)
        self.bias = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
        self.to(torch.float32)

    def forward_pairs(self, H, idx):
        """Score residue pairs.

        Args:
            H: per-residue features; rows are indexed by the pair indices
                (assumed ``(L, d)`` — TODO confirm for other callers).
            idx: ``(N, 2)`` pair indices, either a torch tensor or anything
                ``np.asarray`` accepts. Column 0 selects the ``U`` side and
                column 1 the ``V`` side.

        Returns:
            float32 tensor of shape ``(N,)`` on ``H.device``; empty when
            ``idx`` has no elements.
        """
        on_torch = isinstance(idx, torch.Tensor)
        if not on_torch:
            idx = np.asarray(idx)
        n_items = idx.numel() if on_torch else idx.size
        if n_items == 0:
            return torch.empty((0,), device=H.device, dtype=torch.float32)

        if on_torch:
            rows = idx[:, 0].to(H.device).long()
            cols = idx[:, 1].to(H.device).long()
        else:
            rows = torch.as_tensor(idx[:, 0], device=H.device, dtype=torch.long)
            cols = torch.as_tensor(idx[:, 1], device=H.device, dtype=torch.long)

        feats = H.to(dtype=torch.float32)  # the Linear layers expect float32
        left, right = self.U(feats), self.V(feats)
        logits = (left[rows] * right[cols]).sum(-1) + self.bias
        return logits.to(dtype=torch.float32).view(-1)

# ---- Training helpers ----
def iterate_balanced_pairs(structs, max_pos=None, neg_ratio=3.0, batch_pairs=50000, shuffle=True):
    """Yield ``(H, pairs, labels)`` mini-batches of balanced residue pairs.

    For every non-None structure, pairs come from ``build_train_pairs_balanced``.
    Structures that yield no pairs are skipped. ``pairs`` (integer indices)
    and ``labels`` (float32 targets) are moved to DEVICE here. When ``shuffle``
    is set they are permuted together. ``H`` is passed through exactly as
    stored in the structure; the training loop handles its device/dtype. Each
    structure contributes ceil(n_pairs / batch_pairs) consecutive batches.
    """
    for S in filter(lambda s: s is not None, structs):
        pairs_np, labels_np = build_train_pairs_balanced(S, max_pos=max_pos, neg_ratio=neg_ratio)
        if len(pairs_np) == 0:
            continue
        H = S["H"]
        pair_t = torch.from_numpy(pairs_np).to(DEVICE)
        label_t = torch.from_numpy(labels_np).to(DEVICE)
        if shuffle:
            order = torch.randperm(len(pair_t), device=pair_t.device)
            pair_t, label_t = pair_t[order], label_t[order]
        for start in range(0, len(pair_t), batch_pairs):
            stop = start + batch_pairs
            yield H, pair_t[start:stop], label_t[start:stop]

def _align_logits_targets_f32(logits: torch.Tensor, y: torch.Tensor):
    dev = logits.device
    logits = logits.to(device=dev, dtype=torch.float32)
    y = y.to(device=dev, dtype=torch.float32)
    if (torch.min(y) < 0) or (torch.max(y) > 1):
        y = (y > 0).to(torch.float32)
    if logits.ndim == 2 and logits.size(1) == 1 and y.ndim == 1:
        y = y.unsqueeze(1)
    if logits.shape != y.shape:
        try:
            y = y.view_as(logits)
        except Exception as e:
            raise RuntimeError(f"[ALIGN] Shape mismatch: logits {logits.shape}, target {y.shape}") from e
    return logits, y

def train_one_epoch_balanced(
    model, opt, train_structs,
    max_pos=None, neg_ratio: float = 3.0, batch_pairs: int = 50000,
    pos_weight=None, amp: bool = False, debug_first_batch: bool = False,
):
    """Run one pass over balanced pair batches and return the summed loss.

    Each batch from ``iterate_balanced_pairs`` is prepared as follows:
      - ``H`` is cast to float32, pair indices to int64 and labels to
        float32 ``(N, 1)``, all on DEVICE;
      - the batch is scored with ``model.forward_pairs`` and aligned by
        ``_align_logits_targets_f32``;
      - the model is optimised with BCEWithLogitsLoss, using a scalar
        ``pos_weight`` when one is given.

    Mixed precision with a GradScaler is used only when ``amp`` is set and
    DEVICE is CUDA. With ``debug_first_batch``, dtypes, devices and shapes of
    the first batch are printed once.

    Returns:
        Sum (not mean) of the per-batch loss values.
    """
    model.train()
    use_amp = amp and DEVICE.type == "cuda"
    if pos_weight is None:
        criterion = nn.BCEWithLogitsLoss()
    else:
        criterion = nn.BCEWithLogitsLoss(
            pos_weight=torch.tensor([pos_weight], device=DEVICE, dtype=torch.float32)
        )
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    pending_debug = debug_first_batch
    epoch_loss = 0.0
    batches = iterate_balanced_pairs(
        train_structs, max_pos=max_pos, neg_ratio=neg_ratio, batch_pairs=batch_pairs
    )
    for H, pairs, labels in batches:
        H = H.to(device=DEVICE, dtype=torch.float32, non_blocking=True)
        pairs = pairs.to(device=DEVICE, dtype=torch.long, non_blocking=True)
        labels = labels.to(device=DEVICE, dtype=torch.float32, non_blocking=True).view(-1, 1)
        if use_amp:
            with torch.cuda.amp.autocast():
                logits = model.forward_pairs(H, pairs)
                logits, labels = _align_logits_targets_f32(logits, labels)
                loss = criterion(logits, labels)
            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
        else:
            logits = model.forward_pairs(H, pairs)
            logits, labels = _align_logits_targets_f32(logits, labels)
            loss = criterion(logits, labels)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
        if pending_debug:
            print(f"[DEBUG] H:{H.dtype}/{H.device}/{tuple(H.shape)}  pairs:{pairs.dtype}/{pairs.device}/{tuple(pairs.shape)}")
            print(f"[DEBUG] logits:{logits.dtype}/{logits.device}/{tuple(logits.shape)}  labels:{labels.dtype}/{labels.device}/{tuple(labels.shape)}")
            pending_debug = False
        epoch_loss += float(loss.detach().cpu())
    return epoch_loss

@torch.no_grad()
def evaluate_structs_streamed(model, structs, chunk=200000):
    """Score every valid upper-triangular residue pair and report ranking metrics.

    For each non-None structure, pairs ``(i, j)`` with ``i < j`` and
    ``valid_pair`` set are scored in chunks of ``chunk`` pairs via
    ``model.forward_pairs``. Their sigmoid probabilities are ranked against
    the ``contact`` labels.

    Returns:
        dict with ``global_pr_auc`` / ``global_roc_auc`` over all pooled pairs
        and ``per_pdb`` mapping pdb_id to
        ``dict(pr_auc, roc_auc, p_at_L, p_at_L2, p_at_L5)``, where p_at_k is
        precision among the top-k scored pairs and L is the sequence length.
        A metric that cannot be computed becomes NaN without affecting the
        others. If no structure had scorable pairs, both global metrics are NaN.
    """
    def _safe_metric(metric, y_true, y_score):
        # Best effort: sklearn metrics raise on degenerate inputs (e.g. a single
        # class for ROC-AUC). Scoring each metric separately keeps one failure
        # from also discarding the other, which may still be well defined.
        try:
            return metric(y_true, y_score)
        except Exception:
            return float("nan")

    model.eval()
    per_pdb = {}
    all_p, all_y = [], []
    for S in structs:
        if S is None:
            continue
        H = S["H"].to(DEVICE, dtype=torch.float32)
        C = S["contact"]; V = S["valid_pair"]
        mask = V & np.triu(np.ones_like(C, dtype=bool), k=1)
        idx = np.argwhere(mask)
        if idx.size == 0:
            continue
        probs_list, labels_list = [], []
        for st in range(0, len(idx), chunk):
            ib = idx[st:st + chunk]
            logits = model.forward_pairs(H, ib)  # accepts numpy indices; returns fp32
            probs_list.append(torch.sigmoid(logits).detach().cpu().numpy())
            labels_list.append(C[ib[:, 0], ib[:, 1]].astype(int))
        probs = np.concatenate(probs_list)
        labels = np.concatenate(labels_list)
        ap = _safe_metric(average_precision_score, labels, probs)
        roc = _safe_metric(roc_auc_score, labels, probs)
        # Labels ordered by descending predicted probability, for precision@k.
        ranked = labels[probs.argsort()[::-1]]
        L = len(S["seq"])

        def p_at(k):
            k = max(1, min(len(ranked), k))
            return ranked[:k].mean()

        per_pdb[S.get("pdb_id", "?")] = dict(
            pr_auc=ap, roc_auc=roc,
            p_at_L=p_at(L), p_at_L2=p_at(max(1, L // 2)), p_at_L5=p_at(max(1, L // 5)),
        )
        all_p.append(probs); all_y.append(labels)
        del probs_list, labels_list, probs, labels
        gc.collect()
    if not all_p:
        return dict(global_pr_auc=float("nan"), global_roc_auc=float("nan"), per_pdb=per_pdb)
    P = np.concatenate(all_p); Y = np.concatenate(all_y)
    return dict(
        global_pr_auc=_safe_metric(average_precision_score, Y, P),
        global_roc_auc=_safe_metric(roc_auc_score, Y, P),
        per_pdb=per_pdb,
    )

# ---- Dummy data ----
def create_dummy_structure(pdb_id, seq_length=100, d_model=320):
    """Build a synthetic structure record for smoke-testing the pipeline.

    Embeddings are standard-normal noise on DEVICE. The contact map marks
    roughly 10% of residue pairs at random (not symmetrised) with an empty
    diagonal, and every strictly upper-triangular pair is valid. The record
    has the same keys as the ``process_structure`` output.
    """
    n = seq_length
    H = torch.randn(n, d_model, device=DEVICE, dtype=torch.float32)
    contact = np.random.rand(n, n) < 0.1
    np.fill_diagonal(contact, False)
    return {
        "seq": "A" * n,
        "H": H,
        "contact": contact,
        "valid_pair": np.triu(np.ones((n, n), dtype=bool), k=1),
        "chain_ids": np.zeros(n, dtype=np.int64),
        "pdb_id": pdb_id,
        "path": f"dummy_{pdb_id}.pdb",
    }

# ---- Main ----
def main():
    """Train the bilinear contact scorer end to end and report test metrics.

    Steps:
      1. Load the embedder (ESM2 or the random fallback) and collect structure files.
      2. Build train/val records, falling back to synthetic dummy structures
         when no usable training files exist.
      3. Train for EPOCHS epochs with per-epoch validation, using the
         module-level sampler/batching config.
      4. Save the model to MODEL_PATH, then evaluate on the test set (dummy
         structures if there are no test files).
    """
    def collect_files(folder: Path):
        """Return the sorted structure files directly inside ``folder``.

        Suffixes .pdb/.cif/.mmcif are matched case-insensitively, as
        ``_parser_for`` does. Each directory entry is visited once, so files
        are never duplicated even on case-insensitive filesystems.
        """
        exts = {".pdb", ".cif", ".mmcif"}
        return sorted(p for p in folder.glob("*") if p.suffix.lower() in exts and p.is_file())

    print("Loading ESM2 model...")
    tokenizer, esm2_model = try_load_esm2(MODEL_ID)
    print("Embedder:", type(esm2_model).__name__)

    train_all = collect_files(PDB_TRAIN_DIR)
    test_files = collect_files(PDB_TEST_DIR)
    print(f"Found {len(train_all)} train, {len(test_files)} test files.")

    if len(train_all) == 0:
        print("WARNING: No training files found! Using dummy data for demonstration.")
        train_structs = [create_dummy_structure(f"train_{i}") for i in range(5)]
        val_structs = [create_dummy_structure(f"val_{i}") for i in range(2)]
    else:
        if (MAX_TRAIN_FILES is not None) and (len(train_all) > MAX_TRAIN_FILES):
            train_all = train_all[:MAX_TRAIN_FILES]
            print(f"Capped train files to {len(train_all)}")
        if len(train_all) < 2:
            # train_test_split raises when the train side would be empty; with a
            # single file, train on it and skip validation.
            train_files, val_files = list(train_all), []
        else:
            train_files, val_files = train_test_split(train_all, test_size=VAL_SPLIT, random_state=SEED)
        print(f"Split: {len(train_files)} train / {len(val_files)} val")
        print("Processing train/val (caches embeddings per chain on first run)...")
        train_structs = [process_structure(p, tokenizer, esm2_model) for p in tqdm(train_files)]
        val_structs = [process_structure(p, tokenizer, esm2_model) for p in tqdm(val_files)]
        train_structs = [s for s in train_structs if s is not None]
        val_structs = [s for s in val_structs if s is not None]
        if len(train_structs) == 0:
            print("WARNING: No training structures found! Creating dummy data.")
            train_structs = [create_dummy_structure(f"train_{i}") for i in range(5)]
            val_structs = [create_dummy_structure(f"val_{i}") for i in range(2)]

    d = int(train_structs[0]["H"].shape[-1])
    print(f"Embedding dimension: {d}")
    model = BilinearScorer(d=d, rank=RANK).to(DEVICE).to(torch.float32)
    opt = torch.optim.AdamW(model.parameters(), lr=LR)

    print("Starting training...")
    for ep in range(1, EPOCHS + 1):
        # Sampler and batching knobs come from the module-level config.
        loss = train_one_epoch_balanced(
            model, opt, train_structs,
            max_pos=MAX_POS_PER_STRUCT, neg_ratio=NEG_RATIO, batch_pairs=PAIRS_BATCH,
            debug_first_batch=(ep == 1),
        )
        val_scores = evaluate_structs_streamed(model, val_structs, chunk=EVAL_CHUNK)
        print(f"[epoch {ep}] loss={loss:.3f}  val PR-AUC={val_scores['global_pr_auc']:.4f}  ROC-AUC={val_scores['global_roc_auc']:.4f}")

    torch.save({"state_dict": model.state_dict(), "cfg": {"d": d, "rank": RANK, "lr": LR, "epochs": EPOCHS}}, str(MODEL_PATH))
    print("Saved:", MODEL_PATH)

    print("Processing test...")
    if len(test_files) > 0:
        test_structs = [process_structure(p, tokenizer, esm2_model) for p in tqdm(test_files)]
        test_structs = [s for s in test_structs if s is not None]
    else:
        test_structs = [create_dummy_structure(f"test_{i}") for i in range(2)]
    scores = evaluate_structs_streamed(model, test_structs, chunk=EVAL_CHUNK)
    print("TEST  PR-AUC={:.4f}  ROC-AUC={:.4f}".format(scores["global_pr_auc"], scores["global_roc_auc"]))

if __name__ == "__main__":
    main()
