Skip to content

Commit

Permalink
Merge pull request #612 from stan-dev/fix/complex-output-ordering
Browse files Browse the repository at this point in the history
Fix/complex output ordering
  • Loading branch information
WardBrian committed Aug 25, 2022
2 parents 1f88a4d + 94528b7 commit 997a60b
Show file tree
Hide file tree
Showing 11 changed files with 79 additions and 18 deletions.
11 changes: 8 additions & 3 deletions cmdstanpy/stanfit/gq.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,10 +523,15 @@ def stan_variable(
col_idxs = self._metadata.stan_vars_cols[var]
if len(col_idxs) > 0:
dims.extend(self._metadata.stan_vars_dims[var])
# pylint: disable=redundant-keyword-arg
draws = self._draws[draw1:, :, col_idxs].reshape(dims, order='F')

draws = self._draws[draw1:, :, col_idxs]

if self._metadata.stan_vars_types[var] == BaseType.COMPLEX:
draws = draws[..., 0] + 1j * draws[..., 1]
draws = draws[..., ::2] + 1j * draws[..., 1::2]
dims = dims[:-1]

draws = draws.reshape(dims, order='F')

return draws

def stan_variables(self, inc_warmup: bool = False) -> Dict[str, np.ndarray]:
Expand Down
9 changes: 7 additions & 2 deletions cmdstanpy/stanfit/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,9 +747,14 @@ def stan_variable(
col_idxs = self._metadata.stan_vars_cols[var]
if len(col_idxs) > 0:
dims.extend(self._metadata.stan_vars_dims[var])
draws = self._draws[draw1:, :, col_idxs].reshape(dims, order='F')
draws = self._draws[draw1:, :, col_idxs]

if self._metadata.stan_vars_types[var] == BaseType.COMPLEX:
draws = draws[..., 0] + 1j * draws[..., 1]
draws = draws[..., ::2] + 1j * draws[..., 1::2]
dims = dims[:-1]

draws = draws.reshape(dims, order='F')

return draws

def stan_variables(self) -> Dict[str, np.ndarray]:
Expand Down
11 changes: 8 additions & 3 deletions cmdstanpy/stanfit/mle.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,17 @@ def stan_variable(
dims = (num_rows,) + self._metadata.stan_vars_dims[var]
# pylint: disable=redundant-keyword-arg
if num_rows > 1:
result = self._all_iters[:, col_idxs].reshape(dims, order='F')
result = self._all_iters[:, col_idxs]
else:
result = self._mle[col_idxs].reshape(dims[1:], order="F")
result = self._mle[col_idxs]
dims = dims[1:]

if self._metadata.stan_vars_types[var] == BaseType.COMPLEX:
result = result[..., 0] + 1j * result[..., 1]
result = result[..., ::2] + 1j * result[..., 1::2]
dims = dims[:-1]

result = result.reshape(dims, order='F')

return result

else: # scalar var
Expand Down
11 changes: 6 additions & 5 deletions cmdstanpy/stanfit/vb.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,13 @@ def stan_variable(self, var: str) -> Union[np.ndarray, float]:
shape: Tuple[int, ...] = ()
if len(col_idxs) > 1:
shape = self._metadata.stan_vars_dims[var]
result: np.ndarray = np.asarray(self._variational_mean)[
col_idxs
].reshape(shape, order="F")

result: np.ndarray = np.asarray(self._variational_mean)[col_idxs]
if self._metadata.stan_vars_types[var] == BaseType.COMPLEX:
result = result[..., 0] + 1j * result[..., 1]
result = result[..., ::2] + 1j * result[..., 1::2]
shape = shape[:-1]

result = result.reshape(shape, order="F")

return result
else:
return float(self._variational_mean[col_idxs[0]])
Expand Down
10 changes: 6 additions & 4 deletions cmdstanpy/utils/data_munging.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,14 @@ def build_xarray_data(
if dims:
var_dims += tuple(f"{var_name}_dim_{i}" for i in range(len(dims)))

draws = drawset[start_row:, :, col_idxs].reshape(
*drawset.shape[:2], *dims, order="F"
)
draws = drawset[start_row:, :, col_idxs]

if var_type == BaseType.COMPLEX:
draws = draws[..., 0] + 1j * draws[..., 1]
draws = draws[..., ::2] + 1j * draws[..., 1::2]
var_dims = var_dims[:-1]
dims = dims[:-1]

draws = draws.reshape(*drawset.shape[:2], *dims, order="F")

data[var_name] = (
var_dims,
Expand Down
4 changes: 4 additions & 0 deletions docsrc/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ What's New

For full changes, see the `Releases page <https://github.com/stan-dev/cmdstanpy/releases>`__ on GitHub.

CmdStanPy 1.0.6
---------------

- Fixed an issue where complex number containers in Stan program outputs were not being read in properly by CmdStanPy. The output would have the correct shape, but the values would be mixed up.

CmdStanPy 1.0.6
---------------
Expand Down
4 changes: 3 additions & 1 deletion test/data/complex_var.stan
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ generated quantities {
{{0, 1}, {0, 2}, {0, 3}}};
array[2, 3] complex zs = {{3, 4i, 5}, {1i, 2i, 3i}};
complex z = 3 + 4i;

array[2] int imag = {3, 4};

complex_matrix[2,3] zs_mat = to_matrix(zs);
}

7 changes: 7 additions & 0 deletions test/test_generate_quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,13 @@ def test_complex_output(self):

self.assertEqual(fit.stan_variable('zs').shape, (10, 2, 3))
self.assertEqual(fit.stan_variable('z')[0], 3 + 4j)

self.assertTrue(
np.allclose(
fit.stan_variable('zs')[0], np.array([[3, 4j, 5], [1j, 2j, 3j]])
)
)

# make sure the name 'imag' isn't magic
self.assertEqual(fit.stan_variable('imag').shape, (10, 2))

Expand Down
7 changes: 7 additions & 0 deletions test/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,13 @@ def test_complex_output(self):

self.assertEqual(fit.stan_variable('zs').shape, (2, 3))
self.assertEqual(fit.stan_variable('z'), 3 + 4j)

self.assertTrue(
np.allclose(
fit.stan_variable('zs'), np.array([[3, 4j, 5], [1j, 2j, 3j]])
)
)

# make sure the name 'imag' isn't magic
self.assertEqual(fit.stan_variable('imag').shape, (2,))

Expand Down
16 changes: 16 additions & 0 deletions test/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -1810,11 +1810,27 @@ def test_complex_output(self):
# make sure the name 'imag' isn't magic
self.assertEqual(fit.stan_variable('imag').shape, (10, 2))

self.assertTrue(
np.allclose(
fit.stan_variable('zs')[0], np.array([[3, 4j, 5], [1j, 2j, 3j]])
)
)
self.assertTrue(
np.allclose(
fit.stan_variable('zs_mat')[0],
np.array([[3, 4j, 5], [1j, 2j, 3j]]),
)
)

self.assertNotIn("zs_dim_2", fit.draws_xr())
# getting a raw scalar out of xarray is heavy
self.assertEqual(
fit.draws_xr().z.isel(chain=0, draw=1).data[()], 3 + 4j
)
np.testing.assert_allclose(
fit.draws_xr().zs.isel(chain=0, draw=1).data,
np.array([[3, 4j, 5], [1j, 2j, 3j]]),
)

def test_attrs(self):
stan = os.path.join(DATAFILES_PATH, 'named_output.stan')
Expand Down
7 changes: 7 additions & 0 deletions test/test_variational.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,13 @@ def test_complex_output(self):

self.assertEqual(fit.stan_variable('zs').shape, (2, 3))
self.assertEqual(fit.stan_variable('z'), 3 + 4j)

self.assertTrue(
np.allclose(
fit.stan_variable('zs'), np.array([[3, 4j, 5], [1j, 2j, 3j]])
)
)

# make sure the name 'imag' isn't magic
self.assertEqual(fit.stan_variable('imag').shape, (2,))

Expand Down

0 comments on commit 997a60b

Please sign in to comment.