Neural network layers #16
Conversation
I believe `path_length` is still in use for historical reasons; however, it makes more sense to reason in terms of the number of integration steps, and it also reduces the cost (no dynamic computation of the number of integration steps and no cast to int). I thus replaced every mention of `path_length` in the HMC proposal and program with `num_integration_steps`.
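To make the difference concrete, here is a minimal sketch; the names `step_size` and `path_length` are used as generic HMC parameters, not a reference to mcx's actual signatures:

```python
import math

step_size = 0.1

# Parameterizing by path length: the step count has to be derived at run time,
# which means a division and a cast to int inside the proposal.
path_length = 1.0
num_integration_steps = int(math.ceil(path_length / step_size))

# Parameterizing by the number of steps: the value is an integer from the start.
num_integration_steps = 10
```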
The boundary between programs (sampling algorithms) and runtimes (executors) was not very clear. I removed every dependence of the program on the model, so responsibilities are now clear, and most of the initialization has been transferred to the runtime. I also improved the performance of initial state creation.
@rlouf I'm not sure if this helps, but would a blog post I wrote on shapes be useful to you? Just a thought, no pressure to read it.
@ericmjl Thank you for the link, I did read your post before implementing distributions. It was really helpful to dive into TFP's shape system! Is there anything in particular you think I might have missed that could help me?
@rlouf thank you for the kind words! I think (but I'm not 100% sure) maybe working backwards from the desired semantics might be helpful? Personally, when I think of Gaussian priors on a neural network's weights, I tend to think of them as being the "same" prior (e.g. …). I think I might still be unclear, so let me attempt an example that has contrasts in there. Given the following NN:

```python
@mcx.model
def mnist(image):
    nn ~ ml.Serial(
        dense(400, Normal(0, 1)),
        dense(400, Normal(3, 4)),
        dense(10, Normal(-2, 7)),
        softmax(),
    )
    p = nn(image)
    cat ~ Categorical(p)
    return cat
```

I would read it as:
I think the suggestion I have here matches your 2nd option exactly:
You don't have to accept the exact suggestion, but maybe implementing it one way first and then trying it out might illuminate whether it's good or not? In reimplementing an RNN, I did the layers in an opinionated, "my-way" fashion first, then realized it'd be easier and more compatible to just go stax-style.
Interesting feedback, thank you for taking the time to explain! The NN API is indeed a bit tricky to get right the first time. I am currently leaning towards what you're proposing. Would you agree with simply broadcasting the parameters' shape with the layer's shape to obtain the batch shape of the prior? This way it is also compatible with crazy specs, like a different variance for each layer weight.
Yes, I would! It sounds like a sensible default to have.
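A quick way to see what this default gives for the two kinds of specs mentioned above; the `(784, 400)` weight shape is only an assumed example for a `dense(400)` layer on flattened MNIST images:

```python
import numpy as np

layer_shape = (784, 400)  # assumed weight shape of a dense(400) layer on MNIST

# Scalar parameters, e.g. Normal(0, 1): the prior broadcasts to one i.i.d. draw per weight.
print(np.broadcast_shapes((), layer_shape))           # (784, 400)

# A "crazy" spec with a different variance per weight broadcasts just as well.
print(np.broadcast_shapes((784, 400), layer_shape))   # (784, 400)

# Incompatible parameter shapes are rejected by the same rule:
# np.broadcast_shapes((3,), layer_shape)  -> ValueError
```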
Thank you for your insights! It feels good to have someone else's opinion. Was your RNN project Bayesian? If so, is the code available somewhere?
The RNN wasn't Bayesian, and it was mostly a re-implementation of the original, but done in JAX. Given that it's written stax-style, I'm sure it shouldn't be too hard to extend it to mcx 😄. You can find the repo here, and we have a mini-writeup available too.
To keep a simple API when building a Bayesian Neural Network, we don't want to have to specify the batch shape of the prior distribution so that it matches the layer size. Therefore we add a helper function that re-broadcasts a distribution to a destination shape (here the layer size) so it can be used in the neural network internals.
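A minimal sketch of what such a helper could look like, with a toy `Normal` standing in for mcx's distribution; the class, the function name, and the numpy backend are all assumptions for illustration:

```python
import numpy as np


class Normal:
    """Toy stand-in for a Normal distribution; the API is an assumption."""

    def __init__(self, mu, sigma):
        self.mu, self.sigma = np.asarray(mu), np.asarray(sigma)

    def sample(self, rng):
        # Generator.normal broadcasts loc and scale, one draw per element.
        return rng.normal(self.mu, self.sigma)


def broadcast_distribution(dist, shape):
    """Re-broadcast a distribution's parameters to `shape` (e.g. the layer size)."""
    return Normal(np.broadcast_to(dist.mu, shape), np.broadcast_to(dist.sigma, shape))


rng = np.random.default_rng(0)
layer_prior = broadcast_distribution(Normal(0.0, 1.0), (784, 400))
weights = layer_prior.sample(rng)  # one independent draw per weight, shape (784, 400)
```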
Closing for now; the relevant info is in the discussions.
I open this PR to start thinking about the design of Bayesian neural network layers. The idea is to subclass trax's constructs and allow the use of distributions for the weights and for transformations of the weights.
The goal is to be able to take any model expressed with `trax` and make it Bayesian by adding prior distributions on the weights.
Of course, we should be able to construct hierarchical models by adding hyperpriors on the priors' parameters.
Layers are distributions over functions; let us see what it could look like on a naive MNIST example:
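Something along these lines, reusing the `ml.Serial`, `dense`, and `softmax` combinators from the discussion above; the exact layer sizes are an assumption:

```python
@mcx.model
def mnist(image):
    nn ~ ml.Serial(
        dense(400, Normal(0, 1)),
        dense(400, Normal(0, 1)),
        dense(10, Normal(0, 1)),
        softmax(),
    )
    p = nn(image)
    cat ~ Categorical(p)
    return cat
```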
The above snippet is naive in the sense that the way `Normal(0, 1)` relates to each weight in the layer is not very clear; we need to specify broadcasting rules for the Bayesian layers. We should also be able to easily define hierarchical models:
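For instance, a hyperprior on the scale of the weight priors could read as follows; the `HalfNormal` hyperprior and the layer sizes are only illustrative:

```python
@mcx.model
def mnist(image):
    sigma ~ HalfNormal(1)          # hyperprior shared by every weight prior
    nn ~ ml.Serial(
        dense(400, Normal(0, sigma)),
        dense(10, Normal(0, sigma)),
        softmax(),
    )
    p = nn(image)
    cat ~ Categorical(p)
    return cat
```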
Forward sampling
Let's look now at the design of the forward sampler. We need to return forward samples of the layer's weights as well as of the other random variables. We could define a `sample` method that draws a realization of each layer and performs a forward pass with the drawn weights, where `weights` is a tuple that contains the realized value of every weight. This would keep an API similar to the distributions', with the added `output` return value that reflects the fact that we are sampling a function. Another option is to draw the weights and run the forward pass in two separate calls, which feels less magical.
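A toy sketch of the first option, just to make the calling convention concrete; the `BayesianDense` class and its `sample` signature are illustrative assumptions, not mcx's or trax's API:

```python
import jax
import jax.numpy as jnp


class BayesianDense:
    """Toy dense layer whose weights carry a Normal prior (names are assumptions)."""

    def __init__(self, in_dim, out_dim, mu=0.0, sigma=1.0):
        # Broadcast the scalar prior parameters to the layer's weight shape.
        self.mu = jnp.broadcast_to(mu, (in_dim, out_dim))
        self.sigma = jnp.broadcast_to(sigma, (in_dim, out_dim))

    def sample(self, rng_key, inputs):
        # Draw one realization of the weights from the prior...
        w = self.mu + self.sigma * jax.random.normal(rng_key, self.mu.shape)
        # ...and perform the forward pass with the drawn weights.
        output = inputs @ w
        return output, (w,)


layer = BayesianDense(784, 400)
rng_key = jax.random.PRNGKey(0)
output, weights = layer.sample(rng_key, jnp.ones((1, 784)))
print(output.shape, weights[0].shape)  # (1, 400) (784, 400)
```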
Log-probability density function
Note: the `__call__` method of the layers calls the `pure_fn` method, which is jit-able. I am not sure it is necessary to call it directly here.
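As a rough sketch of what a layer's contribution to the model's log-probability could reduce to under independent Normal priors on the weights (ignoring any transformation of the weights; the function names are placeholders):

```python
import jax.numpy as jnp


def normal_logpdf(w, mu, sigma):
    """Elementwise log-density of Normal(mu, sigma) evaluated at w."""
    return -0.5 * jnp.log(2 * jnp.pi * sigma ** 2) - 0.5 * ((w - mu) / sigma) ** 2


def layer_logpdf(weights, mu, sigma):
    # With an independent prior on each weight, the layer's log-probability
    # is the sum of the elementwise log-densities over the weight tensor.
    return jnp.sum(normal_logpdf(weights, mu, sigma))


# Example: log-probability of a (784, 400) weight draw under Normal(0, 1) priors.
w = jnp.zeros((784, 400))
print(layer_logpdf(w, 0.0, 1.0))
```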