In [2]:

import jax.numpy as jnp
from jax.scipy.linalg import eigh
from jax import grad, jacobian, value_and_grad, jit, lax, vmap, value_and_grad
import MDAnalysis as mda
import re
import jax
import numpy as np


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Extract LJ parameters from topology file converted to txt.
def extract_lj_parameters(path):
    """
    Parse Lennard-Jones C6/C12 coefficients from a plain-text topology dump.

    Each non-overlapping match of ``type=<name> ... c6=<number> ... c12=<number>``
    yields one entry. Because the pattern is used without ``re.DOTALL``, a match
    must lie on a single line. If the same type name matches more than once, the
    last occurrence wins.

    Parameters
    ----------
    path : str
        Path to the text file to scan (e.g. a text dump of a .tpr file).

    Returns
    -------
    dict
        Mapping ``{type_name: {"c6": float, "c12": float}}``.
    """
    pattern = re.compile(r"type=(\S+).*?c6=\s*([0-9\.\-eE+]+).*?c12=\s*([0-9\.\-eE+]+)")
    with open(path, 'r') as handle:
        text = handle.read()
    params = {}
    for type_name, c6_text, c12_text in pattern.findall(text):
        params[type_name] = {"c6": float(c6_text), "c12": float(c12_text)}
    return params

In [None]:
# All simulation outputs live in one directory; change DATA_DIR to relocate them
# instead of editing three separate absolute paths.
DATA_DIR = "/Users/sss/Documents/EnergyGap_project/ENGgromax_output_files"
tpr_file = f"{DATA_DIR}/md.tpr"
xtc_file = f"{DATA_DIR}/md.xtc"
lj_contents = f"{DATA_DIR}/tpr_contents.txt"
lj_params = extract_lj_parameters(lj_contents)


In [None]:
# Introducing the system: load the trajectory, report its size, and group the
# protein atoms by residue.
u = mda.Universe(tpr_file, xtc_file)
protein = u.select_atoms("protein")
time_step = u.trajectory.dt

print(f"{u}, unlike in the paper it reports 39768 atoms\n  However the protein atoms is {len(protein)} same as the paper")
print(f"The time step for the trajectories is : {time_step} ps")
print(f"The system has {u.trajectory.n_frames} frames")
# Frame count and time step are read from the trajectory rather than being
# hard-coded as "81 x 5", so the message stays correct for other runs.
print(f"The total simulation time is {len(u.trajectory)} x {time_step} = {len(u.trajectory) * time_step} ps\n   The paper reports saving a frame every 400ps which means in this work we can only reproduce the first frame, equivalent to frame number 200.\n   However to test the code, more frames are needed, therefore, a frame every 40ps will be selected. ")

protein_residues = protein.residues
print(f"The system has {len(protein_residues)} residues")
# One AtomGroup per residue; read as a global by compute_cv_and_dependencies.
residue_atoms = [residue.atoms for residue in protein_residues]
n_residues = len(protein_residues)
for residue in protein_residues:
    print(f"{residue.resid}, {residue.resname}")

In [None]:
# The distance matrix, a direct function of the positions
def compute_distance_matrix(positions):
    """
    Compute the pairwise Euclidean distance matrix for a set of atomic positions.

    Element (i, j) is the distance between atom ``i`` and atom ``j``. The matrix
    is symmetric and its diagonal is exactly zero.

    Parameters
    ----------
    positions : jnp.ndarray
        Array of shape (n_atoms, 3) with the Cartesian coordinates of each atom.
        Distances are returned in the same length unit as ``positions``.

    Returns
    -------
    distance_matrix : jnp.ndarray
        Array of shape (n_atoms, n_atoms) of pairwise distances.

    Notes
    -----
    - All pairs are computed in one broadcast expression. A Python double loop
      of ``.at[].set`` updates would issue O(n_atoms**2) separate JAX operations,
      which makes ``jax.grad`` through this function impractically slow.
    - ``sqrt`` has an infinite derivative at 0. Pairs with zero separation
      (always the diagonal) are routed through a safe branch, so their distance
      is 0 and their gradient is 0 rather than NaN.
    """
    positions = jnp.asarray(positions)
    diff = positions[:, None, :] - positions[None, :, :]
    sq_dist = jnp.sum(diff * diff, axis=-1)
    nonzero = sq_dist > 0
    # Double-where: give sqrt a harmless input where the distance is zero so the
    # backward pass never evaluates d(sqrt)/dx at 0.
    safe_sq_dist = jnp.where(nonzero, sq_dist, 1.0)
    return jnp.where(nonzero, jnp.sqrt(safe_sq_dist), 0.0)

