Browse files

new standard_call() decorator for linear solvers, checks also data sh…

…apes
  • Loading branch information...
1 parent abb78e3 commit d0fa5be9d2acab89d8d42af286a85c5ddb20efed @rc committed Jul 4, 2012
Showing with 39 additions and 20 deletions.
  1. +39 −20 sfepy/solvers/ls.py
View
59 sfepy/solvers/ls.py
@@ -1,3 +1,5 @@
+import time
+
import numpy as nm
import warnings
@@ -8,6 +10,33 @@
from sfepy.base.base import output, get_default, assert_, try_imports, Struct
from sfepy.solvers.solvers import make_get_conf, LinearSolver
+def standard_call(call):
+ """
+ Decorator handling argument preparation and timing for linear solvers.
+ """
+ def _standard_call(self, rhs, x0=None, conf=None, eps_a=None, eps_r=None,
+ i_max=None, mtx=None, status=None, **kwargs):
+ tt = time.clock()
+
+ conf = get_default(conf, self.conf)
+ mtx = get_default(mtx, self.mtx)
+ status = get_default(status, self.status)
+
+ assert_(mtx.shape[0] == mtx.shape[1] == rhs.shape[0])
+ if x0 is not None:
+ assert_(x0.shape[0] == rhs.shape[0])
+
+ result = call(self, rhs, x0, conf, eps_a, eps_r, i_max, mtx, status,
+ **kwargs)
+
+ ttt = time.clock() - tt
+ if status is not None:
+ status['time'] = ttt
+
+ return result
+
+ return _standard_call
+
class ScipyDirect(LinearSolver):
name = 'ls.scipy_direct'
@@ -74,11 +103,9 @@ def __init__(self, conf, **kwargs):
if self.mtx is not None:
self.solve = self.sls.factorized(self.mtx)
+ @standard_call
def __call__(self, rhs, x0=None, conf=None, eps_a=None, eps_r=None,
i_max=None, mtx=None, status=None, **kwargs):
- conf = get_default(conf, self.conf)
- mtx = get_default(mtx, self.mtx)
- status = get_default(status, self.status)
if self.solve is not None:
# Matrix is already prefactorized.
@@ -163,13 +190,12 @@ def __init__(self, conf, **kwargs):
-1 : 'illegal input or breakdown',
}
+ @standard_call
def __call__(self, rhs, x0=None, conf=None, eps_a=None, eps_r=None,
i_max=None, mtx=None, status=None, **kwargs):
- conf = get_default(conf, self.conf)
+
eps_r = get_default(eps_r, self.conf.eps_r)
i_max = get_default(i_max, self.conf.i_max)
- mtx = get_default(mtx, self.mtx)
- status = get_default(status, self.status)
precond = get_default(kwargs.get('precond', None), self.conf.precond)
callback = get_default(kwargs.get('callback', None), self.conf.callback)
@@ -247,12 +273,11 @@ def __init__( self, conf, **kwargs ):
if self.mtx is not None:
self.mg = self.solver( self.mtx )
+ @standard_call
def __call__(self, rhs, x0=None, conf=None, eps_a=None, eps_r=None,
i_max=None, mtx=None, status=None, **kwargs):
- conf = get_default(conf, self.conf)
+
eps_r = get_default(eps_r, self.eps_r)
- mtx = get_default(mtx, self.mtx)
- status = get_default(status, self.status)
if (self.mg is None) or (mtx is not self.mtx):
self.mg = self.solver(mtx)
@@ -345,14 +370,13 @@ def set_matrix( self, mtx ):
sol, rhs = pmtx.getVecs()
return pmtx, sol, rhs
+ @standard_call
def __call__(self, rhs, x0=None, conf=None, eps_a=None, eps_r=None,
i_max=None, mtx=None, status=None, **kwargs):
- conf = get_default(conf, self.conf)
+
eps_a = get_default(eps_a, self.eps_a)
eps_r = get_default(eps_r, self.eps_r)
i_max = get_default(i_max, self.conf.i_max)
- mtx = get_default(mtx, self.mtx)
- status = get_default(status, self.status)
# There is no use in caching matrix in the solver - always set as new.
pmtx, psol, prhs = self.set_matrix(mtx)
@@ -419,18 +443,16 @@ def process_conf(conf, kwargs):
return Struct(n_proc=get('n_proc', 1),
sub_precond=get('sub_precond', 'icc')) + common
+ @standard_call
def __call__(self, rhs, x0=None, conf=None, eps_a=None, eps_r=None,
i_max=None, mtx=None, status=None, **kwargs):
- import os, sys, shutil, tempfile, time
+ import os, sys, shutil, tempfile
from sfepy import base_dir, data_dir
from sfepy.base.ioutils import ensure_path
- conf = get_default(conf, self.conf)
eps_a = get_default(eps_a, self.eps_a)
eps_r = get_default(eps_r, self.eps_r)
i_max = get_default(i_max, self.conf.i_max)
- mtx = get_default(mtx, self.mtx)
- status = get_default(status, self.status)
petsc = self.petsc
@@ -573,13 +595,10 @@ def __init__(self, conf, **kwargs):
vec0 = aux_state.get_reduced()
conf.idxs[bk] = nm.where(nm.isnan(vec0))[0]
+ @standard_call
def __call__(self, rhs, x0=None, conf=None, eps_a=None, eps_r=None,
i_max=None, mtx=None, status=None, **kwargs):
- conf = get_default(conf, self.conf)
- mtx = get_default(mtx, self.mtx)
- status = get_default(status, self.status)
-
mtxi= self.orig_conf.idxs
mtxslc_s = {}
mtxslc_f = {}

0 comments on commit d0fa5be

Please sign in to comment.