<a href="https://colab.research.google.com/github/vramonlinebsc/diffrential_relaxation/blob/main/ccr_1_iteration.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#!/usr/bin/env python3
"""
DiffRelax: Differentiable NMR Relaxation Engine for Structure Refinement
===========================================================================

Complete self-contained implementation for Google Colab.

PROJECT GOAL:
Build first JAX-based differentiable physics engine for NMR relaxation
that enables gradient-based structure refinement using R1, R2, NOE, and
cross-correlated relaxation (CCR) data.

NOVEL CONTRIBUTION:
- First differentiable CCR implementation
- Enables gradient-based structure refinement with CCR
- Shows when CCR is information-theoretically necessary

TARGET: JACS, J. Phys. Chem. B, or J. Chem. Info. Model.

USAGE IN GOOGLE COLAB:
    1. Upload this file or paste into notebook
    2. Run: %run diffrelax_complete.py
    3. Follow interactive prompts
    4. Results saved with automatic checkpoints

AUTHOR: Built with Claude (Anthropic)
DATE: 2026
"""

import os
import sys
import json
import pickle
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Dict, List, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

# Check environment and install dependencies
def setup_environment():
    """Install all required packages"""
    print("="*70)
    print("DiffRelax: Setting up environment...")
    print("="*70)

    packages = [
        'jax[cuda12]',
        'optax',
        'equinox',
        'biopython',
        'requests',
        'pandas',
        'matplotlib',
        'seaborn'
    ]

    import subprocess
    for pkg in packages:
        print(f"Installing {pkg}...")
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', pkg])

    print("‚úì All packages installed\n")

# Run setup
setup_environment()

# Now import everything
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
import optax
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import requests
from Bio.PDB import PDBParser

# ============================================================================
# PHYSICAL CONSTANTS
# ============================================================================

class Constants:
    """Physical constants for NMR relaxation calculations"""

    # Gyromagnetic ratios (rad/s/T)
    GAMMA_H = 2.6752e8   # 1H
    GAMMA_N = -2.713e7   # 15N
    GAMMA_C = 6.728e7    # 13C

    # Fundamental constants
    H_BAR = 1.054571817e-34  # J¬∑s
    MU_0 = 4 * np.pi * 1e-7  # T¬∑m/A

    # Typical values
    R_NH = 1.02e-10  # N-H bond length (m)
    DELTA_SIGMA_N = -160e-6  # 15N CSA (ppm ‚Üí dimensionless)

    # Magnetic field
    B0_600 = 14.1  # Tesla (600 MHz for 1H)

    @classmethod
    def omega(cls, nucleus='H', field=14.1):
        """Larmor frequency"""
        gamma = {'H': cls.GAMMA_H, 'N': cls.GAMMA_N, 'C': cls.GAMMA_C}
        return gamma[nucleus] * field

# ============================================================================
# CHECKPOINT SYSTEM
# ============================================================================

class CheckpointManager:
    """Manages saving/loading of intermediate results"""

    def __init__(self, base_dir='/content/diffrelax_checkpoints'):
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(exist_ok=True, parents=True)
        print(f"Checkpoint directory: {self.base_dir}")

    def save(self, data, name):
        """Save data with pickle"""
        path = self.base_dir / f"{name}.pkl"
        with open(path, 'wb') as f:
            pickle.dump(data, f)
        print(f"‚úì Saved checkpoint: {name}")

    def load(self, name):
        """Load data, return None if not found"""
        path = self.base_dir / f"{name}.pkl"
        if path.exists():
            with open(path, 'rb') as f:
                print(f"‚úì Loaded checkpoint: {name}")
                return pickle.load(f)
        return None

    def exists(self, name):
        """Check if checkpoint exists"""
        return (self.base_dir / f"{name}.pkl").exists()

# ============================================================================
# DATA FETCHING
# ============================================================================