In [None]:
def _lj_sigma_epsilon(entry):
    """Convert one ``{"c6": ..., "c12": ...}`` entry into ``(sigma, epsilon)``.

    Returns ``(0.0, 0.0)`` when the entry is missing (``None``) or either
    coefficient is non-positive. Such an atom then contributes no Lennard-Jones
    energy; the plain formulas would divide by zero for these entries.
    """
    if entry is None:
        return 0.0, 0.0
    c6, c12 = entry["c6"], entry["c12"]
    if c6 <= 0 or c12 <= 0:
        return 0.0, 0.0
    return (c12 / c6) ** (1 / 6), c6 ** 2 / (4 * c12)


def compute_interaction_energy_matrix(distance_matrix, residue_atoms, lj_params):
    """
    Compute the residue-residue non-bonded interaction energy matrix.

    For every pair of atoms belonging to *different* residues, the function
    evaluates two terms from the precomputed distances:

    - a Lennard-Jones term, using Lorentz-Berthelot mixing of per-type
      sigma/epsilon;
    - a Coulomb term.

    Both are summed into the corresponding residue pair. The result is symmetric
    and has a zero diagonal, because intra-residue pairs are never included.

    Parameters
    ----------
    distance_matrix : jnp.ndarray
        Symmetric atom-atom distance matrix indexed by each atom's ``.index``.
        Distances must be in nm to match the Coulomb constant below.
        NOTE(review): ``lj_params`` is presumably in GROMACS nm-based units as
        well; confirm against the dump it was parsed from.
    residue_atoms : sequence
        One iterable of atoms per residue. Each atom must expose ``.index``,
        ``.type`` and ``.charge``.
    lj_params : dict
        ``{atom_type: {"c6": float, "c12": float}}``. Charges are NOT read from
        here; they come from ``atom.charge``. Atom types missing from this dict,
        or with a non-positive c6/c12, contribute no Lennard-Jones energy.

    Returns
    -------
    energy_matrix : jnp.ndarray
        Array of shape (n_residues, n_residues). Each element is the total
        LJ + Coulomb energy between two residues.

    Notes
    -----
    - All atom pairs are evaluated in one vectorized expression. The per-atom
      parameter lookup is plain Python and does not depend on
      ``distance_matrix``, so ``jax.grad`` only traces the array arithmetic.
    - Pairs at distance 0 contribute 0 and use a safe denominator, so their
      gradient is 0 rather than NaN.
    - NOTE(review): no bonded exclusions are applied. Covalently bonded atoms
      in neighbouring residues contribute full LJ/Coulomb terms; confirm this
      is intended.
    - NOTE(review): ``atom.index`` is used directly as a row/column of
      ``distance_matrix``. This is only consistent if the matrix was built from
      positions ordered by that same index.
    """
    n_residues = len(residue_atoms)
    # 1/(4*pi*eps0) in kJ mol^-1 nm e^-2 (GROMACS value; 138.935485 had two
    # digits transposed).
    conversion_factor = 138.935458

    # Flatten the residue groups into per-atom parameter lists.
    atom_indices, atom_residue, sigmas, epsilons, charges = [], [], [], [], []
    for res_idx, atoms in enumerate(residue_atoms):
        for atom in atoms:
            sigma, epsilon = _lj_sigma_epsilon(lj_params.get(atom.type))
            atom_indices.append(atom.index)
            atom_residue.append(res_idx)
            sigmas.append(sigma)
            epsilons.append(epsilon)
            charges.append(atom.charge)

    if not atom_indices:
        return jnp.zeros((n_residues, n_residues))

    idx = np.asarray(atom_indices)
    res_of = np.asarray(atom_residue)
    sigma = jnp.asarray(sigmas)
    epsilon = jnp.asarray(epsilons)
    charge = jnp.asarray(charges)

    r = jnp.asarray(distance_matrix)[np.ix_(idx, idx)]
    # Count only inter-residue pairs at non-zero separation.
    counted = (r > 0) & jnp.asarray(res_of[:, None] != res_of[None, :])
    # Double-where: excluded pairs get a harmless denominator so both values
    # and gradients stay finite.
    safe_r = jnp.where(counted, r, 1.0)

    # Lorentz-Berthelot mixing rules
    sigma_ij = (sigma[:, None] + sigma[None, :]) / 2
    epsilon_ij = jnp.sqrt(epsilon[:, None] * epsilon[None, :])

    sr6 = (sigma_ij / safe_r) ** 6
    e_vdw = 4 * epsilon_ij * (sr6 ** 2 - sr6)
    e_coul = conversion_factor * (charge[:, None] * charge[None, :]) / safe_r
    pair_energy = jnp.where(counted, e_vdw + e_coul, 0.0)

    # Sum atom-pair energies into residue pairs: first over the row atoms i,
    # then over the column atoms j, giving E_res[a, b] = sum_{i in a, j in b} E_ij.
    by_row_residue = jax.ops.segment_sum(pair_energy, res_of, num_segments=n_residues)
    energy_matrix = jax.ops.segment_sum(by_row_residue.T, res_of, num_segments=n_residues).T
    return energy_matrix

