Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add __auto_init method and use it where applicable #732

Merged
merged 2 commits into from Jul 12, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
@@ -223,8 +223,7 @@ class CollectOperatorRangeRules(RuleTable):

def __init__(self, source, image, extends):
super().__init__(use_caching=True)
self.source, self.image, self.extends = \
source, image, extends
self.__auto_init(locals())

@match_generic(lambda op: op.linear and not op.parametric)
def action_apply_operator(self, op):
@@ -53,8 +53,7 @@ def assemble_lincomb(operators, coefficients, solver_options=None, name=None):
class AssembleLincombRules(RuleTable):
def __init__(self, coefficients, solver_options, name):
super().__init__(use_caching=False)
self.coefficients, self.solver_options, self.name \
= coefficients, solver_options, name
self.__auto_init(locals())

@match_class_any(ZeroOperator)
def action_ZeroOperator(self, ops):
@@ -71,8 +71,7 @@ class ProjectRules(RuleTable):

def __init__(self, range_basis, source_basis, product):
super().__init__(use_caching=True)
self.range_basis, self.source_basis, self.product = \
range_basis, source_basis, product
self.__auto_init(locals())

@match_class(ZeroOperator)
def action_ZeroOperator(self, op):
@@ -244,7 +243,7 @@ class ProjectToSubbasisRules(RuleTable):

def __init__(self, dim_range, dim_source):
super().__init__(use_caching=True)
self.dim_range, self.dim_source = dim_range, dim_source
self.__auto_init(locals())

@match_class(LincombOperator)
def action_recurse(self, op):
@@ -89,8 +89,7 @@ class ImplicitEulerTimeStepper(TimeStepperInterface):
"""

def __init__(self, nt, solver_options='operator'):
self.nt = nt
self.solver_options = solver_options
self.__auto_init(locals())

def solve(self, initial_time, end_time, initial_data, operator, rhs=None, mass=None, mu=None, num_values=None):
return implicit_euler(operator, rhs, mass, initial_data, initial_time, end_time, self.nt, mu, num_values,
@@ -111,7 +110,7 @@ class ExplicitEulerTimeStepper(TimeStepperInterface):
"""

def __init__(self, nt):
self.nt = nt
self.__auto_init(locals())

def solve(self, initial_time, end_time, initial_data, operator, rhs=None, mass=None, mu=None, num_values=None):
if mass is not None:
@@ -43,7 +43,7 @@ class ToMatrixRules(RuleTable):

def __init__(self, format, mu):
super().__init__()
self.format, self.mu = format, mu
self.__auto_init(locals())

@match_class(NumpyMatrixOperator)
def action_NumpyMatrixOperator(self, op):
@@ -112,18 +112,6 @@ def __init__(self, domain,
or all(isinstance(v, tuple) and len(v) == 2 and v[0] in ('l2', 'l2_boundary')
and v[1].dim_domain == domain.dim and v[1].shape_range == () for v in functionals.values()))

self.domain = domain
self.rhs = rhs
self.diffusion = diffusion
self.advection = advection
self.nonlinear_advection = nonlinear_advection
self.nonlinear_advection_derivative = nonlinear_advection_derivative
self.reaction = reaction
self.nonlinear_reaction = nonlinear_reaction
self.nonlinear_reaction_derivative = nonlinear_reaction_derivative
self.dirichlet_data = dirichlet_data
self.neumann_data = neumann_data
self.robin_data = robin_data
self.functionals = FrozenDict(functionals) if functionals is not None else None
self.parameter_space = parameter_space
self.name = name
functionals = FrozenDict(functionals) if functionals is not None else None

self.__auto_init(locals())
@@ -38,12 +38,8 @@ class InstationaryProblem(ImmutableInterface):
"""

def __init__(self, stationary_part, initial_data, T=1., parameter_space=None, name=None):

self.stationary_part = stationary_part
self.initial_data = initial_data
self.T = T
self.parameter_space = parameter_space or stationary_part.parameter_space
self.name = name or ('instationary_' + stationary_part.name)
name = name or ('instationary_' + stationary_part.name)
self.__auto_init(locals())

def with_stationary_part(self, **kwargs):
return self.with_(stationary_part=self.stationary_part.with_(**kwargs))
@@ -117,8 +117,7 @@ def __neg__(self):
class FenicsVectorSpace(ListVectorSpace):

def __init__(self, V, id='STATE'):
self.V = V
self.id = id
self.__auto_init(locals())

@property
def dim(self):
@@ -156,13 +155,9 @@ class FenicsMatrixOperator(OperatorBase):

def __init__(self, matrix, source_space, range_space, solver_options=None, name=None):
assert matrix.rank() == 2
self.source_space = source_space
self.range_space = range_space
self.__auto_init(locals())
self.source = FenicsVectorSpace(source_space)
self.range = FenicsVectorSpace(range_space)
self.matrix = matrix
self.solver_options = solver_options
self.name = name

def apply(self, U, mu=None):
assert U in self.source
@@ -66,8 +66,7 @@ def amax(self):
class NGSolveVectorSpace(ListVectorSpace):

def __init__(self, V, id='STATE'):
self.V = V
self.id = id
self.__auto_init(locals())

def __eq__(self, other):
return type(other) is NGSolveVectorSpace and self.V == other.V and self.id == other.id
@@ -109,11 +108,7 @@ class NGSolveMatrixOperator(OperatorBase):
linear = True

def __init__(self, matrix, range, source, solver_options=None, name=None):
self.range = range
self.source = source
self.matrix = matrix
self.solver_options = solver_options
self.name = name
self.__auto_init(locals())

def apply(self, U, mu=None):
assert U in self.source
@@ -163,8 +158,7 @@ class NGSolveVisualizer(ImmutableInterface):
"""Visualize an NGSolve grid function."""

def __init__(self, mesh, fespace):
self.mesh = mesh
self.fespace = fespace
self.__auto_init(locals())
self.space = NGSolveVectorSpace(fespace)

def visualize(self, U, m, legend=None, separate_colorbars=True, block=True):
@@ -146,6 +146,25 @@ def __new__(cls, classname, bases, classdict):
base_doc = doc
item.__doc__ = base_doc

def __auto_init(self, locals_):
"""Automatically assign __init__ arguments.