class BMRBFetcher:
    """Fetch NMR relaxation data from BMRB database"""

    API_URL = "https://api.bmrb.io/v2"

    def __init__(self, checkpoint_mgr):
        self.checkpoint = checkpoint_mgr

    def fetch_entry(self, bmrb_id):
        """Download BMRB entry with caching"""
        cached = self.checkpoint.load(f'bmrb_{bmrb_id}')
        if cached:
            return cached

        print(f"Downloading BMRB {bmrb_id}...")
        url = f"{self.API_URL}/entry/{bmrb_id}"
        response = requests.get(url, timeout=30)

        if response.status_code == 200:
            data = response.json()
            self.checkpoint.save(data, f'bmrb_{bmrb_id}')
            return data
        else:
            print(f"‚úó Failed: HTTP {response.status_code}")
            return None

    def extract_relaxation(self, entry_data):
        """Extract R1, R2, NOE from BMRB entry"""
        all_data = []

        def safe_get(d, keys, default=None):
            for key in keys:
                if isinstance(d, dict) and key in d:
                    d = d[key]
                else:
                    return default
            return d

        # Extract T1, T2, NOE
        for relax_type, bmrb_key in [
            ('T1', 'heteronucl_T1_relaxation'),
            ('T2', 'heteronucl_T2_relaxation'),
            ('NOE', 'heteronucl_NOEs')
        ]:
            data = safe_get(entry_data, [bmrb_key], [])
            if isinstance(data, list):
                for saveframe in data:
                    if 'data' in saveframe:
                        for row in saveframe['data']:
                            try:
                                all_data.append({
                                    'type': relax_type,
                                    'residue': int(row.get('Comp_index_ID', 0)),
                                    'atom': row.get('Atom_ID', 'N'),
                                    'value': float(row.get('Val', np.nan)),
                                    'error': float(row.get('Val_err', 0.1))
                                })
                            except (ValueError, TypeError):
                                continue

        df = pd.DataFrame(all_data)
        return df if len(df) > 0 else None

    def get_pdb_id(self, entry_data):
        """Extract associated PDB ID"""
        related = entry_data.get('related_entries', [])
        for entry in related:
            if entry.get('Database_name') == 'PDB':
                return entry.get('Database_accession_code')
        return None

    def download_pdb(self, pdb_id):
        """Download PDB structure"""
        cached = self.checkpoint.load(f'pdb_{pdb_id}')
        if cached:
            return cached

        url = f"https://files.rcsb.org/download/{pdb_id}.pdb"
        print(f"Downloading PDB {pdb_id}...")
        response = requests.get(url, timeout=30)

        if response.status_code == 200:
            pdb_path = self.checkpoint.base_dir / f"{pdb_id}.pdb"
            with open(pdb_path, 'w') as f:
                f.write(response.text)
            self.checkpoint.save(str(pdb_path), f'pdb_{pdb_id}')
            return str(pdb_path)
        return None

# ============================================================================
# STRUCTURE HANDLING
# ============================================================================

@dataclass
class SpinSystem:
    """Container for spin system geometry (JAX arrays)"""
    N_coords: jnp.ndarray  # (n, 3)
    H_coords: jnp.ndarray  # (n, 3)
    CA_coords: jnp.ndarray  # (n, 3)
    residue_ids: np.ndarray  # (n,)

    def __len__(self):
        return len(self.N_coords)

def load_structure(pdb_file):
    """Load protein structure and extract N-H spin pairs"""
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('protein', pdb_file)
    model = structure[0]

    N_coords, H_coords, CA_coords, residue_ids = [], [], [], []

    for residue in model.get_residues():
        if residue.has_id('N') and residue.has_id('H') and residue.has_id('CA'):
            N_coords.append(residue['N'].get_coord())
            H_coords.append(residue['H'].get_coord())
            CA_coords.append(residue['CA'].get_coord())
            residue_ids.append(residue.get_id()[1])

    return SpinSystem(
        N_coords=jnp.array(N_coords),
        H_coords=jnp.array(H_coords),
        CA_coords=jnp.array(CA_coords),
        residue_ids=np.array(residue_ids)
    )

# ============================================================================
# GEOMETRY CALCULATIONS (DIFFERENTIABLE)
# ============================================================================

@jit
def compute_nh_vectors(N_coords, H_coords):
    """Compute normalized N-H bond vectors"""
    vectors = H_coords - N_coords
    norms = jnp.linalg.norm(vectors, axis=1, keepdims=True)
    return vectors / norms

