From 3f7813752a741aff2f8faa96c0f43c76010380a8 Mon Sep 17 00:00:00 2001 From: Lukas Neugebauer Date: Fri, 4 Sep 2020 19:05:28 +0200 Subject: [PATCH 1/7] stan_variable now returns a pd.DataFrame --- cmdstanpy/stanfit.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/cmdstanpy/stanfit.py b/cmdstanpy/stanfit.py index 927cefe6..c64d0bdb 100644 --- a/cmdstanpy/stanfit.py +++ b/cmdstanpy/stanfit.py @@ -727,22 +727,23 @@ def draws_as_dataframe( mask.append(name) return self._draws_as_df[mask] - def stan_variable(self, name: str) -> np.ndarray: + def stan_variable(self, name: str) -> pd.DataFrame: """ - Return a new ndarray which contains the set of post-warmup draws + Return a new DataFrame which contains the set of post-warmup draws for the named Stan program variable. Flattens the chains. Underlyingly draws are in chain order, i.e., for a sample consisting of N chains of M draws each, the first M array elements are from chain 1, the next M are from chain 2, and the last M elements are from chain N. - * If the variable is a scalar variable, this returns a 1-d array, - length(draws X chains). + * If the variable is a scalar variable, this returns a 2-d DataFrame + with a singleton dimension, + shape( draws X chains, 1). * If the variable is a vector, this is a 2-d array, shape ( draws X chains, len(vector)) * If the variable is a matrix, this is a 3-d array, shape ( draws X chains, matrix nrows, matrix ncols ). - * If the variable is an array with N dimensions, this is an N+1-d array, + * If the variable is an array with N dimensions, this is an N+1-d array shape ( draws X chains, size(dim 1), ... size(dim N)). :param name: variable name @@ -754,20 +755,23 @@ def stan_variable(self, name: str) -> np.ndarray: dims = self._stan_variable_dims[name] if dims == 1: idx = self.column_names.index(name) - return self._draws[self._draws_warmup :, :, idx].reshape( - (dim0,), order='A' - ) + return pd.DataFrame({ + name: self._draws[self._draws_warmup:, :, idx].reshape( + (dim0,), order='A' + ) + }) else: idxs = [ - x[0] + x for x in enumerate(self.column_names) if x[1].startswith(name + '.') ] - var_dims = [dim0] - var_dims.extend(dims) - return self._draws[ - self._draws_warmup :, :, idxs[0] : idxs[-1] + 1 - ].reshape(tuple(var_dims), order='A') + return pd.DataFrame({ + n: self._draws[ + self._draws_warmup:, :, x + ].reshape(dim0, order='A') + for x, n in idxs + }) def stan_variables(self) -> Dict: """ From 0c48ec57615ff932ca887ae006f0b6a938580941 Mon Sep 17 00:00:00 2001 From: Lukas Neugebauer Date: Fri, 4 Sep 2020 21:18:04 +0200 Subject: [PATCH 2/7] Updated description and adapted unit tests to new output of stan_variable --- cmdstanpy/stanfit.py | 15 +++++++-------- test/test_sample.py | 6 +++--- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/cmdstanpy/stanfit.py b/cmdstanpy/stanfit.py index c64d0bdb..302f18d4 100644 --- a/cmdstanpy/stanfit.py +++ b/cmdstanpy/stanfit.py @@ -736,15 +736,14 @@ def stan_variable(self, name: str) -> pd.DataFrame: elements are from chain 1, the next M are from chain 2, and the last M elements are from chain N. - * If the variable is a scalar variable, this returns a 2-d DataFrame - with a singleton dimension, - shape( draws X chains, 1). - * If the variable is a vector, this is a 2-d array, + * If the variable is a scalar variable, the shape of the DataFrame is + ( draws X chains, 1). + * If the variable is a vector, the shape of the DataFrame is shape ( draws X chains, len(vector)) - * If the variable is a matrix, this is a 3-d array, - shape ( draws X chains, matrix nrows, matrix ncols ). - * If the variable is an array with N dimensions, this is an N+1-d array - shape ( draws X chains, size(dim 1), ... size(dim N)). + * If the variable is a matrix, the shape of the DataFrame is + shape ( draws X chains, size(dim 1) X size(dim 2) ) + * If the variable is an array with N dimensions, the shape of the DataFrame is + shape ( draws X chains, size(dim 1) X ... X size(dim N)). :param name: variable name """ diff --git a/test/test_sample.py b/test/test_sample.py index e0edfe4c..0252164d 100644 --- a/test/test_sample.py +++ b/test/test_sample.py @@ -889,7 +889,7 @@ def test_variable_bern(self): self.assertTrue('theta' in bern_fit._stan_variable_dims) self.assertEqual(bern_fit._stan_variable_dims['theta'], 1) theta = bern_fit.stan_variable(name='theta') - self.assertEqual(theta.shape, (200,)) + self.assertEqual(theta.shape, (200, 1)) with self.assertRaises(ValueError): bern_fit.stan_variable(name='eta') with self.assertRaises(ValueError): @@ -919,7 +919,7 @@ def test_variable_lv(self): self.assertTrue('z' in fit._stan_variable_dims) self.assertEqual(fit._stan_variable_dims['z'], (20, 2)) z = fit.stan_variable(name='z') - self.assertEqual(z.shape, (20, 20, 2)) + self.assertEqual(z.shape, (20, 40)) theta = fit.stan_variable(name='theta') self.assertEqual(theta.shape, (20, 4)) @@ -948,7 +948,7 @@ def test_variables(self): vars = fit.stan_variables() self.assertEqual(len(vars), len(fit._stan_variable_dims)) self.assertTrue('z' in vars) - self.assertEqual(vars['z'].shape, (20, 20, 2)) + self.assertEqual(vars['z'].shape, (20, 40)) self.assertTrue('theta' in vars) self.assertEqual(vars['theta'].shape, (20, 4)) From 98a64578b1eb20d6549ca1cb5840a3370d4a8b65 Mon Sep 17 00:00:00 2001 From: Lukas Neugebauer Date: Sat, 5 Sep 2020 14:21:05 +0200 Subject: [PATCH 3/7] stan_variable now in more streamlined version according to Ari's suggestions --- cmdstanpy/stanfit.py | 33 +++++++++++++-------------------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/cmdstanpy/stanfit.py b/cmdstanpy/stanfit.py index 302f18d4..5187ad52 100644 --- a/cmdstanpy/stanfit.py +++ b/cmdstanpy/stanfit.py @@ -751,26 +751,19 @@ def stan_variable(self, name: str) -> pd.DataFrame: raise ValueError('unknown name: {}'.format(name)) self._assemble_draws() dim0 = self.num_draws * self.runset.chains - dims = self._stan_variable_dims[name] - if dims == 1: - idx = self.column_names.index(name) - return pd.DataFrame({ - name: self._draws[self._draws_warmup:, :, idx].reshape( - (dim0,), order='A' - ) - }) - else: - idxs = [ - x - for x in enumerate(self.column_names) - if x[1].startswith(name + '.') - ] - return pd.DataFrame({ - n: self._draws[ - self._draws_warmup:, :, x - ].reshape(dim0, order='A') - for x, n in idxs - }) + dims = np.prod(self._stan_variable_dims[name]) + pattern = r'^{}(\.\d+)*$'.format(name) + names, idxs = [], [] + for i, column_name in enumerate(self.column_names): + if re.search(pattern, column_name): + names.append(column_name) + idxs.append(i) + return pd.DataFrame( + self._draws[ + self._draws_warmup:, :, idxs + ].reshape((dim0, dims), order='A'), + columns=names + ) def stan_variables(self) -> Dict: """ From 6406aeefead5a5ff1e4d5956b06d84da7a06d611 Mon Sep 17 00:00:00 2001 From: Lukas Neugebauer Date: Sun, 13 Sep 2020 21:05:47 +0200 Subject: [PATCH 4/7] Columns are being renamed, adapted tests --- cmdstanpy/stanfit.py | 4 +- cmdstanpy/utils.py | 18 +++- test/test_generate_quantities.py | 40 ++++---- test/test_sample.py | 160 +++++++++++++++---------------- test/test_utils.py | 44 ++++----- test/test_variational.py | 22 ++--- 6 files changed, 148 insertions(+), 140 deletions(-) diff --git a/cmdstanpy/stanfit.py b/cmdstanpy/stanfit.py index 5187ad52..8d7c3913 100644 --- a/cmdstanpy/stanfit.py +++ b/cmdstanpy/stanfit.py @@ -702,7 +702,7 @@ def draws_as_dataframe( the output, i.e., the sampler was run with ``save_warmup=True``, then the warmup draws are included. Default value is ``False``. """ - pnames_base = [name.split('.')[0] for name in self.column_names] + pnames_base = [name.split('[')[0] for name in self.column_names] if params is not None: for param in params: if not (param in self._column_names or param in pnames_base): @@ -723,7 +723,7 @@ def draws_as_dataframe( mask = [] params = set(params) for name in self.column_names: - if any(item in params for item in (name, name.split('.')[0])): + if any(item in params for item in (name, name.split('[')[0])): mask.append(name) return self._draws_as_df[mask] diff --git a/cmdstanpy/utils.py b/cmdstanpy/utils.py index eca742bc..97852b7b 100644 --- a/cmdstanpy/utils.py +++ b/cmdstanpy/utils.py @@ -636,11 +636,19 @@ def scan_column_names(fd: TextIO, config_dict: Dict, lineno: int) -> int: line = fd.readline().strip() lineno += 1 names = line.split(',') - config_dict['column_names'] = tuple(names) + config_dict['column_names'] = tuple(_rename_columns(names)) config_dict['num_params'] = len(names) - 1 return lineno +def _rename_columns(names: List) -> List: + names = [ + re.sub(r',([\d,]+)$', r'[\1]', column.replace('.', ',')) + for column in names + ] + return names + + def parse_var_dims(names: Tuple[str, ...]) -> Dict: """ Use Stan CSV file column names to get variable names, dimensions. @@ -653,14 +661,14 @@ def parse_var_dims(names: Tuple[str, ...]) -> Dict: while idx < len(names): if names[idx].endswith('__'): pass - elif '.' not in names[idx]: + elif '[' not in names[idx]: vars_dict[names[idx]] = 1 else: - vs = names[idx].split('.') - if idx < len(names) - 1 and names[idx + 1].split('.')[0] == vs[0]: + vs = names[idx].split('[') + if idx < len(names) - 1 and names[idx + 1].split('[')[0] == vs[0]: idx += 1 continue - dims = [int(vs[x]) for x in range(1, len(vs))] + dims = [int(x) for x in vs[1][:-1].split(',')] vars_dict[vs[0]] = tuple(dims) idx += 1 return vars_dict diff --git a/test/test_generate_quantities.py b/test/test_generate_quantities.py index 2a165b51..30f6ece5 100644 --- a/test/test_generate_quantities.py +++ b/test/test_generate_quantities.py @@ -37,16 +37,16 @@ def test_gen_quantities_csv_files(self): csv_file = bern_gqs.runset.csv_files[i] self.assertTrue(os.path.exists(csv_file)) column_names = [ - 'y_rep.1', - 'y_rep.2', - 'y_rep.3', - 'y_rep.4', - 'y_rep.5', - 'y_rep.6', - 'y_rep.7', - 'y_rep.8', - 'y_rep.9', - 'y_rep.10', + 'y_rep[1]', + 'y_rep[2]', + 'y_rep[3]', + 'y_rep[4]', + 'y_rep[5]', + 'y_rep[6]', + 'y_rep[7]', + 'y_rep[8]', + 'y_rep[9]', + 'y_rep[10]', ] self.assertEqual(bern_gqs.column_names, tuple(column_names)) self.assertEqual( @@ -104,16 +104,16 @@ def test_gen_quanties_mcmc_sample(self): csv_file = bern_gqs.runset.csv_files[i] self.assertTrue(os.path.exists(csv_file)) column_names = [ - 'y_rep.1', - 'y_rep.2', - 'y_rep.3', - 'y_rep.4', - 'y_rep.5', - 'y_rep.6', - 'y_rep.7', - 'y_rep.8', - 'y_rep.9', - 'y_rep.10', + 'y_rep[1]', + 'y_rep[2]', + 'y_rep[3]', + 'y_rep[4]', + 'y_rep[5]', + 'y_rep[6]', + 'y_rep[7]', + 'y_rep[8]', + 'y_rep[9]', + 'y_rep[10]', ] self.assertEqual(bern_gqs.column_names, tuple(column_names)) self.assertEqual( diff --git a/test/test_sample.py b/test/test_sample.py index 0252164d..a7d46caa 100644 --- a/test/test_sample.py +++ b/test/test_sample.py @@ -341,88 +341,88 @@ def test_fixed_param_good(self): 'lp__', 'accept_stat__', 'N', - 'y_sim.1', - 'y_sim.2', - 'y_sim.3', - 'y_sim.4', - 'y_sim.5', - 'y_sim.6', - 'y_sim.7', - 'y_sim.8', - 'y_sim.9', - 'y_sim.10', - 'y_sim.11', - 'y_sim.12', - 'y_sim.13', - 'y_sim.14', - 'y_sim.15', - 'y_sim.16', - 'y_sim.17', - 'y_sim.18', - 'y_sim.19', - 'y_sim.20', - 'x_sim.1', - 'x_sim.2', - 'x_sim.3', - 'x_sim.4', - 'x_sim.5', - 'x_sim.6', - 'x_sim.7', - 'x_sim.8', - 'x_sim.9', - 'x_sim.10', - 'x_sim.11', - 'x_sim.12', - 'x_sim.13', - 'x_sim.14', - 'x_sim.15', - 'x_sim.16', - 'x_sim.17', - 'x_sim.18', - 'x_sim.19', - 'x_sim.20', - 'pop_sim.1', - 'pop_sim.2', - 'pop_sim.3', - 'pop_sim.4', - 'pop_sim.5', - 'pop_sim.6', - 'pop_sim.7', - 'pop_sim.8', - 'pop_sim.9', - 'pop_sim.10', - 'pop_sim.11', - 'pop_sim.12', - 'pop_sim.13', - 'pop_sim.14', - 'pop_sim.15', - 'pop_sim.16', - 'pop_sim.17', - 'pop_sim.18', - 'pop_sim.19', - 'pop_sim.20', + 'y_sim[1]', + 'y_sim[2]', + 'y_sim[3]', + 'y_sim[4]', + 'y_sim[5]', + 'y_sim[6]', + 'y_sim[7]', + 'y_sim[8]', + 'y_sim[9]', + 'y_sim[10]', + 'y_sim[11]', + 'y_sim[12]', + 'y_sim[13]', + 'y_sim[14]', + 'y_sim[15]', + 'y_sim[16]', + 'y_sim[17]', + 'y_sim[18]', + 'y_sim[19]', + 'y_sim[20]', + 'x_sim[1]', + 'x_sim[2]', + 'x_sim[3]', + 'x_sim[4]', + 'x_sim[5]', + 'x_sim[6]', + 'x_sim[7]', + 'x_sim[8]', + 'x_sim[9]', + 'x_sim[10]', + 'x_sim[11]', + 'x_sim[12]', + 'x_sim[13]', + 'x_sim[14]', + 'x_sim[15]', + 'x_sim[16]', + 'x_sim[17]', + 'x_sim[18]', + 'x_sim[19]', + 'x_sim[20]', + 'pop_sim[1]', + 'pop_sim[2]', + 'pop_sim[3]', + 'pop_sim[4]', + 'pop_sim[5]', + 'pop_sim[6]', + 'pop_sim[7]', + 'pop_sim[8]', + 'pop_sim[9]', + 'pop_sim[10]', + 'pop_sim[11]', + 'pop_sim[12]', + 'pop_sim[13]', + 'pop_sim[14]', + 'pop_sim[15]', + 'pop_sim[16]', + 'pop_sim[17]', + 'pop_sim[18]', + 'pop_sim[19]', + 'pop_sim[20]', 'alpha_sim', 'beta_sim', - 'eta.1', - 'eta.2', - 'eta.3', - 'eta.4', - 'eta.5', - 'eta.6', - 'eta.7', - 'eta.8', - 'eta.9', - 'eta.10', - 'eta.11', - 'eta.12', - 'eta.13', - 'eta.14', - 'eta.15', - 'eta.16', - 'eta.17', - 'eta.18', - 'eta.19', - 'eta.20', + 'eta[1]', + 'eta[2]', + 'eta[3]', + 'eta[4]', + 'eta[5]', + 'eta[6]', + 'eta[7]', + 'eta[8]', + 'eta[9]', + 'eta[10]', + 'eta[11]', + 'eta[12]', + 'eta[13]', + 'eta[14]', + 'eta[15]', + 'eta[16]', + 'eta[17]', + 'eta[18]', + 'eta[19]', + 'eta[20]', ] self.assertEqual(datagen_fit.column_names, tuple(column_names)) self.assertEqual(datagen_fit.num_draws, 100) diff --git a/test/test_utils.py b/test/test_utils.py index 49270c58..b63eedd1 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -516,16 +516,16 @@ def test_parse_scalars(self): def test_parse_scalar_vec_scalar(self): x = [ 'foo', - 'phi.1', - 'phi.2', - 'phi.3', - 'phi.4', - 'phi.5', - 'phi.6', - 'phi.7', - 'phi.8', - 'phi.9', - 'phi.10', + 'phi[1]', + 'phi[2]', + 'phi[3]', + 'phi[4]', + 'phi[5]', + 'phi[6]', + 'phi[7]', + 'phi[8]', + 'phi[9]', + 'phi[10]', 'bar', ] vars_dict = parse_var_dims(x) @@ -537,18 +537,18 @@ def test_parse_scalar_vec_scalar(self): def test_parse_scalar_matrix_vec(self): x = [ 'foo', - 'phi.1.1', - 'phi.1.2', - 'phi.1.3', - 'phi.1.4', - 'phi.1.5', - 'phi.2.1', - 'phi.2.2', - 'phi.2.3', - 'phi.2.4', - 'phi.2.5', - 'bar.1', - 'bar.2', + 'phi[1,1]', + 'phi[1,2]', + 'phi[1,3]', + 'phi[1,4]', + 'phi[1,5]', + 'phi[2,1]', + 'phi[2,2]', + 'phi[2,3]', + 'phi[2,4]', + 'phi[2,5]', + 'bar[1]', + 'bar[2]', ] vars_dict = parse_var_dims(x) self.assertEqual(len(vars_dict), 3) diff --git a/test/test_variational.py b/test/test_variational.py index b098b561..0aa576f1 100644 --- a/test/test_variational.py +++ b/test/test_variational.py @@ -51,13 +51,13 @@ def test_instantiate(self): self.assertIn('method=variational', variational.__repr__()) self.assertEqual( variational.column_names, - ('lp__', 'log_p__', 'log_g__', 'mu.1', 'mu.2'), + ('lp__', 'log_p__', 'log_g__', 'mu[1]', 'mu[2]'), ) self.assertAlmostEqual( - variational.variational_params_dict['mu.1'], 31.0299, places=2 + variational.variational_params_dict['mu[1]'], 31.0299, places=2 ) self.assertAlmostEqual( - variational.variational_params_dict['mu.2'], 28.8141, places=2 + variational.variational_params_dict['mu[2]'], 28.8141, places=2 ) self.assertEqual(variational.variational_sample.shape, (1000, 5)) @@ -84,7 +84,7 @@ def test_variational_good(self): variational = model.variational(algorithm='meanfield', seed=12345) self.assertEqual( variational.column_names, - ('lp__', 'log_p__', 'log_g__', 'mu.1', 'mu.2'), + ('lp__', 'log_p__', 'log_g__', 'mu[1]', 'mu[2]'), ) self.assertAlmostEqual( @@ -95,10 +95,10 @@ def test_variational_good(self): ) self.assertAlmostEqual( - variational.variational_params_dict['mu.1'], 31.0418, places=2 + variational.variational_params_dict['mu[1]'], 31.0418, places=2 ) self.assertAlmostEqual( - variational.variational_params_dict['mu.2'], 27.4463, places=2 + variational.variational_params_dict['mu[2]'], 27.4463, places=2 ) self.assertEqual( @@ -107,11 +107,11 @@ def test_variational_good(self): ) self.assertEqual( variational.variational_params_np[3], - variational.variational_params_pd['mu.1'][0], + variational.variational_params_pd['mu[1]'][0], ) self.assertEqual( variational.variational_params_np[4], - variational.variational_params_pd['mu.2'][0], + variational.variational_params_pd['mu[2]'][0], ) self.assertEqual(variational.variational_sample.shape, (1000, 5)) @@ -127,13 +127,13 @@ def test_variational_eta_small(self): variational = model.variational(algorithm='meanfield', seed=12345) self.assertEqual( variational.column_names, - ('lp__', 'log_p__', 'log_g__', 'mu.1', 'mu.2'), + ('lp__', 'log_p__', 'log_g__', 'mu[1]', 'mu[2]'), ) self.assertAlmostEqual( - fabs(variational.variational_params_dict['mu.1']), 0.08, places=1 + fabs(variational.variational_params_dict['mu[1]']), 0.08, places=1 ) self.assertAlmostEqual( - fabs(variational.variational_params_dict['mu.2']), 0.09, places=1 + fabs(variational.variational_params_dict['mu[2]']), 0.09, places=1 ) self.assertTrue(True) From c2c5de59c95d09db4ff4599257443b3d571db9f9 Mon Sep 17 00:00:00 2001 From: Lukas Neugebauer Date: Mon, 14 Sep 2020 22:26:43 +0200 Subject: [PATCH 5/7] Updated regex patterns to search for new column names --- cmdstanpy/stanfit.py | 2 +- test/test_sample.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/cmdstanpy/stanfit.py b/cmdstanpy/stanfit.py index 8d7c3913..a960d222 100644 --- a/cmdstanpy/stanfit.py +++ b/cmdstanpy/stanfit.py @@ -752,7 +752,7 @@ def stan_variable(self, name: str) -> pd.DataFrame: self._assemble_draws() dim0 = self.num_draws * self.runset.chains dims = np.prod(self._stan_variable_dims[name]) - pattern = r'^{}(\.\d+)*$'.format(name) + pattern = r'^{}(\[[\d,]+\])?$'.format(name) names, idxs = [], [] for i, column_name in enumerate(self.column_names): if re.search(pattern, column_name): diff --git a/test/test_sample.py b/test/test_sample.py index a7d46caa..1e8dd356 100644 --- a/test/test_sample.py +++ b/test/test_sample.py @@ -527,7 +527,7 @@ def test_validate_big_run(self): os.path.join(DATAFILES_PATH, 'runset-big', 'output_icar_nyc-1.csv'), ] fit = CmdStanMCMC(runset) - phis = ['phi.{}'.format(str(x + 1)) for x in range(2095)] + phis = ['phi[{}]'.format(str(x + 1)) for x in range(2095)] column_names = SAMPLER_STATE + phis self.assertEqual(fit.num_draws, 1000) self.assertEqual(fit.column_names, tuple(column_names)) @@ -537,14 +537,14 @@ def test_validate_big_run(self): self.assertEqual((1000, 2, 2102), fit.draws().shape) phis = fit.draws_as_dataframe(params=['phi']) self.assertEqual((2000, 2095), phis.shape) - phi1 = fit.draws_as_dataframe(params=['phi.1']) + phi1 = fit.draws_as_dataframe(params=['phi[1]']) self.assertEqual((2000, 1), phi1.shape) - mo_phis = fit.draws_as_dataframe(params=['phi.1', 'phi.10', 'phi.100']) + mo_phis = fit.draws_as_dataframe(params=['phi[1]', 'phi[10]', 'phi[100]']) self.assertEqual((2000, 3), mo_phis.shape) - phi2095 = fit.draws_as_dataframe(params=['phi.2095']) + phi2095 = fit.draws_as_dataframe(params=['phi[2095]']) self.assertEqual((2000, 1), phi2095.shape) - with self.assertRaisesRegex(ValueError, 'unknown parameter: phi.2096'): - fit.draws_as_dataframe(params=['phi.2096']) + with self.assertRaisesRegex(ValueError, 'unknown parameter: phi\[2096\]'): + fit.draws_as_dataframe(params=['phi[2096]']) with self.assertRaisesRegex(ValueError, 'unknown parameter: ph'): fit.draws_as_dataframe(params=['ph']) From beef6b0edd4c11385db0380fccaa36ca54d678bf Mon Sep 17 00:00:00 2001 From: Lukas Neugebauer Date: Tue, 15 Sep 2020 00:34:07 +0200 Subject: [PATCH 6/7] Minor formatting to make flake8 and pylint happy --- cmdstanpy/stanfit.py | 4 ++-- test/test_sample.py | 10 +++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/cmdstanpy/stanfit.py b/cmdstanpy/stanfit.py index a960d222..cbbb6f4a 100644 --- a/cmdstanpy/stanfit.py +++ b/cmdstanpy/stanfit.py @@ -742,8 +742,8 @@ def stan_variable(self, name: str) -> pd.DataFrame: shape ( draws X chains, len(vector)) * If the variable is a matrix, the shape of the DataFrame is shape ( draws X chains, size(dim 1) X size(dim 2) ) - * If the variable is an array with N dimensions, the shape of the DataFrame is - shape ( draws X chains, size(dim 1) X ... X size(dim N)). + * If the variable is an array with N dimensions, the shape of the + DataFrame is shape ( draws X chains, size(dim 1) X ... X size(dim N)) :param name: variable name """ diff --git a/test/test_sample.py b/test/test_sample.py index 1e8dd356..c090c397 100644 --- a/test/test_sample.py +++ b/test/test_sample.py @@ -537,13 +537,17 @@ def test_validate_big_run(self): self.assertEqual((1000, 2, 2102), fit.draws().shape) phis = fit.draws_as_dataframe(params=['phi']) self.assertEqual((2000, 2095), phis.shape) - phi1 = fit.draws_as_dataframe(params=['phi[1]']) + phi1 = fit.draws_as_dataframe(params=['phi[1]']) self.assertEqual((2000, 1), phi1.shape) - mo_phis = fit.draws_as_dataframe(params=['phi[1]', 'phi[10]', 'phi[100]']) + mo_phis = fit.draws_as_dataframe( + params=['phi[1]', 'phi[10]', 'phi[100]'] + ) self.assertEqual((2000, 3), mo_phis.shape) phi2095 = fit.draws_as_dataframe(params=['phi[2095]']) self.assertEqual((2000, 1), phi2095.shape) - with self.assertRaisesRegex(ValueError, 'unknown parameter: phi\[2096\]'): + with self.assertRaisesRegex( + ValueError, r'unknown parameter: phi\[2096\]' + ): fit.draws_as_dataframe(params=['phi[2096]']) with self.assertRaisesRegex(ValueError, 'unknown parameter: ph'): fit.draws_as_dataframe(params=['ph']) From eddf81fc23db47a748312c4073fa069768af7935 Mon Sep 17 00:00:00 2001 From: Lukas Neugebauer Date: Tue, 15 Sep 2020 00:37:53 +0200 Subject: [PATCH 7/7] deleted redundant 'shape' --- cmdstanpy/stanfit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cmdstanpy/stanfit.py b/cmdstanpy/stanfit.py index cbbb6f4a..4a523320 100644 --- a/cmdstanpy/stanfit.py +++ b/cmdstanpy/stanfit.py @@ -739,11 +739,11 @@ def stan_variable(self, name: str) -> pd.DataFrame: * If the variable is a scalar variable, the shape of the DataFrame is ( draws X chains, 1). * If the variable is a vector, the shape of the DataFrame is - shape ( draws X chains, len(vector)) + ( draws X chains, len(vector)) * If the variable is a matrix, the shape of the DataFrame is - shape ( draws X chains, size(dim 1) X size(dim 2) ) + ( draws X chains, size(dim 1) X size(dim 2) ) * If the variable is an array with N dimensions, the shape of the - DataFrame is shape ( draws X chains, size(dim 1) X ... X size(dim N)) + DataFrame is ( draws X chains, size(dim 1) X ... X size(dim N)) :param name: variable name """