I want to calculate the derivatives of the output of a neural ODE with respect to its inputs, but I am running into memory issues. I believe the memory-saving adjoint `RecursiveCheckpointAdjoint` doesn't support forward-mode automatic differentiation (like `jax.jvp` or `jax.jacfwd`). In my case, I want the derivatives of f (the output of the neural ODE) with respect to the inputs x and t, as below:
```python
import jax
import jax.numpy as jnp

def f_and_derivs_fast_vec(variables, apply_fn, xt, t1):
    """
    xt: (N, 2) with columns [x, t]
    returns f, f_x, f_t, f_xx, each (N, D)
    """
    xt = jnp.asarray(xt)

    def f_vec(z):  # z: (2,) -> (D,)
        f = apply_fn(variables, z[None, :], t1=t1)  # model returns (1, D)
        return f[0]  # (D,)

    ex = jnp.array([1.0, 0.0])  # d/dx
    et = jnp.array([0.0, 1.0])  # d/dt

    def one_point(z):
        f = f_vec(z)                          # (D,)
        _, fx = jax.jvp(f_vec, (z,), (ex,))   # (D,)
        _, ft = jax.jvp(f_vec, (z,), (et,))   # (D,)

        def gx(y):  # f_x(y) via a jvp, so fxx below is forward-over-forward
            return jax.jvp(f_vec, (y,), (ex,))[1]

        _, fxx = jax.jvp(gx, (z,), (ex,))     # (D,)
        return f, fx, ft, fxx

    f, fx, ft, fxx = jax.vmap(one_point)(xt)
    return f, fx, ft, fxx
```
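For reference, I call this helper roughly like so (a sketch only; `model` and `params` stand in for my actual Flax module and its variables):

```python
# Hypothetical usage; `model` is the PINN defined below, `params` its variables.
x_grid = jnp.linspace(0.0, 1.0, 128)
t_grid = jnp.linspace(0.0, 1.0, 128)
xt = jnp.stack([x_grid, t_grid], axis=-1)  # (N, 2) collocation points
f, f_x, f_t, f_xx = f_and_derivs_fast_vec(params, model.apply, xt, t1=1.0)
```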
I define my neural ODE as follows:
```python
import diffrax
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp

class Func(eqx.Module):
    out_scale: jax.Array
    mlp: eqx.nn.MLP

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.out_scale = jnp.array(1.0)
        self.mlp = eqx.nn.MLP(
            in_size=data_size,
            out_size=data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.swish,
            final_activation=jax.nn.tanh,
            key=key,
        )

    def __call__(self, t, y, args):
        return self.out_scale * self.mlp(y)

class NeuralODE(eqx.Module):
    func: Func

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(data_size, width_size, depth, key=key)

    def __call__(self, t1, y0):
        y0 = jnp.asarray(y0).reshape(-1)  # (D,)
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Tsit5(),
            t0=0.0,
            t1=t1,
            dt0=1e-3,
            y0=y0,
            stepsize_controller=diffrax.PIDController(rtol=3e-3, atol=3e-6),
            saveat=diffrax.SaveAt(t1=True, dense=False),
            adjoint=diffrax.RecursiveCheckpointAdjoint(),
        )
        ys = jnp.asarray(solution.ys).reshape(-1)  # (D,)
        return ys
```
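My understanding from the diffrax docs is that forward-mode autodiff needs a different adjoint. The variant I would presumably have to write is the same solve with only the adjoint swapped, something like this inside `NeuralODE.__call__` (a sketch; I believe `DirectAdjoint` supports forward mode, but as documented it is much less memory-efficient, which is exactly what I'm trying to avoid):

```python
# Hypothetical variant of the solve above: DirectAdjoint works with
# jax.jvp / jax.jacfwd, but is far less memory-efficient than
# RecursiveCheckpointAdjoint, so it doesn't solve my memory problem.
solution = diffrax.diffeqsolve(
    diffrax.ODETerm(self.func),
    diffrax.Tsit5(),
    t0=0.0,
    t1=t1,
    dt0=1e-3,
    y0=y0,
    stepsize_controller=diffrax.PIDController(rtol=3e-3, atol=3e-6),
    saveat=diffrax.SaveAt(t1=True, dense=False),
    adjoint=diffrax.DirectAdjoint(),  # assumption: the forward-mode-capable option
)
```

I define my model as follows: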
```python
import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.random as jr

class PINN(nn.Module):
    n_nodes: int
    n_layers: int = 1
    node_data_size: int = 512  # size of the data input to the NODE
    node_width: int = 64
    node_depth: int = 2

    def setup(self):
        self.hidden_layers = [
            nn.Dense(self.n_nodes, kernel_init=jax.nn.initializers.he_uniform())
            for _ in range(self.n_layers)
        ]
        self.integrator = NeuralODE(
            data_size=self.node_data_size,
            width_size=self.node_width,
            depth=self.node_depth,
            key=jr.PRNGKey(0),
        )

    def encode_input(self, inputs):
        x = inputs
        for idx, dense in enumerate(self.hidden_layers):
            x = dense(x)
            if idx == 0:
                x = 2 * jnp.pi * x
            x = jnp.sin(x)
        return x

    @nn.compact
    def __call__(self, inputs, t1=0):
        xt = inputs  # shape (N, 1)
        f_raw = self.encode_input(xt)
        f_last = jax.vmap(self.integrator, in_axes=(None, 0))(t1, f_raw)  # (N, 512)
        return f_last
```
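Putting it together, the computation that blows up looks roughly like this (a sketch; `params`, the batch size, and the grids are placeholders for my actual training step):

```python
# Hypothetical driver (names and sizes are placeholders).
model = PINN(n_nodes=512)
xt = jnp.stack(
    [jnp.linspace(0.0, 1.0, 1024), jnp.linspace(0.0, 1.0, 1024)], axis=-1
)  # (N, 2)
params = model.init(jr.PRNGKey(1), xt)
f, f_x, f_t, f_xx = f_and_derivs_fast_vec(params, model.apply, xt, t1=1.0)
```

The problem I ran into is: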
```
2025-09-16 19:18:24.767882: E external/xla/xla/service/slow_operation_alarm.cc:65] ******************************** [Compiling module jit_update] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results. ********************************
2025-09-16 19:23:11.188537: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below 33.78GiB (36271006133 bytes) by rematerialization; only reduced to 62.50GiB (67109664958 bytes), down from 62.50GiB (67109716246 bytes) originally
2025-09-16 19:23:34.487702: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 7m9.719906955s ******************************** [Compiling module jit_update] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results. ********************************
2025-09-16 19:24:00.325370: W external/xla/xla/tsl/framework/bfc_allocator.cc:482] Allocator (GPU_0_bfc) ran out of memory trying to allocate 64.71GiB (rounded to 69477018624) requested by op
2025-09-16 19:24:00.325549: W external/xla/xla/tsl/framework/bfc_allocator.cc:494] *___________________________________________________________________________________________________
E0916 19:24:00.325599 2586479 pjrt_stream_executor_client.cc:2985] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 69477018512 bytes.
```
Do you have any suggestions for solving this problem? I really appreciate your help and time!