
# 🧪 TorchMD (Colab) — Hardened Protein MD Demo + Visualizer

This notebook sets up a **stable** coarse‑grained Cα simulation and an **inline viewer**:
- Installs: TorchMD, MoleculeKit, MDTraj, Biopython, NGLView, py3Dmol
- Downloads a structure from the PDB (default: **2ETL**; set `pdb_id` to e.g. **1CRN** for a tiny 46‑residue test)
- Builds a **Cα-only** model
- Runs a **rock‑solid**, damped MD‑like loop in **PyTorch** (velocity‑Verlet + Langevin) with **minimization**, exclusions, and safety clamps
- Saves a **DCD** trajectory (no PDB field-width issues)
- Visualizes with **nglview**, with **py3Dmol** fallback

> Didactic parameters are conservative to avoid blow‑ups; you can loosen them once the visual works.


In [None]:

#@title 📦 Install & Imports (Colab-safe)
!pip -q install torchmd moleculekit mdtraj biopython nglview py3Dmol ipywidgets==7.7.1

# Enable widgets (for NGLView) in Colab
from google.colab import output
output.enable_custom_widget_manager()

import sys, os, glob, shutil, math, random
import numpy as np
import torch
import mdtraj as md
import nglview as nv
from Bio.PDB import PDBList

