Merge pull request google#458 from mblondel:weak_type_consistency
PiperOrigin-RevId: 544601673
JAXopt authors committed Jun 30, 2023
2 parents fe57f15 + 924a1c1 commit 4e80f27
Showing 13 changed files with 146 additions and 65 deletions.
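The change running through these files is a dtype consistency fix: scalar state fields such as error, stepsize and value were previously created with a bare jnp.asarray(...), which defaults to a weakly-typed float32, while the values computed inside update inherit the dtype of the parameters or of the objective. The fields are now pinned explicitly: error and stepsize to the params dtype (the error is a gradient norm, and gradients live in the same space as the params), value to the objective's dtype. A minimal sketch, outside the diff, of why this matters; the bfloat16 params here are only an example:

import jax.numpy as jnp

# Solver parameters in a low-precision dtype.
params = jnp.zeros(3, dtype=jnp.bfloat16)

# Without an explicit dtype, jnp.asarray gives a weakly-typed float32 scalar.
loose = jnp.asarray(jnp.inf)
# Pinning to the params dtype keeps init_state and update consistent.
pinned = jnp.asarray(jnp.inf, dtype=params.dtype)

print(loose.dtype, pinned.dtype)  # float32 bfloat16

If init_state stored the float32 version while update stored a bfloat16 gradient norm, the state pytree would change dtype between iterations, which a jitted lax.while_loop carry rejects.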
25 changes: 18 additions & 7 deletions jaxopt/_src/armijo_sgd.py
@@ -29,6 +29,8 @@
from jaxopt.tree_util import tree_add_scalar_mul, tree_l2_norm
from jaxopt.tree_util import tree_scalar_mul, tree_zeros_like
from jaxopt.tree_util import tree_add, tree_sub
from jaxopt._src.tree_util import tree_single_dtype

from jaxopt._src import base
from jaxopt._src import loop

@@ -244,15 +246,22 @@ def init_state(self, init_params, *args, **kwargs) -> ArmijoState:
velocity = tree_zeros_like(init_params)

if self.has_aux:
_, aux = self.fun(init_params, *args, **kwargs)
value, aux = self.fun(init_params, *args, **kwargs)
else:
value = self.fun(init_params, *args, **kwargs)
aux = None

params_dtype = tree_single_dtype(init_params)

return ArmijoState(iter_num=jnp.asarray(0),
error=jnp.asarray(jnp.inf),
value=jnp.asarray(jnp.inf),
# Error should be dtype-compatible with the parameters,
# not with the value, since the error is derived from the
# gradient, which lives in the same space as the params.
error=jnp.asarray(jnp.inf, dtype=params_dtype),
value=jnp.asarray(jnp.inf, value.dtype),
aux=aux,
stepsize=jnp.asarray(self.max_stepsize),
stepsize=jnp.asarray(self.max_stepsize,
dtype=params_dtype),
velocity=velocity)

def reset_stepsize(self, stepsize):
@@ -275,6 +284,8 @@ def update(self, params, state, *args, **kwargs) -> base.OptStep:
Returns:
(params, state)
"""
dtype = tree_single_dtype(params)

if self.pre_update:
params, state = self.pre_update(params, state, *args, **kwargs)

@@ -300,10 +311,10 @@ def update(self, params, state, *args, **kwargs) -> base.OptStep:
error = tree_l2_norm(grad, squared=False)

next_state = ArmijoState(iter_num=state.iter_num+1,
error=error,
value=f_next,
error=jnp.asarray(error, dtype=dtype),
value=jnp.asarray(f_next),
aux=aux,
stepsize=stepsize,
stepsize=jnp.asarray(stepsize, dtype=dtype),
velocity=next_velocity)

return base.OptStep(next_params, next_state)
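The casts above rely on tree_single_dtype, which the diff imports from jaxopt._src.tree_util but whose body is not shown here. The rough sketch below assumes the helper simply extracts the one dtype shared by all leaves of the params pytree; it is not the library's actual implementation:

import jax
import jax.numpy as jnp

def tree_single_dtype_sketch(tree):
  """Return the common dtype of all leaves, or None if they disagree."""
  dtypes = {jnp.asarray(leaf).dtype for leaf in jax.tree_util.tree_leaves(tree)}
  return dtypes.pop() if len(dtypes) == 1 else None

params = {"w": jnp.ones((2, 2), dtype=jnp.float32), "b": jnp.zeros(2)}
print(tree_single_dtype_sketch(params))  # float32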
6 changes: 4 additions & 2 deletions jaxopt/_src/base.py
@@ -389,8 +389,10 @@ def run_iterator(self,
Returns:
(params, state)
"""
# TODO(mblondel): data-dependent initialization schemes need a batch.
state = self.init_state(init_params, *args, **kwargs)
# Some initializations need the data so we need to draw a batch from the
# iterator.
data = next(iterator)
state = self.init_state(init_params, *args, **kwargs, data=data)
params = init_params

# TODO(mblondel): try and benchmark lax.fori_loop with host_call for `next`.
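With this change, run_iterator draws the first batch itself and forwards it to init_state, so data-dependent initialization (such as ArmijoSGD evaluating the objective at init_params) has a batch to work with; code that calls init_state directly on a stochastic solver should now pass a batch too, as the armijo_sgd_test update below does. The following usage sketch is modeled on that test; the objective, data, and batch generator are made up for illustration:

import numpy as onp
import jax.numpy as jnp
from jaxopt import ArmijoSGD

def fun(params, l2reg, data):
  X, y = data
  residuals = jnp.dot(X, params) - y
  return jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)

rng = onp.random.RandomState(0)
X, y = rng.randn(32, 5), rng.randn(32)
init_params = jnp.zeros(5)

def batches(n_iter, batch_size=8):
  for _ in range(n_iter):
    idx = rng.choice(len(X), size=batch_size, replace=False)
    yield X[idx], y[idx]

opt = ArmijoSGD(fun=fun, maxiter=10, tol=1e-3)
# run_iterator now consumes the first batch internally for init_state.
params, state = opt.run_iterator(init_params, batches(20), l2reg=0.1)

# Calling init_state by hand likewise takes a batch, mirroring the test below.
state0 = opt.init_state(init_params, l2reg=0.1, data=(X, y))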
5 changes: 3 additions & 2 deletions jaxopt/_src/bfgs.py
@@ -161,7 +161,7 @@ def init_state(self,
value=value,
grad=grad,
stepsize=jnp.asarray(self.max_stepsize, dtype=dtype),
error=jnp.asarray(jnp.inf),
error=jnp.asarray(jnp.inf, dtype=dtype),
H=jnp.eye(len(flat_init_params), dtype=dtype),
aux=aux)

@@ -233,11 +233,12 @@ def update(self,
new_H = _einsum('ij,jk,lk', w, state.H, w) + rho * ss
new_H = jnp.where(jnp.isfinite(rho), new_H, state.H)

error = tree_l2_norm(new_grad)
new_state = BfgsState(iter_num=state.iter_num + 1,
value=new_value,
grad=new_grad,
stepsize=jnp.asarray(new_stepsize),
error=tree_l2_norm(new_grad),
error=jnp.asarray(error, dtype=dtype),
H=new_H,
aux=new_aux)

6 changes: 4 additions & 2 deletions jaxopt/_src/lbfgs.py
@@ -284,7 +284,7 @@ def init_state(self,
(value, aux), grad = self._value_and_grad_with_aux(init_params, *args, **kwargs)
return LbfgsState(value=value,
grad=grad,
error=jnp.asarray(jnp.inf),
error=jnp.asarray(jnp.inf, dtype=dtype),
**state_kwargs,
aux=aux,
failed_linesearch=jnp.asarray(False))
@@ -372,11 +372,13 @@ def update(self,
else:
gamma = jnp.array(1.0)

dtype = tree_single_dtype(params)
error = tree_l2_norm(new_grad)
new_state = LbfgsState(iter_num=state.iter_num + 1,
value=new_value,
grad=new_grad,
stepsize=jnp.asarray(new_stepsize),
error=tree_l2_norm(new_grad),
error=jnp.asarray(error, dtype=dtype),
s_history=s_history,
y_history=y_history,
rho_history=rho_history,
14 changes: 9 additions & 5 deletions jaxopt/_src/nonlinear_cg.py
@@ -146,9 +146,11 @@ def init_state(self,
*args,
**kwargs)

dtype = tree_single_dtype(init_params)

return NonlinearCGState(iter_num=jnp.asarray(0),
stepsize=jnp.asarray(self.max_stepsize),
error=jnp.asarray(jnp.inf),
stepsize=jnp.asarray(self.max_stepsize, dtype=dtype),
error=jnp.asarray(jnp.inf, dtype=dtype),
value=value,
grad=grad,
descent_direction=tree_scalar_mul(-1.0, grad),
@@ -229,9 +231,11 @@ def update(self,
new_descent_direction = tree_add_scalar_mul(tree_scalar_mul(-1, new_grad),
new_beta,
descent_direction)
error = tree_l2_norm(grad)
dtype = tree_single_dtype(params)
new_state = NonlinearCGState(iter_num=state.iter_num + 1,
stepsize=jnp.asarray(new_stepsize),
error=tree_l2_norm(grad),
stepsize=jnp.asarray(new_stepsize, dtype=dtype),
error=jnp.asarray(error, dtype=dtype),
value=new_value,
grad=new_grad,
descent_direction=new_descent_direction,
@@ -272,7 +276,7 @@ def __post_init__(self):
)

self.run_ls = linesearch_solver.run

if self.condition is not None:
warnings.warn("Argument condition is deprecated", DeprecationWarning)
if self.decrease_factor is not None:
15 changes: 10 additions & 5 deletions jaxopt/_src/optax_wrapper.py
@@ -106,13 +106,16 @@ def init_state(self,
opt_state = self.opt.init(init_params)

if self.has_aux:
_, aux = self.fun(init_params, *args, **kwargs)
value, aux = self.fun(init_params, *args, **kwargs)
else:
value = self.fun(init_params, *args, **kwargs)
aux = None

params_dtype = tree_util.tree_single_dtype(init_params)

return OptaxState(iter_num=jnp.asarray(0),
value=jnp.asarray(jnp.inf),
error=jnp.asarray(jnp.inf),
value=jnp.asarray(jnp.inf, value.dtype),
error=jnp.asarray(jnp.inf, dtype=params_dtype),
aux=aux,
internal_state=opt_state)

@@ -144,9 +147,11 @@ def update(self,
params = self._apply_updates(params, delta)

# Computes optimality error before update to re-use grad evaluation.
dtype = tree_util.tree_single_dtype(params)
error = tree_util.tree_l2_norm(grad)
new_state = OptaxState(iter_num=state.iter_num + 1,
error=tree_util.tree_l2_norm(grad),
value=value,
error=jnp.asarray(error, dtype=dtype),
value=jnp.asarray(value),
aux=aux,
internal_state=opt_state)
return base.OptStep(params=params, state=new_state)
20 changes: 12 additions & 8 deletions jaxopt/_src/polyak_sgd.py
@@ -140,21 +140,22 @@ def init_state(self,
state
"""
if self.has_aux:
_, aux = self.fun(init_params, *args, **kwargs)
value, aux = self.fun(init_params, *args, **kwargs)
else:
value = self.fun(init_params, *args, **kwargs)
aux = None

if self.momentum == 0:
velocity = None
else:
velocity = tree_zeros_like(init_params)

param_dtype = tree_single_dtype(init_params)
params_dtype = tree_single_dtype(init_params)

return PolyakSGDState(iter_num=jnp.asarray(0),
error=jnp.asarray(jnp.inf),
value=jnp.asarray(jnp.inf),
stepsize=jnp.asarray(1.0, dtype=param_dtype),
error=jnp.asarray(jnp.inf, dtype=params_dtype),
value=jnp.asarray(jnp.inf, dtype=value.dtype),
stepsize=jnp.asarray(1.0, dtype=params_dtype),
aux=aux,
velocity=velocity)

@@ -173,6 +174,8 @@ def update(self,
Returns:
(params, state)
"""
dtype = tree_single_dtype(params)

if self.pre_update:
params, state = self.pre_update(params, state, *args, **kwargs)

@@ -193,11 +196,12 @@ def update(self,
tree_scalar_mul(stepsize, grad))
new_params = tree_add(params, new_velocity)

error = jnp.sqrt(grad_sqnorm)
new_state = PolyakSGDState(iter_num=state.iter_num + 1,
error=jnp.sqrt(grad_sqnorm),
error=jnp.asarray(error, dtype=dtype),
velocity=new_velocity,
value=value,
stepsize=stepsize,
value=jnp.asarray(value),
stepsize=jnp.asarray(stepsize, dtype=dtype),
aux=aux)
return base.OptStep(params=new_params, state=new_state)

21 changes: 14 additions & 7 deletions jaxopt/_src/proximal_gradient.py
@@ -35,6 +35,7 @@
from jaxopt._src.tree_util import tree_l2_norm
from jaxopt._src.tree_util import tree_sub
from jaxopt._src.tree_util import tree_vdot
from jaxopt._src.tree_util import tree_single_dtype


def fista_line_search(
@@ -183,17 +184,19 @@ def init_state(self,
else:
aux = None

dtype = tree_single_dtype(init_params)

if self.acceleration:
state = ProxGradState(iter_num=jnp.asarray(0),
velocity=init_params,
t=jnp.asarray(1.0),
stepsize=jnp.asarray(1.0),
error=jnp.asarray(jnp.inf),
stepsize=jnp.asarray(1.0, dtype=dtype),
error=jnp.asarray(jnp.inf, dtype=dtype),
aux=aux)
else:
state = ProxGradState(iter_num=jnp.asarray(0),
stepsize=jnp.asarray(1.0),
error=jnp.asarray(jnp.inf),
stepsize=jnp.asarray(1.0, dtype=dtype),
error=jnp.asarray(jnp.inf, dtype=dtype),
aux=aux)

return state
@@ -238,6 +241,7 @@ def _iter(self,
return next_x, next_stepsize

def _update(self, x, state, hyperparams_prox, args, kwargs):
dtype = tree_single_dtype(x)
iter_num = state.iter_num
stepsize = state.stepsize
(x_fun_val, aux), x_fun_grad = self._value_and_grad_with_aux(x, *args,
@@ -246,11 +250,13 @@ def _update(self, x, state, hyperparams_prox, args, kwargs):
stepsize, hyperparams_prox, args, kwargs)
error = self._error(tree_sub(next_x, x), next_stepsize)
next_state = ProxGradState(iter_num=iter_num + 1,
stepsize=next_stepsize,
error=error, aux=aux)
stepsize=jnp.asarray(next_stepsize, dtype=dtype),
error=jnp.asarray(error, dtype=dtype),
aux=aux)
return base.OptStep(params=next_x, state=next_state)

def _update_accel(self, x, state, hyperparams_prox, args, kwargs):
dtype = tree_single_dtype(x)
iter_num = state.iter_num
y = state.velocity
t = state.t
Expand All @@ -264,7 +270,8 @@ def _update_accel(self, x, state, hyperparams_prox, args, kwargs):
next_y = tree_add_scalar_mul(next_x, (t - 1) / next_t, diff_x)
next_error = self._error(diff_x, next_stepsize)
next_state = ProxGradState(iter_num=iter_num + 1, velocity=next_y, t=next_t,
stepsize=next_stepsize, error=next_error,
stepsize=jnp.asarray(next_stepsize, dtype=dtype),
error=jnp.asarray(next_error, dtype=dtype),
aux=aux)
return base.OptStep(params=next_x, state=next_state)

18 changes: 0 additions & 18 deletions jaxopt/_src/test_util.py
@@ -161,24 +161,6 @@ def lsq_linear_cube_osp_jac(X, y, l, eps=1e-5, tol=1e-10, max_iter=None):
lsq_linear_cube_osp(X, y, l - eps, tol, max_iter)) / (2 * eps)


def check_states_have_same_types(state1, state2):
if len(state1._fields) != len(state2._fields):
raise ValueError("state1 and state2 should have the same number of "
"attributes.")

for attr1, attr2 in zip(state1._fields, state2._fields):
if attr1 != attr2:
raise ValueError("Attribute names do not agree: %s and %s." % (attr1,
attr2))

type1 = type(getattr(state1, attr1)).__name__
type2 = type(getattr(state2, attr2)).__name__

if type1 != type2:
raise ValueError("Attribute '%s' has different types in state1 and "
"state2: %s vs. %s" % (attr1, type1, type2))


# Test utilities copied from JAX core so we don't depend on their private API.

_dtype_to_32bit_dtype = {
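The deleted check_states_have_same_types compared attribute names and Python type names of two solver states. Purely as an illustration of the kind of dtype-level check this PR is about (not the replacement the PR actually uses, which is not shown in this diff), a leaf-wise comparison could look like this:

import jax
import jax.numpy as jnp

def assert_states_have_same_dtypes(state1, state2):
  # Flatten both state pytrees and compare structure and leaf dtypes.
  leaves1, treedef1 = jax.tree_util.tree_flatten(state1)
  leaves2, treedef2 = jax.tree_util.tree_flatten(state2)
  assert treedef1 == treedef2, "states have different structures"
  for leaf1, leaf2 in zip(leaves1, leaves2):
    dtype1, dtype2 = jnp.asarray(leaf1).dtype, jnp.asarray(leaf2).dtype
    assert dtype1 == dtype2, f"dtype mismatch: {dtype1} vs {dtype2}"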
5 changes: 3 additions & 2 deletions tests/armijo_sgd_test.py
@@ -140,7 +140,7 @@ def dataset_loader(X, y, n_iter):
tol = 1e-3
opt = ArmijoSGD(fun=fun, reset_option='goldstein', maxiter=1000, tol=tol)
iterable = dataset_loader(X, y, n_iter=200)
state = opt.init_state(params, l2reg=l2reg)
state = opt.init_state(params, l2reg=l2reg, data=(X, y))
@jax.jit
def jitted_update(params, state, data):
return opt.update(params, state, l2reg=l2reg, data=data)
@@ -169,7 +169,8 @@ def dataset_loader(X, y, n_iter):
pytree_init = (W_init, b_init)

tol = 3e-1
opt = ArmijoSGD(fun=fun, maxiter=10, tol=tol) # few iterations due to speed issues
# few iterations due to speed issues
opt = ArmijoSGD(fun=fun, maxiter=10, tol=tol)
iterable = dataset_loader(X, y, n_iter=200)
params, _ = opt.run_iterator(pytree_init, iterable, l2reg=l2reg)
