-
-
Notifications
You must be signed in to change notification settings - Fork 58
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Other developers can change the behavior of PyStan by writing a plugin. Plugin developers should create a class which subclasses `stan.plugins.PluginBase`. This class must be referenced in their package's entry points section. This feature is added with a particular use in mind. A developer now has the ability run HMC diagnostics when sampling has completed and alert the user if any problems are detected. Closes #129
- Loading branch information
1 parent
eabb3c8
commit 83a8877
Showing
6 changed files
with
193 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -79,4 +79,5 @@ Documentation | |
installation | ||
upgrading | ||
reference | ||
plugins | ||
contributing |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
========= | ||
Plugins | ||
========= | ||
|
||
This is a guide to installing and creating plugins for PyStan. | ||
|
||
Installing Plugins | ||
================== | ||
|
||
In order to use a plugin, you need to install it. Plugins are published on PyPI and can be installed with ``pip``. | ||
|
||
Plugins are automatically enabled as soon as they are installed. | ||
|
||
Creating Plugins | ||
================ | ||
|
||
Plugin developers should create a class which subclasses :py:class:`stan.plugins.PluginBase`. This | ||
class must be referenced in their package's entry points section. | ||
|
||
For example, if the class is ``mymodule.PrintParameterNames`` then the | ||
setuptools configuration would look like the following:: | ||
|
||
entry_points = { | ||
"stan.plugins": [ | ||
"names = mymodule:PrintParameterNames" | ||
] | ||
} | ||
|
||
The equivalent configuration in poetry would be:: | ||
|
||
[tool.poetry.plugins."stan.plugins"] | ||
names = mymodule:PrintParameterNames | ||
|
||
You can define multiple plugins in the entry points section. Note that the | ||
plugin name (here, `names`) is required but is unused. | ||
|
||
All :py:class:`stan.plugins.PluginBase` subclasses implement methods which define behavior associated with *events*. | ||
Currently, there is only one event supported, ``post_fit``. | ||
|
||
on_post_fit | ||
----------- | ||
|
||
This method defines what happens when sampling has finished and a | ||
:py:class:`stan.fit.Fit` object is about to be returned to the user. The | ||
method takes a :py:class:`stan.fit.Fit` instance as an argument. The method | ||
returns the instance. In a plugin, this method will typically analyze the data contained in | ||
the instance. A plugin might also use this method to modify the instance, adding an | ||
additional method or changing the behavior or an existing method. | ||
|
||
**Arguments:** | ||
|
||
- ``fit``: :py:class:`stan.fit.Fit` instance | ||
|
||
For example, if you wanted to print the names of parameters you would define a plugin as follows:: | ||
|
||
class PrintParameterNames(stan.plugins.PluginBase): | ||
def on_post_fit(self, fit): | ||
for key in fit: | ||
print(key) | ||
return fit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,3 +10,6 @@ API Reference | |
|
||
.. automodule:: stan.model | ||
:members: Model | ||
|
||
.. automodule:: stan.plugins | ||
:members: PluginBase |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import abc | ||
from typing import Generator | ||
|
||
import pkg_resources | ||
|
||
import stan.fit | ||
|
||
|
||
def get_plugins() -> Generator[pkg_resources.EntryPoint, None, None]: | ||
"""Iterate over available plugins.""" | ||
return pkg_resources.iter_entry_points(group="stan.plugins") | ||
|
||
|
||
class PluginBase(abc.ABC): | ||
"""Base class for PyStan plugins. | ||
Plugin developers should create a class which subclasses `PluginBase`. | ||
This class must be referenced in their package's entry points section. | ||
""" | ||
|
||
# Implementation note: this plugin system is simple because there are only | ||
# a couple of places a plugin developer might want to change behavior. For | ||
# a more full-featured plugin system, see Stevedore | ||
# (<https://docs.openstack.org/stevedore>). This plugin system follows | ||
# (approximately) the pattern stevedore labels `ExtensionManager`. | ||
|
||
def on_post_fit(self, fit: stan.fit.Fit) -> stan.fit.Fit: | ||
"""Called with Fit instance when sampling has finished. | ||
The plugin can report information about the samples | ||
contained in the Fit object. It may also add to or | ||
modify the Fit instance. | ||
If the plugin only analyzes the contents of `fit`, | ||
it must return the `fit`. | ||
Argument: | ||
fit: Fit instance. | ||
Returns: | ||
The Fit instance. | ||
""" | ||
return fit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import pkg_resources | ||
import pytest | ||
|
||
import stan | ||
import stan.plugins | ||
|
||
program_code = "parameters {real y;} model {y ~ normal(0,1);}" | ||
|
||
|
||
class DummyPlugin(stan.plugins.PluginBase): | ||
def on_post_fit(self, fit): | ||
"""Do nothing other than print a string.""" | ||
print("In DummyPlugin `on_post_fit`.") | ||
return fit | ||
|
||
|
||
class MockEntryPoint: | ||
@staticmethod | ||
def load(): | ||
return DummyPlugin | ||
|
||
|
||
def mock_iter_entry_points(group): | ||
return iter([MockEntryPoint]) | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def normal_posterior(): | ||
return stan.build(program_code) | ||
|
||
|
||
def test_get_plugins(monkeypatch): | ||
|
||
monkeypatch.setattr(pkg_resources, "iter_entry_points", mock_iter_entry_points) | ||
|
||
entry_points = stan.plugins.get_plugins() | ||
Plugin = next(entry_points).load() | ||
assert isinstance(Plugin(), stan.plugins.PluginBase) | ||
|
||
|
||
def test_dummy_plugin(monkeypatch, capsys, normal_posterior): | ||
|
||
monkeypatch.setattr(pkg_resources, "iter_entry_points", mock_iter_entry_points) | ||
|
||
fit = normal_posterior.sample(stepsize=0.001) | ||
assert fit is not None and "y" in fit | ||
|
||
captured = capsys.readouterr() | ||
assert "In DummyPlugin" in captured.out | ||
|
||
|
||
class OtherDummyPlugin(stan.plugins.PluginBase): | ||
def on_post_fit(self, fit): | ||
"""Do nothing other than print a string.""" | ||
print("In OtherDummyPlugin `on_post_fit`.") | ||
return fit | ||
|
||
|
||
class OtherMockEntryPoint: | ||
@staticmethod | ||
def load(): | ||
return OtherDummyPlugin | ||
|
||
|
||
def test_two_plugins(monkeypatch, capsys, normal_posterior): | ||
"""Make sure that both plugins are used.""" | ||
|
||
def mock_iter_entry_points(group): | ||
return iter([MockEntryPoint, OtherMockEntryPoint]) | ||
|
||
monkeypatch.setattr(pkg_resources, "iter_entry_points", mock_iter_entry_points) | ||
|
||
fit = normal_posterior.sample(stepsize=0.001) | ||
assert fit is not None and "y" in fit | ||
|
||
captured = capsys.readouterr() | ||
assert "In DummyPlugin" in captured.out | ||
assert "In OtherDummyPlugin" in captured.out |