Skip to content

Commit

Permalink
Allow to periodically save with any function (#362)
Browse files Browse the repository at this point in the history
* Allow to periodically save with pandas

* Generalize saving

* add future annotations

* Fix defaults

* remove typing_extensions again

* rename
  • Loading branch information
basnijholt committed Oct 5, 2022
1 parent 21fb3b6 commit 3148552
Showing 1 changed file with 29 additions and 4 deletions.
33 changes: 29 additions & 4 deletions adaptive/runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import abc
import asyncio
import concurrent.futures as concurrent
Expand All @@ -11,11 +13,15 @@
import traceback
import warnings
from contextlib import suppress
from typing import TYPE_CHECKING, Any, Callable

import loky

from adaptive.notebook_integration import in_ipynb, live_info, live_plot

if TYPE_CHECKING:
from adaptive import BaseLearner

try:
import ipyparallel

Expand Down Expand Up @@ -663,15 +669,26 @@ def elapsed_time(self):
end_time = time.time()
return end_time - self.start_time

def start_periodic_saving(self, save_kwargs, interval):
def start_periodic_saving(
self,
save_kwargs: dict[str, Any] | None = None,
interval: int = 30,
method: Callable[[BaseLearner], None] | None = None,
):
"""Periodically save the learner's data.
Parameters
----------
save_kwargs : dict
Key-word arguments for ``learner.save(**save_kwargs)``.
Only used if ``method=None``.
interval : int
Number of seconds between saving the learner.
method : callable
The method to use for saving the learner. If None, the default
saves the learner using "pickle" which calls
``learner.save(**save_kwargs)``. Otherwise provide a callable
that takes the learner and saves the learner.
Example
-------
Expand All @@ -681,11 +698,19 @@ def start_periodic_saving(self, save_kwargs, interval):
... interval=600)
"""

async def _saver(save_kwargs=save_kwargs, interval=interval):
def default_save(learner):
learner.save(**save_kwargs)

if method is None:
method = default_save
if save_kwargs is None:
raise ValueError("Must provide `save_kwargs` if method=None.")

async def _saver():
while self.status() == "running":
self.learner.save(**save_kwargs)
method(self.learner)
await asyncio.sleep(interval)
self.learner.save(**save_kwargs) # one last time
method(self.learner) # one last time

self.saving_task = self.ioloop.create_task(_saver())
return self.saving_task
Expand Down

0 comments on commit 3148552

Please sign in to comment.