
Training an Event Function #208

@rrm45

Hi,

I've been trying to train a drift function and an event function, but the parameters of my event function are not changing. Here's a simplified portion of my code that deals only with the event function:
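```python
import torch.nn as nn
import torch.optim as optim
from torchdiffeq import odeint_adjoint, odeint_event

# ODEFunc, ODEEvent, device, v0, t0, true_t and iters are defined elsewhere (not shown).
loss_fn = nn.MSELoss(reduction='sum')
func = ODEFunc().to(device).double()
event = ODEEvent().to(device).double()
# Only the event function's parameters are being optimized here.
optimizer = optim.Adam(event.parameters(), lr=0.001)
for itr in range(iters):
    optimizer.zero_grad()
    # Integrate until event(t, y) crosses zero; returns the event time and the solution.
    event_t, pred = odeint_event(func, v0, t0, event_fn=event, odeint_interface=odeint_adjoint,
                                 method='bosh3', atol=1e-6)
    loss = loss_fn(event_t, true_t)
    loss.backward()
    optimizer.step()
```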
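
For reference, the gradient this loop is after can be written down with the implicit function theorem. This is a sketch under two assumptions: the drift f does not depend on the event parameters θ (so y(t) itself does not depend on θ), and the event fires transversally (the denominator below is nonzero). The event time t* is defined implicitly by g(t*, y(t*); θ) = 0 with y' = f(t, y), and for a single event time the summed MSE reduces to (t* - t_true)^2:

```latex
\begin{aligned}
0 &= \frac{\partial g}{\partial \theta}
   + \left( \frac{\partial g}{\partial t} + \frac{\partial g}{\partial y}\, f\big(t^*, y(t^*)\big) \right) \frac{d t^*}{d \theta}, \\
\frac{d t^*}{d \theta} &= -\,\frac{\partial g / \partial \theta}{\partial g / \partial t + \frac{\partial g}{\partial y}\, f\big(t^*, y(t^*)\big)}, \\
\frac{d L}{d \theta} &= 2\,(t^* - t_{\text{true}})\, \frac{d t^*}{d \theta}
\qquad \text{for } L = (t^* - t_{\text{true}})^2 .
\end{aligned}
```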

The parameters of `event` never change between iterations, and `list(event.parameters())[0].grad` is always `None`. How do I get gradients to flow through the event times so that my event function can learn? Any help is appreciated.
