Skip to content

Commit

Permalink
feat: Fit implements collections.abc.Mapping
Browse files Browse the repository at this point in the history
The Fit class now implements collections.abc.Mapping. Users can iterate
over a Fit instance and use normal dictionary methods.

Technically a breaking change since it gets rid of the `.values` property.

Closes #142
  • Loading branch information
riddell-stan committed Jan 19, 2021
1 parent 28cc94d commit 235d759
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 32 deletions.
27 changes: 13 additions & 14 deletions stan/fit.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
import collections
import json
from typing import Tuple, cast
from typing import Generator, Tuple, cast

import numpy as np
import simdjson


class Fit:
class Fit(collections.abc.Mapping):
"""Stores draws from one or more chains.
The ``values`` attribute provides direct access to draws. More user-friendly
presentations of draws are available via the ``to_frame`` and ``to_xarray``
methods.
Returned by methods of a ``Model``. Users will not instantiate this class directly.
Attributes:
values: An ndarray with shape (num_sample_and_sampler_params + num_flat_params, num_draws, num_chains)
A `Fit` instance works like a Python dictionary. Other user-friendly views of draws
are available via the ``to_frame`` and ``to_xarray`` methods.
"""

Expand Down Expand Up @@ -59,6 +56,8 @@ def __init__(
# self._draws holds all the draws. We cannot allocate it before looking at the draws
# because we do not know how many sampler-specific parameters are present. Later in this
# function we count them and only then allocate the array for `self._draws`.
#
# _draws is an ndarray with shape (num_sample_and_sampler_params + num_flat_params, num_draws, num_chains)
self._draws: np.ndarray

parser = simdjson.Parser()
Expand Down Expand Up @@ -122,10 +121,6 @@ def to_frame(self):
df.index.name, df.columns.name = "draws", "parameters"
return df

@property
def values(self) -> np.ndarray:
return self._draws

@property
def _finished(self) -> bool:
return not self._draws.flags["WRITEABLE"]
Expand All @@ -139,16 +134,20 @@ def __getitem__(self, param):
param_dim = [] if param in self.sample_and_sampler_param_names else self.dims[self.param_names.index(param)]
# fmt: off
num_samples_saved = (self.num_samples + self.num_warmup * self.save_warmup) // self.num_thin
assert self.values.shape == (len(self.sample_and_sampler_param_names) + len(self.constrained_param_names), num_samples_saved, self.num_chains)
assert self._draws.shape == (len(self.sample_and_sampler_param_names) + len(self.constrained_param_names), num_samples_saved, self.num_chains)
# fmt: on
# Stack chains together. Parameter is still stored flat.
view = self.values[param_indexes, :, :].reshape(len(param_indexes), -1).view()
view = self._draws[param_indexes, :, :].reshape(len(param_indexes), -1).view()
assert view.shape == (len(param_indexes), num_samples_saved * self.num_chains)
# reshape must yield something with at least two dimensions
reshape_args = param_dim + [-1] if param_dim else (1, -1)
# reshape, recover the shape of the stan parameter
return view.reshape(*reshape_args, order="F")

def __iter__(self) -> Generator[str, None, None]:
for name in self.param_names:
yield name

def __len__(self) -> int:
return len(self.param_names)

Expand Down
26 changes: 13 additions & 13 deletions tests/test_basic_bernoulli.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def fit(posterior):

def test_bernoulli_sampling_thin(posterior):
fit = posterior.sample(num_thin=2)
assert fit.values.shape[1] == 500
assert fit["theta"].shape[-1] == 500


def test_bernoulli_sampling_invalid_argument(posterior):
Expand All @@ -45,22 +45,22 @@ def test_bernoulli_sampling(fit):
assert fit.param_names == ("theta",)
assert fit.num_chains == 4

assert fit.values.ndim == 3
assert fit.values.shape[1] == 1000
assert fit.values.shape[2] == 4
assert fit._draws.ndim == 3
assert fit._draws.shape[1] == 1000
assert fit._draws.shape[2] == 4

assert len(fit) == 1 # one parameter (theta)

# for a fit with only one scalar parameter, it is the last one
assert 0.1 < fit.values[-1, :, 0].mean() < 0.4
assert 0.1 < fit.values[-1, :, 1].mean() < 0.4
assert 0.1 < fit.values[-1, :, 2].mean() < 0.4
assert 0.1 < fit.values[-1, :, 3].mean() < 0.4

assert 0.01 < fit.values[-1, :, 0].var() < 0.02
assert 0.01 < fit.values[-1, :, 1].var() < 0.02
assert 0.01 < fit.values[-1, :, 2].var() < 0.02
assert 0.01 < fit.values[-1, :, 3].var() < 0.02
assert 0.1 < fit._draws[-1, :, 0].mean() < 0.4
assert 0.1 < fit._draws[-1, :, 1].mean() < 0.4
assert 0.1 < fit._draws[-1, :, 2].mean() < 0.4
assert 0.1 < fit._draws[-1, :, 3].mean() < 0.4

assert 0.01 < fit._draws[-1, :, 0].var() < 0.02
assert 0.01 < fit._draws[-1, :, 1].var() < 0.02
assert 0.01 < fit._draws[-1, :, 2].var() < 0.02
assert 0.01 < fit._draws[-1, :, 3].var() < 0.02


def test_bernoulli_to_frame(fit):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_eight_schools.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_eight_schools_sample(posterior):
num_chains, num_samples = 2, 200
fit = posterior.sample(num_chains=num_chains, num_samples=num_samples, num_warmup=num_samples)
num_flat_params = schools_data["J"] * 2 + 2
assert fit.values.shape == (
assert fit._draws.shape == (
len(fit.sample_and_sampler_param_names) + num_flat_params,
num_samples,
num_chains,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_eight_schools_large.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_eight_schools_large_sample(posterior):
num_chains, num_samples = 2, 200
fit = posterior.sample(num_chains=num_chains, num_samples=num_samples, num_warmup=num_samples)
num_flat_params = schools_data["J"] * 2 + 2
assert fit.values.shape == (
assert fit._draws.shape == (
len(fit.sample_and_sampler_param_names) + num_flat_params,
num_samples,
num_chains,
Expand Down
11 changes: 11 additions & 0 deletions tests/test_fit_basic_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,14 @@ def test_fit_scalar_param(fit):
y = fit["y"]
assert y.shape == (1, num_samples * num_chains)
assert 9 < np.mean(y) < 11


def test_fit_mapping(fit):
    """Exercise the read-only ``dict`` protocol on the ``fit`` fixture."""
    # Plain iteration and ``keys()`` must both produce the parameter names.
    names = list(fit)
    assert names == ["y"]
    assert names == list(fit.keys())

    # ``values()`` must hand back the same draws as item access.
    last_value = list(fit.values())[-1]
    assert fit["y"].mean() == last_value.mean()

    # ``items()`` must pair the name with those same draws.
    key, value = list(fit.items())[-1]
    assert key == "y"
    assert value.mean() == fit["y"].mean()
6 changes: 3 additions & 3 deletions tests/test_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ def test_normal_sample():
assert posterior is not None
fit = posterior.sample()
offset = len(fit.sample_and_sampler_param_names)
assert fit.values.shape == (offset + 1, 1000, 1) # 1 chain, n samples, 1 param
assert fit._draws.shape == (offset + 1, 1000, 1)  # 1 param, n samples, 1 chain
df = fit.to_frame()
assert (df["y"] == fit.values[offset, :, :].ravel()).all()
assert (df["y"] == fit._draws[offset, :, :].ravel()).all()
assert len(df["y"]) == 1000
assert -0.01 < df["y"].mean() < 0.01
assert -0.01 < df["y"].std() < 0.01
Expand All @@ -30,7 +30,7 @@ def test_normal_sample_chains():
assert posterior is not None
fit = posterior.sample(num_chains=3)
offset = len(fit.sample_and_sampler_param_names)
assert fit.values.shape == (offset + 1, 1000, 3) # 1 param, n samples, 3 chains
assert fit._draws.shape == (offset + 1, 1000, 3) # 1 param, n samples, 3 chains
df = fit.to_frame()
assert len(df["y"]) == 3000
assert -5 < df["y"].mean() < 5
Expand Down

0 comments on commit 235d759

Please sign in to comment.