Skip to content

Commit

Permalink
refactor: Improve clarity of _create_fit code
Browse files Browse the repository at this point in the history
Improve clarity of `_create_fit` function code by
using consistent argument names. Change also usefully
separates the kwargs from the `payload` variable.
  • Loading branch information
riddell-stan committed Apr 19, 2021
1 parent eeac63d commit d0f3b07
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions stan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,8 @@ def hmc_nuts_diag_e_adapt(self, *, num_chains=4, **kwargs) -> stan.fit.Fit:
Fit: instance of Fit allowing access to draws.
"""
kwargs["num_chains"] = num_chains
kwargs["function"] = "stan::services::sample::hmc_nuts_diag_e_adapt"
return self._create_fit(kwargs)
function = "stan::services::sample::hmc_nuts_diag_e_adapt"
return self._create_fit(function=function, num_chains=num_chains, **kwargs)

def fixed_param(self, *, num_chains=4, **kwargs) -> stan.fit.Fit:
"""Draw samples from the model using ``stan::services::sample::fixed_param``.
Expand All @@ -119,38 +118,39 @@ def fixed_param(self, *, num_chains=4, **kwargs) -> stan.fit.Fit:
Fit: instance of Fit allowing access to draws.
"""
kwargs["num_chains"] = num_chains
kwargs["function"] = "stan::services::sample::fixed_param"
return self._create_fit(kwargs)
function = "stan::services::sample::fixed_param"
return self._create_fit(function=function, num_chains=num_chains, **kwargs)

def _create_fit(self, payload: dict) -> stan.fit.Fit:
def _create_fit(self, *, function, num_chains, **kwargs) -> stan.fit.Fit:
"""Make a request to httpstan's ``create_fit`` endpoint and process results.
Users should not use this function.
Arguments:
payload: dict whose JSON-encoded contents will be sent as the request body.
Parameters in ``kwargs`` will be passed to the (Python wrapper of)
`function`. Parameter names are identical to those used in CmdStan.
See the CmdStan documentation for parameter descriptions and default
values.
Returns:
Fit: instance of Fit allowing access to draws.
"""
assert "chain" not in payload, "`chain` id is set automatically."
assert "data" not in payload, "`data` is set in `build`."
assert "random_seed" not in payload, "`random_seed` is set in `build`."
assert "function" in payload
assert "chain" not in kwargs, "`chain` id is set automatically."
assert "data" not in kwargs, "`data` is set in `build`."
assert "random_seed" not in kwargs, "`random_seed` is set in `build`."

# copy kwargs and verify everything is JSON-encodable
payload = json.loads(DataJSONEncoder().encode(payload))
num_chains = payload.pop("num_chains")
kwargs = json.loads(DataJSONEncoder().encode(kwargs))

init: List[Data]
init = payload.pop("init", [dict() for _ in range(num_chains)])
# FIXME: special handling here for `init`, consistent with PyStan 2 but needs docs
init: List[Data] = kwargs.pop("init", [dict() for _ in range(num_chains)])
if len(init) != num_chains:
raise ValueError("Initial values must be provided for each chain.")

payloads = []
for chain in range(1, num_chains + 1):
payload = kwargs.copy()
payload["function"] = function
payload["chain"] = chain # type: ignore
payload["data"] = self.data # type: ignore
payload["init"] = init.pop(0)
Expand All @@ -169,7 +169,7 @@ def _create_fit(self, payload: dict) -> stan.fit.Fit:
"save_warmup",
arguments.lookup_default(arguments.Method["SAMPLE"], "save_warmup"),
)
payloads.append(payload.copy())
payloads.append(payload)

async def go():
io = ConsoleIO()
Expand Down

0 comments on commit d0f3b07

Please sign in to comment.