@jit
def compute_nh_distances(N_coords, H_coords):
    """Compute N-H bond lengths"""
    return jnp.linalg.norm(H_coords - N_coords, axis=1)

@jit
def compute_angles(v1, v2):
    """Angle between unit vectors"""
    cos_theta = jnp.sum(v1 * v2, axis=1)
    cos_theta = jnp.clip(cos_theta, -1.0, 1.0)
    return jnp.arccos(cos_theta)

@jit
def estimate_csa_axis(nh_vector, ca_coord, n_coord):
    """
    Estimate CSA principal axis orientation
    CSA tensor roughly perpendicular to peptide plane
    """
    ca_n = n_coord - ca_coord
    ca_n_norm = ca_n / jnp.linalg.norm(ca_n)
    csa_axis = jnp.cross(nh_vector, ca_n_norm)
    return csa_axis / jnp.linalg.norm(csa_axis)

compute_csa_axes = vmap(estimate_csa_axis, in_axes=(0, 0, 0))

# ============================================================================
# RELAXATION THEORY (DIFFERENTIABLE)
# ============================================================================

@jit
def spectral_density(omega, tau_c, S2=1.0):
    """
    Lorentzian spectral density with model-free parameters
    J(œâ) = (2/5) * S¬≤ * œÑc / (1 + œâ¬≤œÑc¬≤)
    """
    return (2.0/5.0) * S2 * tau_c / (1.0 + (omega * tau_c)**2)

@jit
def R1_dipolar(r_NH, tau_c, S2, omega_H, omega_N):
    """
    Dipolar R1 relaxation rate
    R1 = (d¬≤/4) [J(œâH-œâN) + 3J(œâN) + 6J(œâH+œâN)]
    """
    # Dipolar coupling constant
    d_squared = ((Constants.MU_0/(4*jnp.pi)) * Constants.GAMMA_H *
                 abs(Constants.GAMMA_N) * Constants.H_BAR / r_NH**3)**2

    J_diff = spectral_density(omega_H - omega_N, tau_c, S2)
    J_N = spectral_density(omega_N, tau_c, S2)
    J_sum = spectral_density(omega_H + omega_N, tau_c, S2)

    return (d_squared/4.0) * (J_diff + 3*J_N + 6*J_sum)

@jit
def R2_dipolar(r_NH, tau_c, S2, omega_H, omega_N):
    """
    Dipolar R2 relaxation rate
    R2 = (d¬≤/8) [4J(0) + J(œâH-œâN) + 3J(œâN) + 6J(œâH) + 6J(œâH+œâN)]
    """
    d_squared = ((Constants.MU_0/(4*jnp.pi)) * Constants.GAMMA_H *
                 abs(Constants.GAMMA_N) * Constants.H_BAR / r_NH**3)**2

    J_0 = (2.0/5.0) * S2 * tau_c
    J_diff = spectral_density(omega_H - omega_N, tau_c, S2)
    J_N = spectral_density(omega_N, tau_c, S2)
    J_H = spectral_density(omega_H, tau_c, S2)
    J_sum = spectral_density(omega_H + omega_N, tau_c, S2)

    return (d_squared/8.0) * (4*J_0 + J_diff + 3*J_N + 6*J_H + 6*J_sum)

@jit
def NOE(r_NH, tau_c, S2, omega_H, omega_N):
    """
    Heteronuclear NOE
    NOE = 1 + (Œ≥H/Œ≥N) * (d¬≤/4R1) * [6J(œâH+œâN) - J(œâH-œâN)]
    """
    d_squared = ((Constants.MU_0/(4*jnp.pi)) * Constants.GAMMA_H *
                 abs(Constants.GAMMA_N) * Constants.H_BAR / r_NH**3)**2

    R1_val = R1_dipolar(r_NH, tau_c, S2, omega_H, omega_N)

    J_diff = spectral_density(omega_H - omega_N, tau_c, S2)
    J_sum = spectral_density(omega_H + omega_N, tau_c, S2)

    return 1.0 + (Constants.GAMMA_H/abs(Constants.GAMMA_N)) * \
           (d_squared/(4*R1_val)) * (6*J_sum - J_diff)

