Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 25 additions & 29 deletions cmdstanpy/stanfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ def draws_as_dataframe(
the output, i.e., the sampler was run with ``save_warmup=True``,
then the warmup draws are included. Default value is ``False``.
"""
pnames_base = [name.split('.')[0] for name in self.column_names]
pnames_base = [name.split('[')[0] for name in self.column_names]
if params is not None:
for param in params:
if not (param in self._column_names or param in pnames_base):
Expand All @@ -723,51 +723,47 @@ def draws_as_dataframe(
mask = []
params = set(params)
for name in self.column_names:
if any(item in params for item in (name, name.split('.')[0])):
if any(item in params for item in (name, name.split('[')[0])):
mask.append(name)
return self._draws_as_df[mask]

def stan_variable(self, name: str) -> np.ndarray:
def stan_variable(self, name: str) -> pd.DataFrame:
"""
Return a new ndarray which contains the set of post-warmup draws
Return a new DataFrame which contains the set of post-warmup draws
for the named Stan program variable. Flattens the chains.
Underlyingly draws are in chain order, i.e., for a sample
consisting of N chains of M draws each, the first M array
elements are from chain 1, the next M are from chain 2,
and the last M elements are from chain N.

* If the variable is a scalar variable, this returns a 1-d array,
length(draws X chains).
* If the variable is a vector, this is a 2-d array,
shape ( draws X chains, len(vector))
* If the variable is a matrix, this is a 3-d array,
shape ( draws X chains, matrix nrows, matrix ncols ).
* If the variable is an array with N dimensions, this is an N+1-d array,
shape ( draws X chains, size(dim 1), ... size(dim N)).
* If the variable is a scalar variable, the shape of the DataFrame is
( draws X chains, 1).
* If the variable is a vector, the shape of the DataFrame is
( draws X chains, len(vector))
* If the variable is a matrix, the shape of the DataFrame is
( draws X chains, size(dim 1) X size(dim 2) )
* If the variable is an array with N dimensions, the shape of the
DataFrame is ( draws X chains, size(dim 1) X ... X size(dim N))

:param name: variable name
"""
if name not in self._stan_variable_dims:
raise ValueError('unknown name: {}'.format(name))
self._assemble_draws()
dim0 = self.num_draws * self.runset.chains
dims = self._stan_variable_dims[name]
if dims == 1:
idx = self.column_names.index(name)
return self._draws[self._draws_warmup :, :, idx].reshape(
(dim0,), order='A'
)
else:
idxs = [
x[0]
for x in enumerate(self.column_names)
if x[1].startswith(name + '.')
]
var_dims = [dim0]
var_dims.extend(dims)
return self._draws[
self._draws_warmup :, :, idxs[0] : idxs[-1] + 1
].reshape(tuple(var_dims), order='A')
dims = np.prod(self._stan_variable_dims[name])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good use of prod here

pattern = r'^{}(\[[\d,]+\])?$'.format(name)
names, idxs = [], []
for i, column_name in enumerate(self.column_names):
if re.search(pattern, column_name):
names.append(column_name)
idxs.append(i)
return pd.DataFrame(
self._draws[
self._draws_warmup:, :, idxs
].reshape((dim0, dims), order='A'),
columns=names
)

def stan_variables(self) -> Dict:
"""
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What should we do for this function? Or maybe we could basically copy stan_variable and just change it so that it works with multiple names. I don't mind duplicated code, but I would like to create the DataFrame only once.

Would it make sense to use concat, and to suggest that users learn filter?

https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.filter.html

cc @mitzimorris

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could return the whole DataFrame as default and a DataFrame containing only the requested variables if names are given. But I guess that would make the stan_variable function kind of redundant.

Copy link
Copy Markdown
Contributor

@ahartikainen ahartikainen Sep 5, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, we could make it so that stan_variable would call stan_variables.

What is the output in CmdStanR?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know because I don't use R but I'll have a look at it later. Does stan_variables have to be a function? How about a cached property and stan_variable just uses filter on the columns?

Copy link
Copy Markdown
Contributor Author

@LukasNeugebauer LukasNeugebauer Sep 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just had a look at the CmdStanR documentation, and the equivalent seems to be CmdStanMCMC::draws(), which returns an iterations X chains X variables array, or an iterations X chains array when called with a variable name as argument.

Copy link
Copy Markdown
Member

@mitzimorris mitzimorris Sep 8, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's useful about stan_variables and the corresponding stan_variable_dims is that the returned dict gives you the names of all Stan variables present in the output - something that combines these two things would be good.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So maybe we could still return a dict of DataFrames in stan_variables; that sounds fine to me.

Expand Down
18 changes: 13 additions & 5 deletions cmdstanpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,11 +636,19 @@ def scan_column_names(fd: TextIO, config_dict: Dict, lineno: int) -> int:
line = fd.readline().strip()
lineno += 1
names = line.split(',')
config_dict['column_names'] = tuple(names)
config_dict['column_names'] = tuple(_rename_columns(names))
config_dict['num_params'] = len(names) - 1
return lineno


def _rename_columns(names: List) -> List:
names = [
re.sub(r',([\d,]+)$', r'[\1]', column.replace('.', ','))
for column in names
]
return names


def parse_var_dims(names: Tuple[str, ...]) -> Dict:
"""
Use Stan CSV file column names to get variable names, dimensions.
Expand All @@ -653,14 +661,14 @@ def parse_var_dims(names: Tuple[str, ...]) -> Dict:
while idx < len(names):
if names[idx].endswith('__'):
pass
elif '.' not in names[idx]:
elif '[' not in names[idx]:
vars_dict[names[idx]] = 1
else:
vs = names[idx].split('.')
if idx < len(names) - 1 and names[idx + 1].split('.')[0] == vs[0]:
vs = names[idx].split('[')
if idx < len(names) - 1 and names[idx + 1].split('[')[0] == vs[0]:
idx += 1
continue
dims = [int(vs[x]) for x in range(1, len(vs))]
dims = [int(x) for x in vs[1][:-1].split(',')]
vars_dict[vs[0]] = tuple(dims)
idx += 1
return vars_dict
Expand Down
40 changes: 20 additions & 20 deletions test/test_generate_quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,16 @@ def test_gen_quantities_csv_files(self):
csv_file = bern_gqs.runset.csv_files[i]
self.assertTrue(os.path.exists(csv_file))
column_names = [
'y_rep.1',
'y_rep.2',
'y_rep.3',
'y_rep.4',
'y_rep.5',
'y_rep.6',
'y_rep.7',
'y_rep.8',
'y_rep.9',
'y_rep.10',
'y_rep[1]',
'y_rep[2]',
'y_rep[3]',
'y_rep[4]',
'y_rep[5]',
'y_rep[6]',
'y_rep[7]',
'y_rep[8]',
'y_rep[9]',
'y_rep[10]',
]
self.assertEqual(bern_gqs.column_names, tuple(column_names))
self.assertEqual(
Expand Down Expand Up @@ -104,16 +104,16 @@ def test_gen_quanties_mcmc_sample(self):
csv_file = bern_gqs.runset.csv_files[i]
self.assertTrue(os.path.exists(csv_file))
column_names = [
'y_rep.1',
'y_rep.2',
'y_rep.3',
'y_rep.4',
'y_rep.5',
'y_rep.6',
'y_rep.7',
'y_rep.8',
'y_rep.9',
'y_rep.10',
'y_rep[1]',
'y_rep[2]',
'y_rep[3]',
'y_rep[4]',
'y_rep[5]',
'y_rep[6]',
'y_rep[7]',
'y_rep[8]',
'y_rep[9]',
'y_rep[10]',
]
self.assertEqual(bern_gqs.column_names, tuple(column_names))
self.assertEqual(
Expand Down
Loading