<a href="https://colab.research.google.com/github/rhizome18/rhizome18.github.io/blob/main/AKM_fold_pass1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ============================================================
#  STAGE 1 – FULL DATASET PIPELINE  (Block ID: stage1_full)
#  • Robust gzip retry + MAX_DOMAINS parameter
#  • Defines NPZ_DIR and Prot dataset for later stages
# ============================================================

!pip -q install biopython torch==2.* tqdm py3Dmol

import os, gzip, shutil, urllib.request, ssl, re, concurrent.futures, time
from pathlib import Path
from tqdm.auto import tqdm
from Bio.PDB.MMCIFParser import MMCIFParser
from Bio.PDB import is_aa
import numpy as np, torch

ssl._create_default_https_context = ssl._create_unverified_context
DATA_DIR = Path('/content/protein_data').resolve(); DATA_DIR.mkdir(parents=True, exist_ok=True)
NPZ_DIR  = DATA_DIR / 'npz';               NPZ_DIR.mkdir(parents=True,  exist_ok=True)
print('CUDA available:', torch.cuda.is_available())

# ---- 1. Download CATH S40 list ----------------------------
URLS = [
    'http://download.cathdb.info/cath/releases/all-releases/v4_3_0/non-redundant-data-sets/cath-dataset-nonredundant-S40-v4_3_0.list',
    'http://ftp.ebi.ac.uk/pub/databases/cath/non-redundant-data-sets/cath-dataset-nonredundant-S40-v4_3_0.list',
]
LIST = DATA_DIR / 'cath_s40.list'
for u in URLS:
    try:
        urllib.request.urlretrieve(u, LIST); print('✔ list', u); break
    except Exception as e:
        print('⚠', e)
else:
    raise RuntimeError('CATH list download failed')

entries = []
for line in open(LIST):
    dom = line.split()[0] if line.strip() else ''
    if re.fullmatch(r'[0-9A-Za-z]{4}[A-Za-z0-9][0-9]{2}', dom):
        pdb, chain = dom[:4].lower(), dom[4]
        entries.append((pdb, chain, dom))
print('Total domains in list:', len(entries))

# ---- MAX_DOMAINS parameter --------------------------------
MAX_DOMAINS = 200               # set e.g. 5000 or None for full
subset = entries if MAX_DOMAINS is None else entries[:MAX_DOMAINS]
print('Domains selected:', len(subset))

# ---- 2. Robust mmCIF downloader ---------------------------
PDB_URL = 'https://files.rcsb.org/download/{}.cif.gz'
def fetch_cif(pid: str, retries: int = 3):
    gz  = DATA_DIR / f'{pid}.cif.gz'
    cif = DATA_DIR / f'{pid}.cif'
    if cif.exists():
        return cif
    for attempt in range(retries):
        try:
            if not gz.exists():
                urllib.request.urlretrieve(PDB_URL.format(pid.upper()), gz)
            with gzip.open(gz, 'rb') as fi, open(cif, 'wb') as fo:
                shutil.copyfileobj(fi, fo)
            return cif
        except Exception as e:
            print(f'⚠ {pid} retry {attempt+1}/{retries} ({e})')
            if gz.exists():
                gz.unlink()
            time.sleep(1 + attempt)
    raise RuntimeError(f'Failed to fetch {pid}')

print('Downloading mmCIFs …')
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as ex:
    list(tqdm(ex.map(lambda e: fetch_cif(e[0]), subset), total=len(subset)))

# ---- 3. Parse → NPZ --------------------------------------
parser = MMCIFParser(QUIET=True)
MIN_LEN = 5
NONSTD = {'MSE':'MET','FME':'MET','MEN':'MET','SEP':'SER','TPO':'THR','PTR':'TYR',
          'CSO':'CYS','SEC':'CYS','HYP':'PRO','3CL':'LYS','MLY':'LYS'}
STD3_TO_1 = {'ALA':'A','ARG':'R','ASN':'N','ASP':'D','CYS':'C','GLU':'E','GLN':'Q',
             'GLY':'G','HIS':'H','ILE':'I','LEU':'L','LYS':'K','MET':'M','PHE':'F',
             'PRO':'P','SER':'S','THR':'T','TRP':'W','TYR':'Y','VAL':'V'}

saved = failed = 0
for pdb, want, dom in tqdm(subset, desc='Parsing'):
    try:
        model = parser.get_structure(pdb, fetch_cif(pdb))[0]
        chain = next((c for c in model if c.id.strip().upper() == want.upper()), None)
        if chain is None and len(model) == 1: chain = next(iter(model))
        if chain is None: failed += 1; continue
        coords, res3, seen = [], [], set()
        for res in chain:
            if not is_aa(res, standard=False): continue
            idx = res.id[1]
            if idx in seen or 'CA' not in res: continue
            seen.add(idx); coords.append(res['CA'].coord)
            tri = NONSTD.get(res.get_resname().upper(), res.get_resname().upper())
            res3.append(tri)
        if len(res3) < MIN_LEN: failed += 1; continue
        seq_one = ''.join(STD3_TO_1.get(t, 'X') for t in res3)
        np.savez_compressed(
            NPZ_DIR / f"{pdb}_{chain.id.strip() or ' '}.npz",
            seq=np.frombuffer(seq_one.encode(), dtype='S1'),
            coords=np.array(coords, dtype=np.float32)
        )
        saved += 1
    except Exception:
        failed += 1
print(f'NPZ saved: {saved}   Failures: {failed}')

