Skip to content

Commit

Permalink
BUG: integrate: fix romberg termination condition, and add support fo…
Browse files Browse the repository at this point in the history
…r relative tolerance in stopping condition
  • Loading branch information
pv committed Jul 28, 2010
1 parent 8fb9bd5 commit da3b913
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
24 changes: 16 additions & 8 deletions scipy/integrate/quadrature.py
Expand Up @@ -513,7 +513,7 @@ def _printresmat(function, interval, resmat):
print 'The final result is', resmat[i][j],
print 'after', 2**(len(resmat)-1)+1, 'function evaluations.'

def romberg(function, a, b, args=(), tol=1.48E-8, show=False,
def romberg(function, a, b, args=(), tol=1.48e-8, rtol=1.48e-8, show=False,
divmax=10, vec_func=False):
"""
Romberg integration of a callable function or method.
Expand Down Expand Up @@ -545,12 +545,12 @@ def romberg(function, a, b, args=(), tol=1.48E-8, show=False,
Extra arguments to pass to function. Each element of `args` will
be passed as a single argument to `func`. Default is to pass no
extra arguments.
tol : float, optional
The desired tolerance. Default is 1.48e-8.
tol, rtol : float, optional
The desired absolute and relative tolerances. Defaults are 1.48e-8.
show : bool, optional
Whether to print the results. Default is False.
divmax : int, optional
?? Default is 10.
Maximum order of extrapolation. Default is 10.
vec_func : bool, optional
Whether `func` handles arrays as arguments (i.e whether it is a
"vector" function). Default is False.
Expand Down Expand Up @@ -596,14 +596,14 @@ def romberg(function, a, b, args=(), tol=1.48E-8, show=False,
if isinf(a) or isinf(b):
raise ValueError("Romberg integration only available for finite limits.")
vfunc = vectorize1(function, args, vec_func=vec_func)
i = n = 1
n = 1
interval = [a,b]
intrange = b-a
ordsum = _difftrap(vfunc, interval, n)
result = intrange * ordsum
resmat = [[result]]
lastresult = result + tol * 2.0
while (abs(result - lastresult) > tol) and (i <= divmax):
err = np.inf
for i in xrange(1, divmax+1):
n = n * 2
ordsum = ordsum + _difftrap(vfunc, interval, n)
resmat.append([])
Expand All @@ -612,7 +612,15 @@ def romberg(function, a, b, args=(), tol=1.48E-8, show=False,
resmat[i].append(_romberg_diff(resmat[i-1][k], resmat[i][k], k+1))
result = resmat[i][i]
lastresult = resmat[i-1][i-1]
i = i + 1

err = abs(result - lastresult)
if err < tol or err < rtol*abs(result):
break
else:
warnings.warn(
"divmax (%d) exceeded. Latest difference = %e" % (divmax, err),
AccuracyWarning)

if show:
_printresmat(vfunc, interval, resmat)
return result
Expand Down
8 changes: 8 additions & 0 deletions scipy/integrate/tests/test_quadrature.py
Expand Up @@ -32,6 +32,14 @@ def myfunc(x, n, z): # Bessel function integrand
table_val = 0.30614353532540296487
assert_almost_equal(val, table_val, decimal=7)

def test_romberg_rtol(self):
# Typical function with two extra arguments:
def myfunc(x, n, z): # Bessel function integrand
return 1e19*cos(n*x-z*sin(x))/pi
val = romberg(myfunc,0,pi, args=(2, 1.8), rtol=1e-10)
table_val = 1e19*0.30614353532540296487
assert_allclose(val, table_val, rtol=1e-10)

def test_romb(self):
assert_equal(romb(numpy.arange(17)),128)

Expand Down

0 comments on commit da3b913

Please sign in to comment.