Skip to content

Commit

Permalink
Remove _RequireAttrsABCMeta metaclass and replace with simple check (#…
Browse files Browse the repository at this point in the history
…409)

* Remove _RequireAttrsABCMeta metaclass and replace with simple check

* make BaseLearner a ABC
  • Loading branch information
basnijholt committed Apr 29, 2023
1 parent b16f0e5 commit 82ed0a4
Show file tree
Hide file tree
Showing 11 changed files with 25 additions and 20 deletions.
1 change: 1 addition & 0 deletions adaptive/learner/average_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(
self.min_npoints = max(min_npoints, 2)
self.sum_f: Real = 0.0
self.sum_f_sq: Real = 0.0
self._check_required_attributes()

def new(self) -> AverageLearner:
"""Create a copy of `~adaptive.AverageLearner` without the data."""
Expand Down
1 change: 1 addition & 0 deletions adaptive/learner/average_learner1D.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def __init__(
self._distances: dict[Real, float] = decreasing_dict()
# {xii: error[xii]/min(_distances[xi], _distances[xii], ...}
self.rescaled_error: dict[Real, float] = decreasing_dict()
self._check_required_attributes()

def new(self) -> AverageLearner1D:
"""Create a copy of `~adaptive.AverageLearner1D` without the data."""
Expand Down
1 change: 1 addition & 0 deletions adaptive/learner/balancing_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
)

self.strategy: STRATEGY_TYPE = strategy
self._check_required_attributes()

def new(self) -> BalancingLearner:
"""Create a new `BalancingLearner` with the same parameters."""
Expand Down
17 changes: 15 additions & 2 deletions adaptive/learner/base_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import cloudpickle

from adaptive.utils import _RequireAttrsABCMeta, load, save
from adaptive.utils import load, save


def uses_nth_neighbors(n: int):
Expand Down Expand Up @@ -60,7 +60,7 @@ def _wrapped(loss_per_interval):
return _wrapped


class BaseLearner(metaclass=_RequireAttrsABCMeta):
class BaseLearner(abc.ABC):
"""Base class for algorithms for learning a function 'f: X → Y'.
Attributes
Expand Down Expand Up @@ -198,3 +198,16 @@ def __getstate__(self):

def __setstate__(self, state):
self.__dict__ = cloudpickle.loads(state)

def _check_required_attributes(self):
for name, type_ in self.__annotations__.items():
try:
x = getattr(self, name)
except AttributeError:
raise AttributeError(
f"Required attribute {name} not set in __init__."
) from None
else:
if not isinstance(x, type_):
msg = f"The attribute '{name}' should be of type {type_}, not {type(x)}."
raise TypeError(msg)
1 change: 1 addition & 0 deletions adaptive/learner/data_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(self, learner: BaseLearner, arg_picker: Callable) -> None:
self.extra_data = OrderedDict()
self.function = learner.function
self.arg_picker = arg_picker
self._check_required_attributes()

def new(self) -> DataSaver:
"""Return a new `DataSaver` with the same `arg_picker` and `learner`."""
Expand Down
1 change: 1 addition & 0 deletions adaptive/learner/integrator_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ def __init__(self, function: Callable, bounds: tuple[int, int], tol: float) -> N
ival = _Interval.make_first(*self.bounds)
self.add_ival(ival)
self.first_ival = ival
self._check_required_attributes()

def new(self) -> IntegratorLearner:
"""Create a copy of `~adaptive.Learner2D` without the data."""
Expand Down
1 change: 1 addition & 0 deletions adaptive/learner/learner1D.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ def __init__(
self.__missing_bounds = set(self.bounds) # cache of missing bounds

self._vdim: int | None = None
self._check_required_attributes()

def new(self) -> Learner1D:
"""Create a copy of `~adaptive.Learner1D` without the data."""
Expand Down
1 change: 1 addition & 0 deletions adaptive/learner/learner2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ def __init__(
self._ip = self._ip_combined = None

self.stack_size = 10
self._check_required_attributes()

def new(self) -> Learner2D:
return Learner2D(self.function, self.bounds, self.loss_per_triangle)
Expand Down
2 changes: 2 additions & 0 deletions adaptive/learner/learnerND.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,8 @@ def __init__(self, func, bounds, loss_per_simplex=None):
# _pop_highest_existing_simplex
self._simplex_queue = SortedKeyList(key=_simplex_evaluation_priority)

self._check_required_attributes()

def new(self) -> LearnerND:
"""Create a new learner with the same function and bounds."""
return LearnerND(self.function, self.bounds, self.loss_per_simplex)
Expand Down
1 change: 1 addition & 0 deletions adaptive/learner/sequence_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def __init__(self, function, sequence):
self.sequence = copy(sequence)
self.data = SortedDict()
self.pending_points = set()
self._check_required_attributes()

def new(self) -> SequenceLearner:
"""Return a new `~adaptive.SequenceLearner` without the data."""
Expand Down
18 changes: 0 additions & 18 deletions adaptive/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import abc
import concurrent.futures as concurrent
import functools
import gzip
Expand Down Expand Up @@ -90,23 +89,6 @@ def decorator(method):
return decorator


class _RequireAttrsABCMeta(abc.ABCMeta):
def __call__(self, *args, **kwargs):
obj = super().__call__(*args, **kwargs)
for name, type_ in obj.__annotations__.items():
try:
x = getattr(obj, name)
except AttributeError:
raise AttributeError(
f"Required attribute {name} not set in __init__."
) from None
else:
if not isinstance(x, type_):
msg = f"The attribute '{name}' should be of type {type_}, not {type(x)}."
raise TypeError(msg)
return obj


def _default_parameters(function, function_prefix: str = "function."):
sig = inspect.signature(function)
defaults = {
Expand Down

0 comments on commit 82ed0a4

Please sign in to comment.