Description
Hi,
I've been trying to train both a drift function and an event function, but the parameters of my event function are not changing. Here's a simplified portion of my code that only deals with the event function:
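from torch import nn, optim
from torchdiffeq import odeint_adjoint, odeint_event
# ODEFunc, ODEEvent, device, v0, t0, true_t and iters are defined elsewhere
loss_fn = nn.MSELoss(reduction='sum')
func = ODEFunc().to(device).double()
event = ODEEvent().to(device).double()
# Only the event function's parameters are optimized in this simplified version
optimizer = optim.Adam(event.parameters(), lr=0.001)
for itr in range(iters):
    optimizer.zero_grad()
    # Integrate from t0 until event(t, y) crosses zero; returns the event time and state
    event_t, pred = odeint_event(func, v0, t0, event_fn=event, odeint_interface=odeint_adjoint, method='bosh3', atol=1e-6)
    loss = loss_fn(event_t, true_t)
    loss.backward()
    optimizer.step()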
The parameters of event never change between iterations, and list(event.parameters())[0].grad is always None after loss.backward(). How do I get gradients to flow through the event times so that my event function can learn? Any help is appreciated.
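For context, the event time t* depends on the event parameters theta only through the condition g(t*, y(t*); theta) = 0, so by the implicit function theorem dt*/dtheta = -(dg/dtheta) / (dg/dt + dg/dy * f(t*, y(t*))). A backward pass that only returns a gradient for the ODE state, and none for the event function itself, would leave the event parameters with grad None. The sketch below is a hypothetical toy (dy/dt = -y with event g = y - theta, chosen because it has a closed form), not torchdiffeq's implementation; it only checks the implicit-function-theorem gradient against a finite difference. A workaround sometimes suggested, stated here as an assumption rather than confirmed library behavior, is to carry the event parameters as extra state components with zero dynamics so their gradient arrives through the state.

```python
import math

# Hypothetical toy problem with a closed-form solution:
#   dynamics  dy/dt = f(t, y) = -y, y(0) = y0  =>  y(t) = y0 * exp(-t)
#   event     g(t, y; theta) = y - theta        (fires when y drops to theta)
# so the exact event time is t* = ln(y0 / theta) for 0 < theta < y0.


def event_time(y0: float, theta: float) -> float:
    """Closed-form event time t* = ln(y0 / theta)."""
    return math.log(y0 / theta)


def event_time_grad_ift(y0: float, theta: float) -> float:
    """dt*/dtheta via the implicit function theorem:
    dt*/dtheta = -(dg/dtheta) / (dg/dt + dg/dy * f(t*, y(t*)))."""
    t_star = event_time(y0, theta)
    y_star = y0 * math.exp(-t_star)       # state at the event (equals theta)
    dg_dtheta = -1.0                       # partial of g = y - theta w.r.t. theta
    dg_dt = 0.0                            # g has no explicit time dependence
    dg_dy = 1.0                            # partial of g w.r.t. y
    f_star = -y_star                       # dy/dt evaluated at the event
    total_dg_dt = dg_dt + dg_dy * f_star   # total derivative of g along the trajectory
    return -dg_dtheta / total_dg_dt


def event_time_grad_fd(y0: float, theta: float, h: float = 1e-6) -> float:
    """Central finite-difference estimate of dt*/dtheta, used as a check."""
    return (event_time(y0, theta + h) - event_time(y0, theta - h)) / (2 * h)


if __name__ == "__main__":
    y0, theta = 2.0, 0.5
    print("t*            =", event_time(y0, theta))
    print("IFT gradient  =", event_time_grad_ift(y0, theta))
    print("FD gradient   =", event_time_grad_fd(y0, theta))
```

Here the analytic answer is dt*/dtheta = -1/theta, so both estimates should be close to -2 for theta = 0.5.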