In [None]:
# Compute Eigenvalues, ENG, and SDENG
def compute_eigenvalues(energy_matrix):
    """
    Return the eigenvalues of a symmetric interaction energy matrix.

    Parameters
    ----------
    energy_matrix : jnp.ndarray
        Symmetric matrix of shape (n_residues, n_residues).

    Returns
    -------
    eigenvalues : jnp.ndarray
        1D array of eigenvalues as returned by ``jnp.linalg.eigh`` (ascending
        order). The eigenvectors are computed by ``eigh`` but discarded here.
    """
    eigenvalues, _eigenvectors = jnp.linalg.eigh(energy_matrix)
    return eigenvalues



def compute_eng_sdeng(eigenvalues):
    """
    Compute the normalized energy gap (ENG) and the standard deviation of the
    eigenvalues (SDENG).

    Parameters
    ----------
    eigenvalues : jnp.ndarray
        1D array of eigenvalues of the interaction energy matrix (any order).

    Returns
    -------
    eng : jnp.ndarray (scalar)
        The gap between the two smallest eigenvalues divided by the mean
        separation of adjacent sorted eigenvalues. It is 0.0 when that mean
        separation is non-positive, or when there are fewer than two
        eigenvalues.
    sdeng : jnp.ndarray (scalar)
        Population standard deviation of the eigenvalues.

    Notes
    -----
    - The mean adjacent separation equals (max - min) / (n - 1).
    - ``jnp.where`` evaluates both branches. The division therefore uses a safe
      denominator, so a degenerate spectrum (all eigenvalues equal) yields a
      finite gradient for ENG instead of NaN.
    """
    eigenvalues = jnp.asarray(eigenvalues)
    sdeng = jnp.std(eigenvalues)
    if eigenvalues.shape[0] < 2:
        # No gap or separation is defined for fewer than two eigenvalues.
        return jnp.zeros((), dtype=sdeng.dtype), sdeng

    sorted_eigenvalues = jnp.sort(eigenvalues)
    spectral_gap = sorted_eigenvalues[1] - sorted_eigenvalues[0]
    avg_separation = jnp.mean(jnp.diff(sorted_eigenvalues))

    has_spread = avg_separation > 0
    safe_separation = jnp.where(has_spread, avg_separation, 1.0)
    eng = jnp.where(has_spread, spectral_gap / safe_separation, 0.0)
    return eng, sdeng

# Compute Collective Variable (CV)
def compute_cv(eng, sdeng, alpha, beta):
    """
    Combine ENG and SDENG into the collective variable (CV).

    Parameters
    ----------
    eng : float or jnp.ndarray
        Normalized energy gap (first output of ``compute_eng_sdeng``).
    sdeng : float or jnp.ndarray
        Standard deviation of the eigenvalues.
    alpha : float
        Weight applied to ENG.
    beta : float
        Weight applied to SDENG.

    Returns
    -------
    float or jnp.ndarray
        ``CV = alpha * ENG - beta * SDENG``.
    """
    weighted_eng = alpha * eng
    weighted_sdeng = beta * sdeng
    return weighted_eng - weighted_sdeng


