Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 43 additions & 29 deletions cmdstanpy/stanfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,7 +1074,6 @@ def draws_xr(
0,
self.draws(inc_warmup=inc_warmup),
)

return xr.Dataset(data, coords=coordinates, attrs=attrs).transpose(
'chain', 'draw', ...
)
Expand Down Expand Up @@ -1373,7 +1372,7 @@ def stan_variable(
inc_iterations: bool = False,
warn: bool = True,
name: Optional[str] = None,
) -> np.ndarray:
) -> Union[np.ndarray, float]:
"""
Return a numpy.ndarray which contains the estimates for the
for the named Stan program variable where the dimensions of the
Expand Down Expand Up @@ -1416,38 +1415,34 @@ def stan_variable(
'Invalid estimate, optimization failed to converge.'
)

col_idxs = self._metadata.stan_vars_cols[var]
col_idxs = list(self._metadata.stan_vars_cols[var])
if inc_iterations and self._save_iterations:
num_rows = self._all_iters.shape[0]
else:
num_rows = 1

if len(col_idxs) > 0: # container var
if len(col_idxs) > 1: # container var
dims = (num_rows,) + self._metadata.stan_vars_dims[var]
# pylint: disable=redundant-keyword-arg
if num_rows > 1:
result = self._all_iters[:, col_idxs].reshape( # type: ignore
dims, order='F'
)
else:
mle = np.expand_dims(self._mle, axis=0) # hack for col indexing
result = (
mle[0, col_idxs]
.reshape(dims, order='F') # type: ignore
.squeeze(axis=0)
)
result = self._mle[col_idxs].reshape(dims[1:], order="F")
else: # scalar var
col_idx = col_idxs[0]
if num_rows > 1:
result = self._all_iters[:, col_idxs]
result = self._all_iters[:, col_idx]
else:
result = np.atleast_1d(mle[0, col_idxs])

assert isinstance(result, np.ndarray) # make the typechecker happy
result = float(self._mle[col_idx])
assert isinstance(
result, (np.ndarray, float)
) # make the typechecker happy
return result

def stan_variables(
self, inc_iterations: bool = False
) -> Dict[str, np.ndarray]:
) -> Dict[str, Union[np.ndarray, float]]:
"""
Return a dictionary mapping Stan program variables names
to the corresponding numpy.ndarray containing the inferred values.
Expand Down Expand Up @@ -1988,16 +1983,26 @@ def stan_variable(
return self.mcmc_sample.stan_variable(var, inc_warmup=inc_warmup)
else: # is gq variable
self._assemble_generated_quantities()
col_idxs = self._metadata.stan_vars_cols[var]
draw1 = 0
if (
not inc_warmup
and self.mcmc_sample.metadata.cmdstan_config['save_warmup']
):
draw1 = self.mcmc_sample.num_draws_warmup * self.chains
return flatten_chains(self._draws)[ # type: ignore
draw1:, col_idxs
]
return flatten_chains(self._draws)[:, col_idxs] # type: ignore
draw1 = self.mcmc_sample.num_draws_warmup
num_draws = self.mcmc_sample.num_draws_sampling
if (
inc_warmup
and self.mcmc_sample.metadata.cmdstan_config['save_warmup']
):
num_draws += self.mcmc_sample.num_draws_warmup
dims = [num_draws * self.chains]
col_idxs = self._metadata.stan_vars_cols[var]
if len(col_idxs) > 0:
dims.extend(self._metadata.stan_vars_dims[var])
# pylint: disable=redundant-keyword-arg
return self._draws[draw1:, :, col_idxs].reshape( # type: ignore
dims, order='F'
)

def stan_variables(self, inc_warmup: bool = False) -> Dict[str, np.ndarray]:
"""
Expand Down Expand Up @@ -2143,7 +2148,7 @@ def metadata(self) -> InferenceMetadata:

def stan_variable(
self, var: Optional[str] = None, *, name: Optional[str] = None
) -> np.ndarray:
) -> Union[np.ndarray, float]:
"""
Return a numpy.ndarray which contains the estimates for the
for the named Stan program variable where the dimensions of the
Expand Down Expand Up @@ -2172,14 +2177,18 @@ def stan_variable(
if var not in self._metadata.stan_vars_dims:
raise ValueError('Unknown variable name: {}'.format(var))
col_idxs = list(self._metadata.stan_vars_cols[var])
vals = list(self._variational_mean)
xs = [vals[x] for x in col_idxs]
shape: Tuple[int, ...] = ()
if len(col_idxs) > 0:
if len(col_idxs) > 1:
shape = self._metadata.stan_vars_dims[var]
return np.array(xs).reshape(shape)
result = np.asarray(self._variational_mean)[col_idxs].reshape(
shape, order="F"
)
else:
result = float(self._variational_mean[col_idxs[0]])
assert isinstance(result, (np.ndarray, float))
return result

def stan_variables(self) -> Dict[str, np.ndarray]:
def stan_variables(self) -> Dict[str, Union[np.ndarray, float]]:
"""
Return a dictionary mapping Stan program variables names
to the corresponding numpy.ndarray containing the inferred values.
Expand Down Expand Up @@ -2424,7 +2433,12 @@ def build_xarray_data(
var_dims: Tuple[str, ...] = ('draw', 'chain')
if dims:
var_dims += tuple(f"{var_name}_dim_{i}" for i in range(len(dims)))
data[var_name] = (var_dims, drawset[start_row:, :, col_idxs])
data[var_name] = (
var_dims,
drawset[start_row:, :, col_idxs].reshape(
*drawset.shape[:2], *dims, order="F"
),
)
else:
data[var_name] = (
var_dims,
Expand Down
4 changes: 3 additions & 1 deletion docsrc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,5 +441,7 @@ def emit(self, record):
# }

# Makes the copying behavior on code examples cleaner by removing things like In [10]: from the text to be copied
copybutton_prompt_text = r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: "
copybutton_prompt_text = (
r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: "
)
copybutton_prompt_is_regexp = True
19 changes: 19 additions & 0 deletions test/data/matrix_var.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
transformed data {
int y[10] = {0,1,0,0,0,0,0,0,0,1};
}
parameters {
real<lower=0,upper=1> theta;
}
model {
theta ~ beta(1,1); // uniform prior on interval 0,1
y ~ bernoulli(theta);
}
generated quantities {
# x is a 4 x 3 matrix where i,j entry == rownum
matrix[4, 3] z;
for (row_num in 1:4) {
for (col_num in 1:3) {
z[row_num, col_num] = row_num;
}
}
}
22 changes: 22 additions & 0 deletions test/test_generate_quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,28 @@ def test_no_xarray(self):
with self.assertRaises(RuntimeError):
bern_gqs.draws_xr()

def test_single_row_csv(self):
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
bern_model = CmdStanModel(stan_file=stan)
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
bern_fit = bern_model.sample(
data=jdata,
chains=1,
seed=12345,
iter_sampling=1,
)
stan = os.path.join(DATAFILES_PATH, 'matrix_var.stan')
model = CmdStanModel(stan_file=stan)
gqs = model.generate_quantities(mcmc_sample=bern_fit)
z_as_ndarray = gqs.stan_variable(var="z")
self.assertEqual(z_as_ndarray.shape, (1, 4, 3)) # flattens chains
z_as_xr = gqs.draws_xr(vars="z")
self.assertEqual(z_as_xr.z.data.shape, (1, 1, 4, 3)) # keeps chains
for i in range(4):
for j in range(3):
self.assertEqual(int(z_as_ndarray[0, i, j]), i + 1)
self.assertEqual(int(z_as_xr.z.data[0, 0, i, j]), i + 1)


if __name__ == '__main__':
unittest.main()
21 changes: 17 additions & 4 deletions test/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,13 +208,15 @@ def test_variable_bern(self):
self.assertTrue('theta' in bern_mle.metadata.stan_vars_dims)
self.assertEqual(bern_mle.metadata.stan_vars_dims['theta'], ())
theta = bern_mle.stan_variable(var='theta')
self.assertEqual(theta.shape, ())
self.assertTrue(isinstance(theta, float))
with self.assertRaises(ValueError):
bern_mle.stan_variable(var='eta')
with self.assertRaises(ValueError):
bern_mle.stan_variable(var='lp__')
with LogCapture() as log:
self.assertEqual(bern_mle.stan_variable(name='theta').shape, ())
self.assertTrue(
isinstance(bern_mle.stan_variable(name='theta'), float)
)
log.check_present(
(
'cmdstanpy',
Expand Down Expand Up @@ -250,15 +252,15 @@ def test_variables_3d(self):
var_beta = multidim_mle.stan_variable(var='beta')
self.assertEqual(var_beta.shape, (2,)) # 1-element tuple
var_frac_60 = multidim_mle.stan_variable(var='frac_60')
self.assertEqual(var_frac_60.shape, ())
self.assertTrue(isinstance(var_frac_60, float))
vars = multidim_mle.stan_variables()
self.assertEqual(len(vars), len(multidim_mle.metadata.stan_vars_dims))
self.assertTrue('y_rep' in vars)
self.assertEqual(vars['y_rep'].shape, (5, 4, 3))
self.assertTrue('beta' in vars)
self.assertEqual(vars['beta'].shape, (2,))
self.assertTrue('frac_60' in vars)
self.assertEqual(vars['frac_60'].shape, ())
self.assertTrue(isinstance(vars['frac_60'], float))

multidim_mle_iters = multidim_model.optimize(
data=jdata,
Expand Down Expand Up @@ -579,6 +581,17 @@ def test_optimize_bad(self):
data=no_data, seed=1239812093, inits=None, algorithm='BFGS'
)

def test_single_row_csv(self):
stan = os.path.join(DATAFILES_PATH, 'matrix_var.stan')
model = CmdStanModel(stan_file=stan)
mle = model.optimize()
self.assertTrue(isinstance(mle.stan_variable('theta'), float))
z_as_ndarray = mle.stan_variable(var="z")
self.assertEqual(z_as_ndarray.shape, (4, 3))
for i in range(4):
for j in range(3):
self.assertEqual(int(z_as_ndarray[i, j]), i + 1)


if __name__ == '__main__':
unittest.main()
13 changes: 13 additions & 0 deletions test/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -1668,6 +1668,19 @@ def test_no_xarray(self):
with self.assertRaises(RuntimeError):
bern_fit.draws_xr()

def test_single_row_csv(self):
stan = os.path.join(DATAFILES_PATH, 'matrix_var.stan')
model = CmdStanModel(stan_file=stan)
fit = model.sample(iter_sampling=1, chains=1)
z_as_ndarray = fit.stan_variable(var="z")
self.assertEqual(z_as_ndarray.shape, (1, 4, 3)) # flattens chains
z_as_xr = fit.draws_xr(vars="z")
self.assertEqual(z_as_xr.z.data.shape, (1, 1, 4, 3)) # keeps chains
for i in range(4):
for j in range(3):
self.assertEqual(int(z_as_ndarray[0, i, j]), i + 1)
self.assertEqual(int(z_as_xr.z.data[0, 0, i, j]), i + 1)


if __name__ == '__main__':
unittest.main()
15 changes: 13 additions & 2 deletions test/test_variational.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def test_variables_3d(self):
var_beta = multidim_variational.stan_variable(var='beta')
self.assertEqual(var_beta.shape, (2,)) # 1-element tuple
var_frac_60 = multidim_variational.stan_variable(var='frac_60')
self.assertEqual(var_frac_60.shape, ())
self.assertTrue(isinstance(var_frac_60, float))
vars = multidim_variational.stan_variables()
self.assertEqual(
len(vars), len(multidim_variational.metadata.stan_vars_dims)
Expand All @@ -136,7 +136,7 @@ def test_variables_3d(self):
self.assertTrue('beta' in vars)
self.assertEqual(vars['beta'].shape, (2,))
self.assertTrue('frac_60' in vars)
self.assertEqual(vars['frac_60'].shape, ())
self.assertTrue(isinstance(vars['frac_60'], float))
with self.assertRaises(ValueError):
multidim_variational.stan_variable(var='beta', name='yrep')
with LogCapture() as log:
Expand Down Expand Up @@ -253,6 +253,17 @@ def test_variational_eta_fail(self):
)
)

def test_single_row_csv(self):
stan = os.path.join(DATAFILES_PATH, 'matrix_var.stan')
model = CmdStanModel(stan_file=stan)
vb_fit = model.variational()
self.assertTrue(isinstance(vb_fit.stan_variable('theta'), float))
z_as_ndarray = vb_fit.stan_variable(var="z")
self.assertEqual(z_as_ndarray.shape, (4, 3))
for i in range(4):
for j in range(3):
self.assertEqual(int(z_as_ndarray[i, j]), i + 1)


if __name__ == '__main__':
unittest.main()