@jit
def CCR_DD_CSA(r_NH, tau_c, S2, omega_N, theta, delta_sigma):
    """
    Cross-correlated relaxation: Dipole-Dipole vs CSA

    THIS IS THE NOVEL CONTRIBUTION

    Œ∑ = (c¬∑d/4) [4J(0) + 3J(œâN)]

    where:
        d = dipolar coupling
        c = CSA coupling = (œâN¬∑ŒîœÉ¬∑P2(cosŒ∏)) / ‚àö6
        P2(x) = (3x¬≤-1)/2
        Œ∏ = angle between NH and CSA principal axis
    """
    # Dipolar coupling
    d = (Constants.MU_0/(4*jnp.pi)) * Constants.GAMMA_H * \
        abs(Constants.GAMMA_N) * Constants.H_BAR / r_NH**3

    # CSA coupling with P2 Legendre polynomial
    P2_cos_theta = (3*jnp.cos(theta)**2 - 1) / 2
    c = (omega_N * delta_sigma * P2_cos_theta) / jnp.sqrt(6)

    # Spectral densities
    J_0 = (2.0/5.0) * S2 * tau_c
    J_N = spectral_density(omega_N, tau_c, S2)

    # CCR rate
    return (c * d / 4.0) * (4*J_0 + 3*J_N)

# ============================================================================
# FORWARD MODEL
# ============================================================================

class RelaxationPredictor:
    """Predict all relaxation rates from structure and dynamics"""

    def __init__(self, spin_system, field=14.1):
        self.spin_system = spin_system
        self.omega_H = Constants.omega('H', field)
        self.omega_N = Constants.omega('N', field)

    @jit
    def predict_all(self, tau_c, S2_array, delta_sigma=-160e-6):
        """
        Predict R1, R2, NOE, CCR for all residues

        Args:
            tau_c: correlation time (seconds)
            S2_array: (n,) order parameters
            delta_sigma: CSA anisotropy

        Returns:
            dict with predicted rates
        """
        # Geometry
        nh_vectors = compute_nh_vectors(
            self.spin_system.N_coords,
            self.spin_system.H_coords
        )
        nh_distances = compute_nh_distances(
            self.spin_system.N_coords,
            self.spin_system.H_coords
        )
        csa_axes = compute_csa_axes(
            nh_vectors,
            self.spin_system.CA_coords,
            self.spin_system.N_coords
        )
        csa_angles = compute_angles(nh_vectors, csa_axes)

        # Convert distances to meters
        r_NH_meters = nh_distances * 1e-10

        # Vectorized predictions
        R1 = vmap(R1_dipolar, in_axes=(0,None,0,None,None))(
            r_NH_meters, tau_c, S2_array, self.omega_H, self.omega_N
        )
        R2 = vmap(R2_dipolar, in_axes=(0,None,0,None,None))(
            r_NH_meters, tau_c, S2_array, self.omega_H, self.omega_N
        )
        NOE_vals = vmap(NOE, in_axes=(0,None,0,None,None))(
            r_NH_meters, tau_c, S2_array, self.omega_H, self.omega_N
        )
        CCR = vmap(CCR_DD_CSA, in_axes=(0,None,0,None,0,None))(
            r_NH_meters, tau_c, S2_array, self.omega_N, csa_angles, delta_sigma
        )

        return {
            'R1': R1,
            'R2': R2,
            'NOE': NOE_vals,
            'CCR': CCR,
            'geometry': {
                'nh_vectors': nh_vectors,
                'nh_distances': nh_distances,
                'csa_angles': csa_angles
            }
        }

# ============================================================================
# OPTIMIZATION / STRUCTURE REFINEMENT
# ============================================================================

