Hey Patrick, thank you for the very nice package(s). I am using equinox and optimistix extensively in my current meta-learning project.
When I try to take the gradient through the solver after applying vmap, I get the error `Differentiation rule for 'unvmap_max' not implemented`. I am still trying to isolate the issue in a minimal working example (MWE), but that is taking a while, so I am posting here in case I am making a simple mistake. The pseudocode looks roughly like this:
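import equinox as eqx
import jax.numpy as jnp

def single_example_outer_loss(task_params, single_task_data):
    ...  # model forward pass and loss computation (elided)
    # call bisection to restrict the model output to have a particular mean
    return loss

vmapped_outer_loss = eqx.filter_vmap(single_example_outer_loss)

def average_meta_loss(task_params_batched, multi_task_data):
    # filter_value_and_grad needs a scalar output, so average the per-task losses
    mean_loss = lambda p, d: jnp.mean(vmapped_outer_loss(p, d))
    return eqx.filter_value_and_grad(mean_loss)(task_params_batched, multi_task_data)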
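For comparison, here is a minimal pure-JAX sketch of the same pattern: a bisection that pins the mean of the model output, vmapped over tasks, with the gradient taken through the whole thing. It does not use optimistix, so it says nothing about where `unvmap_max` comes from. Everything in it is an assumption for illustration: a sigmoid model with one hypothetical scalar weight per task, a hypothetical target mean of 0.3, a hand-rolled fixed-iteration bisection, and the helper names `constraint`, `bisect`, `shift` and `task_loss`. The gradient of the root is recovered with the implicit function theorem via a single Newton step, not by differentiating through the loop iterations.

```python
import jax
import jax.numpy as jnp

TARGET_MEAN = 0.3  # hypothetical target for the mean model output


def constraint(c, z):
    # g(c, z) = mean(sigmoid(z + c)) - target; increasing in c, so bisection applies.
    return jnp.mean(jax.nn.sigmoid(z + c)) - TARGET_MEAN


def bisect(z, lo=-20.0, hi=20.0, iters=60):
    # Plain bisection with a fixed iteration count (its output is not differentiated).
    def body(_, bounds):
        a, b = bounds
        mid = 0.5 * (a + b)
        too_high = constraint(mid, z) > 0  # root lies below mid
        return jnp.where(too_high, a, mid), jnp.where(too_high, mid, b)

    init = (jnp.asarray(lo, dtype=z.dtype), jnp.asarray(hi, dtype=z.dtype))
    a, b = jax.lax.fori_loop(0, iters, body, init)
    return 0.5 * (a + b)


def shift(z):
    # Solve g(c, z) = 0 with gradients blocked, then take one Newton step whose
    # denominator is held constant. Its value is ~c*, and its derivative w.r.t. z
    # is the implicit-function-theorem gradient -g_z / g_c.
    z_const = jax.lax.stop_gradient(z)
    c0 = bisect(z_const)
    dg_dc = jax.grad(constraint)(c0, z_const)
    return c0 - constraint(c0, z) / dg_dc


def task_loss(w, x, y):
    z = w * x  # hypothetical one-parameter model per task
    p = jax.nn.sigmoid(z + shift(z))  # outputs constrained to have mean TARGET_MEAN
    return jnp.mean((p - y) ** 2)


def mean_meta_loss(ws, xs, ys):
    # Gradients need a scalar, so average the per-task losses produced by vmap.
    return jnp.mean(jax.vmap(task_loss)(ws, xs, ys))


loss_and_grad = jax.jit(jax.value_and_grad(mean_meta_loss))
```

Calling `loss_and_grad(ws, xs, ys)` with `ws` of shape `(num_tasks,)` and `xs`, `ys` of shape `(num_tasks, num_points)` returns the mean loss and one gradient entry per task weight. The design choice is to keep the bisection loop out of the gradient entirely: only the final Newton step is differentiated, so the backward pass never goes through the loop.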