Skip to content

Commit

Permalink
Merge pull request #222 from python-adaptive/attr_checking_base_class
Browse files Browse the repository at this point in the history
add _RequireAttrsABCMeta and make the BaseLearner use it
  • Loading branch information
basnijholt committed Oct 17, 2019
2 parents 187f88f + 1474a5d commit ede9582
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 29 deletions.
18 changes: 18 additions & 0 deletions adaptive/learner/balancing_learner.py
Expand Up @@ -90,6 +90,24 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):

self.strategy = strategy

@property
def data(self):
data = {}
for i, l in enumerate(self.learners):
data.update({(i, p): v for p, v in l.data.items()})
return data

@property
def pending_points(self):
pending_points = set()
for i, l in enumerate(self.learners):
pending_points.update({(i, p) for p in l.pending_points})
return pending_points

@property
def npoints(self):
return sum(l.npoints for l in self.learners)

@property
def strategy(self):
"""Can be either 'loss_improvements' (default), 'loss', 'npoints', or
Expand Down
21 changes: 11 additions & 10 deletions adaptive/learner/base_learner.py
Expand Up @@ -4,7 +4,7 @@
from contextlib import suppress
from copy import deepcopy

from adaptive.utils import load, save
from adaptive.utils import load, save, _RequireAttrsABCMeta


def uses_nth_neighbors(n):
Expand Down Expand Up @@ -61,30 +61,31 @@ def _wrapped(loss_per_interval):
return _wrapped


class BaseLearner(metaclass=abc.ABCMeta):
class BaseLearner(metaclass=_RequireAttrsABCMeta):
"""Base class for algorithms for learning a function 'f: X → Y'.
Attributes
----------
function : callable: X → Y
The function to learn.
The function to learn. A subclass of BaseLearner might modify
the user's supplied function.
data : dict: X → Y
`function` evaluated at certain points.
The values can be 'None', which indicates that the point
will be evaluated, but that we do not have the result yet.
npoints : int, optional
The number of evaluated points that have been added to the learner.
Subclasses do not *have* to implement this attribute.
pending_points : set, optional
pending_points : set
Points that have been requested but have not been evaluated yet.
Subclasses do not *have* to implement this attribute.
npoints : int
The number of evaluated points that have been added to the learner.
Notes
-----
Subclasses may define a ``plot`` method that takes no parameters
and returns a holoviews plot.
"""

data: dict
npoints: int
pending_points: set

def tell(self, x, y):
"""Tell the learner about a single value.
Expand Down
34 changes: 16 additions & 18 deletions adaptive/learner/integrator_learner.py
Expand Up @@ -100,7 +100,7 @@ class _Interval:
The parent interval.
children : list of `_Interval`s
The intervals resulting from a split.
done_points : dict
data : dict
A dictionary with the x-values and y-values: `{x1: y1, x2: y2 ...}`.
done : bool
The integral and the error for the interval has been calculated.
Expand Down Expand Up @@ -133,15 +133,15 @@ class _Interval:
"ndiv",
"parent",
"children",
"done_points",
"data",
"done_leaves",
"depth_complete",
"removed",
]

def __init__(self, a, b, depth, rdepth):
self.children = []
self.done_points = {}
self.data = {}
self.a = a
self.b = b
self.depth = depth
Expand Down Expand Up @@ -172,9 +172,9 @@ def T(self):

def refinement_complete(self, depth):
"""The interval has all the y-values to calculate the intergral."""
if len(self.done_points) < ns[depth]:
if len(self.data) < ns[depth]:
return False
return all(p in self.done_points for p in self.points(depth))
return all(p in self.data for p in self.points(depth))

def points(self, depth=None):
if depth is None:
Expand Down Expand Up @@ -255,7 +255,7 @@ def complete_process(self, depth):
assert self.depth_complete is None or self.depth_complete == depth - 1
self.depth_complete = depth

fx = [self.done_points[k] for k in self.points(depth)]
fx = [self.data[k] for k in self.points(depth)]
self.fx = np.array(fx)
force_split = False # This may change when refining

