In [13]:
import jax.numpy as jnp
from jax import grad
import re
import scipy.constants as sc
import MDAnalysis as mda
import numpy as np

In [14]:
# pip install -U jax

In [20]:


# Compute interaction energy matrix using flattened coordinates
def compute_interaction_energy_matrix_flat(coords, residues, lj_params, dielectric=78.5):
    """
    Computes the interaction energy matrix from flattened coordinates.
    """
    n_residues = len(residues)
    energy_matrix = jnp.zeros((n_residues, n_residues))

    # Coulomb conversion factor
    conversion_factor = (
        1 / (4 * jnp.pi * sc.epsilon_0 * dielectric)
        * sc.elementary_charge**2
        / sc.Avogadro
        / 1000
    )  # kJ/mol

    for i in range(n_residues):
        for j in range(i + 1, n_residues):
            E_vdw = 0.0
            E_coul = 0.0

            # Pairwise atom interactions
            for atom_i_idx, atom_i in enumerate(residues[i].atoms):
                for atom_j_idx, atom_j in enumerate(residues[j].atoms):
                    # Calculate the distance using flattened coordinates
                    pos_i = coords[atom_i_idx]
                    pos_j = coords[atom_j_idx]
                    r = jnp.linalg.norm(pos_i - pos_j)

                    if r > 0:
                        # Lennard-Jones parameters
                        type_i = atom_i.type
                        type_j = atom_j.type

                        if type_i in lj_params and type_j in lj_params:
                            c6_i = lj_params[type_i]["c6"]
                            c12_i = lj_params[type_i]["c12"]
                            c6_j = lj_params[type_j]["c6"]
                            c12_j = lj_params[type_j]["c12"]

                            sigma_i = (c12_i / c6_i) ** (1 / 6)
                            epsilon_i = c6_i ** 2 / (4 * c12_i)
                            sigma_j = (c12_j / c6_j) ** (1 / 6)
                            epsilon_j = c6_j ** 2 / (4 * c12_j)

                            sigma_ij = (sigma_i + sigma_j) / 2
                            epsilon_ij = jnp.sqrt(epsilon_i * epsilon_j)

                            # Lennard-Jones potential
                            E_vdw += 4 * epsilon_ij * ((sigma_ij / r)**12 - (sigma_ij / r)**6)

                        # Coulomb interaction
                        q_i = atom_i.charge
                        q_j = atom_j.charge
                        E_coul += conversion_factor * (q_i * q_j) / r

            # Fill energy matrix symmetrically
            energy_matrix = energy_matrix.at[i, j].set(E_vdw + E_coul)
            energy_matrix = energy_matrix.at[j, i].set(E_vdw + E_coul)

    return energy_matrix


# # Calculate ENG(t) with flattened coordinates
# def calculate_eng_t_flat(coords, residues, lj_params):
#     """
#     Computes ENG(t) using flattened atomic coordinates.
#     """
#     # Compute the interaction energy matrix
#     interaction_matrix = compute_interaction_energy_matrix_flat(coords, residues, lj_params)

#     # Compute eigenvalues
#     eigenvalues = jnp.linalg.eigh(interaction_matrix)[0]  # Only eigenvalues
#     eigenvalues = jnp.sort(eigenvalues)  # Ensure sorted order

#     # Calculate spectral gap and average separation
#     spectral_gap = eigenvalues[1] - eigenvalues[0]
#     avg_separation = jnp.mean(jnp.diff(eigenvalues))

#     # Calculate and return ENG(t)
#     eng_t = spectral_gap / avg_separation
#     return eng_t


# Compute ENG(t), eigenvalues, and eigenvectors
def calculate_eng_t_flat(coords, residues, lj_params):
    interaction_matrix = compute_interaction_energy_matrix_flat(coords, residues, lj_params)
    eigenvalues, eigenvectors = jnp.linalg.eigh(interaction_matrix)
    eigenvalues = jnp.sort(eigenvalues)

    spectral_gap = eigenvalues[1] - eigenvalues[0]
    avg_separation = jnp.mean(jnp.diff(eigenvalues))
    eng_t = spectral_gap / avg_separation

    return eng_t, eigenvalues, eigenvectors



# # Compute Correlation Distance (CD)
# def compute_cd(eng_time_series, residue_time_series):
#     covariance = jnp.cov(eng_time_series, residue_time_series)[0, 1]
#     std_x = jnp.std(eng_time_series)
#     std_y = jnp.std(residue_time_series)
#     rho = covariance / (std_x * std_y)
#     cd = jnp.sqrt(2 * (1 - rho))
#     return cd




