<a href="https://colab.research.google.com/github/rhizome18/rhizome18.github.io/blob/main/AKM_fold.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [16]:
# Build-Your-Own AlphaFold‑Like Predictor – Stage 1 + Stage 2
# =====================================================================
# Author: Anand / ChatGPT • 28 Apr 2025
# ---------------------------------------------------------------------
# Stage 1: raw‑data pipeline – create (seq, Cα) NPZ files from CATH S40.
# Stage 2: tiny neural net that learns Cα geometry from sequence only.
# ---------------------------------------------------------------------
# Run in Google Colab (GPU runtime recommended).
# =====================================================================

# %% [markdown]
"""
## Milestones
1. **Stage 1** ⇒ `NPZ saved` > 0 and loader sample lengths match.
2. **Stage 2** ⇒ See training loss decreasing and RMSD print‑out.
"""

# %%
# SECTION 0 – Install & imports
# ---------------------------------------------------------------------
!pip -q install biopython torch==2.* tqdm py3Dmol

import os, gzip, shutil, urllib.request, ssl, re, math
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np, torch, torch.nn as nn
from Bio.PDB.MMCIFParser import MMCIFParser
from Bio.PDB import is_aa

ssl._create_default_https_context = ssl._create_unverified_context

DATA_DIR = Path("/content/protein_data").resolve(); DATA_DIR.mkdir(exist_ok=True)
print("Data dir:", DATA_DIR)
print("CUDA available:", torch.cuda.is_available())

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# %%
# SECTION 1 – Download CATH S40 list
# ---------------------------------------------------------------------
URLS = [
    "http://download.cathdb.info/cath/releases/all-releases/v4_3_0/non-redundant-data-sets/cath-dataset-nonredundant-S40-v4_3_0.list",
    "http://ftp.ebi.ac.uk/pub/databases/cath/non-redundant-data-sets/cath-dataset-nonredundant-S40-v4_3_0.list",
]
LIST_PATH = DATA_DIR / "cath_s40.list"
for u in URLS:
    try:
        urllib.request.urlretrieve(u, LIST_PATH); print("✔ downloaded", u); break
    except Exception as e:
        print("⚠", e)
else:
    raise RuntimeError("CATH list download failed")

entries = []
for line in open(LIST_PATH):
    dom = line.split()[0] if line.strip() else ""
    if re.fullmatch(r"[0-9A-Za-z]{4}[A-Za-z0-9][0-9]{2}", dom):
        pdb, chain = dom[:4].lower(), dom[4]
        entries.append((pdb, chain, dom))
print("Total domains:", len(entries))

# %%
# SECTION 2 – mmCIF downloader
# ---------------------------------------------------------------------
PDB_URL = "https://files.rcsb.org/download/{}.cif.gz"

def fetch_cif(pid):
    gz = DATA_DIR / f"{pid}.cif.gz"; cif = DATA_DIR / f"{pid}.cif"
    if cif.exists():
        return cif
    if not gz.exists():
        urllib.request.urlretrieve(PDB_URL.format(pid.upper()), gz)
    with gzip.open(gz, "rb") as fi, open(cif, "wb") as fo:
        shutil.copyfileobj(fi, fo)
    return cif

# Quick smoke‑test
for pid, _, _ in entries[:5]:
    fetch_cif(pid)
print("✅ mmCIF download test passed")

# %%
# SECTION 3 – Parse → NPZ
# ---------------------------------------------------------------------
parser = MMCIFParser(QUIET=True)
NPZ_DIR = DATA_DIR / "npz"; NPZ_DIR.mkdir(exist_ok=True)

MIN_LEN = 5
NONSTD = {"MSE":"MET","FME":"MET","MEN":"MET","SEP":"SER","TPO":"THR","PTR":"TYR","CSO":"CYS","SEC":"CYS","HYP":"PRO","3CL":"LYS","MLY":"LYS"}
STD3_TO_1 = {
    "ALA":"A","ARG":"R","ASN":"N","ASP":"D","CYS":"C","GLU":"E","GLN":"Q","GLY":"G","HIS":"H","ILE":"I","LEU":"L","LYS":"K","MET":"M","PHE":"F","PRO":"P","SER":"S","THR":"T","TRP":"W","TYR":"Y","VAL":"V"
}