# ---- Dataset class (needed by Stage 2 / Stage 3) ----------
class Prot(torch.utils.data.Dataset):
    """Load (seq, Cα-coords) pairs from NPZ files."""
    AA = "ACDEFGHIKLMNPQRSTVWY"
    aa2i = {a: i + 1 for i, a in enumerate(AA)}   # 0 = pad
    def __init__(self, folder):
        self.files = list(Path(folder).glob("*.npz"))
    def __len__(self):
        return len(self.files)
    def __getitem__(self, idx):
        d = np.load(self.files[idx])
        seq = [self.aa2i.get(ch.decode(), 0) for ch in d["seq"]]
        return torch.tensor(seq), torch.tensor(d["coords"])


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m86.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m72.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m55.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m12.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

  0%|          | 0/200 [00:00<?, ?it/s]

⚠ 1a9x retry 1/3 (Compressed file ended before the end-of-stream marker was reached)
⚠ 1a9x retry 1/3 ([Errno 2] No such file or directory: '/content/protein_data/1a9x.cif.gz')
⚠ 1a9x retry 1/3 ([Errno 2] No such file or directory: '/content/protein_data/1a9x.cif.gz')
⚠ 1a9x retry 1/3 ([Errno 2] No such file or directory: '/content/protein_data/1a9x.cif.gz')
⚠ 1a9x retry 1/3 ([Errno 2] No such file or directory: '/content/protein_data/1a9x.cif.gz')
⚠ 1a9x retry 1/3 ([Errno 2] No such file or directory: '/content/protein_data/1a9x.cif.gz')
⚠ 1a9x retry 1/3 ([Errno 2] No such file or directory: '/content/protein_data/1a9x.cif.gz')


Parsing:   0%|          | 0/200 [00:00<?, ?it/s]

NPZ saved: 200   Failures: 0


In [2]:
# ============================================================
#  STAGE 2 – MiniFold training on full dataset
#  (Block ID: stage2_minifold_full)
# ============================================================

# ---- Hyper-parameters --------------------------------------
EMBED_DIM  = 256
N_LAYERS   = 8
LR         = 3e-4
EPOCHS     = 20          # full set; adjust as desired
BATCH_SIZE = 8
MAX_LEN    = 256

# ---- Collate fn -------------------------------------------
from torch.nn.functional import pad as tpad
def collate(batch):
    seqs, coords = zip(*batch)
    L = min(MAX_LEN, max(len(s) for s in seqs))
    pad_seq, pad_xyz, masks = [], [], []
    for s, c in zip(seqs, coords):
        s, c = s[:L], c[:L]
        pad_seq.append(tpad(s, (0, L - len(s))))
        pad_xyz.append(tpad(c, (0,0,0, L - len(c))))
        masks.append(torch.arange(L) < len(s))
    return torch.stack(pad_seq), torch.stack(pad_xyz), torch.stack(masks)

train_loader = torch.utils.data.DataLoader(
    Prot(NPZ_DIR), batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate
)