def compute_cd(eng_time_series, residue_time_series):
    """
    Compute the Correlation Distance (CD) between two time series: ENG(t) and a residue's component.
    
    Parameters:
    -----------
    eng_time_series : list or jnp.ndarray
        The time series of ENG(t) values.
    residue_time_series : list or jnp.ndarray
        The time series of a residue's energy contributions.

    Returns:
    --------
    cd : float
        The Correlation Distance between the two time series.
    """
    # Convert inputs to JAX arrays if they are not already
    eng_time_series = jnp.array(eng_time_series)
    residue_time_series = jnp.array(residue_time_series)

    # Compute Pearson correlation coefficient (rho)
    covariance = jnp.cov(eng_time_series, residue_time_series)[0, 1]
    std_x = jnp.std(eng_time_series)
    std_y = jnp.std(residue_time_series)
    rho = covariance / (std_x * std_y)

    # Compute Correlation Distance (CD)
    cd = jnp.sqrt(2 * (1 - rho))
    return cd





def fine_tuning_loop_2(eng_time_series, residue_time_series, alpha, beta):
    cd = compute_cd(eng_time_series, residue_time_series)
    SQRT_2 = jnp.sqrt(2)

    if cd > SQRT_2:
        alpha += 0.1  # Increase alpha for better exploration
    else:
        beta += 0.1  # Increase beta to reduce fluctuations

    return alpha, beta


def process_trajectory(u, protein_residues, lj_params):
    eng_time_series = []
    eigenvectors_series = []
    alpha, beta = 1.0, 1.0

    for ts in u.trajectory:
        coords = np.array([atom.position for atom in protein_residues.atoms])
        coords_flat = jnp.array(coords.reshape(-1, 3))

        eng_t, eigenvalues, eigenvectors = calculate_eng_t_flat(coords_flat, protein_residues, lj_params)
        eng_time_series.append(eng_t)
        eigenvectors_series.append(eigenvectors[:, 0])  # Only the first eigenvector

        if len(eng_time_series) > 1:
            alpha, beta = fine_tuning_loop_2(
                eng_time_series[-2:],  # Compare the last two ENG(t) values
                eigenvectors_series[-2:],  # Compare the last two eigenvector components
                alpha,
                beta,
            )

    return eng_time_series, alpha, beta


# Compute ENG(t) and its gradient
def calculate_eng_t_with_grad_flat(coords, residues, lj_params):
    """
    Computes ENG(t) and its gradient using flattened atomic coordinates.
    """
    eng_t_value = calculate_eng_t_flat(coords, residues, lj_params)
    eng_t_grad = grad(calculate_eng_t_flat)(coords, residues, lj_params)
    return eng_t_value, eng_t_grad


# Extract Lennard-Jones parameters from topology file
def extract_lj_parameters(lj_parameters_file ):
    """
    Extracts c6 and c12 from the topology file using regex.
    """
    with open(lj_parameters_file , 'r') as file:
        data = file.read()

    # Regular expression patterns for LJ_SR types
    lj_sr_pattern = r"c6=\s*([0-9\.\-eE+]+),\s*c12=\s*([0-9\.\-eE+]+)"
    lj_sr_matches = re.findall(lj_sr_pattern, data)

    lj_params = []
    for match in lj_sr_matches:
        c6, c12 = match
        lj_params.append({
            "type": "LJ_SR",
            "c6": float(c6),
            "c12": float(c12)
        })

    return lj_params



def coarse_tuning_loop(eng_time_series, sdeng_time_series, percentages):
    """
    Coarse-grain tuning loop to compute folded probability based on ENG and SDENG thresholds.
    
    Parameters:
    -----------
    eng_time_series : list
        List of ENG(t) values across all frames.
    sdeng_time_series : list
        List of SDENG values across all frames.
    percentages : list
        List of percentages for threshold adjustment.

    Returns:
    --------
    folded_probabilities : list
        Folded probabilities for each percentage.
    eng_counts : list
        Number of frames satisfying the ENG threshold for each percentage.
    sdeng_counts : list
        Number of frames satisfying the SDENG threshold for each percentage.
    """
    eng_max = max(eng_time_series)
    sdeng_min = min(sdeng_time_series)
    msd = np.std(eng_time_series)  # Mean standard deviation

    eng_counts = []
    sdeng_counts = []
    folded_probabilities = []

    for n in percentages:
        # Calculate thresholds
        eng_threshold = eng_max - (n / 100) * msd
        sdeng_threshold = sdeng_min + (n / 100) * msd

        # Count frames satisfying thresholds
        eng_count = sum(1 for eng in eng_time_series if eng > eng_threshold)
        sdeng_count = sum(1 for sdeng in sdeng_time_series if sdeng < sdeng_threshold)

        # Calculate folded probability
        total_frames = len(eng_time_series)
        folded_probability = (eng_count + sdeng_count) / (2 * total_frames)

        # Append results
        eng_counts.append(eng_count)
        sdeng_counts.append(sdeng_count)
        folded_probabilities.append(folded_probability)

    return folded_probabilities, eng_counts, sdeng_counts