saved, failed = 0, []
for pdb, want, dom in tqdm(entries[:200], desc="Stage 1: parse 200"):
    try:
        model = parser.get_structure(pdb, fetch_cif(pdb))[0]
        chain = next((c for c in model if c.id.strip().upper() == want.upper()), None)
        if chain is None and len(model) == 1:
            chain = next(iter(model))
        if chain is None:
            failed.append((dom, "no chain")); continue
        coords, res3, seen = [], [], set()
        for res in chain:
            if not is_aa(res, standard=False):
                continue
            idx = res.id[1]
            if idx in seen or "CA" not in res:
                continue
            seen.add(idx)
            coords.append(res["CA"].coord)
            tri = res.get_resname().upper(); tri = NONSTD.get(tri, tri)
            res3.append(tri)
        if len(res3) < MIN_LEN:
            failed.append((dom, "short")); continue
        seq_one = "".join(STD3_TO_1.get(t, "X") for t in res3)
        np.savez_compressed(
            NPZ_DIR / f"{pdb}_{chain.id.strip() or ' '}.npz",
            seq=np.frombuffer(seq_one.encode(), dtype='S1'),
            coords=np.array(coords, dtype=np.float32)
        )
        saved += 1
    except Exception as e:
        failed.append((dom, str(e)))
print(f"NPZ saved: {saved}   Failures: {len(failed)}")

# %%
# SECTION 4 – Dataset & loader check
# ---------------------------------------------------------------------
class Prot(torch.utils.data.Dataset):
    AA = "ACDEFGHIKLMNPQRSTVWY"; aa2i = {a: i + 1 for i, a in enumerate(AA)}
    def __init__(self, folder):
        self.files = list(Path(folder).glob("*.npz"))
    def __len__(self):
        return len(self.files)
    def __getitem__(self, idx):
        d = np.load(self.files[idx])
        seq = [self.aa2i.get(c.decode(), 0) for c in d['seq']]
        return torch.tensor(seq), torch.tensor(d['coords'])

loader = torch.utils.data.DataLoader(Prot(NPZ_DIR), batch_size=1, shuffle=True)
print("Loader batches:", len(loader))
seq, xyz = next(iter(loader))
print("Seq/XYZ shapes:", seq.shape, xyz.shape)


Data dir: /content/protein_data
CUDA available: True
✔ downloaded http://download.cathdb.info/cath/releases/all-releases/v4_3_0/non-redundant-data-sets/cath-dataset-nonredundant-S40-v4_3_0.list
Total domains: 31885
✅ mmCIF download test passed


Stage 1: parse 200:   0%|          | 0/200 [00:00<?, ?it/s]

NPZ saved: 200   Failures: 0
Loader batches: 164
Seq/XYZ shapes: torch.Size([1, 389]) torch.Size([1, 389, 3])


In [23]:
# ============================================================
#  STAGE 2  ·  MiniFold – single-sequence distance-map learner
#  (Block ID: stage2_minifold)
# ============================================================

# ---- Hyper-parameters --------------------------------------
EMBED_DIM  = 64
N_LAYERS   = 2
LR         = 3e-4
EPOCHS     = 100          # demo – raise for better learning
BATCH_SIZE = 8
MAX_LEN    = 512        # crop / pad to this length

# ---- Collate fn -------------------------------------------
from torch.nn.functional import pad as tpad
def collate(batch):
    seqs, coords = zip(*batch)
    L = min(MAX_LEN, max(len(s) for s in seqs))
    pad_seq, pad_xyz, masks = [], [], []
    for s, c in zip(seqs, coords):
        s, c = s[:L], c[:L]
        pad_seq.append(tpad(s, (0, L - len(s))))
        pad_xyz.append(tpad(c, (0, 0, 0, L - len(c))))
        masks.append(torch.arange(L) < len(s))
    return torch.stack(pad_seq), torch.stack(pad_xyz), torch.stack(masks)

train_loader = torch.utils.data.DataLoader(
    Prot(NPZ_DIR), batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate
)