This method is used in __init__ to automatically assign __init__ arguments to equally
named object attributes. The values are provided by the `locals_` dict. Usually,
`__auto_init` is called as::

self.__auto_init(locals())

where `locals()` returns a dictionary of all local variables in the current scope.
Only attributes which have not already been set by the user are initialized by
`__auto_init`.
"""
for arg in c._init_arguments:
if arg not in self.__dict__:
setattr(self, arg, locals_[arg])

classdict[f'_{classname}__auto_init'] = __auto_init

c = abc.ABCMeta.__new__(cls, classname, bases, classdict)

# getargspec is deprecated and does not work with keyword only args
@@ -44,12 +44,9 @@ def __init__(self, domain=([0, 0], [1, 1]), left='dirichlet', right='dirichlet',
for bt in (left, right, top, bottom):
if bt is not None and bt not in KNOWN_BOUNDARY_TYPES:
self.logger.warning(f'Unknown boundary type: {bt}')
domain = np.array(domain)
self.__auto_init(locals())
self.boundary_types = frozenset({left, right, top, bottom})
self.left = left
self.right = right
self.top = top
self.bottom = bottom
self.domain = np.array(domain)

@property
def lower_left(self):
@@ -106,10 +103,9 @@ def __init__(self, domain=([0, 0], [1, 1]), top='dirichlet', bottom='dirichlet')
for bt in (top, bottom):
if bt is not None and bt not in KNOWN_BOUNDARY_TYPES:
self.logger.warning(f'Unknown boundary type: {bt}')
domain = np.array(domain)
self.__auto_init(locals())
self.boundary_types = frozenset({top, bottom})
self.top = top
self.bottom = bottom
self.domain = np.array(domain)

@property
def lower_left(self):
@@ -212,10 +208,9 @@ def __init__(self, domain=(0, 1), left='dirichlet', right='dirichlet'):
for bt in (left, right):
if bt is not None and bt not in KNOWN_BOUNDARY_TYPES:
self.logger.warning(f'Unknown boundary type: {bt}')
domain = np.array(domain)
self.__auto_init(locals())
self.boundary_types = frozenset({left, right})
self.left = left
self.right = right
self.domain = np.array(domain)

@property
def width(self):
@@ -34,11 +34,9 @@ class PolygonalDomain(DomainDescriptionInterface):

def __init__(self, points, boundary_types, holes=None):
holes = holes or []
self.points = points
self.holes = holes

if isinstance(boundary_types, dict):
self.boundary_types = boundary_types
pass
# if the boundary types are not given as a dict, try to evaluate at the edge centers to get a dict.
else:
points = [points]
@@ -52,12 +50,14 @@ def __init__(self, points, boundary_types, holes=None):
for p0, p1 in zip(ps, ps_d)]
# evaluate the boundary at the edge centers and save the boundary types together with the
# corresponding edge id.
self.boundary_types = dict(zip([boundary_types(centers)], [list(range(1, len(centers)+1))]))
boundary_types = dict(zip([boundary_types(centers)], [list(range(1, len(centers)+1))]))

for bt in self.boundary_types.keys():
for bt in boundary_types.keys():
if bt is not None and bt not in KNOWN_BOUNDARY_TYPES:
self.logger.warning(f'Unknown boundary type: {bt}')

self.__auto_init(locals())


class CircularSectorDomain(PolygonalDomain):
"""Describes a circular sector domain of variable radius.
@@ -86,29 +86,25 @@ class CircularSectorDomain(PolygonalDomain):
"""

def __init__(self, angle, radius, arc='dirichlet', radii='dirichlet', num_points=100):
self.angle = angle
self.radius = radius
self.arc = arc
self.radii = radii
self.num_points = num_points
assert (0 < self.angle) and (self.angle < 2*np.pi)
assert self.radius > 0
assert self.num_points > 0
assert (0 < angle) and (angle < 2*np.pi)
assert radius > 0
assert num_points > 0

points = [[0., 0.]]
points.extend([[self.radius*np.cos(t), self.radius*np.sin(t)] for t in
np.linspace(start=0, stop=angle, num=self.num_points, endpoint=True)])
points.extend([[radius*np.cos(t), radius*np.sin(t)] for t in
np.linspace(start=0, stop=angle, num=num_points, endpoint=True)])

if self.arc == self.radii:
boundary_types = {self.arc: list(range(1, len(points)+1))}
if arc == radii:
boundary_types = {arc: list(range(1, len(points)+1))}
else:
boundary_types = {self.arc: list(range(2, len(points)))}
boundary_types.update({self.radii: [1, len(points)]})
boundary_types = {arc: list(range(2, len(points)))}
boundary_types.update({radii: [1, len(points)]})

if None in boundary_types:
del boundary_types[None]

super().__init__(points, boundary_types)
self.__auto_init(locals())


class DiscDomain(PolygonalDomain):
@@ -131,14 +127,12 @@ class DiscDomain(PolygonalDomain):
"""

