Skip to content

Commit

Permalink
feat: Add __repr__ and __len__ for Fit
Browse files Browse the repository at this point in the history
Adds a minimally informative __repr__ for the Fit class. Also adds a
simple __len__ which reports the number of parameters.

Closes #41
Closes #92
  • Loading branch information
riddell-stan committed Apr 18, 2020
1 parent 93612c4 commit a7d9290
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 2 deletions.
23 changes: 21 additions & 2 deletions stan/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,11 @@ def to_frame(self):
return df

@property
def values(self):
def values(self) -> np.ndarray:
return self._draws

@property
def _finished(self):
def _finished(self) -> bool:
return not self._draws.flags["WRITEABLE"]

def __getitem__(self, param):
Expand All @@ -158,6 +158,25 @@ def __getitem__(self, param):
# reshape, recover the shape of the stan parameter
return view.reshape(*reshape_args, order="F")

def __len__(self) -> int:
return len(self.param_names)

def __repr__(self) -> str:
# inspired by xarray
parts = [f"<stan.{type(self).__name__}>"]

def summarize_param(param_name, dims):
return f" {param_name}: {tuple(dims)}"

if self.param_names:
parts.append("Parameters:")
for param_name, dims in zip(self.param_names, self.dims):
parts.append(summarize_param(param_name, dims))
if self._finished:
# total draws is num_draws (per-chain) times num_chains
parts.append(f"Draws: {self._draws.shape[-2] * self._draws.shape[-1]}")
return "\n".join(parts)

def _parameter_indexes(self, param: str) -> Sequence[int]:
"""Obtain indexes for values associated with `param`.
Expand Down
2 changes: 2 additions & 0 deletions tests/test_basic_bernoulli.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def test_bernoulli_sampling(fit):
assert fit.values.shape[1] == 1000
assert fit.values.shape[2] == 4

assert len(fit) == 1 # one parameter (theta)

# for a fit with only one scalar parameter, it is the last one
assert 0.1 < fit.values[-1, :, 0].mean() < 0.4
assert 0.1 < fit.values[-1, :, 1].mean() < 0.4
Expand Down
1 change: 1 addition & 0 deletions tests/test_linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,6 @@ def posterior():

def test_linear_regression(posterior):
fit = posterior.sample(num_chains=4)
assert len(fit) == 2 # two parameters (beta, sigma)
assert 0 < fit["sigma"].mean() < 2
assert np.allclose(fit["beta"].mean(axis=1), beta_true, atol=0.05)
41 changes: 41 additions & 0 deletions tests/test_repr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import numpy as np
import pytest

import stan

np.random.seed(1)

program_code = """
data {
int<lower=0> N;
int<lower=0> p;
matrix[N,p] x;
vector[N] y;
}
parameters {
vector[p] beta;
real<lower=0> sigma;
}
model {
y ~ normal(x * beta, sigma);
}
"""

n, p = 50, 3 # smaller n than in tests/test_linear_regression.py
X = np.random.normal(size=(n, p))
X = (X - np.mean(X, axis=0)) / np.std(X, ddof=1, axis=0, keepdims=True)
beta_true = (1, 3, 5)
y = np.dot(X, beta_true) + np.random.normal(size=n)

data = {"N": n, "p": p, "x": X, "y": y}


@pytest.fixture(scope="module")
def posterior():
return stan.build(program_code, data=data, random_seed=1)


def test_repr_fit(posterior):
fit = posterior.sample(num_chains=4)
expected = """<stan.Fit>\nParameters:\n beta: (3,)\n sigma: ()\nDraws: 4000"""
assert repr(fit) == expected

0 comments on commit a7d9290

Please sign in to comment.