# Optimizing the parameters of a coarse-grained potential using jaxDNA

jaxDNA is a sandbox for developing, parameterizing, and testing biomolecular models. It currently has the oxDNA family of models implemented (oxDNA [refs], oxRNA [refs], oxNA [refs]) and extensions to lipid and protein models are actively being pursued. The philosophy of the software is as follows:

[Figure: jaxDNA optimization diagram (jax_dna_opt_diagram.svg)]

Currently, the supported *simulators* are **Jax-MD** and the **oxDNA** standalone C++/CUDA codes, but extensions to support **GROMACS** and **OpenMM** are being developed.

An *optimizer* computes gradients of the chosen *objective* functions with respect to the parameters, then performs gradient descent with an update rule of choice (e.g. RMSProp or Adam). Two strategies for computing gradients are currently implemented: **direct automatic differentiation** and **differentiable trajectory reweighting (DiffTRE)**[refs].

**Direct automatic differentiation** can only be used with the **Jax-MD** simulator, because other MD codebases (including **oxDNA**) are not differentiable. In practice, we have found that the oxDNA energy function is complex enough that differentiating directly through unrolled simulated trajectories leads to exploding gradients unless the timestep is made so small that the calculations become very computationally expensive. An alternative is the **DiffTRE** method of XX and XX [ref], which expresses observables under different parameter sets as weighted averages over a given ensemble of simulated trajectories. The ensemble is re-sampled only periodically. This method must still be implemented in a differentiable framework such as JAX (on which Jax-MD is built), because the gradients of the energy functions and of the averages over sampled states are computed automatically.
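The reweighting idea can be made concrete with a toy problem. The sketch below is not jaxDNA code: `toy_energy` is a hypothetical 1D harmonic potential, and the "trajectory" is drawn directly from its Boltzmann distribution instead of being simulated. It reweights states sampled at reference parameters `theta_ref` to estimate the average of x² at parameters `theta`, then differentiates that estimate with `jax.grad`.

```python
import jax
import jax.numpy as jnp

def toy_energy(theta, x):
    # Hypothetical 1D harmonic energy with stiffness theta (illustration only)
    return 0.5 * theta * x**2

def reweighted_mean_x2(theta, theta_ref, states, beta):
    # Boltzmann weight of each stored state under theta relative to theta_ref
    log_w = -beta * (toy_energy(theta, states) - toy_energy(theta_ref, states))
    weights = jax.nn.softmax(log_w)  # normalized, numerically stable
    # Observable under theta, expressed as a weighted average over the ensemble
    return jnp.dot(weights, states**2)

beta, theta_ref = 1.0, 1.0
key = jax.random.PRNGKey(0)
# States sampled at theta_ref: a Gaussian with variance 1 / (beta * theta_ref)
states = jax.random.normal(key, (10_000,)) / jnp.sqrt(beta * theta_ref)

mean_x2 = reweighted_mean_x2(theta_ref, theta_ref, states, beta)
# Gradient of the reweighted average with respect to theta, no new sampling needed
grad_x2 = jax.grad(reweighted_mean_x2)(theta_ref, theta_ref, states, beta)
print(mean_x2, grad_x2)
```

For this toy potential the exact average is 1/(beta·theta), so the exact derivative at theta = 1 is -1; the printed gradient should be close to that value, up to sampling noise.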

While the software implements the oxDNA family of models as a starting point, it flexibly supports custom geometries and energy functions. One can also augment existing oxDNA models with custom energy functions, as we'll see below.

This tutorial is meant to introduce the user to the power of jaxDNA by (i) implementing a simple reparameterization of the oxDNA 1.0 model to produce a different DNA helical pitch value and (ii) adding a custom energy function to the oxDNA model and parameterizing it.

## Imports & Utils

In [None]:
#install from Github repo
# Need to uninstall the current jax_md due to a version conflict
!pip uninstall -y jax_md
# install the latest version directly from GitHub
!pip install git+https://github.com/jax-md/jax-md.git
!pip install ray

#eventually will want to use wget to get the top and conf files:
#!wget https://raw.githubusercontent.com/rkruegs123/jax-dna-dev/ssec-jax-dna-staging/data/sys-defs/simple-helix/sys.top -O sys.top
#!wget https://raw.githubusercontent.com/rkruegs123/jax-dna-dev/ssec-jax-dna-staging/data/sys-defs/simple-helix/bound_relaxed.conf -O bound_relaxed.conf

In [2]:
import functools
import itertools
import logging
import os
from pathlib import Path
import typing
import jax
import jax.numpy as jnp
import jax_md
import optax
import sys
from tqdm import tqdm
import operator


import jax_dna
import jax_dna.energy as jdna_energy
import jax_dna.energy.dna1 as dna1_energy
import jax_dna.input.toml as toml_reader
import jax_dna.input.tree as jdna_tree
import jax_dna.observables as jd_obs
import jax_dna.observables.pitch as pitch
import jax_dna.optimization.simulator as jdna_simulator
import jax_dna.optimization.objective as jdna_objective
import jax_dna.optimization.optimization as jdna_optimization
import jax_dna.simulators.oxdna as oxdna
import jax_dna.simulators.io as jdna_sio
import jax_dna.simulators.jax_md as jdna_jaxmd
import jax_dna.utils.types as jdna_types
from jax_dna.ui.loggers.logger import Logger
from jax_dna.input import topology, trajectory
import jax_dna.input.topology as jdna_top
import jax_dna.input.trajectory as jdna_traj
import jax_dna.ui.loggers.jupyter as jdna_jupyter

jax.config.update("jax_enable_x64", True)

In [3]:
##for generating plots

from google.colab import output
output.enable_custom_widget_manager()
from IPython.display import display, clear_output
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import time

# Initialize data storage
loss_x = []
loss_y = []
pitch_x = []
pitch_y = []

# Function to create and display figure
def update_plots(TARGET_PITCH, N_ITERS):
    # Create figure with 2 subplots side by side
    fig = make_subplots(rows=1, cols=2,
                        subplot_titles=("Loss", "Pitch"),
                        horizontal_spacing=0.2)

    # Add traces for each subplot
    fig.add_trace(go.Scatter(x=loss_x, y=loss_y, mode='lines+markers', name='Loss'),
                 row=1, col=1)

    fig.add_trace(go.Scatter(x=pitch_x, y=pitch_y, mode='lines+markers', name='Pitch'),
                 row=1, col=2)

    # Add target line
    fig.add_trace(
        go.Scatter(
            x=[-1, N_ITERS],
            y=[TARGET_PITCH, TARGET_PITCH],
            mode='lines',
            line=dict(color='red', width=2, dash='dash'),
            name='Target Pitch',
            showlegend=True
        ),
        row=1, col=2
    )

    fig.update_layout(
        autosize=False,
        width=900,  # Increased width to make more room
        height=450,  # Increased height proportionally
        margin=dict(l=50, r=50, t=60, b=50),
    )

    # Set x-axis labels and ranges
    fig.update_xaxes(title_text="Iteration", range=[-1, N_ITERS], row=1, col=1)
    fig.update_xaxes(title_text="Iteration", range=[-1, N_ITERS], row=1, col=2)

    # Set y-axis labels and ranges
    fig.update_yaxes(title_text="Loss (A.U.)", range=[0, 0.12], row=1, col=1)
    fig.update_yaxes(title_text="Pitch (bps/turn)", range=[9.5, 13.5], row=1, col=2)


    # Clear previous output and display new figure
    clear_output(wait=True)
    display(fig)


## Simulate a simple helix with Jax-MD

## Optimize the oxDNA model parameters to tighten the pitch of the helix

#### Parameters

In [4]:
##set up optimization parameters
opt_config = {
        "n_sim_steps": 5_000, #how long to run each simulation before calculating gradients
        "n_opt_steps": 25, #how many optimization steps to run for
        "learning_rate": 0.001, #for use with optimizer
}
TARGET_NUM_BPS_PER_TURN=12.
TARGET_PITCH = 2.0*jnp.pi/TARGET_NUM_BPS_PER_TURN

##parameters for simulation:
simulation_config, _ = dna1_energy.default_configs()
dt = simulation_config["dt"]
kT = simulation_config["kT"]
diff_coef = simulation_config["diff_coef"]
rot_diff_coef = simulation_config["rot_diff_coef"]

gamma = jax_md.rigid_body.RigidBody(
      center=jnp.array([kT / diff_coef], dtype=jnp.float64),
      orientation=jnp.array([kT / rot_diff_coef], dtype=jnp.float64),
  )
mass = jax_md.rigid_body.RigidBody(
      center=jnp.array([simulation_config["nucleotide_mass"]], dtype=jnp.float64),
      orientation=jnp.array([simulation_config["moment_of_inertia"]], dtype=jnp.float64),
  )

##load the initial configuration and topology of an 8-bp simple helix from data/sys-defs/simple-helix/
##(sys.top and bound_relaxed.conf must be present in the working directory)
experiment_dir = Path(".")
top = jdna_top.from_oxdna_file(experiment_dir / "sys.top")
initial_positions = (
    jdna_traj.from_file(
        experiment_dir / "bound_relaxed.conf", ## THIS FAILS on GOOGLE COLAB GPU RUNTIME
        top.strand_counts,
        is_oxdna=False,
        n_processes=1
        )
    .states[0]
    .to_rigid_body()
)


### Set up the model: geometry + energy function for parameter optimization

In jaxDNA, an energy function comprises a functional form and a *configuration*, which specifies the parameters the function needs. Some of these parameters may be "dependent" on others, such as the parameters of a smoothing function that is constrained to be continuous and differentiable. The configuration and functional forms are combined by the `jax_dna.energy.energy_fn_builder` function.

[Figure placeholder: oxDNA potentials]

We'll load the default energy configuration, which contains the parameters of the oxDNA1 model, and the default energy functions, which are `Fene`, `BondedExcludedVolume`, `Stacking`, `UnbondedExcludedVolume`, `HydrogenBonding`, `CrossStacking`, and `CoaxialStacking`. Finally, we load the default configurations associated with each of these functions, which record which parameters are free and which are computed from the others (i.e. smoothing function parameters).

In [None]:
_, energy_config = dna1_energy.default_configs()
energy_fns = dna1_energy.default_energy_fns()
energy_fn_configs = dna1_energy.default_energy_configs()

Next, we set up a function that combines our energy functions and their configurations, as well as a transform function that maps particle centers-of-mass (COM) to the specific sites used in the energy calculation. For example, in oxDNA each nucleotide has three "interaction sites", each at a fixed distance from the nucleotide COM.

In [None]:
##DNA model geometry:
geometry = energy_config["geometry"]
transform_fn = functools.partial( #function that maps COMs onto rigid nucleotide sites
        dna1_energy.Nucleotide.from_rigid_body,
        com_to_backbone=geometry["com_to_backbone"],
        com_to_hb=geometry["com_to_hb"],
        com_to_stacking=geometry["com_to_stacking"],
)

energy_fn_builder_fn = jdna_energy.energy_fn_builder(
        energy_fns=energy_fns,
        energy_configs=energy_fn_configs,
        transform_fn=transform_fn,
    )
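To illustrate what such a transform does, here is a minimal sketch that places sites at fixed signed distances from each COM. It is not the `Nucleotide.from_rigid_body` implementation: it assumes each nucleotide's orientation can be summarized by a single unit vector `a1` (Jax-MD rigid bodies carry full orientations), and the offsets in `HYPOTHETICAL_OFFSETS` are illustrative numbers, not values read from the jaxDNA configuration. As in the cell above, `functools.partial` binds the geometry once so that the resulting functions only need the rigid-body data.

```python
import functools

import jax.numpy as jnp

def site_positions(com, a1, com_to_site):
    """Place one site per nucleotide at a signed distance from its COM along a1."""
    return com + com_to_site * a1

# Two toy nucleotides: COM positions and unit orientation vectors
com = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.4]])
a1 = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])

# Illustrative offsets only (assumed for this sketch)
HYPOTHETICAL_OFFSETS = {"backbone": -0.4, "stacking": 0.34, "hb": 0.4}

# Bind each offset once; the resulting functions take only (com, a1)
site_fns = {
    name: functools.partial(site_positions, com_to_site=offset)
    for name, offset in HYPOTHETICAL_OFFSETS.items()
}
sites = {name: fn(com, a1) for name, fn in site_fns.items()}
print(sites["backbone"])
```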

We vectorize this energy function with `jax.vmap` so that it can act on entire simulation trajectories, calculating the energy at each time step:

In [None]:
def energy_fn_builder(params: jdna_types.Params) -> typing.Callable:
    # Build the energy function for these parameters, then map it over the frames of a trajectory
    return jax.vmap(
        lambda trajectory: energy_fn_builder_fn(params)(
            trajectory.rigid_body,
            seq=jnp.array(top.seq),
            bonded_neighbors=top.bonded_neighbors,
            unbonded_neighbors=top.unbonded_neighbors.T,
        )
    )
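As a standalone illustration of how `jax.vmap` maps a single-frame function over a trajectory, consider the toy sketch below. The `Frames` container and `toy_frame_energy` are hypothetical stand-ins, not jaxDNA types; the point is only that `vmap` strips the leading (frame) axis from every array in the container before calling the function.

```python
import typing

import jax
import jax.numpy as jnp

class Frames(typing.NamedTuple):
    """Hypothetical trajectory container: every field has a leading frame axis."""
    center: jnp.ndarray       # shape (n_frames, n_particles, 3)
    orientation: jnp.ndarray  # shape (n_frames, n_particles, 3)

def toy_frame_energy(frame: Frames) -> jnp.ndarray:
    # Inside vmap each field has lost its leading axis, so this sees one frame
    return 0.5 * jnp.sum(frame.center**2) + jnp.sum(frame.orientation)

n_frames, n_particles = 4, 2
trajectory = Frames(
    center=jnp.arange(n_frames * n_particles * 3, dtype=jnp.float32).reshape(
        n_frames, n_particles, 3
    ),
    orientation=jnp.ones((n_frames, n_particles, 3)),
)

energies = jax.vmap(toy_frame_energy)(trajectory)  # one energy per frame
print(energies.shape)  # (4,)
```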

### Set up the simulator

We will run simulations with Jax-MD for simplicity (running with the standalone oxDNA code is supported, but would require a working oxDNA binary).

In [None]:
simulator = jdna_jaxmd.JaxMDSimulator(
      energy_configs=energy_fn_configs,
      energy_fns=energy_fns,
      topology=top,
      simulator_params=jdna_jaxmd.StaticSimulatorParams(
          seq=jnp.array(jnp.concat([top.seq[:8][::-1], top.seq[8:][::-1]])),
          mass=mass,
          bonded_neighbors=top.bonded_neighbors,
          checkpoint_every=opt_config["n_sim_steps"],
          dt=dt,
          kT=kT,
          gamma=gamma,
          ),
      space=jax_md.space.free(),
      transform_fn=transform_fn,
      simulator_init=jax_md.simulate.nvt_langevin,
      neighbors=jdna_jaxmd.NoNeighborList(unbonded_nbrs=top.unbonded_neighbors),
  )

Because jaxDNA supports advanced multi-objective optimization, which may require several different simulation types to run concurrently, it is set up to write simulation trajectories to a file, `trajectory.pkl`. We therefore need a simulation function wrapper that calls the Jax-MD simulator, writes the output to a file, and returns the location of that file.

In [None]:
cwd = Path(os.getcwd())
output_dir = cwd / "basic_trajectory"
trajectory_loc = output_dir / "trajectory.pkl"
output_dir.mkdir(parents=True, exist_ok=True)  # no-op if the directory already exists

def simulator_fn(
    params: jdna_types.Params,
    meta: jdna_types.MetaData,
) -> list[Path]:
    # Split off a subkey so the stored key is never consumed directly
    in_key = meta["key"]
    in_key, subkey = jax.random.split(in_key)
    traj = simulator.run(params, initial_positions, opt_config["n_sim_steps"], subkey)
    # Write the trajectory to disk and return its location
    jdna_tree.save_pytree(traj, trajectory_loc)
    return [trajectory_loc]

obs_trajectory = "trajectory"
key = jax.random.PRNGKey(1234)
trajectory_simulator = jdna_simulator.BaseSimulator(
      name="jaxmd-sim",
      fn=simulator_fn,
      exposes = [obs_trajectory],
      meta_data = {"key": key},
  )
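The key handling in `simulator_fn` follows the usual JAX pattern of splitting a key before using it. The toy sketch below shows why that matters for stochastic simulations: `noisy_run` is a hypothetical stand-in for a Langevin run, not a jaxDNA function. Reusing a key reproduces the same noise, whereas splitting and carrying the new key forward yields a fresh stream on every call.

```python
import jax

def noisy_run(key, n_steps=3):
    """Hypothetical stand-in for a stochastic simulation: returns random numbers."""
    return jax.random.normal(key, (n_steps,))

key = jax.random.PRNGKey(1234)

# Reusing the same key reproduces exactly the same random numbers
same_a = noisy_run(key)
same_b = noisy_run(key)

# Splitting, and carrying the returned key forward, gives independent streams
key, subkey_1 = jax.random.split(key)
key, subkey_2 = jax.random.split(key)
fresh_a = noisy_run(subkey_1)
fresh_b = noisy_run(subkey_2)

print(bool((same_a == same_b).all()), bool((fresh_a == fresh_b).all()))
```

Whether successive optimization steps see different noise therefore depends on the key stored in the metadata being advanced between calls.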

### Set up the optimization:

First, we specify the parameters we want to optimize. We'll choose to float them all:

In [None]:
opt_params = []
for ec in energy_fn_configs:
    opt_params.append(
        ec.opt_params #if isinstance(ec, dna1_energy.StackingConfiguration) else {}
    )

##except these weights, which are used for sequence optimizations
for op in opt_params:
    # pop each key independently so both are removed even if they share a config
    op.pop("ss_stack_weights", None)
    op.pop("ss_hb_weights", None)

Now, we'll define the observable we're interested in optimizing. jaxDNA supports multi-objective optimization, but for simplicity, we'll consider a single objective: the helical pitch.

In [None]:
pitch_fn = pitch.PitchAngle(
    rigid_body_transform_fn=transform_fn,
    # we consider only the inner basepairs to eliminate the effect of fraying
    quartets=jnp.array([[[2, 13], [3, 12]], [[3, 12], [4, 11]], [[4, 11], [5, 10]]]),
    displacement_fn=jax_md.space.free()[0],
)

We now define a loss function:

In [None]:
def pitch_loss_fn(
    traj: jax_md.rigid_body.RigidBody,
    weights: jnp.ndarray,
    energy_model: jdna_energy.base.ComposedEnergyFunction,
) -> tuple[float, tuple[tuple[str, typing.Any], dict]]:
    obs = pitch_fn(traj)
    expected_pitch = jnp.dot(weights, obs)  # reweighted ensemble average of the pitch angle
    # Root of the squared error, i.e. the absolute deviation from the target angle
    loss = jnp.sqrt((expected_pitch - TARGET_PITCH)**2)
    return loss, (("pitch", expected_pitch), {})
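A loss that returns both a scalar and some auxiliary data has the shape that `jax.value_and_grad(..., has_aux=True)` expects. The toy sketch below shows this pattern in isolation; `toy_loss`, its `scale` parameter, and the numbers are hypothetical, and we make no claim about how jaxDNA differentiates the loss internally. It uses `jnp.abs`, which has the same value as the square root of the squared error.

```python
import jax
import jax.numpy as jnp

TARGET = 0.5  # hypothetical target value for the observable

def toy_loss(scale, obs, weights):
    # Weighted ensemble average of a parameter-dependent observable
    expected = jnp.dot(weights, scale * obs)
    loss = jnp.abs(expected - TARGET)
    # Auxiliary data is returned alongside the loss but is not differentiated
    return loss, {"expected": expected}

obs = jnp.array([0.4, 0.6, 0.8])
weights = jnp.ones(3) / 3.0  # uniform weights: a plain average

(loss, aux), grad = jax.value_and_grad(toy_loss, has_aux=True)(1.0, obs, weights)
print(loss, aux["expected"], grad)
```

Here the gradient is taken with respect to the first argument only, and the auxiliary dictionary passes through untouched, which is convenient for logging quantities such as the expected pitch.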

We also define an objective, which tells jaxDNA that we are going to use DiffTRE to compute our gradients:

In [None]:
pitch_objective = jdna_objective.DiffTReObjective(
    name = "DiffTRe",
    required_observables = [obs_trajectory],
    needed_observables = [obs_trajectory],
    logging_observables = ["loss", "pitch", "gradients"],
    grad_or_loss_fn = pitch_loss_fn,
    energy_fn_builder = energy_fn_builder,
    opt_params = opt_params,
    min_n_eff_factor = 0.95,
    beta = jnp.array(1/kT),
    n_equilibration_steps = 0,
)
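Reweighting is only reliable while the reweighted ensemble still resembles the sampled one. A common diagnostic is the effective sample size, computed from the entropy of the normalized weights. The sketch below shows one plausible reading of a threshold like `min_n_eff_factor`: resample when the effective fraction of samples drops below it. This is an assumption made for illustration, not a description of jaxDNA's internals, and `effective_sample_fraction` is a hypothetical helper.

```python
import jax.numpy as jnp

def effective_sample_fraction(weights):
    """exp(entropy of the normalized weights) divided by the number of samples."""
    # Replace zero weights by 1 inside the log so that 0 * log(0) contributes 0
    safe = jnp.where(weights > 0, weights, 1.0)
    entropy = -jnp.sum(weights * jnp.log(safe))
    return jnp.exp(entropy) / weights.shape[0]

min_n_eff_factor = 0.95  # same value as in the objective above

# Uniform weights: parameters unchanged since sampling, every state counts
uniform = jnp.ones(100) / 100.0
# Skewed weights: one state dominates, so few states effectively contribute
skewed = jnp.concatenate([jnp.array([0.9]), jnp.full((99,), 0.1 / 99)])

for name, w in [("uniform", uniform), ("skewed", skewed)]:
    fraction = effective_sample_fraction(w)
    needs_resampling = bool(fraction < min_n_eff_factor)
    print(name, float(fraction), needs_resampling)
```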

Putting all of this together, we combine our simulator choice (Jax-MD), objective choice (DiffTRE), and optimizer choice (we'll pick Adam) using jaxDNA's `SimpleOptimizer` class:

In [None]:
opt = jdna_optimization.SimpleOptimizer(
    objective=pitch_objective,
    simulator=trajectory_simulator,
    optimizer=optax.adam(learning_rate=opt_config["learning_rate"]),
)

### Sensitivity analysis: which parameters are most important?

In [None]:
grads = pitch_objective.objective.calculate()

### Run the optimization

In [None]:
for i in tqdm(range(opt_config["n_opt_steps"])):
    opt_state, opt_params = opt.step(opt_params)

    # Collect the logged observables for this step
    log_values = pitch_objective.logging_observables()
    for (name, value) in log_values:
        if 'loss' in name.lower():
            loss_value = value
        if 'pitch' in name.lower():
            pitch_value = value
        print(f"{i}::{name}={value}")
    with open("pitch_losses.txt", 'a') as f:
        f.write(f"{i}\t{loss_value}\t{pitch_value}\n")
    #logger.log_metric("Loss", loss_value, i)
    #logger.log_metric("Pitch", pitch_value, i)

    # Store data points, converting the pitch angle (radians) to bps/turn
    loss_x.append(i)
    loss_y.append(loss_value)
    pitch_x.append(i)
    pitch_y.append(2*jnp.pi/pitch_value)

    # Update the plots (the target line is drawn in bps/turn)
    update_plots(TARGET_PITCH=TARGET_NUM_BPS_PER_TURN, N_ITERS=opt_config["n_opt_steps"])
    time.sleep(0.2)

    opt = opt.post_step(
        optimizer_state=opt_state,
        opt_params=opt_params,
    )

100%|██████████| 25/25 [08:20<00:00, 20.01s/it]
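Note that the loss is computed on the pitch angle in radians, while the plot reports base pairs per turn. The two are related by angle = 2π / (bps per turn), as in the definition of `TARGET_PITCH`. The helper names in this short sketch are hypothetical and exist only to make the conversion explicit.

```python
import jax.numpy as jnp

def bps_per_turn_to_angle(n_bps):
    """Twist angle per base-pair step (radians) for a given number of bps per turn."""
    return 2.0 * jnp.pi / n_bps

def angle_to_bps_per_turn(angle):
    """Inverse conversion: base pairs per full helical turn."""
    return 2.0 * jnp.pi / angle

target_angle = bps_per_turn_to_angle(12.0)
round_trip = angle_to_bps_per_turn(target_angle)
print(float(jnp.degrees(target_angle)), float(round_trip))
```

A target of 12 bps per turn corresponds to a twist of 30 degrees per base-pair step.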