Expand Down Expand Up @@ -375,7 +375,7 @@ def __init__(self, function, bounds, tol):
self.tol = tol
self.max_ivals = 1000
self.priority_split = []
self.done_points = {}
self.data = {}
self.pending_points = set()
self._stack = []
self.x_mapping = defaultdict(lambda: SortedSet([], key=attrgetter("rdepth")))
Expand All @@ -391,13 +391,13 @@ def approximating_intervals(self):
def tell(self, point, value):
if point not in self.x_mapping:
raise ValueError(f"Point {point} doesn't belong to any interval")
self.done_points[point] = value
self.data[point] = value
self.pending_points.discard(point)

# Select the intervals that have this point
ivals = self.x_mapping[point]
for ival in ivals:
ival.done_points[point] = value
ival.data[point] = value

if ival.depth_complete is None:
from_depth = 0 if ival.parent is not None else 2
Expand Down Expand Up @@ -438,8 +438,8 @@ def add_ival(self, ival):
for x in ival.points():
# Update the mappings
self.x_mapping[x].add(ival)
if x in self.done_points:
self.tell(x, self.done_points[x])
if x in self.data:
self.tell(x, self.data[x])
elif x not in self.pending_points:
self.pending_points.add(x)
self._stack.append(x)
Expand Down Expand Up @@ -518,7 +518,7 @@ def _fill_stack(self):
@property
def npoints(self):
"""Number of evaluated points."""
return len(self.done_points)
return len(self.data)

@property
def igral(self):
Expand Down Expand Up @@ -552,11 +552,9 @@ def loss(self, real=True):
def plot(self):
hv = ensure_holoviews()
ivals = sorted(self.ivals, key=attrgetter("a"))
if not self.done_points:
if not self.data:
return hv.Path([])
xs, ys = zip(
*[(x, y) for ival in ivals for x, y in sorted(ival.done_points.items())]
)
xs, ys = zip(*[(x, y) for ival in ivals for x, y in sorted(ival.data.items())])
return hv.Path((xs, ys))

def _get_data(self):
Expand All @@ -565,7 +563,7 @@ def _get_data(self):

return (
self.priority_split,
self.done_points,
self.data,
self.pending_points,
self._stack,
x_mapping,
Expand All @@ -574,7 +572,7 @@ def _get_data(self):
)

def _set_data(self, data):
self.priority_split, self.done_points, self.pending_points, self._stack, x_mapping, self.ivals, self.first_ival = (
self.priority_split, self.data, self.pending_points, self._stack, x_mapping, self.ivals, self.first_ival = (
data
)

Expand Down
2 changes: 2 additions & 0 deletions adaptive/learner/skopt_learner.py
Expand Up @@ -26,10 +26,12 @@ class SKOptLearner(Optimizer, BaseLearner):
def __init__(self, function, **kwargs):
self.function = function
self.pending_points = set()
self.data = {}
super().__init__(**kwargs)

def tell(self, x, y, fit=True):
self.pending_points.discard(x)
self.data[x] = y
super().tell([x], y, fit)

def tell_pending(self, x):
Expand Down
2 changes: 1 addition & 1 deletion adaptive/tests/test_cquad.py
Expand Up @@ -188,7 +188,7 @@ def test_tell_in_random_order(first_add_33=False):
learners.append(learner)

# Check whether the points of the learners are identical
assert set(learners[0].done_points) == set(learners[1].done_points)
assert set(learners[0].data) == set(learners[1].data)

# Test whether approximating_intervals gives a complete set of intervals
for learner in learners:
Expand Down
18 changes: 18 additions & 0 deletions adaptive/utils.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-

import abc
import functools
import gzip
import os
Expand Down Expand Up @@ -67,3 +68,20 @@ def decorator(method):
return functools.wraps(other)(method)

return decorator


class _RequireAttrsABCMeta(abc.ABCMeta):
def __call__(self, *args, **kwargs):
obj = super().__call__(*args, **kwargs)
for name, type_ in obj.__annotations__.items():
try:
x = getattr(obj, name)
except AttributeError:
raise AttributeError(
f"Required attribute {name} not set in __init__."
) from None
else:
if not isinstance(x, type_):
msg = f"The attribute '{name}' should be of type {type_}, not {type(x)}."
raise TypeError(msg)
return obj

0 comments on commit ede9582

Please sign in to comment.