Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for neighbours in loss computation in LearnerND #185

Merged
merged 14 commits into from May 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
52 changes: 52 additions & 0 deletions adaptive/learner/base_learner.py
Expand Up @@ -7,6 +7,58 @@
from adaptive.utils import save, load


def uses_nth_neighbors(n):
    """Decorator to specify how many neighboring intervals the loss function uses.

    Wraps loss functions to indicate that they expect intervals together
    with ``n`` nearest neighbors.

    The loss function will then receive the data of the ``n`` nearest
    neighbors (``nth_neighbors``) along with the data of the interval
    itself in a dict. The `~adaptive.Learner1D` will also make sure that
    the loss is updated whenever one of the ``nth_neighbors`` changes.

    Examples
    --------

    The next function is a part of the `curvature_loss_function` function.

    >>> @uses_nth_neighbors(1)
    ... def triangle_loss(xs, ys):
    ...     xs = [x for x in xs if x is not None]
    ...     ys = [y for y in ys if y is not None]
    ...
    ...     if len(xs) == 2: # we do not have enough points for a triangle
    ...         return xs[1] - xs[0]
    ...
    ...     N = len(xs) - 2 # number of constructed triangles
    ...     if isinstance(ys[0], Iterable):
    ...         pts = [(x, *y) for x, y in zip(xs, ys)]
    ...         vol = simplex_volume_in_embedding
    ...     else:
    ...         pts = [(x, y) for x, y in zip(xs, ys)]
    ...         vol = volume
    ...     return sum(vol(pts[i:i+3]) for i in range(N)) / N

    Or you may define a loss that favours the (local) minima of a function,
    assuming that you know your function will have a single float as output.

    >>> @uses_nth_neighbors(1)
    ... def local_minima_resolving_loss(xs, ys):
    ...     dx = xs[2] - xs[1] # the width of the interval of interest
    ...
    ...     if not ((ys[0] is not None and ys[0] > ys[1])
    ...         or (ys[3] is not None and ys[3] > ys[2])):
    ...         return dx * 100
    ...
    ...     return dx
    """
    def _wrapped(loss_per_interval):
        # Tag the loss function; the learner reads this attribute to know
        # how many neighboring intervals to pass along.
        loss_per_interval.nth_neighbors = n
        return loss_per_interval
    return _wrapped


class BaseLearner(metaclass=abc.ABCMeta):
"""Base class for algorithms for learning a function 'f: X → Y'.

Expand Down
54 changes: 1 addition & 53 deletions adaptive/learner/learner1D.py
Expand Up @@ -10,65 +10,13 @@
import sortedcontainers
import sortedcollections

from adaptive.learner.base_learner import BaseLearner
from adaptive.learner.base_learner import BaseLearner, uses_nth_neighbors
from adaptive.learner.learnerND import volume
from adaptive.learner.triangulation import simplex_volume_in_embedding
from adaptive.notebook_integration import ensure_holoviews
from adaptive.utils import cache_latest


def uses_nth_neighbors(n):
    """Decorator to specify how many neighboring intervals the loss function uses.

    Wraps loss functions to indicate that they expect intervals together
    with ``n`` nearest neighbors.

    The loss function will then receive the data of the ``n`` nearest
    neighbors (``nth_neighbors``) along with the data of the interval
    itself in a dict. The `~adaptive.Learner1D` will also make sure that
    the loss is updated whenever one of the ``nth_neighbors`` changes.

    Examples
    --------

    The next function is a part of the `curvature_loss_function` function.

    >>> @uses_nth_neighbors(1)
    ... def triangle_loss(xs, ys):
    ...     xs = [x for x in xs if x is not None]
    ...     ys = [y for y in ys if y is not None]
    ...
    ...     if len(xs) == 2: # we do not have enough points for a triangle
    ...         return xs[1] - xs[0]
    ...
    ...     N = len(xs) - 2 # number of constructed triangles
    ...     if isinstance(ys[0], Iterable):
    ...         pts = [(x, *y) for x, y in zip(xs, ys)]
    ...         vol = simplex_volume_in_embedding
    ...     else:
    ...         pts = [(x, y) for x, y in zip(xs, ys)]
    ...         vol = volume
    ...     return sum(vol(pts[i:i+3]) for i in range(N)) / N

    Or you may define a loss that favours the (local) minima of a function,
    assuming that you know your function will have a single float as output.

    >>> @uses_nth_neighbors(1)
    ... def local_minima_resolving_loss(xs, ys):
    ...     dx = xs[2] - xs[1] # the width of the interval of interest
    ...
    ...     if not ((ys[0] is not None and ys[0] > ys[1])
    ...         or (ys[3] is not None and ys[3] > ys[2])):
    ...         return dx * 100
    ...
    ...     return dx
    """
    def _wrapped(loss_per_interval):
        # Tag the loss function; the learner reads this attribute to know
        # how many neighboring intervals to pass along.
        loss_per_interval.nth_neighbors = n
        return loss_per_interval
    return _wrapped


@uses_nth_neighbors(0)
def uniform_loss(xs, ys):
"""Loss function that samples the domain uniformly.
Expand Down
150 changes: 133 additions & 17 deletions adaptive/learner/learnerND.py
Expand Up @@ -12,14 +12,20 @@
import scipy.spatial
from sortedcontainers import SortedKeyList

