Skip to content

Commit

Permalink
fix: Improve typing in model.py
Browse files Browse the repository at this point in the history
Improve typing in model.py. In particular, make clear
what our current expectations about the `data` dictionary are.

`data` must be a Python dictionary with values of type
`Union[int, float, Sequence[Union[int, float]], np.ndarray]`.

Nested sequences of Union[int, float] should work. Python's
`typing` library does not currently allow one to describe recursive
types.

Closes #225
  • Loading branch information
riddell-stan committed Feb 28, 2021
1 parent b78eac8 commit 5a09a6b
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions stan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import dataclasses
import json
import re
from typing import List, Optional, Tuple
from typing import Dict, Optional, Sequence, Tuple, Union

import httpstan.models
import httpstan.schemas
Expand All @@ -18,8 +18,10 @@
import stan.fit
import stan.plugins

Data = Dict[str, Union[int, float, Sequence[Union[int, float]], np.ndarray]]

def _make_json_serializable(data: dict) -> dict:

def _make_json_serializable(data: Data) -> dict:
"""Convert `data` with numpy.ndarray-like values to JSON-serializable form.
Returns a new dictionary.
Expand Down Expand Up @@ -245,8 +247,8 @@ def is_iteration_or_elapsed_time_logger_message(msg: simdjson.Object):
return # type: ignore

def constrain_pars(
self, unconstrained_parameters: List[float], include_tparams: bool = True, include_gqs: bool = True
) -> List[float]:
self, unconstrained_parameters: Sequence[float], include_tparams: bool = True, include_gqs: bool = True
) -> Sequence[float]:
"""Transform a sequence of unconstrained parameters to their defined support,
optionally including transformed parameters and generated quantities.
Expand Down Expand Up @@ -278,7 +280,7 @@ async def go():

return asyncio.run(go())

def unconstrain_pars(self, constrained_parameters: List[float]) -> List[float]:
def unconstrain_pars(self, constrained_parameters: Sequence[float]) -> Sequence[float]:
"""Reads constrained parameter values from their specified context and returns a
sequence of unconstrained parameter values.
Expand All @@ -303,7 +305,7 @@ async def go():

return asyncio.run(go())

def log_prob(self, unconstrained_parameters: List[float], adjust_transform: bool = True) -> float:
def log_prob(self, unconstrained_parameters: Sequence[float], adjust_transform: bool = True) -> float:
"""Calculate the log probability of a set of unconstrained parameters.
Arguments:
Expand Down Expand Up @@ -333,7 +335,7 @@ async def go():

return asyncio.run(go())

def grad_log_prob(self, unconstrained_parameters: List[float]) -> float:
def grad_log_prob(self, unconstrained_parameters: Sequence[float]) -> float:
"""Calculate the gradient of the log posterior evaluated at
the unconstrained parameters.
Expand Down Expand Up @@ -366,15 +368,15 @@ async def go():
return asyncio.run(go())


def build(program_code, data=None, random_seed=None) -> Model:
def build(program_code: str, data: Optional[Data] = None, random_seed: Optional[int] = None) -> Model:
"""Build (compile) a Stan program.
Arguments:
program_code (str): Stan program code describing a Stan model.
data (dict): A Python dictionary or mapping providing the data for the
program_code: Stan program code describing a Stan model.
data: A Python dictionary or mapping providing the data for the
model. Variable names are the keys and the values are their
associated values. Default is an empty dictionary.
random_seed (int): Random seed, a positive integer for random number
random_seed: Random seed, a positive integer for random number
generation. Used to make sure that results can be reproduced.
Returns:
Expand All @@ -385,10 +387,8 @@ def build(program_code, data=None, random_seed=None) -> Model:
variable names; see the Stan User's Guide for a complete list.
"""
if data is None:
data = {}
# _make_json_serializable returns a new dict, original `data` unchanged
data = _make_json_serializable(data)
data = _make_json_serializable(data) if data is not None else {}
assert all(not isinstance(value, np.ndarray) for value in data.values())

async def go():
Expand Down

0 comments on commit 5a09a6b

Please sign in to comment.