Permalink
Browse files

Implement operators DiagonalNumexprOperator and DiagonalNumexprSepara…

…bleOperator.
  • Loading branch information...
1 parent 09df6d9 commit d6865da7fb9c94468a9009469e22cb49bc01fc07 @pchanial committed Jul 19, 2012
Showing with 179 additions and 4 deletions.
  1. +115 −1 pyoperators/linear.py
  2. +64 −3 tests/test_linear.py
View
@@ -1,15 +1,20 @@
from __future__ import division
+import numexpr
import numpy as np
from scipy.sparse.linalg import eigsh
from .decorators import linear, real, symmetric, inplace
-from .core import Operator, BlockRowOperator, CompositionOperator, DiagonalOperator, ReductionOperator, asoperator
+from .core import (Operator, BlockRowOperator, BroadcastingOperator,
+ CompositionOperator, DiagonalOperator, ReductionOperator,
+ asoperator)
from .utils import isscalar
__all__ = [
'BandOperator',
+ 'DiagonalNumexprOperator',
+ 'DiagonalNumexprSeparableOperator',
'EigendecompositionOperator',
'IntegrationTrapezeWeightOperator',
'PackOperator',
@@ -20,6 +25,115 @@
]
+class DiagonalNumexprOperator(DiagonalOperator):
+ """
+ DiagonalOperator whose diagonal elements are calculated on the fly using
+ the numexpr package.
+
+ Notes
+ -----
+ - When such instance is added or multiplied to another DiagonalOperator
+ (or subclass, such as an instance of this class), an algebraic
+ simplification takes place, which results in a regular (dense) diagonal
+ operator.
+ - This operator can not be separated so that each part handles a block
+ of a block operator. Also, rightward broadcasting cannot be used. If one of
+ these properties is desired, use the class DiagonalNumexprSeparableOperator.
+ - If the operator's input shape is not specified, its inference costs
+ an evaluation of the expression.
+
+ Example
+ -------
+ >>> alpha = np.arange(100.)
+ >>> d = DiagonalNumexprOperator('(x/x0)**alpha',
+ {'alpha':alpha, 'x':1.2, 'x0':1.})
+
+ """
+ def __init__(self, expr, global_dict=None, dtype=float, **keywords):
+ if not isinstance(expr, str):
+ raise TypeError('The first argument is not a string expression.')
+ if 'broadcast' in keywords and keywords['broadcast'] == 'rightward':
+ raise ValueError('The class DiagonalNumexprOperator does not handle'
+ ' rightward broadcasting. Use the class DiagonalNu'
+ 'exprSeparableOperator for this purpose.')
+ if 'broadcast' not in keywords or keywords['broadcast'] != 'leftward':
+ keywords['broadcast'] = 'disabled'
+ self.expr = expr
+ self.global_dict = global_dict
+ if 'shapein' not in keywords and 'shapeout' not in keywords and \
+ keywords['broadcast'] == 'disabled':
+ keywords['shapein'] = self.get_data().shape
+ BroadcastingOperator.__init__(self, 0, dtype=dtype, **keywords)
+
+ def direct(self, input, output):
+ numexpr.evaluate('(' + self.expr + ') * input',
+ global_dict=self.global_dict, out=output)
+
+ def get_data(self):
+ return numexpr.evaluate(self.expr, global_dict=self.global_dict)
+
+ @staticmethod
+ def _rule_left_block(self, op, cls):
+ return None
+
+ @staticmethod
+ def _rule_right_block(self, op, cls):
+ return None
+
+
+class DiagonalNumexprSeparableOperator(DiagonalOperator):
+ """
+ DiagonalOperator whose diagonal elements are calculated on the fly using
+ the numexpr package and that can be seperated when added or multiplied
+ to a block operator.
+
+ Note
+ ----
+ When such instance is added or multiplied to another DiagonalOperator
+ (or subclass, such as an instance of this class), an algebraic
+ simplification takes place, which results in a regular (dense) diagonal
+ operator.
+
+ Example
+ -------
+ >>> alpha = np.arange(100.)
+ >>> d = SeparableDiagonalNumexprOperator(alpha, '(x/x0)**data',
+ {'x':1.2, 'x0':1.})
+
+ """
+ def __init__(self, data, expr, global_dict=None, var='data', dtype=float,
+ **keywords):
+ if not isinstance(expr, str):
+ raise TypeError('The second argument is not a string expression.')
+ BroadcastingOperator.__init__(self, data, dtype=dtype, **keywords)
+ self.expr = expr
+ self.var = var
+ self.global_dict = global_dict
+ self._global_dict = {} if global_dict is None else global_dict.copy()
+ self._global_dict[var] = self.data.T if self.broadcast == \
+ 'rightward' else self.data
+
+ def direct(self, input, output):
+ if self.broadcast == 'rightward':
+ input = input.T
+ output = output.T
+ numexpr.evaluate('(' + self.expr + ') * input',
+ global_dict=self._global_dict, out=output)
+
+ def get_data(self):
+ local_dict = {self.var:self.data}
+ return numexpr.evaluate(self.expr, local_dict=local_dict,
+ global_dict=self.global_dict)
+
+ @staticmethod
+ def _rule_block(self, op, shape, partition, axis, new_axis, func_operation):
+ if type(self) is not DiagonalNumexprSeparableOperator:
+ return None
+ return DiagonalOperator._rule_block(self, op, shape, partition, axis,
+ new_axis, func_operation, self.expr, global_dict=
+ self.global_dict, var=self.var)
+
+
class IntegrationTrapezeWeightOperator(BlockRowOperator):
"""
Return weights as a block row operator to perform trapeze integration.
View
@@ -3,13 +3,74 @@
import numpy as np
import pyoperators
-from pyoperators import Operator, BlockColumnOperator
-from pyoperators.linear import IntegrationTrapezeWeightOperator, PackOperator, UnpackOperator, SumOperator
-from pyoperators.utils.testing import assert_eq
+from pyoperators import Operator, BlockColumnOperator, DiagonalOperator
+from pyoperators.linear import (DiagonalNumexprOperator,
+ DiagonalNumexprSeparableOperator,
+ IntegrationTrapezeWeightOperator,
+ PackOperator, UnpackOperator, SumOperator)
+from pyoperators.utils.testing import (assert_eq, assert_is_instance,
+ assert_is_none, assert_raises)
+from .common import TestIdentityOperator, assert_inplace_outplace
SHAPES = (None, (), (1,), (3,), (2,3), (2,3,4))
+def test_diagonal_numexpr():
+ diag = np.array([1, 2, 3])
+ expr = '(data+1)*3'
+ def func1(cls, args, broadcast):
+ assert_raises(ValueError, cls, broadcast=broadcast, *args)
+ def func2(cls, args, broadcast, values):
+ if broadcast == 'rightward':
+ expected = (values.T*(diag.T+1)*3).T
+ else:
+ expected = values*(diag+1)*3
+ op = cls(broadcast=broadcast, *args)
+ if broadcast in ('leftward', 'rightward'):
+ assert op.broadcast == broadcast
+ assert_is_none(op.shapein)
+ else:
+ assert op.broadcast == 'disabled'
+ assert_eq(op.shapein, diag.shape)
+ assert_inplace_outplace(op, values, expected)
+ for cls, args in zip((DiagonalNumexprOperator,
+ DiagonalNumexprSeparableOperator),
+ ((expr, {'data':diag}), (diag, expr))):
+ for broadcast in (None, 'rightward', 'leftward', 'disabled'):
+ if cls is DiagonalNumexprOperator and broadcast == 'rightward':
+ yield func1, cls, args, broadcast
+ continue
+ for values in (np.array([3,2,1.]),
+ np.array([[1,2,3],[2,3,4],[3,4,5.]])):
+ if values.ndim > 1 and broadcast in (None, 'disabled'):
+ continue
+ yield func2, cls, args, broadcast, values
+
+def test_diagonal_numexpr2():
+ diag = np.array([1, 2, 3])
+ d1 = DiagonalNumexprOperator('(data+1)*3', {'data':diag})
+ d2 = DiagonalNumexprOperator('(data+2)*2', {'data':np.array([3,2,1])})
+ d = d1 * d2
+ assert_is_instance(d, DiagonalOperator)
+ assert_eq(d.broadcast, 'disabled')
+ assert_eq(d.shapein, (3,))
+ assert_eq(d.data, [60, 72, 72])
+ c = BlockColumnOperator(3*[TestIdentityOperator()], new_axisout=0)
+ v = 2
+ assert_inplace_outplace(d1*c, v, d1(c(v)))
+
+def test_diagonal_numexpr3():
+ d1 = DiagonalNumexprSeparableOperator([1,2,3], '(data+1)*3',
+ broadcast='rightward')
+ d2 = DiagonalNumexprSeparableOperator([3,2,1], '(data+2)*2')
+ d = d1 * d2
+ assert_is_instance(d, DiagonalOperator)
+ assert_eq(d.broadcast, 'disabled')
+ assert_eq(d.data, [60, 72, 72])
+ c = BlockColumnOperator(3*[TestIdentityOperator()], new_axisout=0)
+ v = [1,2]
+ assert_inplace_outplace(d1*c, v, d1(c(v)))
+
def test_integration_trapeze():
@pyoperators.decorators.square
class Op(Operator):

0 comments on commit d6865da

Please sign in to comment.