Error with bisection + vmap #104

@SNMS95

Hey Patrick, thank you for the very nice package(s). I am extensively using equinox and optimistix in my current project for meta-learning.

When I take the gradient through the solver after applying vmap, I get the error "Differentiation rule for 'unvmap_max' not implemented". I am still trying to isolate the issue into a minimal working example (MWE), but that is taking a while, so I am posting here in case I am making a simple mistake. The pseudocode looks something like this:

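import equinox as eqx
import jax.numpy as jnp

def single_example_outer_loss(task_params, single_task_data):
    ....
    # call bisection to restrict the model output to have a particular mean
    return loss


# map the per-task loss over the leading (task) axis
vmapped_outer_loss = eqx.filter_vmap(single_example_outer_loss)

def average_meta_loss(task_params_batched, multi_task_data):
    # average over tasks so that the differentiated function returns a scalar
    mean_loss = lambda params, data: jnp.mean(vmapped_outer_loss(params, data))
    return eqx.filter_value_and_grad(mean_loss)(task_params_batched, multi_task_data)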

I will follow up with an MWE once I have isolated the issue.
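For reference, here is a plain-JAX sketch of the same pattern (bisection inside a per-task loss, then vmap, then a gradient of the mean). It deliberately does not use optimistix, so it says nothing about where the 'unvmap_max' error comes from. The bisection is a hand-written fixed-step loop, and the gradient comes from a single Newton-style correction around the converged root, which reproduces the implicit-function-theorem derivative. All names (`TARGET_MEAN`, `residual`, `bisect`, `solve_shift`, `single_task_loss`) and the sigmoid "model" are hypothetical stand-ins, not taken from the actual project.

```python
import jax
import jax.numpy as jnp

TARGET_MEAN = 0.3  # hypothetical target for the mean model output


def residual(shift, logits):
    # Zero when the shifted outputs have the requested mean; increasing in `shift`.
    return jnp.mean(jax.nn.sigmoid(logits + shift)) - TARGET_MEAN


def bisect(logits, lower=-20.0, upper=20.0, num_steps=60):
    # Fixed-step bisection; assumes the root lies inside [lower, upper].
    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        too_high = residual(mid, logits) > 0
        # If the residual is positive the root is below `mid`, so shrink `hi`.
        return jnp.where(too_high, lo, mid), jnp.where(too_high, mid, hi)

    init = (jnp.asarray(lower, dtype=logits.dtype),
            jnp.asarray(upper, dtype=logits.dtype))
    lo, hi = jax.lax.fori_loop(0, num_steps, body, init)
    return 0.5 * (lo + hi)


def solve_shift(logits):
    frozen = jax.lax.stop_gradient(logits)
    root = bisect(frozen)  # no gradient flows through the iterations
    slope = jax.grad(residual)(root, frozen)  # d residual / d shift at the root
    # The value is (almost) unchanged, but the derivative w.r.t. `logits`
    # becomes -(d residual / d logits) / slope, i.e. the implicit gradient.
    return root - residual(root, logits) / slope


def single_task_loss(logits):
    probs = jax.nn.sigmoid(logits + solve_shift(logits))
    return jnp.sum(probs ** 2)


def mean_loss(batched_logits):
    return jnp.mean(jax.vmap(single_task_loss)(batched_logits))


batched_logits = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
grads = jax.grad(mean_loss)(batched_logits)
```

Differentiating through the bisection iterations themselves would give a zero gradient, because the bracket updates depend on the inputs only through comparisons. That is why the sketch freezes the iterations and adds the correction step afterwards.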
