Skip to content

Commit

Permalink
Implement operators DiagonalNumexprOperator and DiagonalNumexprSepara…
Browse files Browse the repository at this point in the history
…bleOperator.
  • Loading branch information
pchanial committed Jul 19, 2012
1 parent 09df6d9 commit d6865da
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 4 deletions.
116 changes: 115 additions & 1 deletion pyoperators/linear.py
@@ -1,15 +1,20 @@
from __future__ import division

import numexpr
import numpy as np

from scipy.sparse.linalg import eigsh

from .decorators import linear, real, symmetric, inplace
from .core import Operator, BlockRowOperator, CompositionOperator, DiagonalOperator, ReductionOperator, asoperator
from .core import (Operator, BlockRowOperator, BroadcastingOperator,
CompositionOperator, DiagonalOperator, ReductionOperator,
asoperator)
from .utils import isscalar

__all__ = [
'BandOperator',
'DiagonalNumexprOperator',
'DiagonalNumexprSeparableOperator',
'EigendecompositionOperator',
'IntegrationTrapezeWeightOperator',
'PackOperator',
Expand All @@ -20,6 +25,115 @@
]


class DiagonalNumexprOperator(DiagonalOperator):
"""
DiagonalOperator whose diagonal elements are calculated on the fly using
the numexpr package.
Notes
-----
- When such instance is added or multiplied to another DiagonalOperator
(or subclass, such as an instance of this class), an algebraic
simplification takes place, which results in a regular (dense) diagonal
operator.
- This operator can not be separated so that each part handles a block
of a block operator. Also, rightward broadcasting cannot be used. If one of
these properties is desired, use the class DiagonalNumexprSeparableOperator.
- If the operator's input shape is not specified, its inference costs
an evaluation of the expression.
Example
-------
>>> alpha = np.arange(100.)
>>> d = DiagonalNumexprOperator('(x/x0)**alpha',
{'alpha':alpha, 'x':1.2, 'x0':1.})
"""
def __init__(self, expr, global_dict=None, dtype=float, **keywords):
if not isinstance(expr, str):
raise TypeError('The first argument is not a string expression.')
if 'broadcast' in keywords and keywords['broadcast'] == 'rightward':
raise ValueError('The class DiagonalNumexprOperator does not handle'
' rightward broadcasting. Use the class DiagonalNu'
'exprSeparableOperator for this purpose.')
if 'broadcast' not in keywords or keywords['broadcast'] != 'leftward':
keywords['broadcast'] = 'disabled'
self.expr = expr
self.global_dict = global_dict
if 'shapein' not in keywords and 'shapeout' not in keywords and \
keywords['broadcast'] == 'disabled':
keywords['shapein'] = self.get_data().shape
BroadcastingOperator.__init__(self, 0, dtype=dtype, **keywords)

def direct(self, input, output):
numexpr.evaluate('(' + self.expr + ') * input',
global_dict=self.global_dict, out=output)

def get_data(self):
return numexpr.evaluate(self.expr, global_dict=self.global_dict)

@staticmethod
def _rule_left_block(self, op, cls):
return None

@staticmethod
def _rule_right_block(self, op, cls):
return None


class DiagonalNumexprSeparableOperator(DiagonalOperator):
"""
DiagonalOperator whose diagonal elements are calculated on the fly using
the numexpr package and that can be seperated when added or multiplied
to a block operator.
Note
----
When such instance is added or multiplied to another DiagonalOperator
(or subclass, such as an instance of this class), an algebraic
simplification takes place, which results in a regular (dense) diagonal
operator.
Example
-------
>>> alpha = np.arange(100.)
>>> d = SeparableDiagonalNumexprOperator(alpha, '(x/x0)**data',
{'x':1.2, 'x0':1.})
"""
def __init__(self, data, expr, global_dict=None, var='data', dtype=float,
**keywords):
if not isinstance(expr, str):
raise TypeError('The second argument is not a string expression.')
BroadcastingOperator.__init__(self, data, dtype=dtype, **keywords)
self.expr = expr
self.var = var
self.global_dict = global_dict
self._global_dict = {} if global_dict is None else global_dict.copy()
self._global_dict[var] = self.data.T if self.broadcast == \
'rightward' else self.data

