Stan variable #287
Changes from all commits
3f78137
0c48ec5
98a6457
6406aee
c2c5de5
beef6b0
eddf81f
@@ -702,7 +702,7 @@ def draws_as_dataframe(
         the output, i.e., the sampler was run with ``save_warmup=True``,
         then the warmup draws are included. Default value is ``False``.
         """
-        pnames_base = [name.split('.')[0] for name in self.column_names]
+        pnames_base = [name.split('[')[0] for name in self.column_names]
         if params is not None:
             for param in params:
                 if not (param in self._column_names or param in pnames_base):
@@ -723,51 +723,47 @@ def draws_as_dataframe(
         mask = []
         params = set(params)
         for name in self.column_names:
-            if any(item in params for item in (name, name.split('.')[0])):
+            if any(item in params for item in (name, name.split('[')[0])):
                 mask.append(name)
         return self._draws_as_df[mask]

-    def stan_variable(self, name: str) -> np.ndarray:
+    def stan_variable(self, name: str) -> pd.DataFrame:
         """
-        Return a new ndarray which contains the set of post-warmup draws
+        Return a new DataFrame which contains the set of post-warmup draws
         for the named Stan program variable. Flattens the chains.
         Underlyingly draws are in chain order, i.e., for a sample
         consisting of N chains of M draws each, the first M array
         elements are from chain 1, the next M are from chain 2,
         and the last M elements are from chain N.

-        * If the variable is a scalar variable, this returns a 1-d array,
-          length(draws X chains).
-        * If the variable is a vector, this is a 2-d array,
-          shape ( draws X chains, len(vector))
-        * If the variable is a matrix, this is a 3-d array,
-          shape ( draws X chains, matrix nrows, matrix ncols ).
-        * If the variable is an array with N dimensions, this is an N+1-d array,
-          shape ( draws X chains, size(dim 1), ... size(dim N)).
+        * If the variable is a scalar variable, the shape of the DataFrame is
+          ( draws X chains, 1).
+        * If the variable is a vector, the shape of the DataFrame is
+          ( draws X chains, len(vector))
+        * If the variable is a matrix, the shape of the DataFrame is
+          ( draws X chains, size(dim 1) X size(dim 2) )
+        * If the variable is an array with N dimensions, the shape of the
+          DataFrame is ( draws X chains, size(dim 1) X ... X size(dim N))

         :param name: variable name
         """
         if name not in self._stan_variable_dims:
             raise ValueError('unknown name: {}'.format(name))
         self._assemble_draws()
         dim0 = self.num_draws * self.runset.chains
-        dims = self._stan_variable_dims[name]
-        if dims == 1:
-            idx = self.column_names.index(name)
-            return self._draws[self._draws_warmup :, :, idx].reshape(
-                (dim0,), order='A'
-            )
-        else:
-            idxs = [
-                x[0]
-                for x in enumerate(self.column_names)
-                if x[1].startswith(name + '.')
-            ]
-            var_dims = [dim0]
-            var_dims.extend(dims)
-            return self._draws[
-                self._draws_warmup :, :, idxs[0] : idxs[-1] + 1
-            ].reshape(tuple(var_dims), order='A')
+        dims = np.prod(self._stan_variable_dims[name])
+        pattern = r'^{}(\[[\d,]+\])?$'.format(name)
+        names, idxs = [], []
+        for i, column_name in enumerate(self.column_names):
+            if re.search(pattern, column_name):
+                names.append(column_name)
+                idxs.append(i)
+        return pd.DataFrame(
+            self._draws[
+                self._draws_warmup:, :, idxs
+            ].reshape((dim0, dims), order='A'),
+            columns=names
+        )

     def stan_variables(self) -> Dict:
         """
Contributor
What should we do for this function? Or maybe we could basically copy the Would it make sense to use concat? And suggest users to learn filter https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.filter.html cc @mitzimorris
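To make the `filter` suggestion concrete, here is a minimal sketch of selecting one variable's columns with `DataFrame.filter`. The draws table and its column names are made up for illustration; only the `name[i,j]` column convention is taken from the diff.

```python
import pandas as pd

# Hypothetical draws table (two draws); not real sampler output.
draws = pd.DataFrame(
    {
        "lp__": [-1.0, -1.2],
        "theta": [0.3, 0.4],
        "beta[1]": [1.0, 1.1],
        "beta[2]": [2.0, 2.1],
        "beta_raw[1]": [0.5, 0.6],
    }
)

# filter(regex=...) keeps the columns whose label matches the pattern.
# Anchoring with ^ and $ keeps `beta` from also matching `beta_raw[1]`.
beta = draws.filter(regex=r"^beta(\[[\d,]+\])?$")
print(list(beta.columns))
```

An unanchored pattern such as `regex="beta"` would also pick up `beta_raw[1]`, which is why the anchored form from the diff is used here.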
Contributor
Author
It could return the whole DataFrame as default and a DataFrame containing only the requested variables if names are given. But I guess that would make the stan_variable function kind of redundant.
Contributor
True, we could make it so that stan_variable would call stan_variables. What is the output in CmdStanR?
Contributor
Author
I don't know because I don't use R but I'll have a look at it later. Does stan_variables have to be a function? How about a cached property and stan_variable just uses filter on the columns?
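One way the cached-property idea could look, as a rough sketch only. `FitSketch` is a hypothetical stand-in class, not the cmdstanpy fit object, and it assumes the draws are already flattened to shape `(draws * chains, n_columns)`.

```python
import re
from functools import cached_property  # Python 3.8+

import numpy as np
import pandas as pd


class FitSketch:
    """Hypothetical stand-in for the fit object; not the cmdstanpy class."""

    def __init__(self, draws: np.ndarray, column_names: list):
        self._draws = draws  # assumed shape: (draws * chains, n_columns)
        self._column_names = column_names

    @cached_property
    def stan_variables(self) -> pd.DataFrame:
        # Built on first access, then reused on later accesses.
        return pd.DataFrame(self._draws, columns=self._column_names)

    def stan_variable(self, name: str) -> pd.DataFrame:
        # re.escape guards against regex metacharacters in the name.
        pattern = r"^{}(\[[\d,]+\])?$".format(re.escape(name))
        out = self.stan_variables.filter(regex=pattern)
        if out.shape[1] == 0:
            raise ValueError("unknown name: {}".format(name))
        return out


fit = FitSketch(np.arange(8.0).reshape(2, 4), ["lp__", "mu", "tau[1]", "tau[2]"])
tau = fit.stan_variable("tau")
```

With this layout `stan_variable` becomes a thin column filter over the cached frame, which is the redundancy noted earlier in the thread.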
Contributor
Author
Just had a look at the CmdStanR documentation and the equivalent seems to be CmdStanMCMC::draws(), which returns an iterations X chains X variables array, or an iterations X chains array when called with a variable name as argument.
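The two shapes described in this comment can be pictured with a small NumPy sketch. This is not CmdStanR code; the sizes, variable names, and zero-filled values are placeholders chosen only to show the array shapes.

```python
import numpy as np

iterations, chains = 4, 2
variables = ["lp__", "mu", "sigma"]  # hypothetical variable names

# Full sample: iterations X chains X variables (placeholder values).
draws = np.zeros((iterations, chains, len(variables)))

# Selecting a single variable drops the last axis: iterations X chains.
mu = draws[:, :, variables.index("mu")]
print(draws.shape, mu.shape)
```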
Member
what's useful about
Contributor
So maybe we could still return a dict of dataframes in
Good use of prod here
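For readers following along, a standalone sketch of what the new `stan_variable` body does: `np.prod` turns the variable's dimensions into a flattened column count, the regex picks the matching columns, and the reshape stacks chains one after another. Everything here (the sample array, the `stan_variable_dims` mapping, the module-level function) is a made-up illustration. It also differs from the diff in two labeled ways: it uses `re.escape` on the name, and it passes `order="F"` explicitly instead of `order='A'` so the chain ordering does not depend on the memory layout of the array.

```python
import re

import numpy as np
import pandas as pd

# Hypothetical sample: 3 post-warmup draws, 2 chains, a scalar and a 2x2 matrix.
column_names = ["lp__", "m[1,1]", "m[2,1]", "m[1,2]", "m[2,2]"]
num_draws, chains = 3, 2
draws = np.arange(num_draws * chains * len(column_names), dtype=float).reshape(
    (num_draws, chains, len(column_names))
)
# Assumed layout of the dims mapping: 1 for scalars, a tuple otherwise.
stan_variable_dims = {"lp__": 1, "m": (2, 2)}


def stan_variable(name: str) -> pd.DataFrame:
    if name not in stan_variable_dims:
        raise ValueError("unknown name: {}".format(name))
    # Product of the dimension sizes: (2, 2) gives 4, a scalar gives 1.
    n_cols = int(np.prod(stan_variable_dims[name]))
    pattern = r"^{}(\[[\d,]+\])?$".format(re.escape(name))
    idxs = [i for i, col in enumerate(column_names) if re.search(pattern, col)]
    names = [column_names[i] for i in idxs]
    block = draws[:, :, idxs]  # shape (num_draws, chains, n_cols)
    # order="F" makes the draw index vary fastest, so rows come out as
    # all draws of chain 1, then all draws of chain 2, and so on.
    return pd.DataFrame(
        block.reshape((num_draws * chains, n_cols), order="F"),
        columns=names,
    )


m = stan_variable("m")
lp = stan_variable("lp__")
```

Here `m` has one row per draw per chain and one column per matrix element, while `lp` is a single-column frame, matching the shapes listed in the updated docstring.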