Skip to content

Commit

Permalink
refactor: Allow arbitrary requests to create fit endpoint
Browse files Browse the repository at this point in the history
Allow arbitrary requests to httpstan's create fit endpoint.
This is needed in order to simplify adding a `fixed_param`
method. Previously the code making the request to the
create fit endpoint was tied into the `sample` method.

Docstrings have been lightly edited. An ``hmc_nuts_diag_e_adapt`` method
has been added, making explicit what was implicit before.

This commit makes no substantive changes to the code.
  • Loading branch information
riddell-stan committed Mar 6, 2021
1 parent 30e4898 commit 60acff2
Showing 1 changed file with 50 additions and 15 deletions.
65 changes: 50 additions & 15 deletions stan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Model:

model_name: str
program_code: str
data: dict
data: Data
param_names: Tuple[str, ...]
constrained_param_names: Tuple[str, ...]
dims: Tuple[Tuple[int, ...]]
Expand All @@ -50,33 +50,68 @@ def __post_init__(self):
def sample(self, **kwargs) -> stan.fit.Fit:
"""Draw samples from the model.
Parameters in ``kwargs`` will be passed to the default sample function in
stan::services. Parameter names are identical to those used in CmdStan.
See the CmdStan documentation for parameter descriptions and default
values.
Parameters in ``kwargs`` will be passed to the default sample function
in stan::services. The default sample function is currently
``hmc_nuts_diag_e_adapt``. Parameter names are identical to those used
in CmdStan. See the CmdStan documentation for parameter descriptions
and default values.
`num_chains` is the lone PyStan-specific keyword argument. It indicates
the number of independent processes to use when drawing samples.
The default value is 4.
There is one exception: `num_chains`. `num_chains` is a
PyStan-specific keyword argument. It indicates the number of
independent processes to use when drawing samples. The default value
is 4.
Returns:
Fit: instance of Fit allowing access to draws.
"""
assert "chain" not in kwargs, "`chain` id is set automatically."
assert "data" not in kwargs, "`data` is set in `build`."
assert "random_seed" not in kwargs, "`random_seed` is set in `build`."
return self.hmc_nuts_diag_e_adapt(**kwargs)

num_chains = kwargs.pop("num_chains", 4)
def hmc_nuts_diag_e_adapt(self, **kwargs) -> stan.fit.Fit:
"""Draw samples from the model using ``stan::services::hmc_nuts_diag_e_adapt``.
init = kwargs.pop("init", [dict() for _ in range(num_chains)])
Parameters in ``kwargs`` will be passed to the (Python wrapper of)
``stan::services::hmc_nuts_diag_e_adapt``. Parameter names are
identical to those used in CmdStan. See the CmdStan documentation for
parameter descriptions and default values.
There is one exception: `num_chains`. `num_chains` is a
PyStan-specific keyword argument. It indicates the number of
independent processes to use when drawing samples. The default value
is 4.
Returns:
Fit: instance of Fit allowing access to draws.
"""
kwargs["function"] = "stan::services::sample::hmc_nuts_diag_e_adapt"
return self._create_fit(kwargs)

def _create_fit(self, payload: dict) -> stan.fit.Fit:
"""Make a request to httpstan's ``create_fit`` endpoint and process results.
Users should not use this function.
Arguments:
payload: dict whose JSON-encoded contents will be sent as the request body.
Returns:
Fit: instance of Fit allowing access to draws.
"""
assert "chain" not in payload, "`chain` id is set automatically."
assert "data" not in payload, "`data` is set in `build`."
assert "random_seed" not in payload, "`random_seed` is set in `build`."
assert "function" in payload

num_chains = payload.pop("num_chains", 4)

init = payload.pop("init", [dict() for _ in range(num_chains)])
if len(init) != num_chains:
raise ValueError("Initial values must be provided for each chain.")

payloads = []
for chain in range(1, num_chains + 1):
payload = {"function": "stan::services::sample::hmc_nuts_diag_e_adapt"}
payload.update(kwargs)
payload["chain"] = chain # type: ignore
payload["data"] = self.data # type: ignore
payload["init"] = init.pop(0)
Expand Down

0 comments on commit 60acff2

Please sign in to comment.