In [None]:
# Modulate Weights
def modulate_weights_with_probability(eng_series, sdeng_series, frame_idx, alpha, beta, percentages):
    """
    Modulate the weights (alpha and beta) from the ENG and SDENG values at one
    frame, relative to thresholds computed from the frames seen so far.

    Rules
    -----
    Let ``n = percentages[frame_idx % len(percentages)]``, and compute all
    statistics over entries ``0..frame_idx`` only.

    - If ``ENG[frame_idx] > max(ENG) - (n/100) * std(ENG)``, alpha is scaled
      by 1.1.
    - If ``SDENG[frame_idx] < min(SDENG) + (n/100) * std(SDENG)``, beta is
      scaled by 0.9.

    Parameters
    ----------
    eng_series : jnp.ndarray
        Time series of ENG values. Entries after ``frame_idx`` are ignored.
    sdeng_series : jnp.ndarray
        Time series of SDENG values. Entries after ``frame_idx`` are ignored.
    frame_idx : int
        Index of the current frame within the series.
    alpha : float
        Current weight for the ENG term.
    beta : float
        Current weight for the SDENG term.
    percentages : list of float
        Percentages cycled through by frame index to set the threshold offset.

    Returns
    -------
    tuple of (float, float)
        Updated ``(alpha, beta)``.
    """
    # In ``main`` the series are preallocated with zeros for frames that have
    # not been processed yet; including those zeros would distort max/min/std.
    eng_history = eng_series[: frame_idx + 1]
    sdeng_history = sdeng_series[: frame_idx + 1]

    n = percentages[frame_idx % len(percentages)]
    fraction = n / 100

    # Each threshold is offset by a fraction of its *own* series' spread. ENG is
    # a ratio of energies (dimensionless), while SDENG is an energy, so the
    # ENG spread must not be used to shift the SDENG threshold.
    eng_threshold = jnp.max(eng_history) - fraction * jnp.std(eng_history)
    sdeng_threshold = jnp.min(sdeng_history) + fraction * jnp.std(sdeng_history)

    if eng_series[frame_idx] > eng_threshold:
        alpha *= 1.1
    if sdeng_series[frame_idx] < sdeng_threshold:
        beta *= 0.9

    return alpha, beta


def compute_cv_and_dependencies(positions, n_atoms, lj_params, alpha, beta):
    """
    Run the pipeline positions -> distances -> residue energies -> eigenvalues -> CV.

    Parameters
    ----------
    positions : jnp.ndarray
        Atomic positions, shape (n_atoms, 3).
    n_atoms : int
        Number of atoms. Accepted for interface compatibility; not used here.
    lj_params : dict
        Lennard-Jones parameters, forwarded to the energy computation.
    alpha : float
        Weight for ENG.
    beta : float
        Weight for SDENG.

    Returns
    -------
    tuple
        ``(cv_value, eng, sdeng, distance_matrix, energy_matrix)``.

    Notes
    -----
    The residue grouping is read from the module-level ``residue_atoms``
    variable, not from an argument. The cell that defines it must have run
    before this function is called.
    """
    distances = compute_distance_matrix(positions)
    energies = compute_interaction_energy_matrix(distances, residue_atoms, lj_params)
    eng, sdeng = compute_eng_sdeng(compute_eigenvalues(energies))
    return compute_cv(eng, sdeng, alpha, beta), eng, sdeng, distances, energies