from adaptive.learner.base_learner import BaseLearner
from adaptive.learner.base_learner import BaseLearner, uses_nth_neighbors
from adaptive.notebook_integration import ensure_holoviews, ensure_plotly
from adaptive.learner.triangulation import (
Triangulation, point_in_simplex, circumsphere,
simplex_volume_in_embedding, fast_det)
from adaptive.utils import restore, cache_latest


def to_list(inp):
    """Return *inp* as a list: iterables are expanded, scalars are wrapped."""
    return list(inp) if isinstance(inp, Iterable) else [inp]


def volume(simplex, ys=None):
# Notice the parameter ys is there so you can use this volume method as
# as loss function
Expand Down Expand Up @@ -60,6 +66,71 @@ def default_loss(simplex, ys):
return simplex_volume_in_embedding(pts)


@uses_nth_neighbors(1)
def triangle_loss(simplex, values, neighbors, neighbor_values):
    """
    Computes the average of the volumes of the simplex combined with each
    neighbouring point.

    Parameters
    ----------
    simplex : list of tuples
        Each entry is one point of the simplex.
    values : list of values
        The function values of each of the simplex points.
    neighbors : list of tuples
        The neighboring points of the simplex, ordered such that simplex[0]
        exactly opposes neighbors[0], etc.
    neighbor_values : list of values
        The function values for each of the neighboring points.

    Returns
    -------
    loss : float
    """
    # Filter point/value pairs *together*: filtering the two lists
    # independently could pair a point with another point's value if only
    # one entry of a pair is None.
    existing = [(p, v) for p, v in zip(neighbors, neighbor_values)
                if p is not None and v is not None]
    if not existing:
        # No neighbors at all (e.g. a lone simplex): no curvature info.
        return 0

    # Embed the simplex points together with their values.
    s = [(*x, *to_list(y)) for x, y in zip(simplex, values)]

    # Average the volume of the simplex extended with each neighbor.
    return sum(simplex_volume_in_embedding([*s, (*x, *to_list(y))])
               for x, y in existing) / len(existing)


def curvature_loss_function(exploration=0.05):
    """Create a loss function that prioritizes highly curved regions.

    The loss combines the `triangle_loss` (a measure of how far the
    neighboring points deviate from the hyperplane of the simplex, i.e. a
    proxy for local curvature) with the input volume of the simplex, so
    that flat regions are still explored.

    Parameters
    ----------
    exploration : float, default 0.05
        Relative weight of the simplex volume in the total loss; larger
        values favor uniform exploration over curvature refinement.

    Returns
    -------
    curvature_loss : callable
        A loss function decorated with ``@uses_nth_neighbors(1)`` that can
        be passed as ``loss_per_simplex`` to a `~adaptive.LearnerND`.
    """
    @uses_nth_neighbors(1)
    def curvature_loss(simplex, values, neighbors, neighbor_values):
        """Compute the curvature loss of a simplex.

        Parameters
        ----------
        simplex : list of tuples
            Each entry is one point of the simplex.
        values : list of values
            The function values of each of the simplex points.
        neighbors : list of tuples
            The neighboring points of the simplex, ordered such that
            simplex[0] exactly opposes neighbors[0], etc.
        neighbor_values : list of values
            The function values for each of the neighboring points.

        Returns
        -------
        loss : float
        """
        dim = len(simplex[0])  # the number of coordinates
        loss_input_volume = volume(simplex)

        loss_curvature = triangle_loss(simplex, values, neighbors, neighbor_values)
        # Exponents keep the loss dimensionally consistent (units of length).
        return (loss_curvature + exploration * loss_input_volume ** ((2 + dim) / dim)) ** (1 / (2 + dim))
    return curvature_loss


