-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* doc fixes and tweaks * rename progress.py -> monitoring.py Also don't expose ProgressBar directly in xsimlab main namespace. * test hook functions * black * update release notes
- Loading branch information
Showing
8 changed files
with
192 additions
and
91 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# A hack around ProgressBar (monkey patch) so that it renders | ||
# nicely in docs | ||
import io | ||
|
||
import xsimlab | ||
from xsimlab import runtime_hook | ||
from xsimlab.monitoring import ProgressBar as _ProgressBar | ||
|
||
|
||
class ProgressBarHack(_ProgressBar): | ||
"""Redirects progress bar outputs to a variable, and | ||
only display the rendered string (last line) at the end | ||
the simulation. | ||
""" | ||
|
||
def __init__(self, **kwargs): | ||
super(ProgressBarHack, self).__init__(**kwargs) | ||
|
||
self.pbar_output = io.StringIO() | ||
self.tqdm_kwargs.update({"file": self.pbar_output}) | ||
|
||
@runtime_hook("finalize", trigger="post") | ||
def close_bar(self, model, context, state): | ||
super(ProgressBarHack, self).close_bar(model, context, state) | ||
print(self.pbar_output.getvalue().strip().split("\r")[-1]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import importlib | ||
|
||
import pytest | ||
|
||
from ..monitoring import ProgressBar | ||
from . import has_tqdm | ||
|
||
|
||
@pytest.mark.skipif(not has_tqdm, reason="requires tqdm") | ||
@pytest.mark.parametrize( | ||
"frontend,tqdm_module", | ||
[ | ||
("auto", "tqdm"), # assume tests are run in a terminal evironment | ||
("console", "tqdm"), | ||
("gui", "tqdm.gui"), | ||
("notebook", "tqdm.notebook"), | ||
], | ||
) | ||
def test_progress_bar_init(frontend, tqdm_module): | ||
pbar = ProgressBar(frontend=frontend) | ||
tqdm = importlib.import_module(tqdm_module) | ||
|
||
assert pbar.tqdm is tqdm.tqdm | ||
|
||
|
||
@pytest.mark.skipif(not has_tqdm, reason="requires tqdm") | ||
@pytest.mark.parametrize("kw", [{}, {"bar_format": "{bar}"}]) | ||
def test_progress_bar_init_kwargs(kw): | ||
pbar = ProgressBar(**kw) | ||
|
||
assert "bar_format" in pbar.tqdm_kwargs | ||
|
||
if "bar_format" in kw: | ||
assert pbar.tqdm_kwargs["bar_format"] == kw["bar_format"] | ||
|
||
|
||
@pytest.mark.skipif(not has_tqdm, reason="requires tqdm") | ||
def test_progress_bar_init_error(in_dataset, model): | ||
with pytest.raises(ValueError, match=r".*not supported.*"): | ||
ProgressBar(frontend="invalid_frontend") | ||
|
||
|
||
@pytest.mark.parametrize("kw", [{}, {"desc": "custom description"}]) | ||
@pytest.mark.skipif(not has_tqdm, reason="requires tqdm") | ||
def test_progress_bar_init_bar(kw): | ||
pbar = ProgressBar(**kw) | ||
pbar.init_bar(None, {"nsteps": 10}, {}) | ||
|
||
assert pbar.pbar_model.format_dict["total"] == 12 | ||
if kw: | ||
assert pbar.pbar_model.format_dict["prefix"] == "custom description" | ||
else: | ||
assert pbar.pbar_model.format_dict["prefix"] == "initialize" | ||
|
||
|
||
@pytest.mark.skipif(not has_tqdm, reason="requires tqdm") | ||
def test_progress_bar_update_init(): | ||
pbar = ProgressBar() | ||
pbar.init_bar(None, {"nsteps": 10}, {}) | ||
pbar.update_init(None, {}, {}) | ||
|
||
assert pbar.pbar_model.format_dict["n"] == 1 | ||
|
||
|
||
@pytest.mark.skipif(not has_tqdm, reason="requires tqdm") | ||
def test_progress_bar_update_run_step(): | ||
pbar = ProgressBar() | ||
pbar.init_bar(None, {"nsteps": 10}, {}) | ||
pbar.update_init(None, {}, {}) | ||
pbar.update_run_step(None, {"nsteps": 10, "step": 1}, {}) | ||
|
||
assert pbar.pbar_model.format_dict["n"] == 2 | ||
assert pbar.pbar_model.format_dict["prefix"] == "run step 1/10" | ||
|
||
|
||
@pytest.mark.skipif(not has_tqdm, reason="requires tqdm") | ||
def test_progress_bar_update_finalize(): | ||
pbar = ProgressBar() | ||
pbar.init_bar(None, {"nsteps": 10}, {}) | ||
pbar.update_finalize(None, {}, {}) | ||
|
||
assert pbar.pbar_model.format_dict["prefix"] == "finalize" | ||
|
||
|
||
@pytest.mark.skipif(not has_tqdm, reason="requires tqdm") | ||
def test_progress_bar_close_bar(): | ||
pbar = ProgressBar() | ||
pbar.init_bar(None, {"nsteps": 10}, {}) | ||
pbar.close_bar(None, {}, {}) | ||
|
||
assert pbar.pbar_model.format_dict["n"] == 1 | ||
assert pbar.pbar_model.format_dict["prefix"].startswith("Simulation finished") |
This file was deleted.
Oops, something went wrong.