Refactor to store stan.sampling directly to enable diagnostic methods #87

Closed
steveyang90 opened this issue Apr 17, 2020 · 3 comments · Fixed by #94
Labels
new idea / feature request New feature or request refactor Issues to remove tech debt or improve design

Comments

@steveyang90
Collaborator

Currently we store the result of stan.sampling.extract() as self.posterior_samples. However, we want to retrieve chain-level information from stan.sampling directly to enable diagnostic methods.

We can't do that at the moment because we never store stan.sampling as an attribute.

There are two proposed solutions:

  1. Store stan.sampling at the end of fit, and only call stan.sampling.extract() for downstream methods such as predict, plot, diagnostic.
  2. Store self.posterior_samples as-is, and additionally store stan.sampling.to_dataframe() to something like self.posterior_samples_chain.

The first may require refactoring other methods for the cases where we don't use stan.sampling, for example when we fit using VI, MAP, or Pyro.

The second requires storing twice as much information as we otherwise would. An alternative to the second approach is to parse the dataframe into the same structure as our current self.posterior_samples, but this poses a challenge because matrix samples are stored as a single column in a dataframe.

Another alternative to the second approach is to store only the chain info from the dataframe, but we'd have to guarantee that the order is preserved between the dataframe and the arrays in stan.sampling.extract().

@steveyang90 steveyang90 added new idea / feature request New feature or request refactor Issues to remove tech debt or improve design labels Apr 17, 2020
@wangzhishi
Contributor

wangzhishi commented Apr 22, 2020

I agree with @steveyang90 that this only applies to .sampling. Option 1 would trigger too many changes.

First, to keep the samples with the draw order and chain info preserved, we have to turn off permuted=True. That comes with a problem, because:

  • .extract(permuted=True) returns an ordered dictionary
  • .extract(permuted=False) returns an ndarray with 3 dims: iterations x chains x params

I found that .extract(pars=['param_1', ...], permuted=False) returns an ordered dictionary; however, each keyed item has an extra dimension (corresponding to the chain number) compared to the return of .extract(permuted=True).

So I'm thinking of the following concrete plan for the mcmc method, where fit is compiled_stan_file in our case:

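# permuted=False keeps the draw order and the chain dimension.
stan_extract = fit.extract(pars=fit._get_param_names(), permuted=False)
for idx, (key, val) in enumerate(stan_extract.items()):
    if len(fit._get_param_dims()[idx]) == 0:
        # Scalar: (iterations, chains) -> (iterations * chains,).
        # order='F' matters here: it flattens the chains one by one.
        stan_extract[key] = val.flatten(order='F')
    else:
        # Non-scalar: merge iterations and chains, keep the parameter dims.
        stan_extract[key] = val.reshape((-1,) + val.shape[2:], order='F')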

After this, stan_extract has exactly the same structure as the return of .extract(permuted=True), but the sample order is preserved. Say, with 4 chains and 500 samples per chain, we will have 2000 samples in the order [500 in chain1, ..., 500 in chain4], and inside each chain the draw order is also preserved.
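The ordering claim can be checked without Stan. The following is a minimal sketch, assuming the unpermuted draws have shape (iterations, chains) for a scalar and (iterations, chains, size) for a vector. The arrays are synthetic stand-ins, not PyStan output, and each value encodes its own chain and iteration so the order is easy to inspect.

```python
import numpy as np

n_iter, n_chains, size = 500, 4, 8

# Synthetic stand-in for unpermuted scalar draws (assumption, not PyStan output):
# scalar[i, c] = c * 1000 + i, so every value reveals its chain and iteration.
iterations = np.arange(n_iter)[:, None]
chains = np.arange(n_chains)[None, :]
scalar = chains * 1000 + iterations                    # shape (500, 4)

# Synthetic vector draws: vector[i, c, k] = scalar[i, c] * 10 + k.
vector = scalar[:, :, None] * 10 + np.arange(size)     # shape (500, 4, 8)

# Collapse iterations and chains; order='F' keeps each chain contiguous.
flat_scalar = scalar.flatten(order='F')                            # shape (2000,)
flat_vector = vector.reshape(-1, vector.shape[-1], order='F')      # shape (2000, 8)

# The first 500 rows are chain 0 in draw order, the next 500 are chain 1, etc.
print(flat_scalar[:3], flat_scalar[500:503])
```

With the default order='C' the chains would be interleaved instead (row 0 from chain 0, row 1 from chain 1, and so on), which is why the order argument matters.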

Then we can use these ordered samples for diagnostics viz (of course, this needs a bit of processing to cut the samples into chains).
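Cutting the flattened samples back into chains could look like the sketch below. It assumes the samples are stored chain by chain as described above and that every chain has the same number of draws; split_into_chains is a hypothetical helper name, not something in the code base.

```python
import numpy as np

def split_into_chains(samples, n_chains):
    """Hypothetical helper: reshape (chains * draws, ...) to (chains, draws, ...).

    Assumes the rows are ordered chain by chain, with equal draws per chain.
    """
    samples = np.asarray(samples)
    n_total = samples.shape[0]
    if n_total % n_chains != 0:
        raise ValueError("number of samples is not divisible by n_chains")
    # Default C order: row c * draws + i goes to position [c, i].
    return samples.reshape((n_chains, n_total // n_chains) + samples.shape[1:])

# Usage with synthetic data: 4 chains of 500 draws for a vector of size 8.
flat = np.arange(2000 * 8).reshape(2000, 8)
per_chain = split_into_chains(flat, n_chains=4)
print(per_chain.shape)
```

Each per_chain[c] slice can then be plotted as its own trace.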

This seems to require minimal change in our code base.

@steveyang90
Collaborator Author

I think this makes sense @wangzhishi

What is if len(fit._get_param_dims()[idx]) == 0 checking for?

Also, you might find np.transpose() useful here
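As a sketch of how np.transpose() could be used here: moving the chain axis to the front and then reshaping in the default C order gives the same result as reshaping with order='F'. The arrays below are random stand-ins with the assumed (iterations, chains) and (iterations, chains, size) shapes, not PyStan output.

```python
import numpy as np

rng = np.random.default_rng(0)
scalar = rng.normal(size=(500, 4))       # assumed shape: (iterations, chains)
vector = rng.normal(size=(500, 4, 8))    # assumed shape: (iterations, chains, size)

# Approach from the plan above: Fortran-order flatten / reshape.
scalar_f = scalar.flatten(order='F')
vector_f = vector.reshape(-1, vector.shape[-1], order='F')

# Transpose-based alternative: put chains first, then reshape in C order.
scalar_t = np.transpose(scalar).flatten()
vector_t = np.transpose(vector, (1, 0, 2)).reshape(-1, vector.shape[-1])

print(np.array_equal(scalar_f, scalar_t), np.array_equal(vector_f, vector_t))
```

Which form to use is mostly a readability choice; the transpose version makes the "chains first" intent explicit.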

@wangzhishi
Contributor

wangzhishi commented Apr 22, 2020

The check identifies scalar parameters, whose dim is [] and therefore has length 0. For example, with 4 chains, the samples of a scalar parameter (with dim []) and of a vector parameter (size 8, with dim [8]) have shapes (500, 4) and (500, 4, 8). My proposal is to collapse these shapes into (2000,) and (2000, 8), which are consistent with the return of .extract(permuted=True).
