Skip to content

Commit

Permalink
Allow storing the full sequence in SequenceLearner.to_dataframe (#425)
Browse files (browse the repository at this point in the history)
* Allow storing the full sequence in SequenceLearner.to_dataframe

* Test dataframes

* Skip the dataframe test if only the minimal dependencies are installed (pandas missing)
  • Loading branch information
basnijholt committed May 9, 2023
1 parent 0dd5d98 commit 2b94152
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 7 deletions.
51 changes: 45 additions & 6 deletions adaptive/learner/sequence_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import sys
from copy import copy
from typing import Any
from typing import TYPE_CHECKING, Any

import cloudpickle
from sortedcontainers import SortedDict, SortedSet
Expand All @@ -15,6 +15,10 @@
partial_function_from_dataframe,
)

if TYPE_CHECKING:
from collections.abc import Sequence
from typing import Callable

try:
import pandas

Expand Down Expand Up @@ -82,12 +86,17 @@ class SequenceLearner(BaseLearner):
the added benefit of having results in the local kernel already.
"""

def __init__(self, function, sequence):
def __init__(
self,
function: Callable[[Any], Any],
sequence: Sequence[Any],
):
self._original_function = function
self.function = _IgnoreFirstArgument(function)
# prefer range(len(...)) over enumerate to avoid slowdowns
# when passing lazy sequences
self._to_do_indices = SortedSet(range(len(sequence)))
indices = range(len(sequence))
self._to_do_indices = SortedSet(indices)
self._ntotal = len(sequence)
self.sequence = copy(sequence)
self.data = SortedDict()
Expand Down Expand Up @@ -161,6 +170,8 @@ def to_dataframe( # type: ignore[override]
index_name: str = "i",
x_name: str = "x",
y_name: str = "y",
*,
full_sequence: bool = False,
) -> pandas.DataFrame:
"""Return the data as a `pandas.DataFrame`.
Expand All @@ -178,6 +189,9 @@ def to_dataframe( # type: ignore[override]
Name of the input value, by default "x"
y_name : str, optional
Name of the output value, by default "y"
full_sequence : bool, optional
If True, the returned dataframe will have the full sequence
where the y_name values are pd.NA if not evaluated yet.
Returns
-------
Expand All @@ -190,8 +204,16 @@ def to_dataframe( # type: ignore[override]
"""
if not with_pandas:
raise ImportError("pandas is not installed.")
indices, ys = zip(*self.data.items()) if self.data else ([], [])
sequence = [self.sequence[i] for i in indices]
import pandas as pd

if full_sequence:
indices = list(range(len(self.sequence)))
sequence = list(self.sequence)
ys = [self.data.get(i, pd.NA) for i in indices]
else:
indices, ys = zip(*self.data.items()) if self.data else ([], []) # type: ignore[assignment]
sequence = [self.sequence[i] for i in indices]

df = pandas.DataFrame(indices, columns=[index_name])
df[x_name] = sequence
df[y_name] = ys
Expand All @@ -209,6 +231,8 @@ def load_dataframe( # type: ignore[override]
index_name: str = "i",
x_name: str = "x",
y_name: str = "y",
*,
full_sequence: bool = False,
):
"""Load data from a `pandas.DataFrame`.
Expand All @@ -231,10 +255,25 @@ def load_dataframe( # type: ignore[override]
The ``x_name`` used in ``to_dataframe``, by default "x"
y_name : str, optional
The ``y_name`` used in ``to_dataframe``, by default "y"
full_sequence : bool, optional
The ``full_sequence`` used in ``to_dataframe``, by default False
"""
if not with_pandas:
raise ImportError("pandas is not installed.")
import pandas as pd

indices = df[index_name].values
xs = df[x_name].values
self.tell_many(zip(indices, xs), df[y_name].values)
ys = df[y_name].values

if full_sequence:
evaluated_indices = [i for i, y in enumerate(ys) if y is not pd.NA]
xs = xs[evaluated_indices]
ys = ys[evaluated_indices]
indices = indices[evaluated_indices]

self.tell_many(zip(indices, xs), ys)

if with_default_function_args:
self.function = partial_function_from_dataframe(
self._original_function, df, function_prefix
Expand Down
36 changes: 35 additions & 1 deletion adaptive/tests/test_sequence_learner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
import asyncio

import pytest

from adaptive import Runner, SequenceLearner
from adaptive.runner import SequentialExecutor
from adaptive.learner.learner1D import with_pandas
from adaptive.runner import SequentialExecutor, simple

# Default horizontal position of the peak used by the test function below.
offset = 0.0123


def peak(x, offset=offset, wait=True):
    """Return a sharp peak on top of a linear background, wrapped in a dict.

    Parameters
    ----------
    x : float
        Point at which to evaluate the function.
    offset : float, optional
        Position of the peak; defaults to the module-level ``offset``.
    wait : bool, optional
        Accepted for signature compatibility but not used in the body.

    Returns
    -------
    dict
        ``{"x": value}`` where ``value = x + a**2 / (a**2 + (x - offset)**2)``
        with ``a = 0.01``. The non-scalar return value lets tests check that
        arbitrary objects survive being stored in the learner's data.
    """
    a = 0.01
    return {"x": x + a**2 / (a**2 + (x - offset) ** 2)}


class FailOnce:
Expand All @@ -22,3 +32,27 @@ def test_fail_with_sequence_of_unhashable():
runner = Runner(learner, retries=1, executor=SequentialExecutor())
asyncio.get_event_loop().run_until_complete(runner.task)
assert runner.status() == "finished"


@pytest.mark.skipif(not with_pandas, reason="pandas is not installed")
def test_save_load_dataframe():
    """Round-trip a ``SequenceLearner`` through ``to_dataframe``/``load_dataframe``.

    Covers both the default export (only evaluated points) and the
    ``full_sequence=True`` export (one row per sequence element).
    """
    # The sequence has 20 elements (10..29); only 10 of them get evaluated.
    learner = SequenceLearner(peak, sequence=range(10, 30, 1))
    simple(learner, npoints_goal=10)

    # Default export: one row per evaluated point.
    df = learner.to_dataframe()
    assert len(df) == 10
    assert df["x"].iloc[0] == 10

    # Full export: one row per element of the sequence, evaluated or not.
    df_full = learner.to_dataframe(full_sequence=True)
    assert len(df_full) == 20
    assert df_full["x"].iloc[0] == 10
    assert df_full["x"].iloc[-1] == 29

    # Loading the default export into a fresh learner restores the data.
    learner2 = learner.new()
    assert learner2.data == {}
    learner2.load_dataframe(df)
    assert len(learner2.data) == 10
    # Compare the *reloaded* learner with the export; asserting on ``learner``
    # here would only compare the original learner with its own dataframe.
    assert learner2.to_dataframe().equals(df)

    # Loading the full export must skip the not-yet-evaluated rows.
    learner3 = learner.new()
    learner3.load_dataframe(df_full, full_sequence=True)
    assert len(learner3.data) == 10
    assert learner3.to_dataframe(full_sequence=True).equals(df_full)

0 comments on commit 2b94152

Please sign in to comment.