Skip to content

Commit

Permalink
fix: Rename plugin hook
Browse files Browse the repository at this point in the history
The plugin hook should always have been named `on_post_sample`.
Naming it `on_post_fit` was a mistake.

Closes #254
  • Loading branch information
riddell-stan committed Mar 29, 2021
1 parent 08dc3da commit e2eb998
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 10 deletions.
12 changes: 8 additions & 4 deletions doc/plugins.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ You can define multiple plugins in the entry points section. Note that the
plugin name (here, `names`) is required but is unused.

All :py:class:`stan.plugins.PluginBase` subclasses implement methods which define behavior associated with *events*.
Currently, there is only one event supported, ``post_fit``.
There is only one event supported, ``post_sample``.

on_post_fit
-----------
on_post_sample
--------------

This method defines what happens when sampling has finished and a
:py:class:`stan.fit.Fit` object is about to be returned to the user. The
Expand All @@ -56,7 +56,11 @@ additional method or changing the behavior or an existing method.
For example, if you wanted to print the names of parameters you would define a plugin as follows::

class PrintParameterNames(stan.plugins.PluginBase):
def on_post_fit(self, fit):
def on_post_sample(self, fit, **kwargs):
for key in fit:
print(key)
return fit

Note that `on_post_sample` accepts additional keyword arguments (``**kwargs``). Accepting
keyword arguments like this will allow your plugin to be compatible with future versions of the package.
Future versions of the package could, in principle, add additional arguments to `on_post_sample`.
2 changes: 1 addition & 1 deletion stan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def is_iteration_or_elapsed_time_logger_message(msg: simdjson.Object):

for entry_point in stan.plugins.get_plugins():
Plugin = entry_point.load()
fit = Plugin().on_post_fit(fit)
fit = Plugin().on_post_sample(fit)
return fit

try:
Expand Down
2 changes: 1 addition & 1 deletion stan/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class PluginBase(abc.ABC):
# (<https://docs.openstack.org/stevedore>). This plugin system follows
# (approximately) the pattern stevedore labels `ExtensionManager`.

def on_post_fit(self, fit: stan.fit.Fit) -> stan.fit.Fit:
def on_post_sample(self, fit: stan.fit.Fit) -> stan.fit.Fit:
"""Called with Fit instance when sampling has finished.
The plugin can report information about the samples
Expand Down
8 changes: 4 additions & 4 deletions tests/test_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@


class DummyPlugin(stan.plugins.PluginBase):
def on_post_fit(self, fit):
def on_post_sample(self, fit, **kwargs):
"""Do nothing other than print a string."""
print("In DummyPlugin `on_post_fit`.")
print("In DummyPlugin `on_post_sample`.")
return fit


Expand Down Expand Up @@ -50,9 +50,9 @@ def test_dummy_plugin(monkeypatch, capsys, normal_posterior):


class OtherDummyPlugin(stan.plugins.PluginBase):
def on_post_fit(self, fit):
def on_post_sample(self, fit, **kwargs):
"""Do nothing other than print a string."""
print("In OtherDummyPlugin `on_post_fit`.")
print("In OtherDummyPlugin `on_post_sample`.")
return fit


Expand Down

0 comments on commit e2eb998

Please sign in to comment.