Commit

Merge branch 'cache_loss' into 'master'
Cache loss and display it in the live_info widget

See merge request qt/adaptive!117
basnijholt committed Oct 11, 2018
2 parents b5b81ac + d5774ff commit 953ff84
Showing 11 changed files with 38 additions and 18 deletions.
4 changes: 3 additions & 1 deletion adaptive/learner/average_learner.py
@@ -4,8 +4,9 @@

 import numpy as np

-from ..notebook_integration import ensure_holoviews
 from .base_learner import BaseLearner
+from ..notebook_integration import ensure_holoviews
+from ..utils import cache_latest


 class AverageLearner(BaseLearner):
@@ -90,6 +91,7 @@ def std(self):
             return np.inf
         return sqrt((self.sum_f_sq - n * self.mean**2) / (n - 1))

+    @cache_latest
     def loss(self, real=True, *, n=None):
         if n is None:
             n = self.npoints if real else self.n_requested
3 changes: 2 additions & 1 deletion adaptive/learner/balancing_learner.py
@@ -6,7 +6,7 @@

 from .base_learner import BaseLearner
 from ..notebook_integration import ensure_holoviews
-from ..utils import restore, named_product
+from ..utils import cache_latest, named_product, restore


 def dispatch(child_functions, arg):
@@ -116,6 +116,7 @@ def losses(self, real=True):

         return losses

+    @cache_latest
     def loss(self, real=True):
         losses = self.losses(real)
         return max(losses)
1 change: 0 additions & 1 deletion adaptive/learner/data_saver.py
@@ -1,5 +1,4 @@
 # -*- coding: utf-8 -*-
-
 from collections import OrderedDict
 import functools

3 changes: 2 additions & 1 deletion adaptive/learner/integrator_learner.py
@@ -15,7 +15,7 @@
                                 ndiv_max, min_sep, eps, xi, V_inv,
                                 Vcond, alpha, gamma)
 from ..notebook_integration import ensure_holoviews
-from ..utils import restore
+from ..utils import cache_latest, restore


 def _downdate(c, nans, depth):
def _downdate(c, nans, depth):
@@ -514,6 +514,7 @@ def done(self):
                 or (err - err_excess < abs(igral) * self.tol < err_excess)
                 or not self.ivals)

+    @cache_latest
     def loss(self, real=True):
         return abs(abs(self.igral) * self.tol - self.err)

4 changes: 3 additions & 1 deletion adaptive/learner/learner1D.py
@@ -7,8 +7,9 @@
 import numpy as np
 import sortedcontainers

-from ..notebook_integration import ensure_holoviews
 from .base_learner import BaseLearner
+from ..notebook_integration import ensure_holoviews
+from ..utils import cache_latest


 def uniform_loss(interval, scale, function_values):
@@ -156,6 +157,7 @@ def vdim(self):
     def npoints(self):
         return len(self.data)

+    @cache_latest
     def loss(self, real=True):
         losses = self.losses if real else self.losses_combined
         return max(losses.values()) if len(losses) > 0 else float('inf')
8 changes: 4 additions & 4 deletions adaptive/learner/learner2D.py
@@ -7,8 +7,9 @@
 import numpy as np
 from scipy import interpolate

-from ..notebook_integration import ensure_holoviews
 from .base_learner import BaseLearner
+from ..notebook_integration import ensure_holoviews
+from ..utils import cache_latest


 # Learner2D and helper functions.
@@ -267,7 +268,6 @@ def __init__(self, function, bounds, loss_per_triangle=None):
         self._stack.update({p: np.inf for p in self._bounds_points})
         self.function = function
         self._ip = self._ip_combined = None
-        self._loss = np.inf

         self.stack_size = 10

@@ -438,13 +438,13 @@ def ask(self, n, tell_pending=True):

         return points[:n], loss_improvements[:n]

+    @cache_latest
     def loss(self, real=True):
         if not self.bounds_are_done:
             return np.inf
         ip = self.ip() if real else self.ip_combined()
         losses = self.loss_per_triangle(ip)
-        self._loss = losses.max()
-        return self._loss
+        return losses.max()

     def remove_unfinished(self):
         self.pending_points = set()
3 changes: 2 additions & 1 deletion adaptive/learner/learnerND.py
@@ -13,7 +13,7 @@
 from ..notebook_integration import ensure_holoviews
 from .triangulation import (Triangulation, point_in_simplex,
                             circumsphere, simplex_volume_in_embedding)
-from ..utils import restore
+from ..utils import restore, cache_latest


 def volume(simplex, ys=None):
@@ -452,6 +452,7 @@ def losses(self):

         return self._losses

+    @cache_latest
     def loss(self, real=True):
         losses = self.losses()  # XXX: compute pending loss if real == False
         return max(losses.values()) if losses else float('inf')
8 changes: 4 additions & 4 deletions adaptive/learner/skopt_learner.py
@@ -1,11 +1,10 @@
 # -*- coding: utf-8 -*-

 import numpy as np
+from skopt import Optimizer

-from ..notebook_integration import ensure_holoviews
 from .base_learner import BaseLearner
-
-from skopt import Optimizer
+from ..notebook_integration import ensure_holoviews
+from ..utils import restore, cache_latest

-
 class SKOptLearner(Optimizer, BaseLearner):
@@ -38,6 +37,7 @@ def tell_pending(self, x):
     def remove_unfinished(self):
         pass

+    @cache_latest
     def loss(self, real=True):
         if not self.models:
             return np.inf
3 changes: 3 additions & 0 deletions adaptive/notebook_integration.py
@@ -174,6 +174,9 @@ def _info_html(runner):
     with suppress(Exception):
         info.append(('# of points', runner.learner.npoints))

+    with suppress(Exception):
+        info.append(('latest loss', f'{runner.learner._cache["loss"]:.3f}'))
+
     template = '<dt>{}</dt><dd>{}</dd>'
     table = '\n'.join(template.format(k, v) for k, v in info)

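The widget entry added above only formats whatever the decorated loss() last stored; the suppress(Exception) guard keeps the info table rendering before the first loss() call, when _cache does not yet exist. A rough, illustrative equivalent of that lookup (assuming a runner object as passed to _info_html):

# Illustrative only: read the latest cached loss without recomputing it.
latest = getattr(runner.learner, '_cache', {}).get('loss')
if latest is not None:
    print(f'latest loss: {latest:.3f}')
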
14 changes: 14 additions & 0 deletions adaptive/utils.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 from contextlib import contextmanager
+from functools import wraps
 from itertools import product
 import time

@@ -24,3 +25,16 @@ def restore(*learners):
     finally:
         for state, learner in zip(states, learners):
             learner.__setstate__(state)
+
+
+def cache_latest(f):
+    """Cache the latest return value of the function and add it
+    as 'self._cache[f.__name__]'."""
+    @wraps(f)
+    def wrapper(*args, **kwargs):
+        self = args[0]
+        if not hasattr(self, '_cache'):
+            self._cache = {}
+        self._cache[f.__name__] = f(*args, **kwargs)
+        return self._cache[f.__name__]
+    return wrapper
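A minimal usage sketch of the decorator above (the SlowLearner class is hypothetical, purely for illustration, and the import assumes adaptive with this commit applied): the wrapped method still returns its value, and the latest result is additionally stored under the method's name in self._cache.

from adaptive.utils import cache_latest

class SlowLearner:
    @cache_latest
    def loss(self, real=True):
        # stand-in for an expensive loss computation
        return 0.25

learner = SlowLearner()
print(learner.loss())   # 0.25, computed by the wrapped method
print(learner._cache)   # {'loss': 0.25}, i.e. the latest return value
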
5 changes: 1 addition & 4 deletions learner.ipynb
@@ -218,10 +218,7 @@
    "source": [
     "def plot(learner):\n",
     "    plot = learner.plot(tri_alpha=0.2)\n",
-    "    title = f'loss={learner._loss:.3f}, n_points={learner.npoints}'\n",
-    "    return (plot.Image\n",
-    "            + plot.EdgePaths.I.opts(plot=dict(title_format=title))\n",
-    "            + plot)\n",
+    "    return plot.Image + plot.EdgePaths.I + plot\n",
     "\n",
     "runner.live_plot(plotter=plot, update_interval=0.1)"
    ]
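If the loss is still wanted in the plot title, the cached value can stand in for the removed _loss attribute. A possible adaptation of the old notebook cell, sketched here rather than taken from this commit:

def plot(learner):
    p = learner.plot(tri_alpha=0.2)
    # read the cached loss; fall back to inf before the first loss() call
    loss = getattr(learner, '_cache', {}).get('loss', float('inf'))
    title = f'loss={loss:.3f}, n_points={learner.npoints}'
    return (p.Image
            + p.EdgePaths.I.opts(plot=dict(title_format=title))
            + p)

runner.live_plot(plotter=plot, update_interval=0.1)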