print("Python:", sys.version.split()[0])
print("PyTorch:", torch.__version__)
print("TorchMD:", __import__("torchmd").__version__)
print("MDTraj:", md.__version__)
print("nglview:", nv.__version__)


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.9/21.9 MB[0m [31m34.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.5/5.5 MB[0m [31m43.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.5/5.5 MB[0m [31m97.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━



Python: 3.12.11
PyTorch: 2.8.0+cu126
TorchMD: 1.0.4
MDTraj: 1.11.0
nglview: 3.0.8


In [None]:

#@title 📥 Download a protein from the PDB
pdb_id = "2ETL"  #@param {type:"string"}
save_as = "protein.pdb"

import os, shutil
from Bio.PDB import PDBList

# Form fields often pick up stray whitespace; an empty ID cannot be fetched.
pdb_id = pdb_id.strip()
assert pdb_id, "Please enter a PDB ID."

pdbl = PDBList()
# retrieve_pdb_file returns the local path it (would have) written to. Use it
# directly instead of a recursive glob. A glob walks the whole working tree
# (slow on a mounted Drive) and can pick up a stale copy from an earlier run.
# The path is returned even when the download fails, so check that it exists.
downloaded = pdbl.retrieve_pdb_file(pdb_id, file_format="pdb", pdir=".")
assert downloaded and os.path.isfile(downloaded), "PDB download failed; try a different ID."
shutil.copyfile(downloaded, save_as)
print("Saved", save_as)


Downloading PDB structure '2etl'...
Saved protein.pdb


In [None]:

#@title 🧹 Clean PDB to protein only (MoleculeKit)
from moleculekit.molecule import Molecule

RAW_PDB, CLEAN_PDB = "protein.pdb", "protein_clean.pdb"

# Keep only the atoms matched by MoleculeKit's "protein" selection.
# MoleculeKit logs how many atoms were removed.
mol = Molecule(RAW_PDB)
mol.filter("protein")
mol.write(CLEAN_PDB)
print(f"Wrote {CLEAN_PDB}")


2025-08-25 16:26:30,160 - moleculekit.molecule - INFO - Removed 111 atoms. 3482 atoms remaining in the molecule.


Wrote protein_clean.pdb


In [None]:

#@title 🧩 Build Cα-only model (positions, bonds, masses)
import mdtraj as md
import numpy as np

t = md.load_pdb("protein_clean.pdb")
top = t.topology
ca_indices = top.select("name CA")
assert ca_indices.size > 3, "No Cα atoms found. Try another PDB."
t_ca = t.atom_slice(ca_indices)
t_ca.save_pdb("protein_ca.pdb")
print("Cα beads:", t_ca.n_atoms)

# Pseudo-bonds between consecutive Cα beads.
# MDTraj's residue.index is a running index over the whole topology, so a test
# like "index == previous + 1" also links the last bead of one chain to the
# first bead of the next, and bridges missing-residue gaps.
# Instead, only bond neighbours that are in the same chain AND sit at a
# peptide-like Cα–Cα distance (~0.38 nm trans, ~0.29 nm cis).
MAX_CA_CA_NM = 0.45  # nm — MDTraj coordinates are in nanometres
ca_xyz = t_ca.xyz[0]
ca_atoms = list(t_ca.topology.atoms)
bonds = [
    (cur.index, prev.index)  # connect sequential
    for prev, cur in zip(ca_atoms, ca_atoms[1:])
    if prev.residue.chain.index == cur.residue.chain.index
    and np.linalg.norm(ca_xyz[cur.index] - ca_xyz[prev.index]) < MAX_CA_CA_NM
]
# reshape keeps the array 2-D (0, 2) even when no bonds are found
bonds = np.array(bonds, dtype=int).reshape(-1, 2)
print("Cα–Cα bonds:", bonds.shape[0])

# Rough per-bead mass (g/mol). Not used for units rigor here.
masses = np.full((t_ca.n_atoms,), 110.0, dtype=np.float32)

np.savez("ca_model_inputs.npz", bonds=bonds, masses=masses, ca_indices=ca_indices)


Cα beads: 446
Cα–Cα bonds: 445


In [None]:

#@title ⚙️ Rock‑solid coarse‑grained MD (no NaNs) — writes traj_ca.npy
import torch, numpy as np, mdtraj as md, math, random

# Reproducibility
torch.manual_seed(42); np.random.seed(42); random.seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# MDTraj stores coordinates in nanometres, but every parameter below is in Å.
NM_TO_A = 10.0
ref   = md.load_pdb("protein_ca.pdb")
pos0  = torch.tensor(ref.xyz[0] * NM_TO_A, dtype=torch.float32, device=device)  # Å, (N,3)

with np.load("ca_model_inputs.npz") as dat:
    # reshape(-1, 2) keeps an empty bond list 2-D so bonds[:, k] indexing works
    bonds  = torch.tensor(dat["bonds"].reshape(-1, 2), dtype=torch.long,    device=device)
    masses = torch.tensor(dat["masses"],               dtype=torch.float32, device=device)
N      = pos0.shape[0]
mass   = masses.unsqueeze(1)   # (N,1) so it broadcasts against (N,3) forces

# ---- Very conservative parameters ----
dt           = 0.00025      # ps (0.25 fs)
n_steps      = 3000         # short, safe
temperature  = 300.0        # K
gamma        = 8.0          # ps^-1 (more damping)
save_interval= 10
k_bond       = 5.0          # kcal/(mol·Å^2)
epsilon      = 0.001        # kcal/mol (tiny, to avoid big kicks)
sigma        = 3.8          # Å
cutoff       = 6.0          # Å
min_dist     = 0.8          # Å (floor for pair distances)
Fmax         = 50.0         # kcal/(mol·Å) force clip
Vmax         = 20.0         # Å/ps  velocity clip
box_limit    = 100.0        # Å position clamp (minimum half-width)

# PDB coordinates are not centred on the origin. Widen the clamp box so the
# starting structure (plus a margin) fits inside it; otherwise the first clamp
# would squash every bead with |coordinate| > 100 Å.
box_limit = max(box_limit, float(pos0.abs().max()) + 50.0)

# Equilibrium bond lengths from starting coords
with torch.no_grad():
    r0 = torch.linalg.norm(pos0[bonds[:,0]] - pos0[bonds[:,1]], dim=1)

# State
pos = pos0.clone()
vel = torch.zeros_like(pos)

# Constants
kB = 0.0019872041  # kcal/(mol·K)
sqrt_2gkT_dt = math.sqrt(2.0 * gamma * kB * temperature * dt)

# LJ exclusions from the bond graph: drop 1-2 pairs (bonded) and 1-3 pairs
# (sharing a bonded neighbour). For one unbroken chain this equals the
# "|i-j| <= 2" rule. Unlike that rule, pairs across chain breaks stay interacting.
I, J = torch.triu_indices(N, N, offset=1, device=device)  # all i<j pairs
adj = torch.zeros((N, N), dtype=torch.float32, device=device)
if bonds.numel():
    adj[bonds[:, 0], bonds[:, 1]] = 1.0
    adj[bonds[:, 1], bonds[:, 0]] = 1.0
excluded = (adj > 0) | ((adj @ adj) > 0)
lj_exclude = excluded[I, J]  # mask to EXCLUDE for LJ, aligned with (I, J)

def bonded_forces(p):
    """Harmonic Cα–Cα pseudo-bond forces.

    Energy per bond is 0.5 * k_bond * (|p_a - p_b| - r0)**2. It uses the
    module-level `bonds` (index pairs), `r0` (rest lengths) and `k_bond`.

    Args:
        p: (N, 3) bead positions.
    Returns:
        (N, 3) tensor of forces, same dtype/device as `p`.
    """
    a_idx = bonds[:, 0]
    b_idx = bonds[:, 1]
    sep = p[a_idx] - p[b_idx]
    length = torch.linalg.norm(sep, dim=1) + 1e-12        # epsilon guards the division below
    pull = (-k_bond * (length - r0) / length).unsqueeze(1)
    per_bond = pull * sep                                 # force on the first atom of each bond
    forces = torch.zeros_like(p)
    forces.index_add_(0, a_idx, per_bond)
    forces.index_add_(0, b_idx, -per_bond)                # equal and opposite on the partner
    return forces

def lj_forces(p):
    """12-6 Lennard-Jones forces over the precomputed i<j pair list (I, J).

    Pairs flagged in `lj_exclude`, or at or beyond `cutoff`, are skipped.
    Distances are floored at `min_dist`. Each per-pair force component is
    NaN/Inf-sanitised and clamped to ±Fmax before being scattered.

    Args:
        p: (N, 3) bead positions.
    Returns:
        (N, 3) tensor of forces (zeros when no pair qualifies).
    """
    out = torch.zeros_like(p)
    sep = p[J] - p[I]                                   # vector from i to j, per pair
    r = torch.linalg.norm(sep, dim=1) + 1e-12
    keep = (r < cutoff) & ~lj_exclude
    if not keep.any():
        return out

    r_k = torch.clamp(r[keep], min=min_dist)
    s6 = (sigma / r_k) ** 6
    s12 = s6 ** 2
    magnitude = 24.0 * epsilon * (2.0 * s12 - s6) / r_k   # >0 repulsive, <0 attractive
    pair = magnitude.unsqueeze(1) * sep[keep] / r_k.unsqueeze(1)

    pair = torch.nan_to_num(pair, nan=0.0, posinf=0.0, neginf=0.0)
    pair = torch.clamp(pair, -Fmax, Fmax)
    out.index_add_(0, I[keep], -pair)                   # i is pushed opposite to j
    out.index_add_(0, J[keep], pair)
    return out

def clip_forces(F):
    """Sanitise and bound a force tensor.

    NaN and ±Inf entries become 0, then every component is clamped to the
    global ±Fmax so a single spike cannot blow up the integrator.
    """
    finite = torch.nan_to_num(F, nan=0.0, posinf=0.0, neginf=0.0)
    return finite.clamp(min=-Fmax, max=Fmax)

# --- Gentle steepest-descent minimization (bonded + LJ) ---
# r0 was measured on these very coordinates, so bonded forces alone are ~0 at
# the start and a bond-only minimization would not move anything. Including the
# LJ term lets it relax close non-bonded contacts.
with torch.no_grad():
    for _ in range(200):
        F_min = clip_forces(bonded_forces(pos) + lj_forces(pos))
        pos = torch.clamp(pos + 0.0002 * F_min, -box_limit, box_limit)  # tiny gradient-descent step

# --- Langevin dynamics with BAOAB splitting ---
# Unit conversion: 1 (kcal/mol/Å) / (g/mol) = 418.4 Å/ps². Likewise
# kB*T [kcal/mol] / m [g/mol] * 418.4 gives the per-component <v²> in (Å/ps)².
KCAL_PER_G_TO_A2_PER_PS2 = 418.4
acc_scale = KCAL_PER_G_TO_A2_PER_PS2 / mass             # (N,1): force → acceleration
v_thermal = torch.sqrt(kB * temperature * acc_scale)     # (N,1) Å/ps thermal speed per component
c1 = math.exp(-gamma * dt)                               # O-step velocity decay
c2 = math.sqrt(1.0 - c1 * c1)                            # O-step noise weight (fluctuation–dissipation)

def _safe_vel(v):
    """Replace NaNs and clamp each velocity component to ±Vmax."""
    return torch.clamp(torch.nan_to_num(v, nan=0.0), -Vmax, Vmax)

def _safe_pos(p):
    """Replace NaNs and clamp each coordinate to ±box_limit."""
    return torch.clamp(torch.nan_to_num(p, nan=0.0), -box_limit, box_limit)

frames = []
with torch.no_grad():
    F = clip_forces(bonded_forces(pos) + lj_forces(pos))
    for step in range(n_steps):
        vel = _safe_vel(vel + 0.5 * dt * F * acc_scale)               # B: half kick
        pos = _safe_pos(pos + 0.5 * dt * vel)                         # A: half drift
        # O: exact friction + noise, with ONE fresh Gaussian sample per step
        vel = _safe_vel(c1 * vel + c2 * v_thermal * torch.randn_like(vel))
        pos = _safe_pos(pos + 0.5 * dt * vel)                         # A: half drift
        F = clip_forces(bonded_forces(pos) + lj_forces(pos))          # reused by next step's first B
        vel = _safe_vel(vel + 0.5 * dt * F * acc_scale)               # B: half kick

        if step % save_interval == 0:
            # .copy(): on CPU, .numpy() shares memory with the tensor, so an
            # in-place update would silently rewrite frames already stored.
            frames.append(pos.detach().cpu().numpy().copy())

# reshape keeps a (0, N, 3) shape if no frame was recorded
frames = np.array(frames, dtype=np.float32).reshape(-1, pos.shape[0], 3)  # (T,N,3) Å
# Drop any frames that contain NaN/Inf just in case
valid = np.isfinite(frames).all(axis=(1,2))
frames = frames[valid]
np.save("traj_ca.npy", frames)
print("Saved frames:", frames.shape, "| valid frames kept:", valid.sum())


Using device: cpu
Saved frames: (300, 446, 3) | valid frames kept: 300


In [None]:
# Turn on Colab's custom widget manager so ipywidgets-based viewers (NGLView)
# can render. The install cell already makes this same call.
# NOTE(review): looks redundant — presumably kept so the viewer cell can be run
# on its own; confirm it is harmless to call twice.
from google.colab import output
output.enable_custom_widget_manager()

In [None]:

#@title 💾 DCD + robust visualization (filters NaNs; py3Dmol fallback)
import numpy as np, mdtraj as md

frames = np.load("traj_ca.npy")
assert frames.ndim == 3 and frames.shape[2] == 3, "Trajectory array malformed."

# Filter again (defensive)
good_mask = np.isfinite(frames).all(axis=(1,2))
frames = frames[good_mask]
assert frames.shape[0] > 0, "No valid frames to visualize (all were NaN/Inf)."

xyz_nm = frames / 10.0  # Å → nm (MDTraj works in nanometres)
ref = md.load_pdb("protein_ca.pdb")
traj = md.Trajectory(xyz_nm, ref.topology)
traj.save_dcd("traj_ca.dcd")
print("Wrote traj_ca.dcd with", traj.n_frames, "frames")

# NGLView player size (CSS units).
NGL_WIDTH, NGL_HEIGHT = "1000px", "800px"

# Try NGLView first. Everything NGLView-related (build, style, size, display)
# stays inside the try, so exactly one widget is shown and any failure falls
# through to the py3Dmol fallback instead of crashing afterwards.
try:
    import nglview as nv
    from IPython.display import display
    view = nv.show_mdtraj(traj)
    view.clear_representations()
    view.add_representation("cartoon", selection="protein", opacity=0.9)
    view.player.delay = 25
    # Enlarge the widget via nglview's (private) remote-call API.
    view._remote_call('setSize', target='Widget', args=[NGL_WIDTH, NGL_HEIGHT])
    display(view)
    print("Displayed with NGLView ✅")
except Exception as e:
    print("NGLView failed:", e)
    # Fallback: clamp coords and write multi-model PDB for py3Dmol
    clamp = np.clip(frames, -9999.0, 9999.0) / 10.0  # nm
    traj2 = md.Trajectory(clamp, ref.topology)
    mm_pdb = "traj_ca_models.pdb"
    traj2.save_pdb(mm_pdb)

    import py3Dmol
    with open(mm_pdb, "r") as fh:
        pdb_data = fh.read()
    v = py3Dmol.view(width=700, height=500)
    v.addModelsAsFrames(pdb_data, "pdb")
    v.setStyle({"cartoon": {"opacity": 0.9}})
    v.animate({"loop": "forward", "interval": 50})
    v.zoomTo()
    v.show()
    print("Displayed with py3Dmol ✅")


Wrote traj_ca.dcd with 300 frames


NGLWidget(max_frame=299)

Displayed with NGLView ✅


NGLWidget(max_frame=299)