Correction for Adjoint Gradient Calculations #35

Closed
aarcher07 opened this issue Jul 12, 2022 · 4 comments

aarcher07 commented Jul 12, 2022

When computing the gradient with respect to initial conditions, I find that the adjoint and forward gradient computations differ slightly. Following equation 14 of the CVODES manual, the discrepancy is in the adjoint computation and can be corrected by adding the constant -np.matmul(sens0, lambda_out - grads[0, :]) to grad_out from solve_backward. Here sens0 is the initial sensitivities, lambda_out is the adjoint variable at time 0, and grads[0, :] is the derivative of the likelihood with respect to the state variables at time 0.

I have to adjust lambda at time 0 by -grads[0, :] because line 691 of sunode/solver.py appears not to loop over the initial time point of grads.
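For reference, my reading of that identity (a paraphrase, and sunode's sign conventions may differ from the manual's): with initial sensitivities s(t_0) = ∂y(t_0)/∂p, the adjoint gradient has a boundary term in addition to the integral term,

$$
\frac{dG}{dp} = \lambda^T(t_0)\, s(t_0) + \int_{t_0}^{T} \left( g_p + \lambda^T f_p \right) dt ,
$$

so leaving out the λ^T(t_0) s(t_0) term produces exactly the kind of constant offset described above.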

Does SolveODEAdjointBackward in sunode/sunode/wrappers/as_aesara.py implement a similar correction constant when computing the gradient wrt initial conditions?

I attached some code below as an example. It is based on the example script in the README. It computes the gradient of the sum of hares and lynx at times 0, 5 and 10 with respect to alpha, beta and hares0, where hares0 is log10(hares(0)).

import numpy as np
import sunode
import sunode.wrappers.as_aesara
import pymc as pm
import matplotlib.pyplot as plt
lib = sunode._cvodes.lib


def lotka_volterra(t, y, p):
    """Right hand side of Lotka-Volterra equation.

    All inputs are dataclasses of sympy variables, or in the case
    of non-scalar variables numpy arrays of sympy variables.
    """
    return {
        'hares': p.alpha * y.hares - p.beta * y.lynx * y.hares,
        'lynx': p.delta * y.hares * y.lynx - p.gamma * y.lynx,
    }

# initialize problem
problem = sunode.symode.SympyProblem(
    params={
        # We need to specify the shape of each parameter.
        # Any empty tuple corresponds to a scalar value.
        'alpha': (),
        'beta': (),
        'gamma': (),
        'delta': (),
        'hares0': ()
    },
    states={
        # The same for all state variables
        'hares': (),
        'lynx': (),
    },
    rhs_sympy=lotka_volterra,
    derivative_params=[
        # We need to specify with respect to which variables
        # gradients should be computed.
        ('alpha',),
        ('beta',),
        ('hares0',),
    ],
)

tvals = np.linspace(0, 10, 3)

y0 = np.zeros((), dtype=problem.state_dtype)
y0['hares'] = 1e0
y0['lynx'] = 0.1
params_dict = {
    'alpha': 0.1,
    'beta': 0.2,
    'gamma': 0.3,
    'delta': 0.4,
    'hares0': 1e0
}


# Initial sensitivities, shape (n_params, n_states): rows are (alpha, beta, hares0),
# columns are (hares, lynx). Only d hares(0) / d hares0 = ln(10) * hares(0) is nonzero.
sens0 = np.zeros((3, 2))
sens0[2, 0] = np.log(10) * 1e0

solver = sunode.solver.Solver(problem, solver='BDF', sens_mode='simultaneous')
yout, sens_out = solver.make_output_buffers(tvals)


# gradient via forward sensitivity
solver.set_params_dict(params_dict)
solver.solve(t0=0, tvals=tvals, y0=y0, y_out=yout, sens0=sens0, sens_out=sens_out)

grad_out_fwd = [sens_out[:, j, :].sum() for j in range(3)]
print(grad_out_fwd)

# gradient via adjoint sensitivity
solver = sunode.solver.AdjointSolver(problem, solver='BDF')
solver.set_params_dict(params_dict)
tvals_expanded = np.linspace(0, 10, 21)
yout, grad_out, lambda_out = solver.make_output_buffers(tvals_expanded)
lib.CVodeSetMaxNumSteps(solver._ode, 10000)  # raise CVODES' internal step limit
solver.solve_forward(t0=0, tvals=tvals, y0=y0, y_out=yout)
grads = np.zeros_like(yout)
grads[::10, :] = 1  # nonzero gradient only at t = 0, 5, 10
solver.solve_backward(t0=tvals_expanded[-1], tend=tvals_expanded[0], tvals=tvals_expanded[1:-1],
                      grads=grads, grad_out=grad_out, lamda_out=lambda_out)
grad_out_adj = -np.matmul(sens0, lambda_out - grads[0, :]) + grad_out
print(grad_out_adj)
@michaelosthege added the bug label on Jul 15, 2022
@michaelosthege (Member) commented:

Thanks for opening such a detailed issue!
I edited the formatting in your comment and added a link to that line.


I'm not familiar with the implementation, but could this be a bug caused by t_intervals and grads having different lengths, so that the for iterator never reaches the None element in reversed(grads)? Inside that loop there is an if grad is not None: check, which appears to have been written for that t=0 element.
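Purely as an illustration of the mismatch I have in mind (a hypothetical sketch, not the actual sunode code):

intervals = ["(t0, t1)", "(t1, t2)"]       # one entry per integration interval: len == n_points - 1
grads = ["grad_t0", "grad_t1", "grad_t2"]  # one entry per output time: len == n_points

for interval, grad in zip(reversed(intervals), reversed(grads)):
    print(interval, grad)

# Prints only the (t1, t2)/grad_t2 and (t0, t1)/grad_t1 pairs: zip stops with the
# shorter sequence, so the grad_t0 entry (the one the "if grad is not None:" check
# would guard) is never reached.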

@aarcher07 (Author) commented:

Great! Thank you for editing my post.

There is also another issue. When the adjoint equations are evaluated at only a small number of time points, the gradients computed via the adjoint method are inaccurate compared to those from forward sensitivities.

Following the example above, if I evaluate the adjoint equations only at times 0, 5, 10 with grads = np.ones_like(yout) (a rough sketch of this call is below the numbers), then I get

  • grad_out_fwd = [29.633367875233063, -8.63361922455043, 10.2485995824757]
  • grad_out_adj = [ 8.08665182 -1.22041063 8.53419337].

However, as in my original post, if I evaluate the adjoint equations at np.linspace(0, 10, 21), which includes times 0, 5 and 10, and zero-pad grads at the times not equal to 0, 5, 10, then I get

  • grad_out_fwd = [29.633367875233063, -8.63361922455043, 10.2485995824757]
  • grad_out_adj = [27.71999675 -7.42137746 10.507334 ].
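The coarse-grid variant referred to above is roughly this (a sketch, assuming the same setup and call pattern as the script in my first post):

solver = sunode.solver.AdjointSolver(problem, solver='BDF')
solver.set_params_dict(params_dict)
tvals = np.linspace(0, 10, 3)  # 0, 5, 10
yout, grad_out, lambda_out = solver.make_output_buffers(tvals)
solver.solve_forward(t0=0, tvals=tvals, y0=y0, y_out=yout)
grads = np.ones_like(yout)  # gradient of the objective at every output time
solver.solve_backward(t0=tvals[-1], tend=tvals[0], tvals=tvals[1:-1],
                      grads=grads, grad_out=grad_out, lamda_out=lambda_out)
grad_out_adj = -np.matmul(sens0, lambda_out - grads[0, :]) + grad_out
print(grad_out_adj)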

Thank you for looking into these issues!

@aseyboldt (Member) commented Nov 26, 2022

@aarcher07 Thank you for reporting this, and sorry for the very late reply...

I think the problem you are seeing is due to a small mistake in the arguments to solve_backward. If I replace the call with this, I get the same results as the forward solver:

# Instead of this
#solver.solve_backward(t0=tvals_expanded[-1], tend=tvals_expanded[0], tvals=tvals_expanded[1:-1],
#                      grads=grads, grad_out=grad_out, lamda_out=lambda_out)

# It should be this
solver.solve_backward(
    t0=tvals_expanded[-1],
    tend=tvals_expanded[0],
    tvals=tvals_expanded,
    grads=grads,
    grad_out=grad_out,
    lamda_out=lambda_out
)

grad_out_adj = -sens0 @ lambda_out + grad_out
print(grad_out_adj)

# Output

# from forward
# [29.633367875233063, -8.63361922455043, 10.2485995824757]

# from adjoint
# [29.63336772 -8.63361915 10.24859955]

The problem is that by passing tvals=tvals_expanded[1:-1], we don't actually use the first two entries of grads, and the time points for the remaining gradients no longer match the correct tvals.
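As a quick numerical check (a sketch that continues the script above; grad_out_fwd is the list from the forward solve):

# Forward and adjoint gradients should now agree up to solver tolerances.
print(np.allclose(grad_out_fwd, grad_out_adj, rtol=1e-6))  # True with the values above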

@aseyboldt (Member) commented:

I'm closing this because I think it was a problem in the example code, but feel free to reopen or comment if you don't agree or have questions.