def fit_dynamics(predictor, experimental_data, n_iterations=500):
    """
    Fit œÑc and S¬≤ to match experimental R1, R2, NOE
    This demonstrates the inverse problem
    """
    n_residues = len(predictor.spin_system)

    # Initialize parameters
    tau_c = 5e-9  # 5 ns initial guess
    S2_array = jnp.ones(n_residues) * 0.85

    # Setup optimizer
    optimizer = optax.adam(learning_rate=1e-7)
    params = {'tau_c': tau_c, 'S2': S2_array}
    opt_state = optimizer.init(params)

    # Loss function
    def loss_fn(params):
        predicted = predictor.predict_all(params['tau_c'], params['S2'])

        loss = 0.0
        for key in ['R1', 'R2', 'NOE']:
            if key in experimental_data:
                exp = experimental_data[key]
                pred = predicted[key]
                # Chi-squared loss
                loss += jnp.sum(((pred - exp) / 0.1)**2)

        return loss

    # Optimization loop
    losses = []
    for i in range(n_iterations):
        loss_val, grads = jax.value_and_grad(loss_fn)(params)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        losses.append(float(loss_val))

        if i % 100 == 0:
            print(f"Iteration {i}: Loss = {loss_val:.4f}, "
                  f"œÑc = {params['tau_c']*1e9:.2f} ns")

    return params, losses

# ============================================================================
# MAIN WORKFLOW
# ============================================================================

