Skip to content

Commit

Permalink
update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
yoelcortes committed Jan 9, 2021
1 parent 75085c5 commit 6b36dbc
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 12 deletions.
2 changes: 1 addition & 1 deletion flexsolve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@
from .problem import *
from .profiler import *

__version__ = '0.4.2'
__version__ = '0.4.3'
12 changes: 6 additions & 6 deletions flexsolve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,14 @@ def aitken_iter(x, gg, dxg, dgg_g):
return scalar_aitken_iter(x, gg, dxg, dgg_g)

@overload(wegstein_iter)
def jit_wegstein_iter(x, dx, g1, g0):
def jit_wegstein_iter(x, dx, g1, g0): # pragma: no cover
if isinstance(x, types.Array) and x.ndim:
return array_wegstein_iter
else:
return scalar_wegstein_iter

@overload(aitken_iter)
def jit_aitken_iter(x, gg, dxg, dgg_g):
def jit_aitken_iter(x, gg, dxg, dgg_g): # pragma: no cover
if isinstance(x, types.Array) and x.ndim:
return array_aitken_iter
else:
Expand Down Expand Up @@ -147,15 +147,15 @@ def IQ_iter(y0, y1, y2, x0, x1, x2, dx, df0, xlast):
return x

@njitable(cache=True)
def raise_iter_error(): # pragma: no cover
def raise_iter_error():
raise RuntimeError('maximum number of iterations exceeded; root could not be solved')

@njitable(cache=True)
def raise_tol_error(): # pragma: no cover
raise RuntimeError('minimum tolerance reached; root could not be solved')

@njitable(cache=True)
def raise_convergence_error(): # pragma: no cover
def raise_convergence_error():
raise RuntimeError('objective function either oscillates or diverges from solution; root could not be solved')

@njitable(cache=True)
Expand Down Expand Up @@ -183,7 +183,7 @@ def fixedpoint_converged(dx, xtol):
return scalar_fixedpoint_converged(dx, xtol)

@overload(fixedpoint_converged)
def jit_fixedpoint_converged(dx, xtol):
def jit_fixedpoint_converged(dx, xtol): # pragma: no cover
if isinstance(dx, types.Array) and dx.ndim:
return array_fixedpoint_converged
else:
Expand All @@ -196,5 +196,5 @@ def mean(x):
def scalar_mean(x): return x

@overload(mean)
def jit_mean(x):
def jit_mean(x): # pragma: no cover
return np.mean if isinstance(x, types.Array) and x.ndim else scalar_mean
14 changes: 9 additions & 5 deletions tests/test_fixedpoint_array_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,6 @@ def test_fixedpoint_array_solvers():

assert p.sizes() == {'Wegstein': 5, 'Aitken': 5, 'Lstsq': 22, 'Fixed point': 194}

flx.speed_up()
solution = flx.wegstein(f, feed, convergenceiter=4, xtol=5e-8)
assert_allclose(solution, real_solution)
solution = flx.aitken(f, feed, convergenceiter=4, xtol=5e-8)
assert_allclose(solution, real_solution)

def test_fixedpoint_array_solvers2():
original_feed = feed.copy()
Expand All @@ -95,6 +90,8 @@ def test_fixedpoint_array_solvers2():
solution = flx.wegstein(f2, feed, convergenceiter=4, xtol=1e-8, maxiter=200)
with pytest.raises(RuntimeError):
solution = flx.wegstein(f2, feed, xtol=1e-8, maxiter=20)
with pytest.raises(RuntimeError):
solution = flx.aitken(f2, feed, convergenceiter=4, xtol=1e-8, maxiter=200)

solution = flx.wegstein(p, feed, checkconvergence=False, convergenceiter=4, xtol=1e-8)
p.archive('Wegstein early termination')
Expand Down Expand Up @@ -135,6 +132,13 @@ def test_fixedpoint_array_solvers2():
'Lstsq': 27, 'Lstsq early termination': 29,
'Fixed point': 191, 'Fixed point early termination': 191}

def test_fixedpoint_array_with_speed_up():
flx.speed_up()
solution = flx.wegstein(f, feed, convergenceiter=4, xtol=5e-8)
assert_allclose(solution, real_solution)
solution = flx.aitken(f, feed, convergenceiter=4, xtol=5e-8)
assert_allclose(solution, real_solution)

def test_conditional_fixedpoint_array_solvers():
original_feed = feed.copy()

Expand Down

0 comments on commit 6b36dbc

Please sign in to comment.