Replies: 5 comments 12 replies
-
Welcome :) I've never used GitHub Discussions before, but this looks like the perfect place to discuss things that are not directly related to the code. Indeed, tests and examples are very useful at this stage. Tests on distributions (especially regarding the shapes) are important, but also regression tests for the language and the API. Regarding examples, I started implementing examples from Statistical Rethinking and have my eye on Bayesian Methods for Hackers as well.
Tests are a great way to start contributing. I also added a few other suggestions in this issue; in particular, MCX could do with a
-
Sure, that's a great idea! I usually write notes in markdown when I'm learning/thinking through stuff, so if I can organize them into any usable/coherent fashion, that would be easy to add as documentation.
-
Thank you for taking the time to share your notes! Since you're the
first one (that I know of) to read the internals, I'd be happy to hear
your thoughts about the general design and the naming conventions. What
made sense immediately? What didn't? Maybe we could start from your
notes and your answers to those questions to document the internals.
What the linear_regression function looks like in AST form is:
Oh that's cool! I haven't plotted this figure yet, but it would have
made things easier. That'd be great in the docs.
Finally, if an unanticipated case arises, the ModelDefinitionParser.recursive_visit
method throws a TypeError and tells the user to report it as an issue. This
means that some syntax has not been anticipated and perhaps needs to be
incorporated into the code.
Yes, added or explicitly forbidden. We don't want the code to
inexplicably crash, so we prefer to raise a SyntaxError by default.
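A minimal sketch of this fail-loudly pattern, using only the standard library `ast` module (toy code for illustration; this is not MCX's actual parser, and the whitelist of node types is hypothetical):

```python
import ast

class ToyModelParser(ast.NodeVisitor):
    """Sketch of a parser that fails loudly on unanticipated syntax."""

    # The handful of node types this toy parser anticipates.
    ALLOWED = (ast.Module, ast.FunctionDef, ast.arguments, ast.arg,
               ast.Assign, ast.Return, ast.BinOp, ast.Add, ast.Name,
               ast.Constant, ast.Load, ast.Store)

    def generic_visit(self, node):
        if not isinstance(node, self.ALLOWED):
            raise TypeError(
                f"Unsupported syntax: {type(node).__name__}. "
                "Please report this as an issue."
            )
        super().generic_visit(node)

# Anticipated syntax passes through silently...
ToyModelParser().visit(ast.parse("def f(x):\n    y = x + 1\n    return y"))

# ...while unanticipated syntax raises a TypeError asking for a bug report.
try:
    ToyModelParser().visit(ast.parse("for i in range(3): pass"))
except TypeError as e:
    print(e)
```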
Note that this is where we are actually sampling from the priors of the
model; this is what is called 'prior predictive sampling' in Richard
McElreath's Statistical Rethinking book.
There is an important distinction to make between the two prior
distributions. MCX models are functions that can return a value; the
prior distribution of this value is the _prior predictive distribution_,
given by `sample`. Calling the model `model(rng_key, *args)` also gives
you samples from this distribution.
MCX models also implicitly define a (multivariate) probability
distribution; samples from this distribution are given by
`joint_sample`.
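To make the distinction concrete, here is a plain-Python toy (not actual MCX code; the names `toy_joint_sample` and `toy_model` are hypothetical) for a linear regression with priors on `a` and `b`:

```python
import random

def toy_joint_sample(rng):
    # One sample from the implicit joint distribution over the random
    # variables: the analogue of `joint_sample`, returned as a dict.
    return {"a": rng.gauss(0.0, 1.0), "b": rng.gauss(0.0, 1.0)}

def toy_model(rng, x):
    # Calling the model returns a value; over repeated prior draws of
    # (a, b), its distribution is the prior predictive distribution,
    # the analogue of what `sample` draws from.
    params = toy_joint_sample(rng)
    return params["a"] * x + params["b"] + rng.gauss(0.0, 0.1)

rng = random.Random(0)
joint_draw = toy_joint_sample(rng)       # a dict {'a': ..., 'b': ...}
predictive_draw = toy_model(rng, x=2.0)  # a single predicted value
```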
Reading your notes I feel that `sample` is an unfortunate name choice;
it is not obvious what it samples exactly. Would you agree?
Now what we have is an MCX model function which we can:
Yes, it is important to note that MCX models will be model-first and not
distribution-first, as `model()` returns a value and not a class
instance. This may change in the future, as the latter would make the
internals slightly simpler.
But keep this in mind: models being hybrid objects can sometimes be
confusing.
b) Or pass this to a sampler for doing inference -> not really sure where/how
this happens right now on the compiler-refactor branch.
`samples = mcx.sampler(rng_key, model, (x_data,), {'observations': y_data}, kernel).run()`
I had to check the code, which must mean the API is not quite there yet. I sometimes wonder
if we should be more explicit:
```python
samples = mcx.sampler(
    rng_key,
    model.condition(X=x_data, observations=y_data),
    kernel
).run()
```
-
and then stepping through everything with a debugger. As I mentioned, I kept
notes in markdown, which I'd be happy to add some part of to the docs
if you think it would be useful; it's a bit pedantic and skips over some of the
finer details of how things are handled, so I'm not sure how useful it is.
That was exhaustive! Shortened, while keeping the big-picture ideas, this
could be very useful for contributors or anyone curious about how this
all works.
And the mcx.core.sample_joint function is pretty much doing
the same thing as the mcx.core.sample function, it's just returning a
dictionary with the parameters (i.e. random variables) and samples from those
priors? Is that correct?
I think I anticipated that question in my previous reply. Yes, that's
correct. The reason they don't share more code is the behavior
when an MCX model is called within the current model.
Another question: in the sample method of the model class, and in the call
method which is used when you call, for example, linear_regression(rng_key=
rng_key, x=some_data), where exactly is the sample size defined?
You _should_ be able to specify a `sample_shape` argument in the `sample`
method to follow the API of `mcx.Distribution`. I simply forgot!
However, you won't be able to specify the sample size for
`linear_regression(rng_key, x)`. You can, however, do:
```python
import jax

keys = jax.random.split(rng_key, num_samples)
samples = jax.vmap(linear_regression, in_axes=(0, None))(keys, x)
```
You can see the `sample` method as a shortcut for the more
barebones call.
was this removed/left out for a particular reason?
No, I forgot!
Hope this isn't too many questions!
Nope :)
-
Thanks for the feedback! I am glad this is understandable; very few
people are comfortable with parsing/modifying the AST so I was a bit
afraid this would be confusing. Still some effort needed on the naming
side (cf our discussion on prior sampling).
I used showast in a Jupyter Lab session. I'm quite a visual learner, so it
definitely helps me to visualize these trees. I could envision some boxes
around different parts of this tree indicating where things are handled
internally by MCX, but maybe it's a bit overkill!
I'm a visual learner too, and this is super helpful! It can be a great
debugging tool for the internals as well; reading through 200 lines of
an AST dump, or printing a graph traversal, is not particularly pleasant.
Sorry, I'm still not sure I get it. What exactly do you mean when you say "the
two prior distributions" and "functions that can return a value; the prior
distribution of this value is the prior predictive distribution"? My
understanding is that the prior predictive distribution would simply be a
joint distribution of the priors which you can use to simulate predictions, or
am I misunderstanding?
The model implicitly defines a joint distribution on the random
variables (say `a` and `b`). `sample_joint` returns prior samples
from this distribution where each sample is a dictionary of values
`{"a": val_a, "b": val_b}`.
The returned value is, by definition, the "predicted" value
of the model. The distribution of this value, when the random variables
are distributed according to their prior distribution, is what I call the
prior predictive distribution.
Yes, maybe my misunderstanding comes from the naming; perhaps it would be better
named otherwise, but I'm not sure what at the moment.
`predictive_sample` ?
This is a great point. I think making the prior predictive and inference stuff
as straightforward as possible is a great goal. Having used NumPyro a bit, I
found the use of their Predictive class for doing these two things quite
confusing, so having this easy and understandable would be a big plus.
Currently to get samples from the prior predictive distribution you
would call `mcx.predict(rng_key, model, *args)`.
To get the posterior predictive distribution you first need to evaluate
the model to use the posterior distribution of the random variables:
`evaluated_model = mcx.evaluate(model, trace)`. You can then use the
same function, `mcx.predict(rng_key, evaluated_model, *args)`, to get
samples from the predictive distribution.
I spent a lot of time thinking about this and I think it is the only way
to have a coherent API. `mcx.predict` returns samples from the
predictive distribution. Between prior and posterior predictive sampling
the difference is the model you sample from: one where random variables
are distributed according to their prior distribution for the former,
one where they are distributed according to their posterior distribution
for the latter.
Another option would be to use multiple dispatch so `mcx.predict`
returns prior predictive samples if `rng_key`, `model`, and the args are
specified, posterior predictive samples if you specify the trace as
well. But it feels a bit magical. Yet another one is to split between
`mcx.prior_predict` and `mcx.posterior_predict`.
In terms of how to call it, I think what you propose could be good; I like the
use of "condition" as a term. But would there be a case where you would be passing
a model to a sampler without conditioning it on data (if the
whole prior predictive stuff is a method that you can call on the model
object)?
That can happen in theory, for instance when people try samplers on
Neal's funnel.
Good then I'll keep asking as I dig in more, very cool stuff thus far! 🎉
Great! I've been working on this pretty much in isolation so questions
and constructive criticism are much appreciated.
-
Hi, I'd be interested in contributing but am fairly new to PPLs and to Bayesian stats (i.e. I have read Statistical Rethinking and worked through a number of the end-of-chapter problems, and have used PyMC3 and NumPyro a bit for personal/work things), so I'm a bit unsure about what I can contribute, but I would see contributing as a great way to dive deeper into these topics and contribute to some OSS. The only issues I see marked as "good first issue" are the call for use cases and examples and the one with tests. I'd be happy to start with the tests.
I'd maybe have a few more questions:
- I've read through /design_notes/mcx_design.md, but I'm wondering where I could start reading a bit more so as to be more useful (e.g. I've never dealt with abstract syntax trees in Python; I guess this would be something very fundamental to read up on). As is often the case when learning new stuff, it's pretty easy to get overwhelmed by what I don't know yet 😄
- In the docs it's mentioned that there is inspiration drawn from PyMC3, TensorFlow Probability, NumPyro etc., but I'm wondering: what exactly would be the value add for MCX, or rather what would it be adding/doing differently from all these other Python PPLs? Is it simply, as stated in /design_notes/mcx_design.md, that it treats the whole model definition part as the program and the sampling procedure as the "compiler", or is there something I've missed?
- Is this the correct way/place to discuss such questions? I've actually never seen the Discussions feature in GitHub 🤷‍♂️
Cheers!