diff --git a/scipy/lib/decorator.py b/scipy/lib/decorator.py new file mode 100644 index 000000000000..2e8c123380e8 --- /dev/null +++ b/scipy/lib/decorator.py @@ -0,0 +1,210 @@ +########################## LICENCE ############################### +## +## Copyright (c) 2005-2011, Michele Simionato +## All rights reserved. +## +## Redistributions of source code must retain the above copyright +## notice, this list of conditions and the following disclaimer. +## Redistributions in bytecode form must reproduce the above copyright +## notice, this list of conditions and the following disclaimer in +## the documentation and/or other materials provided with the +## distribution. + +## THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +## "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +## LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +## A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +## HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +## INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +## BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS +## OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +## ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR +## TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +## USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH +## DAMAGE. + +""" +Decorator module, see http://pypi.python.org/pypi/decorator +for the documentation. +""" + +__version__ = '3.3.2' + +__all__ = ["decorator", "FunctionMaker", "partial"] + +import sys, re, inspect + +try: + from functools import partial +except ImportError: # for Python version < 2.5 + class partial(object): + "A simple replacement of functools.partial" + def __init__(self, func, *args, **kw): + self.func = func + self.args = args + self.keywords = kw + def __call__(self, *otherargs, **otherkw): + kw = self.keywords.copy() + kw.update(otherkw) + return self.func(*(self.args + otherargs), **kw) + +if sys.version >= '3': + from inspect import getfullargspec +else: + class getfullargspec(object): + "A quick and dirty replacement for getfullargspec for Python 2.X" + def __init__(self, f): + self.args, self.varargs, self.varkw, self.defaults = \ + inspect.getargspec(f) + self.kwonlyargs = [] + self.kwonlydefaults = None + self.annotations = getattr(f, '__annotations__', {}) + def __iter__(self): + yield self.args + yield self.varargs + yield self.varkw + yield self.defaults + +DEF = re.compile('\s*def\s*([_\w][_\w\d]*)\s*\(') + +# basic functionality +class FunctionMaker(object): + """ + An object with the ability to create functions with a given signature. + It has attributes name, doc, module, signature, defaults, dict and + methods update and make. + """ + def __init__(self, func=None, name=None, signature=None, + defaults=None, doc=None, module=None, funcdict=None): + self.shortsignature = signature + if func: + # func can be a class or a callable, but not an instance method + self.name = func.__name__ + if self.name == '': # small hack for lambda functions + self.name = '_lambda_' + self.doc = func.__doc__ + self.module = func.__module__ + if inspect.isfunction(func): + argspec = getfullargspec(func) + for a in ('args', 'varargs', 'varkw', 'defaults', 'kwonlyargs', + 'kwonlydefaults', 'annotations'): + setattr(self, a, getattr(argspec, a)) + for i, arg in enumerate(self.args): + setattr(self, 'arg%d' % i, arg) + self.signature = inspect.formatargspec( + formatvalue=lambda val: "", *argspec)[1:-1] + allargs = list(self.args) + if self.varargs: + allargs.append('*' + self.varargs) + if self.varkw: + allargs.append('**' + self.varkw) + try: + self.shortsignature = ', '.join(allargs) + except TypeError: # exotic signature, valid only in Python 2.X + self.shortsignature = self.signature + self.dict = func.__dict__.copy() + # func=None happens when decorating a caller + if name: + self.name = name + if signature is not None: + self.signature = signature + if defaults: + self.defaults = defaults + if doc: + self.doc = doc + if module: + self.module = module + if funcdict: + self.dict = funcdict + # check existence required attributes + assert hasattr(self, 'name') + if not hasattr(self, 'signature'): + raise TypeError('You are decorating a non function: %s' % func) + + def update(self, func, **kw): + "Update the signature of func with the data in self" + func.__name__ = self.name + func.__doc__ = getattr(self, 'doc', None) + func.__dict__ = getattr(self, 'dict', {}) + func.func_defaults = getattr(self, 'defaults', ()) + func.__kwdefaults__ = getattr(self, 'kwonlydefaults', None) + callermodule = sys._getframe(3).f_globals.get('__name__', '?') + func.__module__ = getattr(self, 'module', callermodule) + func.__dict__.update(kw) + + def make(self, src_templ, evaldict=None, addsource=False, **attrs): + "Make a new function from a given template and update the signature" + src = src_templ % vars(self) # expand name and signature + evaldict = evaldict or {} + mo = DEF.match(src) + if mo is None: + raise SyntaxError('not a valid function template\n%s' % src) + name = mo.group(1) # extract the function name + names = set([name] + [arg.strip(' *') for arg in + self.shortsignature.split(',')]) + for n in names: + if n in ('_func_', '_call_'): + raise NameError('%s is overridden in\n%s' % (n, src)) + if not src.endswith('\n'): # add a newline just for safety + src += '\n' # this is needed in old versions of Python + try: + code = compile(src, '', 'single') + # print >> sys.stderr, 'Compiling %s' % src + exec code in evaldict + except: + print >> sys.stderr, 'Error in generated code:' + print >> sys.stderr, src + raise + func = evaldict[name] + if addsource: + attrs['__source__'] = src + self.update(func, **attrs) + return func + + @classmethod + def create(cls, obj, body, evaldict, defaults=None, + doc=None, module=None, addsource=True, **attrs): + """ + Create a function from the strings name, signature and body. + evaldict is the evaluation dictionary. If addsource is true an attribute + __source__ is added to the result. The attributes attrs are added, + if any. + """ + if isinstance(obj, str): # "name(signature)" + name, rest = obj.strip().split('(', 1) + signature = rest[:-1] #strip a right parens + func = None + else: # a function + name = None + signature = None + func = obj + self = cls(func, name, signature, defaults, doc, module) + ibody = '\n'.join(' ' + line for line in body.splitlines()) + return self.make('def %(name)s(%(signature)s):\n' + ibody, + evaldict, addsource, **attrs) + +def decorator(caller, func=None): + """ + decorator(caller) converts a caller function into a decorator; + decorator(caller, func) decorates a function using a caller. + """ + if func is not None: # returns a decorated function + evaldict = func.func_globals.copy() + evaldict['_call_'] = caller + evaldict['_func_'] = func + return FunctionMaker.create( + func, "return _call_(_func_, %(shortsignature)s)", + evaldict, undecorated=func, __wrapped__=func) + else: # returns a decorator + if isinstance(caller, partial): + return partial(decorator, caller) + # otherwise assume caller is a function + first = inspect.getargspec(caller)[0][0] # first arg + evaldict = caller.func_globals.copy() + evaldict['_call_'] = caller + evaldict['decorator'] = decorator + return FunctionMaker.create( + '%s(%s)' % (caller.__name__, first), + 'return decorator(_call_, %s)' % first, + evaldict, undecorated=caller, __wrapped__=caller, + doc=caller.__doc__, module=caller.__module__) diff --git a/scipy/sparse/linalg/isolve/iterative.py b/scipy/sparse/linalg/isolve/iterative.py index 528c63fe92fe..4d25ff751b61 100644 --- a/scipy/sparse/linalg/isolve/iterative.py +++ b/scipy/sparse/linalg/isolve/iterative.py @@ -6,6 +6,7 @@ import numpy as np from scipy.sparse.linalg.interface import LinearOperator +from scipy.lib.decorator import decorator from utils import make_system _type_conv = {'f':'s', 'd':'d', 'F':'c', 'D':'z'} @@ -71,12 +72,22 @@ def combine(fn): return fn return combine - +@decorator +def non_reentrant(func, *a, **kw): + d = func.__dict__ + if d.get('__entered'): + raise RuntimeError("%s is not re-entrant" % func.__name__) + try: + d['__entered'] = True + return func(*a, **kw) + finally: + d['__entered'] = False @set_docstring('Use BIConjugate Gradient iteration to solve A x = b', 'The real or complex N-by-N matrix of the linear system\n' 'It is required that the linear operator can produce\n' '``Ax`` and ``A^T x``.') +@non_reentrant def bicg(A, b, x0=None, tol=1e-5, maxiter=None, xtype=None, M=None, callback=None): A,M,x,b,postprocess = make_system(A,M,x0,b,xtype) @@ -140,6 +151,7 @@ def bicg(A, b, x0=None, tol=1e-5, maxiter=None, xtype=None, M=None, callback=Non @set_docstring('Use BIConjugate Gradient STABilized iteration to solve A x = b', 'The real or complex N-by-N matrix of the linear system\n' '``A`` must represent a hermitian, positive definite matrix') +@non_reentrant def bicgstab(A, b, x0=None, tol=1e-5, maxiter=None, xtype=None, M=None, callback=None): A,M,x,b,postprocess = make_system(A,M,x0,b,xtype) @@ -200,6 +212,7 @@ def bicgstab(A, b, x0=None, tol=1e-5, maxiter=None, xtype=None, M=None, callback @set_docstring('Use Conjugate Gradient iteration to solve A x = b', 'The real or complex N-by-N matrix of the linear system\n' '``A`` must represent a hermitian, positive definite matrix') +@non_reentrant def cg(A, b, x0=None, tol=1e-5, maxiter=None, xtype=None, M=None, callback=None): A,M,x,b,postprocess = make_system(A,M,x0,b,xtype) @@ -259,6 +272,7 @@ def cg(A, b, x0=None, tol=1e-5, maxiter=None, xtype=None, M=None, callback=None) @set_docstring('Use Conjugate Gradient Squared iteration to solve A x = b', 'The real-valued N-by-N matrix of the linear system') +@non_reentrant def cgs(A, b, x0=None, tol=1e-5, maxiter=None, xtype=None, M=None, callback=None): A,M,x,b,postprocess = make_system(A,M,x0,b,xtype) @@ -314,7 +328,7 @@ def cgs(A, b, x0=None, tol=1e-5, maxiter=None, xtype=None, M=None, callback=None return postprocess(x), info - +@non_reentrant def gmres(A, b, x0=None, tol=1e-5, restart=None, maxiter=None, xtype=None, M=None, callback=None, restrt=None): """ Use Generalized Minimal RESidual iteration to solve A x = b. @@ -481,6 +495,7 @@ def gmres(A, b, x0=None, tol=1e-5, restart=None, maxiter=None, xtype=None, M=Non return postprocess(x), info +@non_reentrant def qmr(A, b, x0=None, tol=1e-5, maxiter=None, xtype=None, M1=None, M2=None, callback=None): """Use Quasi-Minimal Residual iteration to solve A x = b diff --git a/scipy/sparse/linalg/isolve/tests/test_iterative.py b/scipy/sparse/linalg/isolve/tests/test_iterative.py index ca003eaa536c..5c51fbc663bf 100644 --- a/scipy/sparse/linalg/isolve/tests/test_iterative.py +++ b/scipy/sparse/linalg/isolve/tests/test_iterative.py @@ -5,7 +5,7 @@ import numpy as np from numpy.testing import TestCase, assert_equal, assert_array_equal, \ - assert_, assert_allclose + assert_, assert_allclose, assert_raises from numpy import zeros, ones, arange, array, abs, max from numpy.linalg import cond @@ -208,6 +208,29 @@ def test_gmres_basic(): assert_allclose(x_gm[0], 0.359, rtol=1e-2) +def test_reentrancy(): + non_reentrant = [cg, cgs, bicg, bicgstab, gmres, qmr] + reentrant = [lgmres, minres] + for solver in reentrant + non_reentrant: + yield _check_reentrancy, solver, solver in reentrant + +def _check_reentrancy(solver, is_reentrant): + def matvec(x): + A = np.array([[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0]]) + y, info = solver(A, x) + assert_equal(info, 0) + return y + b = np.array([1, 1./2, 1./3]) + op = LinearOperator((3, 3), matvec=matvec, rmatvec=matvec, + dtype=b.dtype) + + if not is_reentrant: + assert_raises(RuntimeError, solver, op, b) + else: + y, info = solver(op, b) + assert_equal(info, 0) + assert_allclose(y, [1, 1, 1]) + #------------------------------------------------------------------------------