pytree output structure mismatch error in backprop during vmap #48
Hmm, I think I agree that sounds like a plausible root cause. I'm still looking at this, but FWIW I've managed to reduce it to this MWE. Curiously, the type of `x0` seems to matter:

```python
import jax.random as jr
import jax.numpy as jnp
import jax
import optimistix as optx

jax.config.update("jax_enable_x64", True)

CRASH = True


def rf(x, g):
    return x[0], x[1] - g


def opt_2st_vec(g):
    if CRASH:
        x0 = (0.5, 0.5)
    else:
        x0 = jnp.array([0.5, 0.5])
    solver = optx.Newton(atol=1e-8, rtol=1e-8)
    solution = optx.root_find(rf, solver, x0, args=g)
    return solution.value[0]


def loss_fn(x):
    return jnp.sum(jax.vmap(opt_2st_vec)(x))


x = jr.uniform(jr.key(0), (128,))
jax.grad(loss_fn)(x)
```

I'll keep poking at this, but let me know if you find anything sooner than that.
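To illustrate why the type of `x0` can matter here (this is background, not the fix): the tuple and array versions hold the same numbers but are structurally different pytrees, which is exactly the sort of difference an `out_structure` check compares. A minimal sketch:

```python
import jax
import jax.numpy as jnp

# Same numbers, different pytrees: the tuple has two scalar leaves,
# the array is a single leaf.
tuple_x0 = (0.5, 0.5)
array_x0 = jnp.array([0.5, 0.5])

tuple_structure = jax.tree_util.tree_structure(tuple_x0)
array_structure = jax.tree_util.tree_structure(array_x0)

print(tuple_structure)  # a PyTreeDef with two leaves
print(array_structure)  # a PyTreeDef with one leaf
print(tuple_structure == array_structure)  # False
```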
Okay, got it! I've opened patrick-kidger/equinox#671 and patrick-kidger/lineax#84 to fix this. (Although the Lineax CI will fail, as it can't see the updated Equinox PR.)
Fantastic, thanks for the quick fix & workaround.
Hi @patrick-kidger and @FFroehlich, I might have a related issue. It persists even with the fixes in equinox@dev and lineax@vprim_transpose_symbolic_zeros. I'm vmapping a nonlinear solve (parameter estimation for ODEs across many individuals, each with their own parameter set), and I get the same pytree/out_structure mismatch error. The error goes away if I use a for loop, and it also goes away with a nonlinear solver that does not use gradients (Nelder-Mead). I'm working on an MWE, starting by adapting yours from above, @patrick-kidger. For added context: I have a nested hierarchical model composed of Equinox modules, and I now want to optimize the final layer (population level) to leverage JAX's SPMD capabilities.
Here comes the MWE:

```python
import jax.random as jr
import jax.numpy as jnp
import jax
import optimistix as optx
import equinox as eqx
import diffrax as dfx

jax.config.update("jax_enable_x64", True)

GRAD = True
VMAP = True


def dydt(t, y, args):
    k = args
    return -k * y


class Individual(eqx.Module):
    term: dfx.ODETerm
    solver: dfx.Tsit5
    y0: float
    t0: int
    t1: int
    dt0: float
    saveat: dfx.SaveAt

    def __init__(self, ode_system, y0):
        self.term = dfx.ODETerm(ode_system)
        self.solver = dfx.Tsit5()
        self.y0 = y0
        self.t0 = 0
        self.t1 = 10
        self.dt0 = 0.01
        self.saveat = dfx.SaveAt(ts=jnp.arange(self.t0, self.t1, self.dt0))

    def simulate(self, args):
        sol = dfx.diffeqsolve(
            self.term,
            self.solver,
            self.t0,
            self.t1,
            self.dt0,
            self.y0,
            args=args,
            saveat=self.saveat,
            adjoint=dfx.DirectAdjoint(),
        )
        return sol.ys

    def estimate_param(self, initial_param, ydata, solver):
        args = (self.simulate, ydata)

        def residuals(param, args):
            model, ydata = args
            yfit = model(param)
            res = ydata - yfit
            return res

        sol = optx.least_squares(
            residuals,
            solver,
            initial_param,
            args=args,
        )
        return sol.value


m = Individual(dydt, 10.0)


def generate_data(individual_model):  # Noise-free
    k0s = (0.3, 0.5, 0.7)  # Vary parameters
    ydata = []
    for k0 in k0s:
        y = individual_model.simulate(k0)
        ydata.append(y)
    return jnp.array(ydata)


data = generate_data(m)
initial_k0 = 0.5  # Starting point for all


def run(initial_param, individual_model, individual_data):
    if GRAD:
        solver = optx.LevenbergMarquardt(rtol=1e-07, atol=1e-07)
    else:
        solver = optx.NelderMead(rtol=1e-07, atol=1e-07)
    if VMAP:
        get_params = jax.vmap(individual_model.estimate_param, in_axes=(None, 0, None))
        params = get_params(initial_param, individual_data, solver)
    else:
        params = [individual_model.estimate_param(initial_param, y, solver) for y in individual_data]
    return params


params = run(initial_k0, m, data)
params
```

And this is how it behaves (with equinox@dev and lineax@vprim_transpose_symbolic_zeros): the error is raised only if (GRAD and VMAP); every other combination runs through.
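As a sanity check on the batching pattern alone (separate from the bug), the `in_axes=(None, 0, None)` layout used in the MWE — broadcast the shared initial guess, batch over per-individual data, broadcast the solver-like extra argument — can be exercised with a hypothetical closed-form fit for the same decay model, with no implicit solve involved. The function and variable names here are illustrative, not part of the MWE:

```python
import jax
import jax.numpy as jnp


def fit_k(y0, ydata, ts):
    # Closed-form least-squares estimate of k in y = y0 * exp(-k * t):
    # regress log(y) - log(y0) = -k * t, so k = -sum(t * dlog) / sum(t^2).
    dlog = jnp.log(ydata) - jnp.log(y0)
    return -jnp.sum(ts * dlog) / jnp.sum(ts * ts)


ts = jnp.arange(0.1, 1.0, 0.1)
ks = jnp.array([0.3, 0.5, 0.7])
ydata = 10.0 * jnp.exp(-ks[:, None] * ts[None, :])  # one row per individual

# Same in_axes pattern as the MWE: broadcast y0 and ts, batch over rows of ydata.
batched_fit = jax.vmap(fit_k, in_axes=(None, 0, None))
fitted = batched_fit(10.0, ydata, ts)  # recovers approximately [0.3, 0.5, 0.7]
```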
A few postscripts:
Thank you for the issue! This was a fairly tricky one. Ultimately I think this is a sort-of bug (or at least a questionable design decision) in one of the upstream libraries, and I've pushed a fix to a branch. Can you give that branch a go on your actual (non-MWE) problem, and let me know if that fixes things? If so then I'll merge it.
It works! Thank you so much for taking a look at this, even during the Easter holidays. It is very much appreciated! I want to add that I am new to the ecosystem and enjoy it very much; it is so well thought through and documented. I hope I can start contributing something other than questions as I get to know it better :)
Awesome stuff, I'm glad to hear it! I hope you enjoy using the ecosystem. :) On this basis I've just merged the fix, so it will appear in the next release of Equinox. |
I am running into

```
ValueError: pytree does not match out_structure
```

errors when computing gradients for functions where Optimistix is called via vmap. The errors disappear when replacing `jax.vmap` with an equivalent for loop. I have included an MWE, `bug_report.py`, which can switch between `jax.vmap` and for loops via the `VMAP` variable. My first impression is that the implicit solve during backprop gets passed the wrong (unbatched?) input vector.

MWE:

package versions: