Skip to content

Commit

Permalink
docs: Show expected form of initial values
Browse files Browse the repository at this point in the history
Document the type of initial values. Adds a simple
example.

Closes #261
  • Loading branch information
riddell-stan committed Apr 4, 2021
1 parent 29b5cc7 commit 5236d81
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
6 changes: 3 additions & 3 deletions doc/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ API Reference
.. automodule:: stan
:members: build

.. automodule:: stan.fit
:members: Fit

.. automodule:: stan.model
:members: Model

.. automodule:: stan.fit
:members: Fit

.. automodule:: stan.plugins
:members: PluginBase
13 changes: 12 additions & 1 deletion stan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import dataclasses
import json
import re
from typing import Dict, Optional, Sequence, Tuple, Union
from typing import Dict, List, Optional, Sequence, Tuple, Union

import httpstan.models
import httpstan.schemas
Expand Down Expand Up @@ -70,6 +70,16 @@ def sample(self, *, num_chains=4, **kwargs) -> stan.fit.Fit:
Returns:
Fit: instance of Fit allowing access to draws.
Examples:
User-defined initial values for parameters must be provided
for each chain. Typically they will be the same for each chain.
The following example shows how user-defined initial parameters
are provided:
>>> program_code = "parameters {real y;} model {y ~ normal(0,1);}"
>>> posterior = stan.build(program_code)
>>> fit = posterior.sample(num_chains=2, init=[{"y": 3}, {"y": 3}])
"""
return self.hmc_nuts_diag_e_adapt(num_chains=num_chains, **kwargs)

Expand Down Expand Up @@ -134,6 +144,7 @@ def _create_fit(self, payload: dict) -> stan.fit.Fit:
payload = json.loads(DataJSONEncoder().encode(payload))
num_chains = payload.pop("num_chains")

init: List[Data]
init = payload.pop("init", [dict() for _ in range(num_chains)])
if len(init) != num_chains:
raise ValueError("Initial values must be provided for each chain.")
Expand Down

0 comments on commit 5236d81

Please sign in to comment.