Skip to content

Commit

Permalink
Typehint SequenceLearner (#366)
Browse files Browse the repository at this point in the history
  • Loading branch information
basnijholt committed Oct 11, 2022
1 parent 50fae43 commit 9860573
Showing 1 changed file with 27 additions and 13 deletions.
40 changes: 27 additions & 13 deletions adaptive/learner/sequence_learner.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

from copy import copy
from typing import Any, Tuple

import cloudpickle
from sortedcontainers import SortedDict, SortedSet

from adaptive.learner.base_learner import BaseLearner
from adaptive.types import Int
from adaptive.utils import assign_defaults, partial_function_from_dataframe

try:
Expand All @@ -16,6 +18,14 @@
except ModuleNotFoundError:
with_pandas = False

try:
from typing import TypeAlias
except ImportError:
from typing_extensions import TypeAlias


PointType: TypeAlias = Tuple[Int, Any]


class _IgnoreFirstArgument:
"""Remove the first argument from the call signature.
Expand All @@ -30,7 +40,7 @@ class _IgnoreFirstArgument:
def __init__(self, function):
self.function = function

def __call__(self, index_point, *args, **kwargs):
def __call__(self, index_point: PointType, *args, **kwargs):
index, point = index_point
return self.function(point, *args, **kwargs)

Expand Down Expand Up @@ -81,7 +91,9 @@ def new(self) -> SequenceLearner:
"""Return a new `~adaptive.SequenceLearner` without the data."""
return SequenceLearner(self._original_function, self.sequence)

def ask(self, n, tell_pending=True):
def ask(
self, n: int, tell_pending: bool = True
) -> tuple[list[PointType], list[float]]:
indices = []
points = []
loss_improvements = []
Expand All @@ -99,40 +111,40 @@ def ask(self, n, tell_pending=True):

return points, loss_improvements

def loss(self, real=True):
def loss(self, real: bool = True) -> float:
if not (self._to_do_indices or self.pending_points):
return 0
return 0.0
else:
npoints = self.npoints + (0 if real else len(self.pending_points))
return (self._ntotal - npoints) / self._ntotal

def remove_unfinished(self):
def remove_unfinished(self) -> None:
for i in self.pending_points:
self._to_do_indices.add(i)
self.pending_points = set()

def tell(self, point, value):
def tell(self, point: PointType, value: Any) -> None:
index, point = point
self.data[index] = value
self.pending_points.discard(index)
self._to_do_indices.discard(index)

def tell_pending(self, point):
def tell_pending(self, point: PointType) -> None:
index, point = point
self.pending_points.add(index)
self._to_do_indices.discard(index)

def done(self):
def done(self) -> bool:
return not self._to_do_indices and not self.pending_points

def result(self):
def result(self) -> list[Any]:
"""Get the function values in the same order as ``sequence``."""
if not self.done():
raise Exception("Learner is not yet complete.")
return list(self.data.values())

@property
def npoints(self):
def npoints(self) -> int:
return len(self.data)

def to_dataframe(
Expand Down Expand Up @@ -213,16 +225,18 @@ def load_dataframe(
y_name : str, optional
The ``y_name`` used in ``to_dataframe``, by default "y"
"""
self.tell_many(df[[index_name, x_name]].values, df[y_name].values)
indices = df[index_name].values
xs = df[x_name].values
self.tell_many(zip(indices, xs), df[y_name].values)
if with_default_function_args:
self.function = partial_function_from_dataframe(
self._original_function, df, function_prefix
)

def _get_data(self):
def _get_data(self) -> dict[int, Any]:
return self.data

def _set_data(self, data):
def _set_data(self, data: dict[int, Any]) -> None:
if data:
indices, values = zip(*data.items())
# the points aren't used by tell, so we can safely pass None
Expand Down

0 comments on commit 9860573

Please sign in to comment.