Merge pull request google#464 from phinate:fix-lbfgsb-grad
PiperOrigin-RevId: 545594782
JAXopt authors committed Jul 5, 2023
2 parents 9afa6f7 + a5aada2 commit 22a4af9
Showing 2 changed files with 28 additions and 1 deletion.
14 changes: 13 additions & 1 deletion jaxopt/_src/lbfgsb.py
@@ -21,6 +21,7 @@
 # [2] J. Nocedal and S. Wright. Numerical Optimization, second edition.

 import dataclasses
+import inspect
 import warnings
 from typing import Any, Callable, NamedTuple, Optional, Union

@@ -558,15 +559,26 @@ def _value_and_grad_fun(self, params, *args, **kwargs):
       params = params.params
     (value, _), grad = self._value_and_grad_with_aux(params, *args, **kwargs)
     return value, grad
+
+  def _grad_fun(self, params, *args, **kwargs):
+    return self._value_and_grad_fun(params, *args, **kwargs)[1]

   def __post_init__(self):
     _, _, self._value_and_grad_with_aux = base._make_funs_with_aux(
         fun=self.fun,
         value_and_grad=self.value_and_grad,
         has_aux=self.has_aux,
     )

+    # Sets up reference signature.
+    fun = getattr(self.fun, "subfun", self.fun)
+    signature = inspect.signature(fun)
+    parameters = list(signature.parameters.values())
+    new_param = inspect.Parameter(name="bounds",
+                                  kind=inspect.Parameter.POSITIONAL_OR_KEYWORD)
+    parameters.insert(1, new_param)
+    self.reference_signature = inspect.Signature(parameters)
+
-    self.reference_signature = self.fun

     jit, unroll = self._get_loop_options()
     linesearch_solver = _setup_linesearch(
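Note (added here for context, not part of the diff): the block added to __post_init__ builds a reference signature for the user's objective with a bounds parameter spliced in right after the first argument, presumably so that argument binding during implicit differentiation accounts for the extra bounds argument that run() receives alongside the objective's own arguments. A minimal standalone sketch of that inspect manipulation, using a hypothetical objective fit_objective, looks like this:

import inspect

def fit_objective(pars, data, x):
  # Hypothetical user objective; only its signature matters for this sketch.
  ...

# Copy the objective's signature and insert a positional-or-keyword
# "bounds" parameter at index 1, mirroring the added __post_init__ logic.
signature = inspect.signature(fit_objective)
parameters = list(signature.parameters.values())
bounds_param = inspect.Parameter(
    name="bounds", kind=inspect.Parameter.POSITIONAL_OR_KEYWORD)
parameters.insert(1, bounds_param)
reference_signature = inspect.Signature(parameters)
print(reference_signature)  # (pars, bounds, data, x)

The getattr(self.fun, "subfun", self.fun) line in the diff additionally handles objectives that have been wrapped and expose the original callable as .subfun; the sketch above skips that step.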
15 changes: 15 additions & 0 deletions tests/lbfgsb_test.py
@@ -248,6 +248,21 @@ def fun(x):

     self.assertEqual(N_CALLS, n_iter + 1)

+  def test_grad_with_bounds(self):
+    # Test that the gradient is correct when bounds are specified by keyword.
+    # Pertinent to issue #463.
+    def pipeline(x, init_pars, bounds, data):
+      def fit_objective(pars, data, x):
+        return -jax.scipy.stats.norm.logpdf(pars, loc=data*x, scale=1.0)
+      solver = LBFGSB(fun=fit_objective, implicit_diff=True, maxiter=500, tol=1e-6)
+      return solver.run(init_pars, bounds=bounds, data=data, x=x)[0]
+
+    grad_fn = jax.grad(pipeline)
+    data = jnp.array(1.5)
+    res = grad_fn(0.5, jnp.array(0.0), (jnp.array(0.0), jnp.array(10.0)), data)
+    self.assertEqual(res, data)
+
+

 if __name__ == "__main__":
   absltest.main()
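A brief note on the expected value in the new test (added here for context, not part of the commit): with loc = data * x and the bounds (0, 10) not binding, the minimizer of -logpdf(pars; loc=data*x, scale=1) is pars* = data * x, so pipeline effectively computes data * x and its derivative with respect to x is data = 1.5, which is exactly what assertEqual(res, data) checks. A tiny closed-form sanity check of that arithmetic:

import jax
import jax.numpy as jnp

def closed_form(x, data):
  # At an interior optimum the solver's output reduces to data * x,
  # so the gradient of the pipeline with respect to x is data.
  return data * x

print(jax.grad(closed_form)(0.5, jnp.array(1.5)))  # 1.5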
