diff --git a/cmdstanpy/stanfit.py b/cmdstanpy/stanfit.py index 927cefe6..4a523320 100644 --- a/cmdstanpy/stanfit.py +++ b/cmdstanpy/stanfit.py @@ -702,7 +702,7 @@ def draws_as_dataframe( the output, i.e., the sampler was run with ``save_warmup=True``, then the warmup draws are included. Default value is ``False``. """ - pnames_base = [name.split('.')[0] for name in self.column_names] + pnames_base = [name.split('[')[0] for name in self.column_names] if params is not None: for param in params: if not (param in self._column_names or param in pnames_base): @@ -723,27 +723,27 @@ def draws_as_dataframe( mask = [] params = set(params) for name in self.column_names: - if any(item in params for item in (name, name.split('.')[0])): + if any(item in params for item in (name, name.split('[')[0])): mask.append(name) return self._draws_as_df[mask] - def stan_variable(self, name: str) -> np.ndarray: + def stan_variable(self, name: str) -> pd.DataFrame: """ - Return a new ndarray which contains the set of post-warmup draws + Return a new DataFrame which contains the set of post-warmup draws for the named Stan program variable. Flattens the chains. Underlyingly draws are in chain order, i.e., for a sample consisting of N chains of M draws each, the first M array elements are from chain 1, the next M are from chain 2, and the last M elements are from chain N. - * If the variable is a scalar variable, this returns a 1-d array, - length(draws X chains). - * If the variable is a vector, this is a 2-d array, - shape ( draws X chains, len(vector)) - * If the variable is a matrix, this is a 3-d array, - shape ( draws X chains, matrix nrows, matrix ncols ). - * If the variable is an array with N dimensions, this is an N+1-d array, - shape ( draws X chains, size(dim 1), ... size(dim N)). + * If the variable is a scalar variable, the shape of the DataFrame is + ( draws X chains, 1). + * If the variable is a vector, the shape of the DataFrame is + ( draws X chains, len(vector)) + * If the variable is a matrix, the shape of the DataFrame is + ( draws X chains, size(dim 1) X size(dim 2) ) + * If the variable is an array with N dimensions, the shape of the + DataFrame is ( draws X chains, size(dim 1) X ... X size(dim N)) :param name: variable name """ @@ -751,23 +751,19 @@ def stan_variable(self, name: str) -> np.ndarray: raise ValueError('unknown name: {}'.format(name)) self._assemble_draws() dim0 = self.num_draws * self.runset.chains - dims = self._stan_variable_dims[name] - if dims == 1: - idx = self.column_names.index(name) - return self._draws[self._draws_warmup :, :, idx].reshape( - (dim0,), order='A' - ) - else: - idxs = [ - x[0] - for x in enumerate(self.column_names) - if x[1].startswith(name + '.') - ] - var_dims = [dim0] - var_dims.extend(dims) - return self._draws[ - self._draws_warmup :, :, idxs[0] : idxs[-1] + 1 - ].reshape(tuple(var_dims), order='A') + dims = np.prod(self._stan_variable_dims[name]) + pattern = r'^{}(\[[\d,]+\])?$'.format(name) + names, idxs = [], [] + for i, column_name in enumerate(self.column_names): + if re.search(pattern, column_name): + names.append(column_name) + idxs.append(i) + return pd.DataFrame( + self._draws[ + self._draws_warmup:, :, idxs + ].reshape((dim0, dims), order='A'), + columns=names + ) def stan_variables(self) -> Dict: """ diff --git a/cmdstanpy/utils.py b/cmdstanpy/utils.py index eca742bc..97852b7b 100644 --- a/cmdstanpy/utils.py +++ b/cmdstanpy/utils.py @@ -636,11 +636,19 @@ def scan_column_names(fd: TextIO, config_dict: Dict, lineno: int) -> int: line = fd.readline().strip() lineno += 1 names = line.split(',') - config_dict['column_names'] = tuple(names) + config_dict['column_names'] = tuple(_rename_columns(names)) config_dict['num_params'] = len(names) - 1 return lineno +def _rename_columns(names: List) -> List: + names = [ + re.sub(r',([\d,]+)$', r'[\1]', column.replace('.', ',')) + for column in names + ] + return names + + def parse_var_dims(names: Tuple[str, ...]) -> Dict: """ Use Stan CSV file column names to get variable names, dimensions. @@ -653,14 +661,14 @@ def parse_var_dims(names: Tuple[str, ...]) -> Dict: while idx < len(names): if names[idx].endswith('__'): pass - elif '.' not in names[idx]: + elif '[' not in names[idx]: vars_dict[names[idx]] = 1 else: - vs = names[idx].split('.') - if idx < len(names) - 1 and names[idx + 1].split('.')[0] == vs[0]: + vs = names[idx].split('[') + if idx < len(names) - 1 and names[idx + 1].split('[')[0] == vs[0]: idx += 1 continue - dims = [int(vs[x]) for x in range(1, len(vs))] + dims = [int(x) for x in vs[1][:-1].split(',')] vars_dict[vs[0]] = tuple(dims) idx += 1 return vars_dict diff --git a/test/test_generate_quantities.py b/test/test_generate_quantities.py index 2a165b51..30f6ece5 100644 --- a/test/test_generate_quantities.py +++ b/test/test_generate_quantities.py @@ -37,16 +37,16 @@ def test_gen_quantities_csv_files(self): csv_file = bern_gqs.runset.csv_files[i] self.assertTrue(os.path.exists(csv_file)) column_names = [ - 'y_rep.1', - 'y_rep.2', - 'y_rep.3', - 'y_rep.4', - 'y_rep.5', - 'y_rep.6', - 'y_rep.7', - 'y_rep.8', - 'y_rep.9', - 'y_rep.10', + 'y_rep[1]', + 'y_rep[2]', + 'y_rep[3]', + 'y_rep[4]', + 'y_rep[5]', + 'y_rep[6]', + 'y_rep[7]', + 'y_rep[8]', + 'y_rep[9]', + 'y_rep[10]', ] self.assertEqual(bern_gqs.column_names, tuple(column_names)) self.assertEqual( @@ -104,16 +104,16 @@ def test_gen_quanties_mcmc_sample(self): csv_file = bern_gqs.runset.csv_files[i] self.assertTrue(os.path.exists(csv_file)) column_names = [ - 'y_rep.1', - 'y_rep.2', - 'y_rep.3', - 'y_rep.4', - 'y_rep.5', - 'y_rep.6', - 'y_rep.7', - 'y_rep.8', - 'y_rep.9', - 'y_rep.10', + 'y_rep[1]', + 'y_rep[2]', + 'y_rep[3]', + 'y_rep[4]', + 'y_rep[5]', + 'y_rep[6]', + 'y_rep[7]', + 'y_rep[8]', + 'y_rep[9]', + 'y_rep[10]', ] self.assertEqual(bern_gqs.column_names, tuple(column_names)) self.assertEqual( diff --git a/test/test_sample.py b/test/test_sample.py index e0edfe4c..c090c397 100644 --- a/test/test_sample.py +++ b/test/test_sample.py @@ -341,88 +341,88 @@ def test_fixed_param_good(self): 'lp__', 'accept_stat__', 'N', - 'y_sim.1', - 'y_sim.2', - 'y_sim.3', - 'y_sim.4', - 'y_sim.5', - 'y_sim.6', - 'y_sim.7', - 'y_sim.8', - 'y_sim.9', - 'y_sim.10', - 'y_sim.11', - 'y_sim.12', - 'y_sim.13', - 'y_sim.14', - 'y_sim.15', - 'y_sim.16', - 'y_sim.17', - 'y_sim.18', - 'y_sim.19', - 'y_sim.20', - 'x_sim.1', - 'x_sim.2', - 'x_sim.3', - 'x_sim.4', - 'x_sim.5', - 'x_sim.6', - 'x_sim.7', - 'x_sim.8', - 'x_sim.9', - 'x_sim.10', - 'x_sim.11', - 'x_sim.12', - 'x_sim.13', - 'x_sim.14', - 'x_sim.15', - 'x_sim.16', - 'x_sim.17', - 'x_sim.18', - 'x_sim.19', - 'x_sim.20', - 'pop_sim.1', - 'pop_sim.2', - 'pop_sim.3', - 'pop_sim.4', - 'pop_sim.5', - 'pop_sim.6', - 'pop_sim.7', - 'pop_sim.8', - 'pop_sim.9', - 'pop_sim.10', - 'pop_sim.11', - 'pop_sim.12', - 'pop_sim.13', - 'pop_sim.14', - 'pop_sim.15', - 'pop_sim.16', - 'pop_sim.17', - 'pop_sim.18', - 'pop_sim.19', - 'pop_sim.20', + 'y_sim[1]', + 'y_sim[2]', + 'y_sim[3]', + 'y_sim[4]', + 'y_sim[5]', + 'y_sim[6]', + 'y_sim[7]', + 'y_sim[8]', + 'y_sim[9]', + 'y_sim[10]', + 'y_sim[11]', + 'y_sim[12]', + 'y_sim[13]', + 'y_sim[14]', + 'y_sim[15]', + 'y_sim[16]', + 'y_sim[17]', + 'y_sim[18]', + 'y_sim[19]', + 'y_sim[20]', + 'x_sim[1]', + 'x_sim[2]', + 'x_sim[3]', + 'x_sim[4]', + 'x_sim[5]', + 'x_sim[6]', + 'x_sim[7]', + 'x_sim[8]', + 'x_sim[9]', + 'x_sim[10]', + 'x_sim[11]', + 'x_sim[12]', + 'x_sim[13]', + 'x_sim[14]', + 'x_sim[15]', + 'x_sim[16]', + 'x_sim[17]', + 'x_sim[18]', + 'x_sim[19]', + 'x_sim[20]', + 'pop_sim[1]', + 'pop_sim[2]', + 'pop_sim[3]', + 'pop_sim[4]', + 'pop_sim[5]', + 'pop_sim[6]', + 'pop_sim[7]', + 'pop_sim[8]', + 'pop_sim[9]', + 'pop_sim[10]', + 'pop_sim[11]', + 'pop_sim[12]', + 'pop_sim[13]', + 'pop_sim[14]', + 'pop_sim[15]', + 'pop_sim[16]', + 'pop_sim[17]', + 'pop_sim[18]', + 'pop_sim[19]', + 'pop_sim[20]', 'alpha_sim', 'beta_sim', - 'eta.1', - 'eta.2', - 'eta.3', - 'eta.4', - 'eta.5', - 'eta.6', - 'eta.7', - 'eta.8', - 'eta.9', - 'eta.10', - 'eta.11', - 'eta.12', - 'eta.13', - 'eta.14', - 'eta.15', - 'eta.16', - 'eta.17', - 'eta.18', - 'eta.19', - 'eta.20', + 'eta[1]', + 'eta[2]', + 'eta[3]', + 'eta[4]', + 'eta[5]', + 'eta[6]', + 'eta[7]', + 'eta[8]', + 'eta[9]', + 'eta[10]', + 'eta[11]', + 'eta[12]', + 'eta[13]', + 'eta[14]', + 'eta[15]', + 'eta[16]', + 'eta[17]', + 'eta[18]', + 'eta[19]', + 'eta[20]', ] self.assertEqual(datagen_fit.column_names, tuple(column_names)) self.assertEqual(datagen_fit.num_draws, 100) @@ -527,7 +527,7 @@ def test_validate_big_run(self): os.path.join(DATAFILES_PATH, 'runset-big', 'output_icar_nyc-1.csv'), ] fit = CmdStanMCMC(runset) - phis = ['phi.{}'.format(str(x + 1)) for x in range(2095)] + phis = ['phi[{}]'.format(str(x + 1)) for x in range(2095)] column_names = SAMPLER_STATE + phis self.assertEqual(fit.num_draws, 1000) self.assertEqual(fit.column_names, tuple(column_names)) @@ -537,14 +537,18 @@ def test_validate_big_run(self): self.assertEqual((1000, 2, 2102), fit.draws().shape) phis = fit.draws_as_dataframe(params=['phi']) self.assertEqual((2000, 2095), phis.shape) - phi1 = fit.draws_as_dataframe(params=['phi.1']) + phi1 = fit.draws_as_dataframe(params=['phi[1]']) self.assertEqual((2000, 1), phi1.shape) - mo_phis = fit.draws_as_dataframe(params=['phi.1', 'phi.10', 'phi.100']) + mo_phis = fit.draws_as_dataframe( + params=['phi[1]', 'phi[10]', 'phi[100]'] + ) self.assertEqual((2000, 3), mo_phis.shape) - phi2095 = fit.draws_as_dataframe(params=['phi.2095']) + phi2095 = fit.draws_as_dataframe(params=['phi[2095]']) self.assertEqual((2000, 1), phi2095.shape) - with self.assertRaisesRegex(ValueError, 'unknown parameter: phi.2096'): - fit.draws_as_dataframe(params=['phi.2096']) + with self.assertRaisesRegex( + ValueError, r'unknown parameter: phi\[2096\]' + ): + fit.draws_as_dataframe(params=['phi[2096]']) with self.assertRaisesRegex(ValueError, 'unknown parameter: ph'): fit.draws_as_dataframe(params=['ph']) @@ -889,7 +893,7 @@ def test_variable_bern(self): self.assertTrue('theta' in bern_fit._stan_variable_dims) self.assertEqual(bern_fit._stan_variable_dims['theta'], 1) theta = bern_fit.stan_variable(name='theta') - self.assertEqual(theta.shape, (200,)) + self.assertEqual(theta.shape, (200, 1)) with self.assertRaises(ValueError): bern_fit.stan_variable(name='eta') with self.assertRaises(ValueError): @@ -919,7 +923,7 @@ def test_variable_lv(self): self.assertTrue('z' in fit._stan_variable_dims) self.assertEqual(fit._stan_variable_dims['z'], (20, 2)) z = fit.stan_variable(name='z') - self.assertEqual(z.shape, (20, 20, 2)) + self.assertEqual(z.shape, (20, 40)) theta = fit.stan_variable(name='theta') self.assertEqual(theta.shape, (20, 4)) @@ -948,7 +952,7 @@ def test_variables(self): vars = fit.stan_variables() self.assertEqual(len(vars), len(fit._stan_variable_dims)) self.assertTrue('z' in vars) - self.assertEqual(vars['z'].shape, (20, 20, 2)) + self.assertEqual(vars['z'].shape, (20, 40)) self.assertTrue('theta' in vars) self.assertEqual(vars['theta'].shape, (20, 4)) diff --git a/test/test_utils.py b/test/test_utils.py index 49270c58..b63eedd1 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -516,16 +516,16 @@ def test_parse_scalars(self): def test_parse_scalar_vec_scalar(self): x = [ 'foo', - 'phi.1', - 'phi.2', - 'phi.3', - 'phi.4', - 'phi.5', - 'phi.6', - 'phi.7', - 'phi.8', - 'phi.9', - 'phi.10', + 'phi[1]', + 'phi[2]', + 'phi[3]', + 'phi[4]', + 'phi[5]', + 'phi[6]', + 'phi[7]', + 'phi[8]', + 'phi[9]', + 'phi[10]', 'bar', ] vars_dict = parse_var_dims(x) @@ -537,18 +537,18 @@ def test_parse_scalar_vec_scalar(self): def test_parse_scalar_matrix_vec(self): x = [ 'foo', - 'phi.1.1', - 'phi.1.2', - 'phi.1.3', - 'phi.1.4', - 'phi.1.5', - 'phi.2.1', - 'phi.2.2', - 'phi.2.3', - 'phi.2.4', - 'phi.2.5', - 'bar.1', - 'bar.2', + 'phi[1,1]', + 'phi[1,2]', + 'phi[1,3]', + 'phi[1,4]', + 'phi[1,5]', + 'phi[2,1]', + 'phi[2,2]', + 'phi[2,3]', + 'phi[2,4]', + 'phi[2,5]', + 'bar[1]', + 'bar[2]', ] vars_dict = parse_var_dims(x) self.assertEqual(len(vars_dict), 3) diff --git a/test/test_variational.py b/test/test_variational.py index b098b561..0aa576f1 100644 --- a/test/test_variational.py +++ b/test/test_variational.py @@ -51,13 +51,13 @@ def test_instantiate(self): self.assertIn('method=variational', variational.__repr__()) self.assertEqual( variational.column_names, - ('lp__', 'log_p__', 'log_g__', 'mu.1', 'mu.2'), + ('lp__', 'log_p__', 'log_g__', 'mu[1]', 'mu[2]'), ) self.assertAlmostEqual( - variational.variational_params_dict['mu.1'], 31.0299, places=2 + variational.variational_params_dict['mu[1]'], 31.0299, places=2 ) self.assertAlmostEqual( - variational.variational_params_dict['mu.2'], 28.8141, places=2 + variational.variational_params_dict['mu[2]'], 28.8141, places=2 ) self.assertEqual(variational.variational_sample.shape, (1000, 5)) @@ -84,7 +84,7 @@ def test_variational_good(self): variational = model.variational(algorithm='meanfield', seed=12345) self.assertEqual( variational.column_names, - ('lp__', 'log_p__', 'log_g__', 'mu.1', 'mu.2'), + ('lp__', 'log_p__', 'log_g__', 'mu[1]', 'mu[2]'), ) self.assertAlmostEqual( @@ -95,10 +95,10 @@ def test_variational_good(self): ) self.assertAlmostEqual( - variational.variational_params_dict['mu.1'], 31.0418, places=2 + variational.variational_params_dict['mu[1]'], 31.0418, places=2 ) self.assertAlmostEqual( - variational.variational_params_dict['mu.2'], 27.4463, places=2 + variational.variational_params_dict['mu[2]'], 27.4463, places=2 ) self.assertEqual( @@ -107,11 +107,11 @@ def test_variational_good(self): ) self.assertEqual( variational.variational_params_np[3], - variational.variational_params_pd['mu.1'][0], + variational.variational_params_pd['mu[1]'][0], ) self.assertEqual( variational.variational_params_np[4], - variational.variational_params_pd['mu.2'][0], + variational.variational_params_pd['mu[2]'][0], ) self.assertEqual(variational.variational_sample.shape, (1000, 5)) @@ -127,13 +127,13 @@ def test_variational_eta_small(self): variational = model.variational(algorithm='meanfield', seed=12345) self.assertEqual( variational.column_names, - ('lp__', 'log_p__', 'log_g__', 'mu.1', 'mu.2'), + ('lp__', 'log_p__', 'log_g__', 'mu[1]', 'mu[2]'), ) self.assertAlmostEqual( - fabs(variational.variational_params_dict['mu.1']), 0.08, places=1 + fabs(variational.variational_params_dict['mu[1]']), 0.08, places=1 ) self.assertAlmostEqual( - fabs(variational.variational_params_dict['mu.2']), 0.09, places=1 + fabs(variational.variational_params_dict['mu[2]']), 0.09, places=1 ) self.assertTrue(True)