Skip to content

Commit

Permalink
Merge pull request #1171 from pybamm-team/issue-1170-jax-arraylike
Browse files Browse the repository at this point in the history
#1170 change jnp.max and min to maximum and minimum
  • Loading branch information
tlestang committed Sep 16, 2020
2 parents 6726d60 + 866a01b commit 87dc41a
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions pybamm/solvers/jax_bdf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol):
state['rtol'] = rtol
state['M'] = mass
EPS = jnp.finfo(y0.dtype).eps
state['newton_tol'] = jnp.max((10 * EPS / rtol, jnp.min((0.03, rtol ** 0.5))))
state['newton_tol'] = jnp.maximum(10 * EPS / rtol, jnp.minimum(0.03, rtol ** 0.5))

scale_y0 = atol + rtol * jnp.abs(y0)
y0, not_converged = _select_initial_conditions(
Expand Down Expand Up @@ -325,7 +325,7 @@ def _select_initial_step(atol, rtol, fun, t0, y0, f0, h0):
d2 = jnp.sqrt(jnp.mean(((f1 - f0) / scale)**2))
order = 1
h1 = h0 * d2 ** (-1 / (order + 1))
return jnp.min((100 * h0, h1))
return jnp.minimum(100 * h0, h1)


def _predict(state, D):
Expand Down Expand Up @@ -559,7 +559,7 @@ def _prepare_next_step_order_change(state, d, y, n_iter):
max_index = jnp.argmax(factors)
order += max_index - 1

factor = jnp.min((MAX_FACTOR, safety * factors[max_index]))
factor = jnp.minimum(MAX_FACTOR, safety * factors[max_index])

new_state = _update_step_size_and_lu(state._replace(D=D, order=order), factor)
return new_state
Expand Down Expand Up @@ -599,9 +599,9 @@ def while_body(while_state):
state
)

#if not_converged * updated_jacobian:
# if not_converged * updated_jacobian:
# print('not converged, update step size by 0.3')
#if not_converged * (updated_jacobian == False):
# if not_converged * (updated_jacobian == False):
# print('not converged, update jacobian')

# if not converged and jacobian not updated, then update the jacobian and try
Expand All @@ -626,11 +626,11 @@ def while_body(while_state):
error_norm = rms_norm(error / scale_y)

# calculate optimal step size factor as per eq 2.46 of [2]
factor = jnp.max((MIN_FACTOR,
safety *
error_norm ** (-1 / (state.order + 1))))
factor = jnp.maximum(MIN_FACTOR,
safety *
error_norm ** (-1 / (state.order + 1)))

#if converged * (error_norm > 1):
# if converged * (error_norm > 1):
# print('converged, but error is too large',error_norm, factor, d, scale_y)

(state, step_accepted) = tree_multimap(
Expand Down

0 comments on commit 87dc41a

Please sign in to comment.