-
Notifications
You must be signed in to change notification settings - Fork 58
/
average_learner.py
146 lines (121 loc) · 4.16 KB
/
average_learner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from math import sqrt
import numpy as np
from adaptive.learner.base_learner import BaseLearner
from adaptive.notebook_integration import ensure_holoviews
from adaptive.utils import cache_latest
class AverageLearner(BaseLearner):
    """A naive implementation of adaptive computing of averages.

    The learned function must depend on an integer input variable that
    represents the source of randomness.

    Parameters
    ----------
    function : callable
        The function to learn. It is only stored on the instance
        (``self.function``); this class never calls it itself.
    atol : float, optional
        Desired absolute tolerance. Treated as ``np.inf`` (criterion
        disabled) when omitted.
    rtol : float, optional
        Desired relative tolerance. Treated as ``np.inf`` (criterion
        disabled) when omitted.

    Attributes
    ----------
    data : dict
        Sampled points and values.
    pending_points : set
        Points that still have to be evaluated.
    npoints : int
        Number of evaluated points.
    sum_f : number
        Running sum of all values passed to `tell`.
    sum_f_sq : number
        Running sum of the squares of all values passed to `tell`.

    Raises
    ------
    ValueError
        If neither `atol` nor `rtol` is given.
    """

    def __init__(self, function, atol=None, rtol=None):
        if atol is None and rtol is None:
            # ValueError is a subclass of Exception, so callers that caught
            # the previous generic Exception keep working.
            raise ValueError("At least one of `atol` and `rtol` should be set.")

        self.data = {}
        self.pending_points = set()
        self.function = function
        # A missing tolerance becomes infinite, which makes its term in
        # `loss` evaluate to zero.
        self.atol = np.inf if atol is None else atol
        self.rtol = np.inf if rtol is None else rtol
        self.npoints = 0
        self.sum_f = 0
        self.sum_f_sq = 0

    @property
    def n_requested(self):
        """Number of evaluated points plus number of pending points."""
        return self.npoints + len(self.pending_points)

    def ask(self, n, tell_pending=True):
        """Choose `n` new integer points to evaluate.

        Parameters
        ----------
        n : int
            Number of points to return.
        tell_pending : bool, default: True
            If True, the returned points are added to `pending_points`.

        Returns
        -------
        points : list of int
            The chosen points.
        loss_improvements : list of float
            One (identical) estimated loss improvement per point.
        """
        if n <= 0:
            # Nothing to hand out; also avoids dividing by zero below.
            return [], []

        start = self.n_requested
        points = list(range(start, start + n))

        if any(p in self.data or p in self.pending_points for p in points):
            # This means some of the points `< self.n_requested` do not exist.
            # Fill those gaps first; `sorted` makes the choice deterministic
            # (lowest free indices) instead of relying on set iteration order.
            free = set(range(start + n)) - set(self.data) - set(self.pending_points)
            points = sorted(free)[:n]

        # The total improvement is shared equally between the `n` points.
        loss_improvements = [self._loss_improvement(n) / n] * n

        if tell_pending:
            for p in points:
                self.tell_pending(p)
        return points, loss_improvements

    def tell(self, n, value):
        """Record `value` as the result for point `n`.

        A point that is already present in `data` is ignored.
        """
        if n in self.data:
            # The point has already been added before.
            return

        self.data[n] = value
        self.pending_points.discard(n)
        self.sum_f += value
        self.sum_f_sq += value ** 2
        self.npoints += 1

    def tell_pending(self, n):
        """Mark point `n` as requested but not yet evaluated."""
        self.pending_points.add(n)

    @property
    def mean(self):
        """The average of all values in `data`.

        Divides by `npoints`, so it raises ``ZeroDivisionError`` while no
        point has been evaluated yet.
        """
        return self.sum_f / self.npoints

    @property
    def std(self):
        """The corrected sample standard deviation of the values
        in `data`.

        Returns ``np.inf`` while fewer than two points have been evaluated.
        """
        n = self.npoints
        if n < 2:
            return np.inf
        numerator = self.sum_f_sq - n * self.mean ** 2
        if numerator < 0:
            # in this case the numerator ~ -1e-15
            return 0
        return sqrt(numerator / (n - 1))

    # NOTE(review): `cache_latest` is imported from `adaptive.utils`; its
    # exact caching semantics are not visible here -- confirm there.
    @cache_latest
    def loss(self, real=True, *, n=None):
        """Return the standard error scaled by the tolerances.

        The result is the larger of ``standard_error / atol`` and
        ``standard_error / abs(mean) / rtol``.

        Parameters
        ----------
        real : bool, default: True
            When `n` is not given: use `npoints` if True, otherwise
            `n_requested`, as the sample size.
        n : int, optional
            Explicit sample size used for the standard error.

        Returns
        -------
        float
            ``np.inf`` while fewer than two points are evaluated or
            `n` is smaller than two.
        """
        if n is None:
            n = self.npoints if real else self.n_requested

        # With fewer than two evaluated points `std` is infinite and `mean`
        # may be undefined (0 / 0), whatever `n` is.
        if n < 2 or self.npoints < 2:
            return np.inf

        standard_error = self.std / sqrt(n)
        abs_loss = standard_error / self.atol

        mean = self.mean
        if mean != 0:
            rel_loss = standard_error / abs(mean) / self.rtol
        elif standard_error == 0 or np.isinf(self.rtol):
            # Zero error, or the relative criterion is disabled: no
            # contribution instead of a division by zero.
            rel_loss = 0.0
        else:
            # A non-zero error relative to a zero mean is unbounded.
            rel_loss = np.inf
        return max(abs_loss, rel_loss)

    def _loss_improvement(self, n):
        """Estimate how much the loss drops after `n` more evaluated points.

        Returns ``np.inf`` while the current loss is not finite.
        """
        loss = self.loss()
        if np.isfinite(loss):
            return loss - self.loss(n=self.npoints + n)
        else:
            return np.inf

    def remove_unfinished(self):
        """Remove uncomputed data from the learner."""
        self.pending_points = set()

    def plot(self):
        """Returns a histogram of the evaluated data.

        Returns
        -------
        holoviews.element.Histogram
            A histogram of the evaluated data."""
        hv = ensure_holoviews()
        vals = [v for v in self.data.values() if v is not None]
        if not vals:
            return hv.Histogram([[], []])
        # At least 5 bins, growing with the square root of the sample count.
        num_bins = int(max(5, sqrt(self.npoints)))
        vals = hv.Points(vals)
        return hv.operation.histogram(vals, num_bins=num_bins, dimension=1)

    def _get_data(self):
        """Return ``(data, npoints, sum_f, sum_f_sq)`` as a tuple."""
        return (self.data, self.npoints, self.sum_f, self.sum_f_sq)

    def _set_data(self, data):
        """Restore the state from a tuple as returned by `_get_data`."""
        self.data, self.npoints, self.sum_f, self.sum_f_sq = data