In [None]:
def compute_cv_gradient_per_frame(positions, n_atoms, lj_params, alpha, beta):
    """
    Compute dCV/d(positions) for a single frame with ``jax.grad``.

    Parameters
    ----------
    positions : jnp.ndarray
        Atomic positions, shape (n_atoms, 3).
    n_atoms : int
        Number of atoms; used to reshape between flat and (n_atoms, 3) form.
    lj_params : dict
        Lennard-Jones parameters.
    alpha : float
        Weight for ENG (held fixed during differentiation).
    beta : float
        Weight for SDENG (held fixed during differentiation).

    Returns
    -------
    jnp.ndarray
        Gradient of the CV with respect to the positions, shape (n_atoms, 3).
    """
    def _cv_of_flat(flat_coords):
        # Differentiate with respect to a flat vector, reshaped back for the pipeline.
        coords = flat_coords.reshape((n_atoms, 3))
        return compute_cv_and_dependencies(coords, n_atoms, lj_params, alpha, beta)[0]

    cv_grad_fn = grad(_cv_of_flat)
    flat_gradient = cv_grad_fn(positions.flatten())
    return flat_gradient.reshape((n_atoms, 3))

In [None]:
def main(u, protein, lj_params, stride_ps=40.0):
    """
    Compute the Collective Variable (CV), its gradient, and the modulated
    weights over a subsampled MD trajectory, printing per-frame results.

    Parameters
    ----------
    u : MDAnalysis.Universe
        The MD trajectory universe object.
    protein : MDAnalysis.AtomGroup
        The protein atoms whose positions drive the CV.
    lj_params : dict
        Lennard-Jones interaction parameters.
    stride_ps : float, optional
        Target spacing between analysed frames, in ps (default 40.0). It is
        converted to a frame stride via the trajectory time step, rounded, and
        never smaller than one frame.

    Returns
    -------
    dict
        Contains:

        - ``"eng"``, ``"sdeng"``, ``"cv"``: per-selected-frame jnp arrays;
        - ``"alpha"``, ``"beta"``: the weights after the last frame.

    Raises
    ------
    ValueError
        If the protein atoms are not universe atoms ``0..len(protein)-1``. The
        energy step indexes the protein-only distance matrix with
        ``atom.index``, which is only valid in that case.
    """
    time_step = u.trajectory.dt
    # int(40 / dt) truncated and became 0 for dt > 40 ps, making range() raise.
    frame_stride = max(1, round(stride_ps / time_step))
    selected_frames = range(0, u.trajectory.n_frames, frame_stride)
    percentages = list(range(101))

    n_atoms = len(protein.atoms)
    if not np.array_equal(protein.atoms.indices, np.arange(n_atoms)):
        raise ValueError(
            "protein atoms must be the first atoms of the universe (indices 0..n-1) "
            "so that atom.index can index the protein-only distance matrix"
        )

    alpha, beta = 1.0, 1.0
    n_selected = len(selected_frames)
    eng_time_series = jnp.zeros(n_selected)
    sdeng_time_series = jnp.zeros(n_selected)
    cv_time_series = jnp.zeros(n_selected)

    for frame_idx, frame in enumerate(selected_frames):
        u.trajectory[frame]  # seek to the frame; updates protein positions

        # MDAnalysis reports positions in Angstrom. The Coulomb constant and the
        # GROMACS C6/C12 parameters are nm-based, so convert to nm. Gradients
        # are therefore per nm.
        positions = jnp.asarray(protein.positions) / 10.0

        cv_value, eng, sdeng, _, _ = compute_cv_and_dependencies(positions, n_atoms, lj_params, alpha, beta)

        eng_time_series = eng_time_series.at[frame_idx].set(eng)
        sdeng_time_series = sdeng_time_series.at[frame_idx].set(sdeng)
        cv_time_series = cv_time_series.at[frame_idx].set(cv_value)

        alpha, beta = modulate_weights_with_probability(
            eng_time_series, sdeng_time_series, frame_idx, alpha, beta, percentages
        )

        # The gradient uses the weights just updated for this frame.
        cv_grad = compute_cv_gradient_per_frame(positions, n_atoms, lj_params, alpha, beta)

        print(f"Frame {frame_idx}: CV(T)={cv_value}, ENG(T)={eng}, SDENG(T)={sdeng}, Alpha={alpha}, Beta={beta}")
        print(f"Frame {frame_idx}: Gradient Norm = {jnp.linalg.norm(cv_grad)}")

    return {
        "eng": eng_time_series,
        "sdeng": sdeng_time_series,
        "cv": cv_time_series,
        "alpha": alpha,
        "beta": beta,
    }

In [None]:
# Run the full analysis on the loaded system and display what main() returns.
sim_run = main(u=u, protein=protein, lj_params=lj_params)
sim_run
