In [None]:
#| default_exp pymor.parameters

# pymor.parameters

> Extended functionality for [pyMOR](https://pymor.org/) parameters

In [None]:
#| hide
from fastcore.test import test_eq

In [None]:
#| export
from numbers import Number
from fastcore.basics import patch
import sympy as sy
from sympy.parsing.sympy_parser import parse_expr
from pymor.parameters.functionals import ProductParameterFunctional, ExpressionParameterFunctional, ParameterFunctional
from pymor.algorithms.rules import match_class, RuleTable, match_always
from pymor.basic import LincombOperator
from pymor.models.interface import Model
from pymor.operators.interface import Operator

In [None]:
#| export
@patch
def __eq__(self:ExpressionParameterFunctional, other):
    """Structural equality for expression functionals.

    Two `ExpressionParameterFunctional` objects compare equal when their
    expression strings match and they declare the same parameters; any
    other kind of object is never equal.
    """
    if not isinstance(other, ExpressionParameterFunctional):
        return False
    same_expression = self.expression == other.expression
    return same_expression and self.parameters == other.parameters

In [None]:
intensity_functional = ExpressionParameterFunctional('3890540.14*sqrt(IntensitySI1)', {'IntensitySI1': 1})
functional = ProductParameterFunctional((intensity_functional, -1.0))

In [None]:
# Compare against an independently constructed functional so the patched
# `ExpressionParameterFunctional.__eq__` is actually exercised (not identity).
functional.factors[0] == ExpressionParameterFunctional('3890540.14*sqrt(IntensitySI1)', {'IntensitySI1': 1})

True

In [None]:
#| hide
# Structural equality: same expression and parameters -> equal ...
test_eq(functional.factors[0] == ExpressionParameterFunctional('3890540.14*sqrt(IntensitySI1)', {'IntensitySI1': 1}), True)
# ... a different expression -> not equal.
test_eq(functional.factors[0] == ExpressionParameterFunctional('2*sqrt(IntensitySI1)', {'IntensitySI1': 1}), False)

In [None]:
#| export
@patch
def __str__(self:ParameterFunctional):
    """Render the functional as ``f(<name>, <name>, ...)`` from its parameter names."""
    names = ", ".join(self.parameters)
    return f'f({names})'

In [None]:
# The patched `__str__` renders a functional by its parameter names.
str(functional)

'f(IntensitySI1)'

In [None]:
#| hide
test_eq(str(functional), 'f(IntensitySI1)')

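## Simplify functionals

Rewrite parameter functionals into simpler, equivalent forms. Nested `ProductParameterFunctional`s are flattened. Products made up solely of numbers and `ExpressionParameterFunctional`s are multiplied out into a single expression, or into a number when no parameter remains. The rules are applied recursively to the coefficients of `LincombOperator`s and to the children of models and operators.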

In [None]:
#| export
class SimplifyFunctionalRules(RuleTable):
    """|RuleTable| for the :func:`simplify_functionals` algorithm.

    - Nested `ProductParameterFunctional` factors are simplified and inlined.
    - A product consisting only of numbers and `ExpressionParameterFunctional`
      objects is multiplied out into a single `ExpressionParameterFunctional`,
      or into a plain Python number when no parameter remains.
    - Coefficients and operands of `LincombOperator` objects are simplified,
      and other `Model`/`Operator` objects are traversed recursively.
    - Anything else is returned unchanged.
    """

    def __init__(self):
        super().__init__(use_caching=True)

    def _flatten_factors(self, factors):
        """Return `factors` with every nested `ProductParameterFunctional` simplified and inlined."""
        flat = []
        for factor in factors:
            if isinstance(factor, ProductParameterFunctional):
                # Simplify the child first. It may collapse into something that is
                # no longer a product (an expression or a number), so only splice in
                # `.factors` when the result is still a product.
                factor = self.apply(factor)
            if isinstance(factor, ProductParameterFunctional):
                flat.extend(factor.factors)
            else:
                flat.append(factor)
        return flat

    @match_class(ProductParameterFunctional)
    def action_ProductParameterFunctional(self, functional):
        # merge child ProductParameterFunctional objects
        if any(isinstance(factor, ProductParameterFunctional) for factor in functional.factors):
            functional = functional.with_(factors=self._flatten_factors(functional.factors))

        # multiply together numbers and ExpressionParameterFunctional objects
        if all(isinstance(factor, (ExpressionParameterFunctional, Number)) for factor in functional.factors):
            # `sy.Mul` always yields a sympy object (`S.One` for no factors), whereas
            # `sy.prod` returns a plain Python number when every factor is a number,
            # which has no `.is_number` attribute.
            product = sy.Mul(*(
                parse_expr(factor.expression) if isinstance(factor, ExpressionParameterFunctional) else factor
                for factor in functional.factors
            ))
            if product.is_number:
                # Hand back a plain Python number rather than a sympy object so that
                # downstream numpy-based code does not receive sympy types.
                value = complex(product)
                functional = value.real if value.imag == 0 else value
            else:
                functional = ExpressionParameterFunctional(str(product), parameters=functional.parameters)

        return functional

    @match_class(LincombOperator)
    def action_LincombOperator(self, op):
        # This rule takes precedence over `action_recurse`, so recurse into the
        # operands explicitly before simplifying the coefficients.
        op = self.replace_children(op)
        return op.with_(coefficients=[self.apply(c) for c in op.coefficients])

    @match_class(Model, Operator)
    def action_recurse(self, op):
        return self.replace_children(op)

    @match_always
    def action_generic(self, expr):
        return expr

### simplify_functionals -


In [None]:
#| export
def simplify_functionals(obj):
    """Simplify the parameter functionals contained in `obj`.

    `obj` is processed by a fresh `SimplifyFunctionalRules` instance, and the
    (possibly rebuilt) result is returned. Objects no rule applies to come back
    unchanged.
    """
    rules = SimplifyFunctionalRules()
    return rules.apply(obj)

In [None]:
# The -1.0 factor is multiplied into the expression, leaving a single ExpressionParameterFunctional.
simplify_functionals(functional)

ExpressionParameterFunctional('-3890540.14*sqrt(IntensitySI1)', {IntensitySI1: 1})

In [None]:
#| hide
test_eq(simplify_functionals(functional), ExpressionParameterFunctional('-3890540.14*sqrt(IntensitySI1)', {'IntensitySI1': 1}))

## Export -

In [None]:
#| hide
import nbdev
nbdev.nbdev_export()

In [None]:
#| hide
# #| export 
# # Allow 2D parameter values in `Mu` for broadcasting purposes.
# class Mu(Mu):
#     def __new__(cls, *args, **kwargs):
#         raw_values = dict(*args, **kwargs)
#         values_for_t = {}
#         for k, v in sorted(raw_values.items()):
#             assert isinstance(k, str)
#             if callable(v):
#                 # note: We can't import Function globally due to circular dependencies, so
#                 # we import it locally in this branch to avoid executing the import statement
#                 # each time a Mu is created (which would make instantiation of simple Mus without
#                 # time dependency significantly more expensive).
#                 # from pymor.analyticalproblems.functions import Function
#                 assert k != 't'
#                 assert isinstance(v, Function) and v.dim_domain == 1 and len(v.shape_range) == 1
#                 vv = v(raw_values.get('t', 0))
#             else:
#                 vv = np.array(v, copy=False, ndmin=1)
#                 # assert vv.ndim == 1
#                 assert k != 't' or len(vv) == 1
#             assert not vv.setflags(write=False)
#             values_for_t[k] = vv

#         mu = FrozenDict.__new__(cls, values_for_t)
#         mu._raw_values = raw_values
#         return mu

# #| export 
# @patch
# def broadcast(
#     self:Mu, 
#     transpose=False # Transpose arrays of values before broadcasting
# )->np.ndarray:
#     """Broadcast all parameter values together to create 2D array of `Mu` objects with scalar parameters."""
#     values = np.array(np.broadcast_arrays(*self.values())).transpose(1, 2, 0)
#     if transpose:
#         values = values.transpose(1, 0, 2)
#     return array_map(lambda v: Mu(zip(self.keys(), v)), values, 2)

# mu = Mu(A=[[.5, .8, 1.2]], B=.4)

# mu.broadcast()

# mu.broadcast(transpose=True)

# #| export
# @patch
# def scalar_parameters(self:Mu):
#     return Mu({k: v for k, v in self.items() if v.size == 1})