# Process trajectory and integrate coarse tuning loop
def process_trajectory_with_coarse_tuning(u, protein_residues, lj_params, percentages):
    eng_time_series = []
    sdeng_time_series = []
    eigenvectors_series = []
    alpha, beta = 1.0, 1.0

    for ts in u.trajectory:
        coords = np.array([atom.position for atom in protein_residues.atoms])
        coords_flat = jnp.array(coords.reshape(-1, 3))

        eng_t, eigenvalues, eigenvectors = calculate_eng_t_flat(coords_flat, protein_residues, lj_params)
        eng_time_series.append(eng_t)
        sdeng_time_series.append(jnp.std(eigenvalues))  # Standard deviation of eigenvalues
        eigenvectors_series.append(eigenvectors)

        if len(eng_time_series) > 1:
            alpha, beta = fine_tuning_loop_2(
                eng_time_series[-2:],
                eigenvectors_series[-2:],
                alpha,
                beta,
            )

    # Perform coarse tuning loop
    folded_probabilities, eng_counts, sdeng_counts = coarse_tuning_loop(
        eng_time_series, sdeng_time_series, percentages
    )

    return eng_time_series, sdeng_time_series, folded_probabilities, eng_counts, sdeng_counts, alpha, beta








In [21]:
tpr_file= "/Users/sss/Documents/EnergyGap_project/ENGgromax_output_files/md.tpr"
xtc_file = "/Users/sss/Documents/EnergyGap_project/ENGgromax_output_files/md.xtc"
lj_parameters_file = "/Users/sss/Documents/EnergyGap_project/ENGgromax_output_files/tpr_contents.txt"

In [22]:
# # Load the molecular dynamics trajectory
# u = mda.Universe(tpr_file, xtc_file)

# # Select the protein residues
# protein = u.select_atoms("protein")
# protein_residues = protein.residues

# # Example LJ parameters for testing (replace with actual extracted parameters)
# # lj_params = {"type1": {"c6": 1.0, "c12": 1.0}, "type2": {"c6": 0.8, "c12": 0.5}}
# lj_params = extract_lj_parameters(lj_parameters_file)
# # Extract coordinates for the first frame
# u.trajectory[0]
# coords = np.array([atom.position for atom in protein.atoms])

# # Flatten coordinates for input into JAX
# coords_flat = jnp.array(coords.reshape(-1, 3))

# # Calculate ENG(t) and its gradient
# eng_t, gradient = calculate_eng_t_with_grad_flat(coords_flat, protein_residues, lj_params)

# print(f"ENG(t): {eng_t}")
# print(f"Gradient of ENG(t): {gradient}")


In [23]:
u = mda.Universe(tpr_file, xtc_file)
protein = u.select_atoms("protein")
protein_residues = protein.residues
lj_params = extract_lj_parameters(lj_parameters_file)

percentages = [10, 20, 30, 40, 50]  # Example percentages for coarse tuning

results = process_trajectory_with_coarse_tuning(u, protein_residues, lj_params, percentages)
eng_time_series, sdeng_time_series, folded_probabilities, eng_counts, sdeng_counts, alpha, beta = results

print(f"ENG(t) Time Series: {eng_time_series}")
print(f"SDENG Time Series: {sdeng_time_series}")
print(f"Folded Probabilities: {folded_probabilities}")
print(f"ENG Counts: {eng_counts}")
print(f"SDENG Counts: {sdeng_counts}")
print(f"Final Alpha: {alpha}")
print(f"Final Beta: {beta}")

ValueError: Incompatible shapes for broadcasting: shapes=[(2,), (2, 20, 20)]