# ---- MiniFold model ----------------------------------------
import torch.nn as nn
class MiniFold(nn.Module):
    def __init__(self, vocab=21, dim=EMBED_DIM, layers=N_LAYERS):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim, padding_idx=0)
        enc = nn.TransformerEncoderLayer(dim, nhead=4,
                                         dim_feedforward=dim*4,
                                         batch_first=True)
        self.encoder = nn.TransformerEncoder(enc, num_layers=layers)
        self.coord_head = nn.Linear(dim, 3)
    def forward(self, seq, mask):
        x = self.embed(seq)
        x = self.encoder(x, src_key_padding_mask=~mask)
        return self.coord_head(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = MiniFold().to(device)
opt    = torch.optim.Adam(model.parameters(), lr=LR)

# ---- Distance-loss helpers ---------------------------------
def pairwise_dists(x):
    diff = x[:, :, None, :] - x[:, None, :, :]
    return torch.linalg.vector_norm(diff, dim=-1)

# ---- Training loop -----------------------------------------
for epoch in range(EPOCHS):
    model.train(); running = 0
    for seq, xyz_gt, mask in train_loader:
        seq, xyz_gt, mask = seq.to(device), xyz_gt.to(device), mask.to(device)
        pred_xyz = model(seq, mask)

        d_pred = pairwise_dists(pred_xyz)
        d_gt   = pairwise_dists(xyz_gt)
        pmask  = (mask.float()[:, :, None] * mask.float()[:, None, :]).bool()

        loss = ((d_pred - d_gt)[pmask] ** 2).mean()
        opt.zero_grad(); loss.backward(); opt.step()
        running += loss.item()
    print(f"Epoch {epoch+1}/{EPOCHS}   loss = {running / len(train_loader):.3f}")

# ---- Aligned RMSD quick check ------------------------------
# Kabsch alignment
def kabsch(P, Q, mask):
    P, Q = P[mask], Q[mask]
    Pc = P - P.mean(0); Qc = Q - Q.mean(0)
    C = Pc.T @ Qc
    V,_,Wt = torch.linalg.svd(C); R = V @ Wt
    if torch.det(R) < 0:
        V[:, -1] *= -1; R = V @ Wt
    return (Pc @ R) + Q.mean(0)

model.eval()
with torch.no_grad():
    seq, xyz_gt, mask = next(iter(train_loader))
    pred = model(seq.to(device), mask.to(device)).cpu()
    aligned = kabsch(pred[0], xyz_gt[0], mask[0])
    rmsd = torch.sqrt(((aligned - xyz_gt[0][mask[0]])**2).mean()).item()
    print(f"Aligned backbone RMSD: {rmsd:.2f} Å")


Epoch 1/20   loss = 427.265
Epoch 2/20   loss = 300.626
Epoch 3/20   loss = 256.810
Epoch 4/20   loss = 240.027
Epoch 5/20   loss = 223.626
Epoch 6/20   loss = 218.779
Epoch 7/20   loss = 218.434
Epoch 8/20   loss = 211.425
Epoch 9/20   loss = 213.561
Epoch 10/20   loss = 206.766
Epoch 11/20   loss = 211.190
Epoch 12/20   loss = 210.789
Epoch 13/20   loss = 207.971
Epoch 14/20   loss = 208.290
Epoch 15/20   loss = 203.522
Epoch 16/20   loss = 209.130
Epoch 17/20   loss = 205.643
Epoch 18/20   loss = 208.261
Epoch 19/20   loss = 208.399
Epoch 20/20   loss = 208.116
Aligned backbone RMSD: 16.48 Å


  output = torch._nested_tensor_from_mask(


In [7]:
# ============================================================
#  MiniFold inference → Py3Dmol view + PDB download
# ============================================================

import torch, numpy as np, uuid
from pathlib import Path
import py3Dmol, IPython.display as disp

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()

# ---- 1. Input sequence -------------------------------------
query_seq = (
    "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTT"
    "FTYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKG"
    "IDFKEDGNILGHKLEYNYNSHNVYIMADF"
)  # GFP 238 aa

query_seq = query_seq.replace("\n", "").strip().upper()
L = len(query_seq)

# ---- 2. Tokenise & mask ------------------------------------
ids  = torch.tensor([Prot.aa2i.get(aa, 0) for aa in query_seq]).unsqueeze(0).to(device)
mask = torch.ones(1, L, dtype=torch.bool, device=device)

# ---- 3. Forward pass ---------------------------------------
with torch.no_grad():
    xyz = model(ids, mask).cpu().squeeze(0).numpy()  # [L,3]
xyz -= xyz.mean(0, keepdims=True)                    # center

# ---- 4. 1-letter → 3-letter mapping ------------------------
one2three = {
    "A":"ALA","C":"CYS","D":"ASP","E":"GLU","F":"PHE","G":"GLY","H":"HIS",
    "I":"ILE","K":"LYS","L":"LEU","M":"MET","N":"ASN","P":"PRO","Q":"GLN",
    "R":"ARG","S":"SER","T":"THR","V":"VAL","W":"TRP","Y":"TYR"
}

pdb_lines = []
for i,(aa,coord) in enumerate(zip(query_seq, xyz), 1):
    res3 = one2three.get(aa,"UNK")
    pdb_lines.append(
        f"ATOM  {i:5d}  CA  {res3} A{i:4d}    "
        f"{coord[0]:8.3f}{coord[1]:8.3f}{coord[2]:8.3f}  1.00  0.00           C"
    )
pdb_lines.append("END")
pdb_str = "\n".join(pdb_lines)

# ---- 5. Save & view ----------------------------------------
out = Path(f"/content/minifold_pred_{uuid.uuid4().hex[:6]}.pdb")
out.write_text(pdb_str)
print("PDB saved to", out)

view = py3Dmol.view(width=500, height=400)
view.addModel(pdb_str, "pdb")
view.setStyle({"cartoon": {"color": "spectrum"}})
view.zoomTo()
disp.display(view)          # <— renders in Colab

# ---- 6. Download link --------------------------------------
from google.colab import files; files.download(str(out))


PDB saved to /content/minifold_pred_40d6b8.pdb


<py3Dmol.view at 0x7c44ade6d910>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [10]:
# ============================================================
#  Stage-3 · build A3M alignment cache (slow-but-sure mode)
# ============================================================

import requests, time
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np

# ---- folders ------------------------------------------------
DATA_DIR = Path("/content/protein_data"); DATA_DIR.mkdir(parents=True, exist_ok=True)
NPZ_DIR  = DATA_DIR / "npz"
MSA_DIR  = DATA_DIR / "msa";             MSA_DIR.mkdir(parents=True, exist_ok=True)

# ---- parameters --------------------------------------------
N_ROWS   = 64
MAX_LEN  = 256
API_URL  = "https://api.colabfold.com/v1/mmseqs2/search"
TIMEOUT  = 120           # allow long queue waits
RETRIES  = 8             # generous retry budget
SLEEP    = 5             # base back-off seconds

NPZ_LIST = list(NPZ_DIR.glob("*.npz"))

def mmseqs_query(seq: str) -> str | None:
    """Return A3M string or None if all retries fail."""
    for n in range(1, RETRIES + 1):
        try:
            r = requests.post(API_URL, json={"sequence": seq}, timeout=TIMEOUT)
            if r.status_code == 429:                      # rate-limited
                time.sleep(SLEEP * 2**n); continue
            r.raise_for_status()
            return r.json()["alignment"]
        except requests.exceptions.RequestException:
            time.sleep(SLEEP * 2**n)
    return None

def build_msa(npz_path: Path):
    uid = npz_path.stem
    out = MSA_DIR / f"{uid}.a3m"
    if out.exists(): return True                # already cached

    seq = "".join(ch.decode() for ch in np.load(npz_path)["seq"])
    ali = mmseqs_query(seq)
    if ali is None:
        print("⚠", uid, "skipped after retries"); return False

    # keep first N_ROWS hits, crop columns
    lines, keep = ali.splitlines(), []
    for hdr, s in zip(lines[::2], lines[1::2]):
        keep.extend([hdr, s[:MAX_LEN]])
        if len(keep)//2 >= N_ROWS: break
    with open(out, "w") as fh: fh.write("\n".join(keep) + "\n")
    return True

print(f"Building MSAs serially for {len(NPZ_LIST)} domains …")
success = 0
for npz in tqdm(NPZ_LIST):
    if build_msa(npz):
        success += 1

print(f"✅ MSA cache complete: {success} / {len(NPZ_LIST)} alignments saved → {MSA_DIR}")


Building MSAs serially for 164 domains …


  0%|          | 0/164 [00:00<?, ?it/s]

⚠ 1afp_A skipped after retries


KeyboardInterrupt: 

In [23]:
# ============================================================
#  Stage-3 · local MMseqs2 search  –  plain FASTA workflow
# ============================================================

!apt -q install mmseqs2

from pathlib import Path
import urllib.request, subprocess, tempfile, numpy as np, os, gzip, shutil
from tqdm.auto import tqdm

ROOT   = Path("/content/protein_data")
NPZDIR = ROOT / "npz"
MSADIR = ROOT / "msa";            MSADIR.mkdir(parents=True, exist_ok=True)
TMPDIR = ROOT / "mm_tmp";         TMPDIR.mkdir(parents=True, exist_ok=True)
DBDIR  = ROOT / "ur50_small_db"

GZ_URL   = "https://mmseqs.com/microdatabases/UniRef50_2023_03_small.fasta.gz"
GZ_PATH  = ROOT / "UniRef50_small.fasta.gz"
FASTA    = ROOT / "UniRef50_small.fasta"   # uncompressed ≈ 3 GB

# 1 ── download gzip once ------------------------------------
if not GZ_PATH.exists():
    print("⬇ downloading UniRef50_small.gz (1.3 GB) …")
    urllib.request.urlretrieve(GZ_URL, GZ_PATH)

# 2 ── decompress if needed ----------------------------------
if not FASTA.exists():
    print("🔄 decompressing to plain FASTA (≈3 GB) …")
    with gzip.open(GZ_PATH, "rb") as gi, open(FASTA, "wb") as fo:
        shutil.copyfileobj(gi, fo)

# 3 ── create MMseqs2 DB (once) ------------------------------
if not DBDIR.exists():
    print("🏗  creating MMseqs2 DB from plain FASTA …")
    subprocess.run(["mmseqs", "createdb", FASTA, DBDIR, TMPDIR], check=True)

# 4 ── helper: search & convert ------------------------------
def make_a3m(uid, seq, out_path):
    with tempfile.TemporaryDirectory() as td:
        q = Path(td)/"q.fa"; q.write_text(f">{uid}\n{seq}\n")
        subprocess.run(
            ["mmseqs", "easy-search", q, DBDIR,
             Path(td)/"res", Path(td)/"tmp",
             "-a", "-s", "7.5", "--format-output", "query,target,alnseq"],
            stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
        subprocess.run(
            ["mmseqs", "convertalis", q, DBDIR,
             Path(td)/"res", out_path, "--format-mode", "3"],
            stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)

# 5 ── build MSAs for all domains ----------------------------
for npz in tqdm(list(NPZDIR.glob("*.npz")), desc="Searching"):
    uid  = npz.stem
    a3m  = MSADIR / f"{uid}.a3m"
    if a3m.exists(): continue
    seq = "".join(ch.decode() for ch in np.load(npz)["seq"])
    try:
        make_a3m(uid, seq, a3m)
    except subprocess.CalledProcessError:
        print(f"⚠ {uid} – no hits (empty A3M)")

print("✅ local MSA cache ready:",
      len(list(MSADIR.glob('*.a3m'))), "alignments in", MSADIR)


Reading package lists...
Building dependency tree...
Reading state information...
mmseqs2 is already the newest version (13-45111+ds-2).
0 upgraded, 0 newly installed, 0 to remove and 34 not upgraded.
🔄 decompressing to plain FASTA (≈3 GB) …


BadGzipFile: Not a gzipped file (b'\n\n')

In [31]:
# ============================================================
#  Stage-3 · MiniFold-E (ESM-2 embeddings, no MSA)  — dtype-safe
# ============================================================

!pip -q install fair-esm==2.0.0 torch==2.* tqdm

import torch, esm, numpy as np, shutil, gzip, os, tempfile
from pathlib import Path
from tqdm.auto import tqdm

# ---- paths --------------------------------------------------
DATA_DIR = Path("/content/protein_data")
NPZ_DIR  = DATA_DIR / "npz"

# ---- 1 · Load frozen ESM-2 model ---------------------------
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
esm_model.eval().cpu()                         # CPU fine for 164 seqs
batch_converter = alphabet.get_batch_converter()

def seq_to_emb(seq: str):
    """Return per-residue ESM embedding  [L,1280]  (float16)"""
    _, _, toks = batch_converter([("p", seq)])
    with torch.no_grad():
        out = esm_model(toks, repr_layers=[33], return_contacts=False)
    emb = out["representations"][33][0, 1:1+len(seq)]
    return emb.cpu().half()                   # save as FP16

# ---- 2 · Cache embeddings in NPZs --------------------------
for npz in tqdm(list(NPZ_DIR.glob("*.npz")), desc="Embedding"):
    with np.load(npz, allow_pickle=False) as d:
        if "emb" in d:        # already done
            continue
        seq = "".join(ch.decode() for ch in d["seq"])
        emb = seq_to_emb(seq).numpy()         # float16
        np.savez_compressed(npz, seq=d["seq"], coords=d["coords"], emb=emb)

print("✅ ESM embeddings cached in NPZs")

# ---- 3 · Dataset wrapper -----------------------------------
class ProtEmb(torch.utils.data.Dataset):
    AA = "ACDEFGHIKLMNPQRSTVWY"
    aa2i = {a:i+1 for i,a in enumerate(AA)}   # pad=0
    def __init__(self, folder):
        self.files = list(Path(folder).glob("*.npz"))
    def __len__(self): return len(self.files)
    def __getitem__(self, i):
        d = np.load(self.files[i])
        seq = [self.aa2i.get(c.decode(), 0) for c in d["seq"]]
        return (torch.tensor(seq),
                torch.tensor(d["coords"]),
                torch.tensor(d["emb"]))

# ---- 4 · Collate (pads & casts embeddings to FP32) ---------
MAX_LEN, BATCH_SIZE = 256, 8
from torch.nn.functional import pad as tpad
def collate(batch):
    seqs, xyzs, embs = zip(*batch)
    L = min(MAX_LEN, max(len(s) for s in seqs))
    pad_s, pad_x, pad_e, masks = [], [], [], []
    for s, x, e in zip(seqs, xyzs, embs):
        s, x, e = s[:L], x[:L], e[:L]
        pad_s.append(tpad(s, (0, L-len(s))))
        pad_x.append(tpad(x, (0,0,0, L-len(x))))
        # ---- FP16 → FP32 to match Linear weights
        e32 = e.float()
        pad_e.append(tpad(e32, (0,0,0, L-len(e))))
        masks.append(torch.arange(L) < len(s))
    return (torch.stack(pad_s),
            torch.stack(pad_x),
            torch.stack(pad_e),
            torch.stack(masks))

train_loader = torch.utils.data.DataLoader(
    ProtEmb(NPZ_DIR), batch_size=BATCH_SIZE, shuffle=True,
    collate_fn=collate
)

# ---- 5 · MiniFold-E model ----------------------------------
import torch.nn as nn
EMBED_DIM, N_LAYERS, LR, EPOCHS = 256, 8, 3e-4, 50

class MiniFoldE(nn.Module):
    def __init__(self, vocab=21):
        super().__init__()
        self.embed  = nn.Embedding(vocab, EMBED_DIM, padding_idx=0)
        self.projE  = nn.Linear(1280, EMBED_DIM, bias=False)
        enc = nn.TransformerEncoderLayer(
            EMBED_DIM, nhead=4, dim_feedforward=EMBED_DIM*4,
            batch_first=True)
        self.encoder = nn.TransformerEncoder(enc, num_layers=N_LAYERS)
        self.coord_head = nn.Linear(EMBED_DIM, 3)
    def forward(self, seq, emb, mask):
        x = self.embed(seq) + self.projE(emb)   # dtype matches (FP32)
        x = self.encoder(x, src_key_padding_mask=~mask)
        return self.coord_head(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = MiniFoldE().to(device)
opt    = torch.optim.Adam(model.parameters(), lr=LR)

# ---- 6 · Training loop -------------------------------------
def pairwise(x): return torch.cdist(x, x)        # quick distances
for ep in range(1, EPOCHS+1):
    model.train(); run = 0
    for s,x,e,m in train_loader:
        s,x,e,m = s.to(device),x.to(device),e.to(device),m.to(device)
        pred = model(s,e,m)
        loss = ((pairwise(pred)-pairwise(x))[m[:,:,None] & m[:,None,:]]**2).mean()
        opt.zero_grad(); loss.backward(); opt.step()
        run += loss.item()
    print(f"Epoch {ep:2d}/{EPOCHS}  loss = {run/len(train_loader):.2f}")

print("🏁 MiniFold-E training done — ready for torsion/side-chain head next.")


Embedding:   0%|          | 0/164 [00:00<?, ?it/s]

✅ ESM embeddings cached in NPZs
Epoch  1/50  loss = 423.80
Epoch  2/50  loss = 289.03
Epoch  3/50  loss = 255.99
Epoch  4/50  loss = 238.98
Epoch  5/50  loss = 223.87
Epoch  6/50  loss = 218.63
Epoch  7/50  loss = 197.89
Epoch  8/50  loss = 184.49
Epoch  9/50  loss = 183.68
Epoch 10/50  loss = 180.99
Epoch 11/50  loss = 172.89
Epoch 12/50  loss = 168.78
Epoch 13/50  loss = 167.17
Epoch 14/50  loss = 162.64
Epoch 15/50  loss = 160.96
Epoch 16/50  loss = 158.74
Epoch 17/50  loss = 155.96
Epoch 18/50  loss = 149.87
Epoch 19/50  loss = 141.21
Epoch 20/50  loss = 135.74
Epoch 21/50  loss = 129.12
Epoch 22/50  loss = 129.72
Epoch 23/50  loss = 120.24
Epoch 24/50  loss = 112.79
Epoch 25/50  loss = 112.12
Epoch 26/50  loss = 107.93
Epoch 27/50  loss = 105.63
Epoch 28/50  loss = 100.78
Epoch 29/50  loss = 98.59
Epoch 30/50  loss = 93.61
Epoch 31/50  loss = 92.73
Epoch 32/50  loss = 95.10
Epoch 33/50  loss = 95.72
Epoch 34/50  loss = 91.85
Epoch 35/50  loss = 89.34
Epoch 36/50  loss = 89.07
Epoc

In [32]:
# ------------------------------------------------------------
#  Quick C-α RMSD check (device-safe)
# ------------------------------------------------------------
import torch

def kabsch_align(P, Q):
    Pc, Qc = P - P.mean(0), Q - Q.mean(0)
    C = Pc.T @ Qc
    V, _, Wt = torch.linalg.svd(C)
    R = V @ Wt
    if torch.det(R) < 0:                       # right-handed
        V[:, -1] *= -1; R = V @ Wt
    return Pc @ R + Q.mean(0)

model.eval()
with torch.no_grad():
    seq, xyz_gt, emb, mask = next(iter(train_loader))
    seq, emb, mask = seq.to(device), emb.to(device), mask.to(device)

    pred_xyz = model(seq, emb, mask).cpu()[0]   # → CPU
    gt_xyz   = xyz_gt[0]                        # already CPU
    m_bool   = mask[0].cpu()                    # <─ fix: move to CPU

    aligned = kabsch_align(pred_xyz[m_bool], gt_xyz[m_bool])
    rmsd = torch.sqrt(((aligned - gt_xyz[m_bool])**2).mean()).item()
    print(f"Aligned C-α RMSD: {rmsd:.2f} Å")

Aligned C-α RMSD: 7.36 Å


In [33]:
# ============================================
#  A · enrich NPZs with N, C coords + phi/psi
# ============================================
!pip -q install biopython

import numpy as np, math
from Bio.PDB.MMCIFParser import MMCIFParser
from Bio.PDB import is_aa
from pathlib import Path
from tqdm.auto import tqdm

DATA = Path("/content/protein_data")
NPZ  = DATA / "npz"
CIF  = {p.stem[:4]: p for p in DATA.glob("*.cif")}   # from Stage-1
parser = MMCIFParser(QUIET=True)

def dih(p,q,r,s):
    b0,b1,b2=q-p,r-q,s-r
    n1,n2=np.cross(b0,b1),np.cross(b1,b2)
    x=(n1*n2).sum(); y=(np.cross(n1,n2)*b1/np.linalg.norm(b1)).sum()
    return math.atan2(y,x)

for npz in tqdm(list(NPZ.glob("*.npz")), desc="torsion"):
    uid = npz.stem; pdb, chain = uid[:4], uid[5]
    if "phi" in np.load(npz): continue          # already enriched
    model = parser.get_structure(pdb, CIF[pdb])[0]
    ch    = next(c for c in model if c.id.strip()==chain.strip())
    N,C,phi,psi,CA,prevC=[],[],[],[],[],None
    for res in ch:
        if not is_aa(res): continue
        if not all(a in res for a in ("N","CA","C")): continue
        N.append(res["N"].coord); CA.append(res["CA"].coord); C.append(res["C"].coord)
        if prevC is not None:
            phi.append(dih(prevC, res["N"].coord, res["CA"].coord, res["C"].coord))
        prevC = res["C"].coord
    phi = [0.0]+phi;                                    # pad ends
    for i in range(len(N)-1):
        psi.append(dih(N[i], CA[i], C[i], N[i+1]))
    psi.append(0.0)
    np.savez_compressed(npz, **np.load(npz), N=np.array(N), C=np.array(C),
                        phi=np.array(phi), psi=np.array(psi))
print("✅ NPZs updated with backbone atoms and torsions")


torsion:   0%|          | 0/164 [00:00<?, ?it/s]

✅ NPZs updated with backbone atoms and torsions


In [34]:
# ============================================
#  B · model with frame + torsion heads
# ============================================

import torch, torch.nn as nn, math
from pathlib import Path
import numpy as np

EMBED, NLAY = 256, 8         # same as before

def axis2rot(v):
    θ = torch.linalg.norm(v, dim=-1, keepdim=True)+1e-8
    k = v/θ
    K = torch.zeros((*v.shape[:-1],3,3), device=v.device)
    K[...,0,1],K[...,1,0]= -k[...,2], k[...,2]
    K[...,0,2],K[...,2,0]=  k[...,1],-k[...,1]
    K[...,1,2],K[...,2,1]= -k[...,0], k[...,0]
    I = torch.eye(3, device=v.device)
    return I + torch.sin(θ)[...,None]*K + (1-torch.cos(θ))[...,None]*(K@K)

class MiniFoldAll(nn.Module):
    def __init__(self, vocab=21):
        super().__init__()
        self.embed  = nn.Embedding(vocab, EMBED, padding_idx=0)
        self.projE  = nn.Linear(1280, EMBED, bias=False)
        enc = nn.TransformerEncoderLayer(EMBED,4,EMBED*4,batch_first=True)
        self.encoder = nn.TransformerEncoder(enc, num_layers=NLAY)
        self.frame_head   = nn.Linear(EMBED, 6)    # 3 t  + 3 axis-angle
        self.tors_head    = nn.Linear(EMBED, 4)    # sin/cos φ, ψ
    def forward(self, seq, emb, mask):
        h = self.embed(seq) + self.projE(emb)
        h = self.encoder(h, src_key_padding_mask=~mask)
        T = self.frame_head(h)                     # [B,L,6]
        rot = axis2rot(T[...,3:])                  # (B,L,3,3)
        trans = T[...,:3]                          # (B,L,3)
        # backbone atoms in local frame (Å)
        N_off  = torch.tensor([-1.458,  0.000, 0.000], device=h.device)
        CA_off = torch.tensor([ 0.000,  0.000, 0.000], device=h.device)
        C_off  = torch.tensor([ 1.525, -0.525, 0.000], device=h.device)
        N  = (rot @ N_off)  + trans
        CA = (rot @ CA_off) + trans
        C  = (rot @ C_off)  + trans
        tors = self.tors_head(h)                  # [B,L,4]  sinϕ cosϕ sinψ cosψ
        return {"N":N,"CA":CA,"C":C,"tors":tors,"mask":mask}


In [39]:
# ============================================
#  C · fine-tune MiniFold-AllAtom  (dtype-safe)
# ============================================

from torch.nn.functional import mse_loss, pad as tpad
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import torch, numpy as np
from pathlib import Path

# ----- dataset ----------------------------------------------
class ProtFull(torch.utils.data.Dataset):
    AA = "ACDEFGHIKLMNPQRSTVWY"; aa2i = {a:i+1 for i,a in enumerate(AA)}
    def __init__(self, folder): self.f = list(Path(folder).glob("*.npz"))
    def __len__(self): return len(self.f)
    def __getitem__(s, i):
        d = np.load(s.f[i])
        aa = [ProtFull.aa2i.get(c.decode(), 0) for c in d["seq"]]
        return (torch.tensor(aa,                 dtype=torch.long),
                torch.tensor(d["coords"],       dtype=torch.float32),
                torch.tensor(d["N"],            dtype=torch.float32),
                torch.tensor(d["C"],            dtype=torch.float32),
                torch.tensor(d["phi"],          dtype=torch.float32),
                torch.tensor(d["psi"],          dtype=torch.float32),
                torch.tensor(d["emb"],          dtype=torch.float32))  # cast here

# ----- collate ----------------------------------------------
MAX_LEN, BATCH = 256, 8
def pad1(t,L): return tpad(t, (0, L-len(t)))
def pad2(t,L): return tpad(t, (0,0,0, L-len(t)))

def collate(batch):
    seq, ca, N, C, phi, psi, emb = zip(*batch)
    L = min(MAX_LEN, max(len(s) for s in seq))
    out = [
        torch.stack([pad1(s[:L], L) for s in seq]),
        torch.stack([pad2(x[:L], L) for x in ca]),
        torch.stack([pad2(x[:L], L) for x in N]),
        torch.stack([pad2(x[:L], L) for x in C]),
        torch.stack([pad1(t[:L], L) for t in phi]),
        torch.stack([pad1(t[:L], L) for t in psi]),
        torch.stack([pad2(e[:L], L) for e in emb]),
    ]
    mask = torch.stack([torch.arange(L) < len(s) for s in seq])
    return (*out, mask)

loader = DataLoader(
    ProtFull("/content/protein_data/npz"),
    batch_size=BATCH, shuffle=True, collate_fn=collate)

# ----- fine-tune --------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = MiniFoldAll().to(device)
opt    = torch.optim.Adam(model.parameters(), lr=3e-4)

def fape(pred, gt, mask):
    diff = (pred - gt)[mask]
    return torch.sqrt((diff**2).sum(-1) + 1e-8).mean()

EPOCHS = 20
for ep in range(1, EPOCHS+1):
    cum = 0
    for s, ca_gt, N_gt, C_gt, phi_gt, psi_gt, emb, m in loader:
        s, emb, m = s.to(device), emb.to(device), m.to(device)
        out = model(s, emb, m)
        m3 = m.unsqueeze(-1).expand(-1, -1, 3)

        phi = phi_gt.to(device)
        psi = psi_gt.to(device)

        loss = (
            fape(out["N"],  N_gt.to(device), m3) +
            fape(out["C"],  C_gt.to(device), m3) +
            mse_loss(out["tors"][...,0][m], torch.sin(phi)[m]) +
            mse_loss(out["tors"][...,1][m], torch.cos(phi)[m]) +
            mse_loss(out["tors"][...,2][m], torch.sin(psi)[m]) +
            mse_loss(out["tors"][...,3][m], torch.cos(psi)[m])
        )
        opt.zero_grad(); loss.backward(); opt.step()
        cum += loss.item()
    print(f"Epoch {ep}/{EPOCHS}  loss = {cum/len(loader):.2f}")

print("🏁 Fine-tuning complete — backbone frames and φ/ψ torsions are now learned.")


Epoch 1/20  loss = 5150.89
Epoch 2/20  loss = 4915.58
Epoch 3/20  loss = 4905.28
Epoch 4/20  loss = 4855.53
Epoch 5/20  loss = 4777.72
Epoch 6/20  loss = 4705.46
Epoch 7/20  loss = 4605.28
Epoch 8/20  loss = 4488.41
Epoch 9/20  loss = 4497.80
Epoch 10/20  loss = 4428.46
Epoch 11/20  loss = 4444.48
Epoch 12/20  loss = 4376.16
Epoch 13/20  loss = 4383.30
Epoch 14/20  loss = 4438.85
Epoch 15/20  loss = 4357.11
Epoch 16/20  loss = 4300.79
Epoch 17/20  loss = 4338.42
Epoch 18/20  loss = 4364.30
Epoch 19/20  loss = 4308.06
Epoch 20/20  loss = 4334.27
🏁 Fine-tuning complete — backbone frames and φ/ψ torsions are now learned.


In [55]:
# --------------------------------------------
#  Proper fine-tune: correct mask broadcast
#  + bond-length prior    (run 15-20 epochs)
# --------------------------------------------
def fape(pred, gt, mask_xyz):              # mask_xyz shape B,L,3
    diff = (pred - gt)[mask_xyz]
    return torch.sqrt((diff**2).sum(-1) + 1e-8).mean()

def bond_len(ca, mask, target=3.8):
    d = torch.linalg.norm(ca[:,1:] - ca[:,:-1], dim=-1)   # (B,L-1)
    m = mask[:,1:] & mask[:,:-1]
    return ((d[m] - target)**2).mean()

opt = torch.optim.Adam(model.parameters(), lr=5e-5)
EPOCHS = 20
for ep in range(1, EPOCHS+1):
    tot = 0
    for s, ca_gt, N_gt, C_gt, phi_gt, psi_gt, emb, m in loader:
        s, emb, m = s.to(device), emb.to(device), m.to(device)
        out = model(s, emb, m)
        m3 = m.unsqueeze(-1).expand(-1, -1, 3)            # broadcast to xyz
        loss = (
            fape(out["N"] ,  N_gt .to(device), m3) +
            fape(out["CA"], ca_gt.to(device), m3) +        # now contributes
            fape(out["C"] ,  C_gt .to(device), m3) +
            bond_len(out["CA"], m) +                       # correct mask
            mse_loss(out["tors"][...,0][m], torch.sin(phi_gt).to(device)[m]) +
            mse_loss(out["tors"][...,1][m], torch.cos(phi_gt).to(device)[m]) +
            mse_loss(out["tors"][...,2][m], torch.sin(psi_gt).to(device)[m]) +
            mse_loss(out["tors"][...,3][m], torch.cos(psi_gt).to(device)[m])
        )
        opt.zero_grad(); loss.backward(); opt.step()
        tot += loss.item()
    print(f"Epoch {ep:2d}/{EPOCHS}  loss = {tot/len(loader):.1f}")




Epoch  1/20  loss = 5696.1
Epoch  2/20  loss = 5552.9
Epoch  3/20  loss = 5673.5
Epoch  4/20  loss = 5540.4
Epoch  5/20  loss = 5541.1
Epoch  6/20  loss = 5634.9
Epoch  7/20  loss = 5672.2
Epoch  8/20  loss = 5552.7
Epoch  9/20  loss = 5624.9
Epoch 10/20  loss = 5598.6
Epoch 11/20  loss = 5491.9
Epoch 12/20  loss = 5592.5
Epoch 13/20  loss = 5451.4
Epoch 14/20  loss = 5562.3
Epoch 15/20  loss = 5445.4
Epoch 16/20  loss = 5424.6
Epoch 17/20  loss = 5370.4
Epoch 18/20  loss = 5427.4
Epoch 19/20  loss = 5429.3
Epoch 20/20  loss = 5462.6


In [56]:
# ============================================================
#  Inference helper (unique serials, full backbone viewed)
# ============================================================

!pip -q install py3Dmol

import torch, py3Dmol, math
from pathlib import Path, PurePosixPath
from datetime import datetime

AA  = "ACDEFGHIKLMNPQRSTVWY"; aa2i = {a:i+1 for i,a in enumerate(AA)}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def seq_to_ids(seq): return torch.tensor([aa2i.get(a,0) for a in seq], dtype=torch.long)

def predict_backbone(seq):
    model.eval()
    with torch.no_grad():
        ids  = seq_to_ids(seq).unsqueeze(0).to(device)
        emb  = seq_to_emb(seq).unsqueeze(0).to(torch.float32).to(device)
        mask = torch.ones_like(ids, dtype=torch.bool).to(device)
        out  = model(ids, emb, mask)
    return {k: v.squeeze(0).cpu().numpy() for k,v in out.items()}

def save_pdb(bb, seq, fname):
    lines=[]
    serial = 1
    for i,(n,ca,c) in enumerate(zip(bb["N"], bb["CA"], bb["C"]), 1):
        res = seq[i-1]
        r   = i
        lines.append(f"ATOM  {serial:5d}  N   {res:<3s} A{r:4d}    "
                     f"{n[0]:8.3f}{n[1]:8.3f}{n[2]:8.3f}  1.00  0.00           N")
        serial += 1
        lines.append(f"ATOM  {serial:5d}  CA  {res:<3s} A{r:4d}    "
                     f"{ca[0]:8.3f}{ca[1]:8.3f}{ca[2]:8.3f}  1.00  0.00           C")
        serial += 1
        lines.append(f"ATOM  {serial:5d}  C   {res:<3s} A{r:4d}    "
                     f"{c[0]:8.3f}{c[1]:8.3f}{c[2]:8.3f}  1.00  0.00           C")
        serial += 1
    Path(fname).write_text("\n".join(lines) + "\nEND\n")

# --- demo ---------------------------------------------------
test_seq = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLK"   # GFP fragment
bb   = predict_backbone(test_seq)
tag  = datetime.now().strftime("%Y%m%d_%H%M%S")
pdbf = f"/content/minifoldE_pred_{tag}.pdb"
save_pdb(bb, test_seq, pdbf)
print("PDB saved →", pdbf)

# --- 3-D view ----------------------------------------------
view = py3Dmol.view(width=450, height=320)
view.addModel(Path(pdbf).read_text(), 'pdb')
view.setStyle({'cartoon': {'color': 'spectrum'}})
view.zoomTo(); view.show()


PDB saved → /content/minifoldE_pred_20250429_041540.pdb


In [57]:
from google.colab import files
files.download(pdbf)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [58]:
# ============================================
#  Quick sanity check – C-α coordinates
# ============================================
seq = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLK"
bb  = predict_backbone(seq)          # uses the helper from previous cell
ca  = bb["CA"]                       # (L,3) numpy

print("First 10 C-α coordinates")
for i,xyz in enumerate(ca[:10],1):
    print(f"{i:2d}  {xyz}")

# pairwise distance between CA1 and CAi
print("\nDistance from residue 1 to first 10 residues")
d = np.linalg.norm(ca - ca[0], axis=1)
for i,val in enumerate(d[:10],1):
    print(f"1 → {i:2d}   {val:6.2f} Å")


First 10 C-α coordinates
 1  [-0.05270198 -1.8775257   1.7340031 ]
 2  [-0.05868143 -2.130478    1.6887217 ]
 3  [ 0.4473447 -2.1065643  2.1198704]
 4  [-0.02386281 -0.68777657  2.128097  ]
 5  [-0.35148764 -2.0692577   1.0876955 ]
 6  [-0.3409083 -2.0812936  1.1020449]
 7  [-0.43363196 -1.8169308   0.9465486 ]
 8  [-0.31809232 -1.5992271   1.1880381 ]
 9  [ 0.2443927 -1.3419114  2.5121348]
10  [ 0.06442496 -0.6862275   2.2712295 ]

Distance from residue 1 to first 10 residues
1 →  1     0.00 Å
1 →  2     0.26 Å
1 →  3     0.67 Å
1 →  4     1.25 Å
1 →  5     0.74 Å
1 →  6     0.72 Å
1 →  7     0.88 Å
1 →  8     0.67 Å
1 →  9     0.99 Å
1 → 10     1.31 Å


In [60]:
if ep == 1 and tot == 0:     # first batch first epoch
    print("DEBUG  N-FAPE :", fape(out["N"] , N_gt.to(device), m3).item())
    print("DEBUG  CA-FAPE:", fape(out["CA"], ca_gt.to(device), m3).item())
    print("DEBUG  C-FAPE :", fape(out["C"] , C_gt.to(device), m3).item())
    print("DEBUG  bond   :", bond_len(out["CA"], m).item())