def direct(self, input, output):
if self.broadcast == 'rightward':
input = input.T
output = output.T
numexpr.evaluate('(' + self.expr + ') * input',
global_dict=self._global_dict, out=output)

def get_data(self):
local_dict = {self.var:self.data}
return numexpr.evaluate(self.expr, local_dict=local_dict,
global_dict=self.global_dict)

@staticmethod
def _rule_block(self, op, shape, partition, axis, new_axis, func_operation):
if type(self) is not DiagonalNumexprSeparableOperator:
return None
return DiagonalOperator._rule_block(self, op, shape, partition, axis,
new_axis, func_operation, self.expr, global_dict=
self.global_dict, var=self.var)


class IntegrationTrapezeWeightOperator(BlockRowOperator):
"""
Return weights as a block row operator to perform trapeze integration.
Expand Down
67 changes: 64 additions & 3 deletions tests/test_linear.py
Expand Up @@ -3,13 +3,74 @@
import numpy as np
import pyoperators

from pyoperators import Operator, BlockColumnOperator
from pyoperators.linear import IntegrationTrapezeWeightOperator, PackOperator, UnpackOperator, SumOperator
from pyoperators.utils.testing import assert_eq
from pyoperators import Operator, BlockColumnOperator, DiagonalOperator
from pyoperators.linear import (DiagonalNumexprOperator,
DiagonalNumexprSeparableOperator,
IntegrationTrapezeWeightOperator,
PackOperator, UnpackOperator, SumOperator)
from pyoperators.utils.testing import (assert_eq, assert_is_instance,
assert_is_none, assert_raises)
from .common import TestIdentityOperator, assert_inplace_outplace

SHAPES = (None, (), (1,), (3,), (2,3), (2,3,4))


def test_diagonal_numexpr():
diag = np.array([1, 2, 3])
expr = '(data+1)*3'
def func1(cls, args, broadcast):
assert_raises(ValueError, cls, broadcast=broadcast, *args)
def func2(cls, args, broadcast, values):
if broadcast == 'rightward':
expected = (values.T*(diag.T+1)*3).T
else:
expected = values*(diag+1)*3
op = cls(broadcast=broadcast, *args)
if broadcast in ('leftward', 'rightward'):
assert op.broadcast == broadcast
assert_is_none(op.shapein)
else:
assert op.broadcast == 'disabled'
assert_eq(op.shapein, diag.shape)
assert_inplace_outplace(op, values, expected)
for cls, args in zip((DiagonalNumexprOperator,
DiagonalNumexprSeparableOperator),
((expr, {'data':diag}), (diag, expr))):
for broadcast in (None, 'rightward', 'leftward', 'disabled'):
if cls is DiagonalNumexprOperator and broadcast == 'rightward':
yield func1, cls, args, broadcast
continue
for values in (np.array([3,2,1.]),
np.array([[1,2,3],[2,3,4],[3,4,5.]])):
if values.ndim > 1 and broadcast in (None, 'disabled'):
continue
yield func2, cls, args, broadcast, values

def test_diagonal_numexpr2():
diag = np.array([1, 2, 3])
d1 = DiagonalNumexprOperator('(data+1)*3', {'data':diag})
d2 = DiagonalNumexprOperator('(data+2)*2', {'data':np.array([3,2,1])})
d = d1 * d2
assert_is_instance(d, DiagonalOperator)
assert_eq(d.broadcast, 'disabled')
assert_eq(d.shapein, (3,))
assert_eq(d.data, [60, 72, 72])
c = BlockColumnOperator(3*[TestIdentityOperator()], new_axisout=0)
v = 2
assert_inplace_outplace(d1*c, v, d1(c(v)))

def test_diagonal_numexpr3():
d1 = DiagonalNumexprSeparableOperator([1,2,3], '(data+1)*3',
broadcast='rightward')
d2 = DiagonalNumexprSeparableOperator([3,2,1], '(data+2)*2')
d = d1 * d2
assert_is_instance(d, DiagonalOperator)
assert_eq(d.broadcast, 'disabled')
assert_eq(d.data, [60, 72, 72])
c = BlockColumnOperator(3*[TestIdentityOperator()], new_axisout=0)
v = [1,2]
assert_inplace_outplace(d1*c, v, d1(c(v)))

def test_integration_trapeze():
@pyoperators.decorators.square
class Op(Operator):
Expand Down

0 comments on commit d6865da

Please sign in to comment.