Skip to content

Commit

Permalink
feat: Add plugin system
Browse files Browse the repository at this point in the history
Other developers can change the behavior of PyStan by writing a plugin.
Plugin developers should create a class which subclasses
`stan.plugins.PluginBase`. This class must be referenced in their
package's entry points section.

This feature is added with a particular use in mind. A developer now has
the ability run HMC diagnostics when sampling has completed and alert
the user if any problems are detected.

Closes #129
  • Loading branch information
riddell-stan committed Feb 23, 2021
1 parent eabb3c8 commit 83a8877
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 1 deletion.
1 change: 1 addition & 0 deletions doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,5 @@ Documentation
installation
upgrading
reference
plugins
contributing
60 changes: 60 additions & 0 deletions doc/plugins.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
=========
Plugins
=========

This is a guide to installing and creating plugins for PyStan.

Installing Plugins
==================

In order to use a plugin, you need to install it. Plugins are published on PyPI and can be installed with ``pip``.

Plugins are automatically enabled as soon as they are installed.

Creating Plugins
================

Plugin developers should create a class which subclasses :py:class:`stan.plugins.PluginBase`. This
class must be referenced in their package's entry points section.

For example, if the class is ``mymodule.PrintParameterNames`` then the
setuptools configuration would look like the following::

entry_points = {
"stan.plugins": [
"names = mymodule:PrintParameterNames"
]
}

The equivalent configuration in poetry would be::

[tool.poetry.plugins."stan.plugins"]
names = mymodule:PrintParameterNames

You can define multiple plugins in the entry points section. Note that the
plugin name (here, `names`) is required but is unused.

All :py:class:`stan.plugins.PluginBase` subclasses implement methods which define behavior associated with *events*.
Currently, there is only one event supported, ``post_fit``.

on_post_fit
-----------

This method defines what happens when sampling has finished and a
:py:class:`stan.fit.Fit` object is about to be returned to the user. The
method takes a :py:class:`stan.fit.Fit` instance as an argument. The method
returns the instance. In a plugin, this method will typically analyze the data contained in
the instance. A plugin might also use this method to modify the instance, adding an
additional method or changing the behavior or an existing method.

**Arguments:**

- ``fit``: :py:class:`stan.fit.Fit` instance

For example, if you wanted to print the names of parameters you would define a plugin as follows::

class PrintParameterNames(stan.plugins.PluginBase):
def on_post_fit(self, fit):
for key in fit:
print(key)
return fit
3 changes: 3 additions & 0 deletions doc/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ API Reference

.. automodule:: stan.model
:members: Model

.. automodule:: stan.plugins
:members: PluginBase
8 changes: 7 additions & 1 deletion stan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import stan.common
import stan.fit
import stan.plugins


def _make_json_serializable(data: dict) -> dict:
Expand Down Expand Up @@ -223,7 +224,7 @@ def is_iteration_or_elapsed_time_logger_message(msg: simdjson.Object):
progress_bar.finish()
io.error_line("\n<info>Done.</info>")

return stan.fit.Fit(
fit = stan.fit.Fit(
stan_outputs,
num_chains,
self.param_names,
Expand All @@ -235,6 +236,11 @@ def is_iteration_or_elapsed_time_logger_message(msg: simdjson.Object):
save_warmup,
)

for entry_point in stan.plugins.get_plugins():
Plugin = entry_point.load()
fit = Plugin().on_post_fit(fit)
return fit

try:
return asyncio.run(go())
except KeyboardInterrupt:
Expand Down
44 changes: 44 additions & 0 deletions stan/plugins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import abc
from typing import Generator

import pkg_resources

import stan.fit


def get_plugins() -> Generator[pkg_resources.EntryPoint, None, None]:
"""Iterate over available plugins."""
return pkg_resources.iter_entry_points(group="stan.plugins")


class PluginBase(abc.ABC):
"""Base class for PyStan plugins.
Plugin developers should create a class which subclasses `PluginBase`.
This class must be referenced in their package's entry points section.
"""

# Implementation note: this plugin system is simple because there are only
# a couple of places a plugin developer might want to change behavior. For
# a more full-featured plugin system, see Stevedore
# (<https://docs.openstack.org/stevedore>). This plugin system follows
# (approximately) the pattern stevedore labels `ExtensionManager`.

def on_post_fit(self, fit: stan.fit.Fit) -> stan.fit.Fit:
"""Called with Fit instance when sampling has finished.
The plugin can report information about the samples
contained in the Fit object. It may also add to or
modify the Fit instance.
If the plugin only analyzes the contents of `fit`,
it must return the `fit`.
Argument:
fit: Fit instance.
Returns:
The Fit instance.
"""
return fit
78 changes: 78 additions & 0 deletions tests/test_plugins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import pkg_resources
import pytest

import stan
import stan.plugins

program_code = "parameters {real y;} model {y ~ normal(0,1);}"


class DummyPlugin(stan.plugins.PluginBase):
def on_post_fit(self, fit):
"""Do nothing other than print a string."""
print("In DummyPlugin `on_post_fit`.")
return fit


class MockEntryPoint:
@staticmethod
def load():
return DummyPlugin


def mock_iter_entry_points(group):
return iter([MockEntryPoint])


@pytest.fixture(scope="module")
def normal_posterior():
return stan.build(program_code)


def test_get_plugins(monkeypatch):

monkeypatch.setattr(pkg_resources, "iter_entry_points", mock_iter_entry_points)

entry_points = stan.plugins.get_plugins()
Plugin = next(entry_points).load()
assert isinstance(Plugin(), stan.plugins.PluginBase)


def test_dummy_plugin(monkeypatch, capsys, normal_posterior):

monkeypatch.setattr(pkg_resources, "iter_entry_points", mock_iter_entry_points)

fit = normal_posterior.sample(stepsize=0.001)
assert fit is not None and "y" in fit

captured = capsys.readouterr()
assert "In DummyPlugin" in captured.out


class OtherDummyPlugin(stan.plugins.PluginBase):
def on_post_fit(self, fit):
"""Do nothing other than print a string."""
print("In OtherDummyPlugin `on_post_fit`.")
return fit


class OtherMockEntryPoint:
@staticmethod
def load():
return OtherDummyPlugin


def test_two_plugins(monkeypatch, capsys, normal_posterior):
"""Make sure that both plugins are used."""

def mock_iter_entry_points(group):
return iter([MockEntryPoint, OtherMockEntryPoint])

monkeypatch.setattr(pkg_resources, "iter_entry_points", mock_iter_entry_points)

fit = normal_posterior.sample(stepsize=0.001)
assert fit is not None and "y" in fit

captured = capsys.readouterr()
assert "In DummyPlugin" in captured.out
assert "In OtherDummyPlugin" in captured.out

0 comments on commit 83a8877

Please sign in to comment.