Skip to content

Commit

Permalink
Merge pull request #791 from pymor/project_nobases
Browse files Browse the repository at this point in the history
Add rule to ProjectRules for the case that source_basis range basis are None
  • Loading branch information
sdrave committed Oct 18, 2019
2 parents 7499317 + b6818fd commit b3519e6
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions src/pymor/algorithms/projection.py
Expand Up @@ -4,7 +4,7 @@

import numpy as np

from pymor.algorithms.rules import RuleTable, match_class, match_generic
from pymor.algorithms.rules import RuleTable, match_class, match_generic, match_always
from pymor.core.exceptions import RuleNotMatchingError, NoMatchingRuleError
from pymor.operators.basic import ProjectedOperator
from pymor.operators.block import BlockOperatorBase, BlockRowOperator, BlockColumnOperator
Expand Down Expand Up @@ -80,6 +80,13 @@ def __init__(self, range_basis, source_basis):
super().__init__(use_caching=True)
self.__auto_init(locals())

@match_always
def action_no_bases(self, op):
if self.range_basis is None and self.source_basis is None:
return op
else:
raise RuleNotMatchingError

@match_class(ZeroOperator)
def action_ZeroOperator(self, op):
range_basis, source_basis = self.range_basis, self.source_basis
Expand Down Expand Up @@ -108,19 +115,16 @@ def action_ConstantOperator(self, op):
def action_apply_basis(self, op):
range_basis, source_basis = self.range_basis, self.source_basis
if source_basis is None:
if range_basis is None:
return op
try:
V = op.apply_adjoint(range_basis)
except NotImplementedError:
raise RuleNotMatchingError('apply_adjoint not implemented')
if isinstance(op.source, NumpyVectorSpace):
from pymor.operators.numpy import NumpyMatrixOperator
return NumpyMatrixOperator(V.to_numpy(), source_id=op.source.id, name=op.name)
else:
try:
V = op.apply_adjoint(range_basis)
except NotImplementedError:
raise RuleNotMatchingError('apply_adjoint not implemented')
if isinstance(op.source, NumpyVectorSpace):
from pymor.operators.numpy import NumpyMatrixOperator
return NumpyMatrixOperator(V.to_numpy(), source_id=op.source.id, name=op.name)
else:
from pymor.operators.constructions import VectorArrayOperator
return VectorArrayOperator(V, adjoint=True, name=op.name)
from pymor.operators.constructions import VectorArrayOperator
return VectorArrayOperator(V, adjoint=True, name=op.name)
else:
if range_basis is None:
V = op.apply(source_basis)
Expand Down

0 comments on commit b3519e6

Please sign in to comment.