def choose_point_in_simplex(simplex, transform=None):
"""Choose a new point in inside a simplex.

Expand All @@ -70,9 +141,10 @@ def choose_point_in_simplex(simplex, transform=None):
Parameters
----------
simplex : numpy array
The coordinates of a triangle with shape (N+1, N)
The coordinates of a triangle with shape (N+1, N).
transform : N*N matrix
The multiplication to apply to the simplex before choosing the new point
The multiplication to apply to the simplex before choosing
the new point.

Returns
-------
Expand Down Expand Up @@ -164,6 +236,17 @@ class LearnerND(BaseLearner):
def __init__(self, func, bounds, loss_per_simplex=None):
self._vdim = None
self.loss_per_simplex = loss_per_simplex or default_loss

if hasattr(self.loss_per_simplex, 'nth_neighbors'):
if self.loss_per_simplex.nth_neighbors > 1:
raise NotImplementedError('The provided loss function wants '
'next-nearest neighboring simplices for the loss computation, '
'this feature is not yet implemented, either use '
'nth_neightbors = 0 or 1')
self.nth_neighbors = self.loss_per_simplex.nth_neighbors
else:
self.nth_neighbors = 0

self.data = OrderedDict()
self.pending_points = set()

Expand Down Expand Up @@ -252,14 +335,15 @@ def tri(self):

try:
self._tri = Triangulation(self.points)
self._update_losses(set(), self._tri.simplices)
return self._tri
except ValueError:
# A ValueError is raised if we do not have enough points or
# the provided points are coplanar, so we need more points to
# create a valid triangulation
return None

self._update_losses(set(), self._tri.simplices)
return self._tri

@property
def values(self):
"""Get the values from `data` as a numpy array."""
Expand Down Expand Up @@ -326,10 +410,10 @@ def tell_pending(self, point, *, simplex=None):

simplex = tuple(simplex)
simplices = [self.tri.vertex_to_simplices[i] for i in simplex]
neighbours = set.union(*simplices)
neighbors = set.union(*simplices)
# Neighbours also includes the simplex itself

for simpl in neighbours:
for simpl in neighbors:
_, to_add = self._try_adding_pending_point_to_simplex(point, simpl)
if to_add is None:
continue
Expand Down Expand Up @@ -394,6 +478,7 @@ def _pop_highest_existing_simplex(self):
# find the simplex with the highest loss, we do need to check that the
# simplex hasn't been deleted yet
while len(self._simplex_queue):
# XXX: Need to add check that the loss is the most recent computed loss
loss, simplex, subsimplex = self._simplex_queue.pop(0)
if (subsimplex is None
and simplex in self.tri.simplices
Expand Down Expand Up @@ -449,6 +534,35 @@ def _ask(self):

return self._ask_best_point() # O(log N)

def _compute_loss(self, simplex):
    """Return the loss of *simplex*, evaluated on the scaled problem.

    The simplex vertices are mapped through ``self._transform`` and the
    values multiplied by ``self._output_multiplier`` so the loss is
    computed on a domain scaled to a unit cube.  When the loss function
    requested neighbors (``self.nth_neighbors == 1``), the opposing
    vertices and their (scaled) values are passed along as well.
    """
    vertices = self.tri.get_vertices(simplex)
    values = [self.data[tuple(v)] for v in vertices]

    # Scale the simplex to a cube with sides 1.
    vertices = vertices @ self._transform
    values = self._output_multiplier * np.array(values)

    if self.nth_neighbors == 0:
        # The loss function only needs the simplex itself.
        return float(self.loss_per_simplex(vertices, values))

    # The loss function also wants the neighboring (opposing) vertices.
    neighbors = self.tri.get_opposing_vertices(simplex)
    neighbor_points = self.tri.get_vertices(neighbors)
    # Look the values up *before* scaling the points: data is keyed on
    # the unscaled coordinates.  Missing points/values stay None.
    neighbor_values = [self.data.get(p, None) for p in neighbor_points]

    neighbor_points = [p if p is None else p @ self._transform
                       for p in neighbor_points]
    neighbor_values = [v if v is None else self._output_multiplier * v
                       for v in neighbor_values]

    return float(self.loss_per_simplex(
        vertices, values, neighbor_points, neighbor_values))

def _update_losses(self, to_delete: set, to_add: set):
# XXX: add the points outside the triangulation to this as well
pending_points_unbound = set()
Expand All @@ -461,7 +575,6 @@ def _update_losses(self, to_delete: set, to_add: set):

pending_points_unbound = set(p for p in pending_points_unbound
if p not in self.data)

for simplex in to_add:
loss = self._compute_loss(simplex)
self._losses[simplex] = loss
Expand All @@ -476,17 +589,20 @@ def _update_losses(self, to_delete: set, to_add: set):
self._update_subsimplex_losses(
simplex, self._subtriangulations[simplex].simplices)

def _compute_loss(self, simplex):
# get the loss
vertices = self.tri.get_vertices(simplex)
values = [self.data[tuple(v)] for v in vertices]
if self.nth_neighbors:
points_of_added_simplices = set.union(*[set(s) for s in to_add])
neighbors = self.tri.get_simplices_attached_to_points(
points_of_added_simplices) - to_add
for simplex in neighbors:
loss = self._compute_loss(simplex)
self._losses[simplex] = loss

# scale them to a cube with sides 1
vertices = vertices @ self._transform
values = self._output_multiplier * np.array(values)
if simplex not in self._subtriangulations:
self._simplex_queue.add((loss, simplex, None))
continue

# compute the loss on the scaled simplex
return float(self.loss_per_simplex(vertices, values))
self._update_subsimplex_losses(
simplex, self._subtriangulations[simplex].simplices)

def _recompute_all_losses(self):
"""Recompute all losses and pending losses."""
Expand Down