def __init__(self, radius, boundary='dirichlet', num_points=100):
self.radius = radius
self.boundary = boundary
self.num_points = num_points
assert self.radius > 0
assert self.num_points > 0
assert radius > 0
assert num_points > 0

points = [[self.radius*np.cos(t), self.radius*np.sin(t)] for t in
points = [[radius*np.cos(t), radius*np.sin(t)] for t in
np.linspace(start=0, stop=2*np.pi, num=num_points, endpoint=False)]
boundary_types = {} if self.boundary is None else {boundary: list(range(1, len(points)+1))}
boundary_types = {} if boundary is None else {boundary: list(range(1, len(points)+1))}

super().__init__(points, boundary_types)
self.__auto_init(locals())
@@ -66,10 +66,8 @@ def __init__(self, value=np.array(1.0), dim_domain=1, name=None):
assert dim_domain > 0
assert isinstance(value, (Number, np.ndarray))
value = np.array(value)
self.value = value
self.dim_domain = dim_domain
self.__auto_init(locals())
self.shape_range = value.shape
self.name = name

def __str__(self):
return f'{self.name}: x -> {self.value}'
@@ -114,12 +112,11 @@ class GenericFunction(FunctionBase):
def __init__(self, mapping, dim_domain=1, shape_range=(), parameter_type=None, name=None):
assert dim_domain > 0
assert isinstance(shape_range, (Number, tuple))
self.dim_domain = dim_domain
self.shape_range = shape_range if isinstance(shape_range, tuple) else (shape_range,)
self.name = name
self.mapping = mapping
if not isinstance(shape_range, tuple):
shape_range = (shape_range,)
if parameter_type is not None:
self.build_parameter_type(parameter_type)
self.__auto_init(locals())

def __str__(self):
return f'{self.name}: x -> {self.mapping}'
@@ -173,11 +170,11 @@ class ExpressionFunction(GenericFunction):
functions = ExpressionParameterFunctional.functions

def __init__(self, expression, dim_domain=1, shape_range=(), parameter_type=None, values=None, name=None):
self.expression = expression
self.values = values or {}
values = values or {}
code = compile(expression, '<expression>', 'eval')
super().__init__(lambda x, mu={}: eval(code, dict(self.functions, **self.values), dict(mu, x=x, mu=mu)),
super().__init__(lambda x, mu={}: eval(code, dict(self.functions, **values), dict(mu, x=x, mu=mu)),
dim_domain, shape_range, parameter_type, name)
self.__auto_init(locals())

def __reduce__(self):
return (ExpressionFunction,
@@ -214,13 +211,11 @@ def __init__(self, functions, coefficients, name=None):
assert all(isinstance(c, (ParameterFunctionalInterface, Number)) for c in coefficients)
assert all(f.dim_domain == functions[0].dim_domain for f in functions[1:])
assert all(f.shape_range == functions[0].shape_range for f in functions[1:])
self.dim_domain = functions[0].dim_domain
self.shape_range = functions[0].shape_range
self.functions = functions
self.coefficients = coefficients
self.name = name
self.__auto_init(locals())
self.build_parameter_type(*chain(functions,
(f for f in coefficients if isinstance(f, ParameterFunctionalInterface))))
self.dim_domain = functions[0].dim_domain
self.shape_range = functions[0].shape_range

def evaluate_coefficients(self, mu):
"""Compute the linear coefficients for a given |Parameter| `mu`."""
@@ -253,11 +248,10 @@ def __init__(self, functions, name=None):
assert all(isinstance(f, FunctionInterface) for f in functions)
assert all(f.dim_domain == functions[0].dim_domain for f in functions[1:])
assert all(f.shape_range == functions[0].shape_range for f in functions[1:])
self.__auto_init(locals())
self.build_parameter_type(*functions)
self.dim_domain = functions[0].dim_domain
self.shape_range = functions[0].shape_range
self.functions = functions
self.name = name
self.build_parameter_type(*functions)

def evaluate(self, x, mu=None):
mu = self.parse_parameter(mu)