Permalink
Browse files

Refactor code related to flag validation.

  • Loading branch information...
1 parent 3818a1f commit 0194f89166d2cdfcd3eb8da375d06d74edbfbebb @pchanial committed May 9, 2012
Showing with 35 additions and 25 deletions.
  1. +35 −25 pyoperators/core.py
View
60 pyoperators/core.py
@@ -1170,31 +1170,17 @@ def _set_flags(self, flags=None, **keywords):
flags = keywords
if isinstance(flags, OperatorFlags):
self.flags = flags
- elif isinstance(flags, (dict, list, tuple, str)):
- if isinstance(flags, str):
- flags = [f.strip() for f in flags.split(',')]
- elif isscalar(flags):
- flags = (flags,)
- if isinstance(flags, (list, tuple)):
- flags = dict((f,True) for f in flags)
- if any(not isinstance(f, str) for f in flags.keys()):
- raise TypeError("Invalid type for the operator flags: {0}." \
- .format(flags))
- if any(f not in OperatorFlags._fields for f in flags):
- raise ValueError("Invalid operator flags '{0}'. The properties "
- "must be one of the following: ".format(flags.keys()) + \
- strenum(OperatorFlags._fields) + '.')
- self.flags = self.flags._replace(**flags)
- flags = [ f for f in flags if flags[f]]
- if 'symmetric' in flags or 'hermitian' in flags or \
- 'orthogonal' in flags or 'unitary' in flags:
- self.flags = self.flags._replace(linear=True, square=True)
- if 'orthogonal' in flags:
- self.flags = self.flags._replace(real=True)
- if 'involutary' in flags:
- self.flags = self.flags._replace(square=True)
- elif flags is not None:
- raise TypeError("Invalid input flags: '{0}'.".format(flags))
+ return
+ flags = self._validate_flags(flags)
+ f = [k for k,v in flags.items() if v]
+ if 'symmetric' in f or 'hermitian' in f or 'orthogonal' in f or \
+ 'unitary' in f:
+ flags['linear'] = flags['square'] = True
+ if 'orthogonal' in f:
+ flags['real'] = True
+ if 'involutary' in f:
+ flags['square'] = True
+ self.flags = self.flags._replace(**flags)
def _validate_arguments(self, input, output):
"""
@@ -1231,6 +1217,30 @@ def _validate_arguments(self, input, output):
output = memory.allocate(shapeout, dtype, 'in ' + self.__name__)
return input, output
+ @staticmethod
+ def _validate_flags(flags):
+ """ Return flags as a dictionary. """
+ if flags is None:
+ return {}
+ if isinstance(flags, dict):
+ return flags.copy()
+ if isinstance(flags, OperatorFlags):
+ return dict((k,v) for k,v in zip(OperatorFlags._fields, flags))
+ if isinstance(flags, str):
+ flags = [f.strip() for f in flags.split(',')]
+ if not isinstance(flags, (list, tuple)):
+ raise TypeError("The operator flags have an invalid type '{0}'.".
+ format(flags))
+ flags = dict((f,True) for f in flags)
+ if any(not isinstance(f, str) for f in flags.keys()):
+ raise TypeError("Invalid type for the operator flags: {0}." \
+ .format(flags))
+ if any(f not in OperatorFlags._fields for f in flags):
+ raise ValueError("Invalid operator flags '{0}'. The properties "
+ "must be one of the following: ".format(flags.keys()) + \
+ strenum(OperatorFlags._fields) + '.')
+ return flags
+
def __mul__(self, other):
if isinstance(other, np.ndarray):
return self.matvec(other)

0 comments on commit 0194f89

Please sign in to comment.