
Out of Memory: How to reduce the memory cost for calculating the derivatives wrt the input #691

@zhlfmyzhh

Description


I want to calculate the derivatives of the output of a neural ODE w.r.t. the input. However, I am running into memory issues. I believe the memory-saving RecursiveCheckpointAdjoint doesn't support forward-mode automatic differentiation (e.g. jax.jvp or jax.jacfwd). In my case, I want to calculate the derivatives of f (the output of the Neural ODE) w.r.t. x and t (the inputs), as below:

```python
import jax
import jax.numpy as jnp


def f_and_derivs_fast_vec(variables, apply_fn, xt, t1):
    """
    xt: (N,2) with columns [x, t]
    returns f, f_x, f_t, f_xx each (N, D)
    """
    xt = jnp.asarray(xt)

    def f_vec(z):                                # z: (2,) -> (D,)
        f = apply_fn(variables, z[None, :], t1=t1)   # model returns (1,D)
        return f[0]                                   # (D,)

    ex = jnp.array([1.0, 0.0])  # d/dx
    et = jnp.array([0.0, 1.0])  # d/dt

    def one_point(z):
        f     = f_vec(z)                             # (D,)
        _, fx  = jax.jvp(f_vec, (z,), (ex,))         # (D,)
        _, ft  = jax.jvp(f_vec, (z,), (et,))         # (D,)
        def gx(y): return jax.jvp(f_vec, (y,), (ex,))[1]  # f_x(y)
        _, fxx = jax.jvp(gx, (z,), (ex,))            # (D,)
        return f, fx, ft, fxx

    f, fx, ft, fxx = jax.vmap(one_point)(xt)
    return f, fx, ft, fxx
```
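For context, here is a minimal, self-contained sketch of the pattern I am relying on: jax.jvp pushed through diffrax.diffeqsolve. The vector field and sizes below are toy placeholders, not my actual model, and I use diffrax.DirectAdjoint only because, as far as I understand the docs, it is an adjoint that does allow forward-mode autodiff (unlike RecursiveCheckpointAdjoint), at the cost of much higher memory use:

```python
# Toy sketch (hypothetical vector field and sizes) of forward-mode
# differentiation through a diffrax solve. As far as I understand, this
# needs an adjoint that supports jax.jvp, e.g. diffrax.DirectAdjoint.
import diffrax
import jax
import jax.numpy as jnp


def solve(y0):
    term = diffrax.ODETerm(lambda t, y, args: -y)  # toy vector field
    sol = diffrax.diffeqsolve(
        term,
        diffrax.Tsit5(),
        t0=0.0,
        t1=1.0,
        dt0=1e-2,
        y0=y0,
        adjoint=diffrax.DirectAdjoint(),  # allows forward-mode autodiff
    )
    return sol.ys[0]  # state at t1, shape (4,)


y0 = jnp.ones(4)
tangent = jnp.zeros(4).at[0].set(1.0)
# derivative of y(t1) w.r.t. y0[0], computed with forward-mode jvp
y1, dy1 = jax.jvp(solve, (y0,), (tangent,))
```

This is essentially what `one_point` above does per input point, just with my model's `apply_fn` instead of the toy vector field.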

I define my Neural ODE as below:

```python
import diffrax
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp


class Func(eqx.Module):
    out_scale: jax.Array
    mlp: eqx.nn.MLP

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.out_scale = jnp.array(1.0)
        self.mlp = eqx.nn.MLP(
            in_size=data_size,
            out_size=data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.swish,
            final_activation=jax.nn.tanh,
            key=key,
        )

    def __call__(self, t, y, args):
        return self.out_scale * self.mlp(y)

class NeuralODE(eqx.Module):
    func: Func

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(data_size, width_size, depth, key=key)

    def __call__(self, t1, y0):
        y0 = jnp.asarray(y0).reshape(-1)  # (D,)
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Tsit5(),
            t0=0.0,
            t1=t1,
            dt0=1e-3,
            y0=y0,
            stepsize_controller=diffrax.PIDController(rtol=3e-3, atol=3e-6),
            saveat=diffrax.SaveAt(t1=True, dense=False),
            adjoint=diffrax.RecursiveCheckpointAdjoint(),
        )
        ys = jnp.asarray(solution.ys).reshape(-1)  # (D,)

        return ys
```
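For completeness, a rough sketch of how NeuralODE is called on its own (made-up small sizes, just to show the `__call__(t1, y0)` signature that the model below relies on):

```python
# Standalone call of NeuralODE (hypothetical small sizes): __call__(t1, y0)
# integrates from t0 = 0 to t1 and returns the state at t1 with shape (D,).
import jax.numpy as jnp
import jax.random as jr

node = NeuralODE(data_size=8, width_size=16, depth=2, key=jr.PRNGKey(0))
y1 = node(1.0, jnp.ones(8))  # shape (8,)
```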

I define my model as below:

```python
import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.random as jr


class PINN(nn.Module):
    n_nodes: int
    n_layers: int = 1
    node_data_size: int = 512  # Size of the data input to the NODE
    node_width: int = 64
    node_depth: int = 2

    def setup(self):
        self.hidden_layers = [nn.Dense(self.n_nodes, kernel_init=jax.nn.initializers.he_uniform())
                              for _ in range(self.n_layers)]
        self.integrator = NeuralODE(data_size=self.node_data_size, width_size=self.node_width, depth=self.node_depth, key=jr.PRNGKey(0))

    def encode_input(self, inputs):
        x = inputs
        for idx, dense in enumerate(self.hidden_layers):
            x = dense(x)
            if idx == 0:
                x = 2 * jnp.pi * x
            x = jnp.sin(x)
        return x

    @nn.compact
    def __call__(self, inputs, t1=0):
        xt = inputs  # shape (N, 1)

        f_raw = self.encode_input(xt)
        f_last = jax.vmap(self.integrator, in_axes=(None, 0))(t1, f_raw) # (N, 512)

        return f_last
```
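And a rough sketch of how I wire the pieces together (hypothetical small sizes; n_nodes is chosen equal to node_data_size so the encoder output matches the NODE's state size, as in my real 512-dimensional setup):

```python
# Hypothetical end-to-end usage with a small batch (N = 4).
# In my real setup n_nodes = node_data_size = 512.
import jax.numpy as jnp
import jax.random as jr

model = PINN(n_nodes=8, node_data_size=8, node_width=16, node_depth=2)
xt = jnp.stack([jnp.linspace(0.0, 1.0, 4), jnp.full(4, 0.5)], axis=1)  # (N, 2): [x, t]
variables = model.init(jr.PRNGKey(0), xt, t1=1.0)
f, fx, ft, fxx = f_and_derivs_fast_vec(variables, model.apply, xt, t1=1.0)  # each (N, 8)
```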

The problem I ran into is:

```
2025-09-16 19:18:24.767882: E external/xla/xla/service/slow_operation_alarm.cc:65] ********************************
[Compiling module jit_update] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2025-09-16 19:23:11.188537: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below 33.78GiB (36271006133 bytes) by rematerialization; only reduced to 62.50GiB (67109664958 bytes), down from 62.50GiB (67109716246 bytes) originally
2025-09-16 19:23:34.487702: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 7m9.719906955s
********************************
[Compiling module jit_update] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2025-09-16 19:24:00.325370: W external/xla/xla/tsl/framework/bfc_allocator.cc:482] Allocator (GPU_0_bfc) ran out of memory trying to allocate 64.71GiB (rounded to 69477018624) requested by op
2025-09-16 19:24:00.325549: W external/xla/xla/tsl/framework/bfc_allocator.cc:494] *___________________________________________________________________________________________________
E0916 19:24:00.325599 2586479 pjrt_stream_executor_client.cc:2985] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 69477018512 bytes.
```

Do you have any suggestions for solving this problem? I really appreciate your help and time!
