-
-
Notifications
You must be signed in to change notification settings - Fork 111
add a draft for traces #139
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
One thing I would love to see in PyMC4 is a more clear notion of "what is a trace?" It's problematic -- especially when incorporating a model and inference into a larger data flow -- that PyMC3 has at least two kinds of trace: |
pymc4/inference/utils.py
Outdated
| arviz.data.inference_data.InferenceData | ||
| """ | ||
| import arviz as az | ||
| az_dict = {k: np.swapaxes(v.numpy(), 1, 0) for k, v in pm4_trace.items()} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It could potentially be problematic if you have > 1d batch shape: arviz-devs/arviz#456 (comment)
But I guess in this case it is fine since we batch the log_prob ourself and the num_chain is only 1d
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That already works: <xarray.DataArray 'hierarchical_model/beta' (chain: 50, draw: 200, hierarchical_model/beta_dim_0: 85)>
|
@rpgoldman I think we should just output arviz trace objects everywhere and use that standard. |
Sounds good to me! A good bit of pain in trying to speed up the posterior predictive sampling was not knowing what sort of arguments could be passed in as the trace (and the fact that the tests used some odd ones, like a list made up of the |
ArviZ technically doesn't have a trace object. I would say you want to output |
| """ | ||
| Tensorflow to Arviz trace convertor. | ||
| Convert a PyMC4 trace as returned by sample() to an ArviZ trace object |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would change this to to an az.InferenceData object or Arviz InferenceData object
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure, this is a bit too specific. Let me try to see if I can add a helper class in TFP so that the output is a bit more standardized.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add types (including return type)? This information seems to be available in the docstrings...
| Returns | ||
| ------- | ||
| arviz.data.inference_data.InferenceData |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can be shortened to az.InferenceData for end users
| """ | ||
| """PyMC4 continuous random variables for tensorflow.""" | ||
| import tensorflow_probability as tfp | ||
| from pymc4.distributions import abstract |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
being picky here, should we use relative imports for inside the library? Not part of this PR but just wanted to ask
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I use doctest to test code snippets in documentation. Doctest is complaining if import is relative sometimes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had that problem with pytest, too -- complaints about relative imports -- and it turned out for me it was because I was running the tests (this is for PyMC3) inside the source directory. When I ran it "above" the directory (i.e., from pymc3 instead of pymc3/pymc3/) the complaints about relative imports went away.
I think the advantage of relative imports is that they don't risk an import cycle as much as absolute ones do.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the advantage of relative imports is that they don't risk an import cycle as much as absolute ones do.
I believe this part is true, that relative imports risk less that another package of same name will be imported from sys.path versus the adjacent module
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My personal feeling is that it is hard to run into a problem you describe, you would not import pymc4 in the first place. But for our use case we can test code snippets in docs without mess
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No worries, for this PR just ignore me there's more important things :) thanks @ferrine
| from .. import Model, flow | ||
|
|
||
|
|
||
| def initialize_state(model: Model, observed: Optional[dict] = None) -> flow.SamplingState: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe refine the type declaration here? I.e., change to Optional[Dict[x, y]] for some x and y? Or declare a type for this kind of dictionary, e.g., ObsDict = Dict[x, y]?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is nothing special about observed except keys are strings. The API is not that narrow at this point.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that would be Dict[str, Any] then.
That way we know the keys are names, rather than variables (and so does mypy).
| """ | ||
| Tensorflow to Arviz trace convertor. | ||
| Convert a PyMC4 trace as returned by sample() to an ArviZ trace object |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add types (including return type)? This information seems to be available in the docstrings...
Adding a draft for traces. Previously, traces were returned in a raw format. This PR should solve this issue and ease plotting