def main():
    """Run complete DiffRelax pipeline"""

    print("="*70)
    print("DiffRelax: Differentiable NMR Relaxation Engine")
    print("="*70)
    print(f"\nJAX backend: {jax.default_backend()}")
    print(f"Devices: {jax.devices()}\n")

    # Initialize checkpoint manager
    checkpoint = CheckpointManager()

    # PHASE 1: Data acquisition
    print("\n" + "="*70)
    print("PHASE 1: Data Acquisition")
    print("="*70)

    fetcher = BMRBFetcher(checkpoint)

    # Fetch GB3 (most studied protein)
    bmrb_id = 15477
    entry = fetcher.fetch_entry(bmrb_id)

    if not entry:
        print("‚úó Failed to fetch data")
        return

    relaxation_df = fetcher.extract_relaxation(entry)
    pdb_id = fetcher.get_pdb_id(entry)
    pdb_file = fetcher.download_pdb(pdb_id) if pdb_id else None

    if relaxation_df is None or pdb_file is None:
        print("‚úó Missing data")
        return

    print(f"\n‚úì Data acquired:")
    print(f"  Relaxation measurements: {len(relaxation_df)}")
    print(f"  PDB structure: {pdb_id}")

    # PHASE 2: Structure loading
    print("\n" + "="*70)
    print("PHASE 2: Structure Loading")
    print("="*70)

    spin_system = load_structure(pdb_file)
    print(f"‚úì Loaded {len(spin_system)} N-H spin pairs")

    # PHASE 3: Forward model test
    print("\n" + "="*70)
    print("PHASE 3: Forward Model Test")
    print("="*70)

    predictor = RelaxationPredictor(spin_system)

    # Test prediction with typical values
    tau_c_test = 5e-9  # 5 ns
    S2_test = jnp.ones(len(spin_system)) * 0.85

    predicted = predictor.predict_all(tau_c_test, S2_test)

    print(f"\n‚úì Predicted relaxation rates:")
    print(f"  R1: {jnp.mean(predicted['R1']):.2f} ¬± {jnp.std(predicted['R1']):.2f} s‚Åª¬π")
    print(f"  R2: {jnp.mean(predicted['R2']):.2f} ¬± {jnp.std(predicted['R2']):.2f} s‚Åª¬π")
    print(f"  NOE: {jnp.mean(predicted['NOE']):.3f} ¬± {jnp.std(predicted['NOE']):.3f}")
    print(f"  CCR: {jnp.mean(predicted['CCR']):.2f} ¬± {jnp.std(predicted['CCR']):.2f} s‚Åª¬π")

    # PHASE 4: Visualization
    print("\n" + "="*70)
    print("PHASE 4: Visualization")
    print("="*70)

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    residues = spin_system.residue_ids

    # Plot predictions
    axes[0,0].plot(residues, predicted['R1'], 'o-', alpha=0.6)
    axes[0,0].set_ylabel('R1 (s‚Åª¬π)')
    axes[0,0].set_title('Predicted R1')
    axes[0,0].grid(True, alpha=0.3)

    axes[0,1].plot(residues, predicted['R2'], 'o-', alpha=0.6, color='orange')
    axes[0,1].set_ylabel('R2 (s‚Åª¬π)')
    axes[0,1].set_title('Predicted R2')
    axes[0,1].grid(True, alpha=0.3)

    axes[1,0].plot(residues, predicted['NOE'], 'o-', alpha=0.6, color='green')
    axes[1,0].set_ylabel('NOE')
    axes[1,0].set_title('Predicted NOE')
    axes[1,0].grid(True, alpha=0.3)

    axes[1,1].plot(residues, predicted['CCR'], 'o-', alpha=0.6, color='red')
    axes[1,1].set_ylabel('CCR (s‚Åª¬π)')
    axes[1,1].set_title('Predicted CCR (DD-CSA)')
    axes[1,1].grid(True, alpha=0.3)
    axes[1,1].set_xlabel('Residue')

    plt.tight_layout()
    fig_path = checkpoint.base_dir / 'predictions.png'
    plt.savefig(fig_path, dpi=150)
    print(f"‚úì Saved figure: {fig_path}")
    plt.show()

    # PHASE 5: Gradient test
    print("\n" + "="*70)
    print("PHASE 5: Gradient Verification")
    print("="*70)

    def test_gradients():
        def loss_fn(coords):
            test_sys = SpinSystem(
                N_coords=jnp.array([[0., 0., 0.]]),
                H_coords=coords,
                CA_coords=jnp.array([[-1., 0., 0.]]),
                residue_ids=np.array([1])
            )
            test_pred = RelaxationPredictor(test_sys)
            result = test_pred.predict_all(5e-9, jnp.array([0.85]))
            return result['R1'][0]

        H_test = jnp.array([[0., 0., 1.02]])
        grad_fn = grad(loss_fn)
        grads = grad_fn(H_test)

        print(f"‚úì Gradients computed: {grads[0]}")
        print(f"‚úì All finite: {jnp.all(jnp.isfinite(grads))}")

        return jnp.all(jnp.isfinite(grads))

    grad_ok = test_gradients()

    # Save results
    results = {
        'bmrb_id': bmrb_id,
        'pdb_id': pdb_id,
        'n_residues': len(spin_system),
        'predicted_rates': {k: np.array(v) if isinstance(v, jnp.ndarray) else v
                           for k, v in predicted.items() if k != 'geometry'},
        'gradients_ok': bool(grad_ok)
    }

    checkpoint.save(results, 'results_phase1')

    # SUMMARY
    print("\n" + "="*70)
    print("SUMMARY")
    print("="*70)
    print(f"\n‚úì Successfully completed Phase 1:")
    print(f"  - Downloaded BMRB {bmrb_id} and PDB {pdb_id}")
    print(f"  - Loaded {len(spin_system)} residues")
    print(f"  - Built differentiable forward model")
    print(f"  - Predicted R1, R2, NOE, CCR")
    print(f"  - Verified gradient flow: {'‚úì' if grad_ok else '‚úó'}")
    print(f"\nüìÅ Results saved to: {checkpoint.base_dir}")
    print(f"\nüéØ NEXT STEPS:")
    print(f"  1. Compare predictions to experimental data")
    print(f"  2. Fit dynamics parameters (œÑc, S¬≤)")
    print(f"  3. Refine structure using gradients")
    print(f"  4. Analyze when CCR is essential")
    print(f"  5. Generate paper figures")
    print("\n" + "="*70)

if __name__ == "__main__":
    main()

DiffRelax: Setting up environment...
Installing jax[cuda12]...
Installing optax...
Installing equinox...
Installing biopython...
Installing requests...
Installing pandas...
Installing matplotlib...
Installing seaborn...
‚úì All packages installed

DiffRelax: Differentiable NMR Relaxation Engine

JAX backend: gpu
Devices: [CudaDevice(id=0)]

Checkpoint directory: /content/diffrelax_checkpoints

PHASE 1: Data Acquisition
Downloading BMRB 15477...
‚úì Saved checkpoint: bmrb_15477
‚úó Missing data
