Implement AIR model #78
Agreed. I expect that it will be necessary to use … I see work on this is underway, which is great! One thing I'm not clear about is how this will interact with sequence models, where the length of the sequence is stochastic, when writing the model to operate on entire mini-batches. As a simple example, consider this:

```python
def local_model(batch):
    n = batch.size(0)
    # first random choice, made for all data points
    x = pyro.sample("x", Normal(zeros([n, 1]), ones([n, 1])))
    # flip coins to decide which data points make a second random choice
    p = pyro.sample("p", Bernoulli(0.5 * ones([n, 1])))
    # make the second random choice for that subset of the mini-batch
    m = int(torch.sum(p))
    y = pyro.sample("y", Normal(zeros([m, 1]), ones([m, 1])))
    # combine the first and second choices
    z = combine(x, y, p)  # `combine` and `dist` are placeholders
    pyro.observe("z", dist, z)

pyro.map_data(data, local_model)
```

The interesting feature of this is that the number of choices made per data point is not fixed. (This is similar but not identical to the case where the inputs are of differing lengths. Related: #34, #67.) Here's the bit I'm unclear about: I'm guessing that the Rao-Blackwellized … Will models of this form be supported by the planned implementation, or is there a different way of expressing the model to make it fit? (Other than by using …) I suppose a direct approach to solving this would be to pass the information about which subset of data points a choice is being made for to … Eventually, I think we'd ideally like to write local models for a single data point and have the back end figure out the mini-batch version, but this is tricky of course. (Probably not news, but here is some related work on the variable-length input case: 1, 2.)
i discussed some related ideas with @eb8680 yesterday. if i'm understanding correctly this is about the vectorized version of map_data? (i.e. the webppl style version that maps the observation function over each data point in the batch should be ok?) we'll document this more fully elsewhere, but my current thinking is that the functional map version should be the basic version that defines the correct behavior; vectorization is then an optimization that the implementation can try to do. the thought is that this optimization is achievable by overloading the tensor library to have a batch dim, while marking some tensor ops as "unsafe for vectorization"; if an observation function is trying to be vectorized but hits an unsafe op, it will bail out and fall back on the independent map version. does this make sense?
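The bail-out idea above can be sketched in a few lines of plain Python. This is only a toy illustration of the control flow, not Pyro's actual implementation; `UnsafeForVectorizationError`, `map_version`, and `maybe_vectorized` are all hypothetical names:

```python
class UnsafeForVectorizationError(Exception):
    """Raised when an op cannot be applied across the batch dimension."""

def map_version(f, batch):
    # the basic, always-correct semantics: apply f independently per data point
    return [f(x) for x in batch]

def maybe_vectorized(f, batch):
    # optimization: try running f once over the whole batch; if it hits an
    # op marked "unsafe for vectorization", bail out to the map version
    try:
        return f(batch)
    except UnsafeForVectorizationError:
        return map_version(f, batch)

def double(x):
    if isinstance(x, list):
        # pretend this op is marked unsafe when given a whole batch
        raise UnsafeForVectorizationError
    return 2 * x

result = maybe_vectorized(double, [1, 2, 3])  # falls back to the map version
```

The key design point is the one made above: the mapped version defines correctness, so the fallback is always semantically safe.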
Yes.
Yes, OK in the sense that we can write the model down and inference will do the right thing. The question I had in my mind, but didn't make explicit, was whether the non-mini-batch version would be OK in terms of the performance goals of the project. I'll measure the performance of both implementations at some point.
Yep, I think I get the general idea.
right: probably not. but doing it in two steps seems cleaner anyhow.
I'm working on this over on this branch.
An initial test suggests that the vectorized version is about 5 times faster than webppl style. This probably underestimates the difference we'll see in the end, since the particular vectorized implementation I measured this on was doing a lot of wasted computation that we could likely avoid. (Each sequence in the batch was padded with extra steps so that all sequences had the same length.)
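One standard way to limit that wasted computation while keeping a rectangular batch is to pad and then mask, so padded steps contribute nothing to the quantity being accumulated. A generic sketch (not the implementation measured above, and with made-up data):

```python
# batch of variable-length sequences, padded to the max length
seqs = [[1.0, 2.0], [3.0, 4.0, 5.0], [6.0]]
max_len = max(len(s) for s in seqs)
padded = [s + [0.0] * (max_len - len(s)) for s in seqs]
mask = [[1.0] * len(s) + [0.0] * (max_len - len(s)) for s in seqs]

# a per-step quantity (e.g. a log-likelihood term) is computed at every
# position, including padding; the mask zeroes out the padded positions
per_step = [[v + 1.0 for v in row] for row in padded]
total = sum(v * m
            for row, mrow in zip(per_step, mask)
            for v, m in zip(row, mrow))
```

Here `total` counts only the 6 real steps; without the mask, the 3 padded positions would each leak a spurious `1.0` into the sum.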
@eb8680 The main thing is to use the independence info from `map_data`.
@null-a do you have a way to get a quick check of how the model is performing now (with neural baseline but sans mapdata independence)? |
@ngoodman No. I don't have baselines implemented yet, and once I do I don't know how to check performance other than by running it, which isn't quick. @martinjankowiak I'm still unclear about how an RNN can be used to output a baseline for each choice in a sequence. Any chance you could provide a rough sketch of how this would look? |
@null-a with the code that's currently in dev, i don't think you can do that. can you provide a pseudo-code snippet that sketches out what you want? then i can see what set of changes would be required. doing this more or less elegantly may require changes in the interface.
@martinjankowiak I think the idea would be to have an extra RNN (in addition to the inference net) that runs along the sequence, which is used to output a baseline value at each choice. After the choice, the sampled value would be used to produce a new hidden state. So focussing on a single choice in the sequence for a single data point, and ignoring the inference net, we might have something like:
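A rough sketch of that idea in plain Python (the class and method names here are hypothetical, and a real version would be a torch module so its parameters get optimized along with everything else):

```python
import math
import random

random.seed(0)

class BaselineRNN:
    """Sketch of an extra RNN that emits a baseline before each choice and
    then absorbs the sampled value into its hidden state."""

    def __init__(self, hidden_size):
        rand = lambda: random.uniform(-0.1, 0.1)
        self.W_h = [[rand() for _ in range(hidden_size)]
                    for _ in range(hidden_size)]
        self.w_x = [rand() for _ in range(hidden_size)]
        self.w_out = [rand() for _ in range(hidden_size)]
        self.h = [0.0] * hidden_size

    def baseline(self):
        # emit a baseline value for the upcoming random choice
        return sum(w * h for w, h in zip(self.w_out, self.h))

    def observe_sample(self, value):
        # after the choice, fold the sampled value into the hidden state
        self.h = [math.tanh(sum(w * h for w, h in zip(row, self.h)) + wx * value)
                  for row, wx in zip(self.W_h, self.w_x)]

# one step of the sequence, for a single data point:
net = BaselineRNN(hidden_size=4)
b = net.baseline()        # used to reduce variance of the score-function term
sampled_value = 0.7       # stand-in for the value drawn at this sample site
net.observe_sample(sampled_value)
```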
I guess as long as we can package the whole baseline net up into a single torch module (so that all of its params are updated) then we can probably make it work with the current interface. That seems like it might be doable, so I'll make an attempt at some point. Thanks.
Progress update: It looks like my implementation is really slow at present. I estimate it will take around a year to get reasonable inferences out of the guide (optimizing on the CPU), and much longer to run optimization for as long as they did for the paper. (Assuming I'm interpreting the results in the paper correctly.) So, I'll need a bunch of tricks to speed this up, and I guess that one of those will end up being the use of vectorized `map_data`.
hmm.... we didn't have vectorization in webppl, so how did we get acceptable performance there (or did we never get that far)? |
The latter. |
Hi guys. Regarding baselines: the paper has a section titled "I.5 Supervised learning" … As such, they basically learn baselines in a supervised way (apparently). I shot them an email to follow up, but wanted to highlight this here. Linking to #126.
My understanding is that this is used on the 3D scene example and not the multi-mnist example I'm working on. (I'm mentioning this only to point out that I don't think this stands in the way of me reproducing the result I'm shooting for, and not to take anything away from the idea, which is interesting.)
woohoo!
fascinating. if you only give it two timesteps, then it uses the first appropriately? and you're sure you are counting in the right direction? ;)
we (collectively) should make a plan for getting a clean version worked out and into dev, since i think we'll need it for release.
interesting! could it be as simple as decreasing the prior probability of recursion, to better encourage not using extra steps? (this still doesn't explain why the earlier result was punting on the first step, which is very odd.)
Yeah, I already made one change in that direction for the reason you suggest, but I could go further. Related to this, yesterday I learned that for the paper they "annealed the success probability from a value close to 1 to either …"
useful blog post -- nice find! annealing the success probability doesn't seem crazy (if a bit hacky). it would be nice if this were straightforward to do in pyro (it's the kind of tinkering with learning that we want to make accessible). i think it might be as simple as making success prob an arg to model, which then becomes an arg to kl.step, and changing it as we like over learning? if it's more complex than that, there's no need to implement now, but it'd be nice to think it through sometime.
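A minimal sketch of that, assuming the parameter really can just be threaded through as an argument (the linear schedule and endpoint values here are made up, and `model`, `guide`, and `kl.step` are stand-ins for the real entry points):

```python
def anneal(step, num_steps, start=0.99, end=0.5):
    # linearly anneal the success probability from `start` down to `end`
    frac = min(step / num_steps, 1.0)
    return start + frac * (end - start)

def train(num_steps=1000):
    losses = []
    for step in range(num_steps):
        success_prob = anneal(step, num_steps)
        # the annealed value is just an arg threaded through to the model
        # at each optimization step, e.g.:
        #   loss = kl.step(model, guide, success_prob=success_prob)
        losses.append(success_prob)  # placeholder for the real loss
    return losses
```

Nothing in the optimizer needs to know the value is changing; the model just sees a different success probability at each step.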
Yeah, I think it's as simple as that. |
Closed by #259 |
The attend-infer-repeat (AIR) model is a great example of stochastic recursion in deep probabilistic models. It is going to be one of the anchor examples for the first pyro release.
We need to implement it, implement any extensions that training relies on, and replicate (some) results from the paper.
Getting acceptably low variance for the gradient estimator term from the discrete RV will likely take some amount of Rao-Blackwellization and data-dependent baselines. @martinjankowiak has started or planned these.
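For intuition, a toy illustration of why a baseline cuts the variance of a score-function (REINFORCE-style) gradient estimate. Everything here is a simplified stand-in for the project's estimator, with a hand-picked constant baseline rather than a learned, data-dependent one:

```python
import random

random.seed(0)

def score_grad_samples(p, f, baseline, n):
    # score-function estimate of d/dlogit E_{z ~ Bernoulli(p)}[f(z)]:
    #   (f(z) - baseline) * (z - p),  since d/dlogit log q(z) = z - p
    out = []
    for _ in range(n):
        z = 1.0 if random.random() < p else 0.0
        out.append((f(z) - baseline) * (z - p))
    return out

def mean(xs):
    return sum(xs) / len(xs)

def variance(xs):
    m = mean(xs)
    return sum((x - m) ** 2 for x in xs) / len(xs)

f = lambda z: 10.0 + z  # reward with a large constant offset
no_baseline = score_grad_samples(0.7, f, baseline=0.0, n=50000)
with_baseline = score_grad_samples(0.7, f, baseline=10.0, n=50000)
# both estimators target the same gradient, but subtracting the constant
# offset as a baseline leaves a far lower-variance estimator
```

A data-dependent baseline (as planned here) pushes this further by predicting the offset per data point instead of using one constant.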