# ---- MiniFold model ----------------------------------------
import torch.nn as nn
class MiniFold(nn.Module):
    def __init__(self, vocab=21, dim=EMBED_DIM, layers=N_LAYERS):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim, padding_idx=0)
        enc = nn.TransformerEncoderLayer(
            dim, nhead=4, dim_feedforward=dim*4, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(enc, num_layers=layers)
        self.coord_head = nn.Linear(dim, 3)

    def forward(self, seq, mask):
        x = self.embed(seq)
        x = self.encoder(x, src_key_padding_mask=~mask)
        return self.coord_head(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = MiniFold().to(device)
opt    = torch.optim.Adam(model.parameters(), lr=LR)

# ---- Distance-loss helpers ---------------------------------
def pairwise_dists(x):
    diff = x[:, :, None, :] - x[:, None, :, :]
    return torch.linalg.vector_norm(diff, dim=-1)

# ---- Training loop -----------------------------------------
for epoch in range(EPOCHS):
    model.train(); running = 0
    for seq, xyz_gt, mask in train_loader:
        seq, xyz_gt, mask = seq.to(device), xyz_gt.to(device), mask.to(device)
        pred_xyz = model(seq, mask)

        d_pred = pairwise_dists(pred_xyz)
        d_gt   = pairwise_dists(xyz_gt)
        pmask  = (mask.float()[:, :, None] * mask.float()[:, None, :]).bool()

        loss = ((d_pred - d_gt)[pmask] ** 2).mean()
        opt.zero_grad(); loss.backward(); opt.step()
        running += loss.item()

    print(f"Epoch {epoch+1}/{EPOCHS}   loss = {running / len(train_loader):.4f}")

# ---- Quick RMSD sanity check -------------------------------
model.eval()
with torch.no_grad():
    seq, xyz_gt, mask = next(iter(train_loader))
    pred_xyz = model(seq.to(device), mask.to(device)).cpu()
    rmsd = torch.sqrt(((pred_xyz - xyz_gt)**2).sum(-1).mean()).item()
    print(f"Backbone RMSD (single batch): {rmsd:.2f} Å")


Epoch 1/100   loss = 833.8588
Epoch 2/100   loss = 738.7631
Epoch 3/100   loss = 648.1308
Epoch 4/100   loss = 571.1488
Epoch 5/100   loss = 557.4504
Epoch 6/100   loss = 537.6607
Epoch 7/100   loss = 526.6514
Epoch 8/100   loss = 514.6407
Epoch 9/100   loss = 494.6375
Epoch 10/100   loss = 461.1045
Epoch 11/100   loss = 447.8661
Epoch 12/100   loss = 447.8212
Epoch 13/100   loss = 419.4497
Epoch 14/100   loss = 412.5819
Epoch 15/100   loss = 380.9445
Epoch 16/100   loss = 383.9218
Epoch 17/100   loss = 372.1361
Epoch 18/100   loss = 361.6285
Epoch 19/100   loss = 364.6108
Epoch 20/100   loss = 347.1488
Epoch 21/100   loss = 324.8620
Epoch 22/100   loss = 337.4598
Epoch 23/100   loss = 328.1169
Epoch 24/100   loss = 341.6473
Epoch 25/100   loss = 340.8503
Epoch 26/100   loss = 337.9175
Epoch 27/100   loss = 337.7351
Epoch 28/100   loss = 322.2369
Epoch 29/100   loss = 315.8418
Epoch 30/100   loss = 331.2799
Epoch 31/100   loss = 319.8882
Epoch 32/100   loss = 313.3160
Epoch 33/100   lo

In [24]:
# --- Kabsch alignment -------------------------------------------------
def kabsch(P, Q, mask=None):
    """Align coords P to Q; return aligned P."""
    if mask is not None:
        P, Q = P[mask], Q[mask]
    P_cent = P - P.mean(0)
    Q_cent = Q - Q.mean(0)
    C = P_cent.T @ Q_cent
    V, S, Wt = torch.linalg.svd(C)
    R = V @ Wt
    # Fix improper rotation
    if torch.det(R) < 0:
        V[:, -1] *= -1
        R = V @ Wt
    return (P_cent @ R) + Q.mean(0)

# ---- Aligned RMSD sanity check --------------------------------
model.eval()
with torch.no_grad():
    seq, xyz_gt, mask = next(iter(train_loader))
    pred_xyz = model(seq.to(device), mask.to(device)).cpu()

    P_aligned = kabsch(pred_xyz[0], xyz_gt[0], mask[0])
    rmsd_aligned = torch.sqrt(
        ((P_aligned - xyz_gt[0][mask[0]]) ** 2).mean()
    ).item()

    print(f"Aligned backbone RMSD: {rmsd_aligned:.2f} Å")


Aligned backbone RMSD: 17.43 Å
