
Implement AIR model #78

Closed
ngoodman opened this issue Aug 22, 2017 · 24 comments
@ngoodman
Collaborator

The attend-infer-repeat (AIR) model is a great example of stochastic recursion in deep probabilistic models. It is going to be one of the anchor examples for the first pyro release.

We need to implement it, implement any extensions that training relies on, and replicate (some) results from the paper.

Getting acceptably low variance for the gradient estimator term from the discrete RV will likely take some amount of rao-blackwellization and data dependent baselines. @martinjankowiak has started or planned these.
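To make the variance issue concrete, here is a minimal NumPy sketch (not Pyro code; the toy cost function and numbers are invented for illustration) of the score-function (REINFORCE) gradient estimator for a single Bernoulli choice, and of how subtracting a baseline leaves the estimate unbiased while shrinking its variance:

```python
import numpy as np

rng = np.random.default_rng(0)

def reinforce_grad(logit, cost_fn, baseline=0.0, n_samples=10000):
    # Score-function estimate of d E[cost(z)] / d logit for z ~ Bernoulli(sigmoid(logit)).
    # Subtracting a baseline keeps the estimator unbiased (E[score] = 0)
    # but can greatly reduce its variance.
    p = 1.0 / (1.0 + np.exp(-logit))
    z = rng.random(n_samples) < p        # sampled discrete choices
    score = z - p                        # d log q(z) / d logit for a Bernoulli
    terms = (cost_fn(z) - baseline) * score
    return terms.mean(), terms.var()

cost = lambda z: np.where(z, 4.0, 2.0)   # toy per-sample "cost"

g_plain, v_plain = reinforce_grad(0.5, cost, baseline=0.0)
g_base, v_base = reinforce_grad(0.5, cost, baseline=3.0)  # baseline near E[cost]
# Both estimates agree in expectation; the baseline version has much lower variance.
```

A data-dependent baseline generalizes the constant `baseline=3.0` above to a per-datum prediction of the expected cost, which is where the neural baselines come in.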

@null-a
Collaborator

null-a commented Aug 24, 2017

Getting acceptably low variance for the gradient estimator term from the discrete RV will likely take some amount of rao-blackwellization and data dependent baselines.

Agreed. I expect that it will be necessary to use map_data to weight LR terms by only those "costs" incurred by the associated data point.

I see work on this is underway, which is great! One thing I'm not clear about is how this will interact with sequence models where the sequence length is stochastic, when the model is written to operate on entire mini-batches.

As a simple example, consider this:

def local_model(batch):
  n = batch.size(0)

  # first random choice, made for every data point in the mini batch
  x = pyro.sample("x", Normal(zeros([n, 1]), ones([n, 1])))

  # flip coins to decide which data points make a second random choice
  p = pyro.sample("p", Bernoulli(0.5 * ones([n, 1])))

  # make the second random choice for only that subset of the mini batch
  m = int(torch.sum(p))
  y = pyro.sample("y", Normal(zeros([m, 1]), ones([m, 1])))

  # combine first and second choices
  z = combine(x, y, p)

  pyro.observe("z", dist, z)

map_data(data, local_model)

The interesting feature of this is that the number of choices made per data point is not fixed. (This is similar but not identical to the case where the inputs are of differing lengths. Related #34, #67.)

Here's the bit I'm unclear about... I'm guessing that the Rao-Blackwellized map_data implementation will assume that a data point has a fixed position within the mini batch, in order to track the "cost" for each data point across multiple choices? However, that assumption doesn't hold for the example above -- since we only make the second choice for a subset of the data, a particular data point may appear in a different position across choices.

Will models of this form be supported by the planned implementation, or is there a different way of expressing the model to make it fit? (Other than by using batch_size=1 and combining gradient estimates for a mini batch by hand.)

I suppose a direct approach to solving this would be to pass the information about which subset of data points a choice is being made for to sample, recovering the ability to track data points across choices. (Either as a separate argument, or encoded in the parameters in some way.) This is a bit fiddly though, so hopefully there's something better?

Eventually, I think we'd ideally like to write local models for a single data point and have the back end figure out the mini batch version, but this is tricky of course. (Probably not news, but here is some related work on the variable length input case: 1, 2.)

@ngoodman
Collaborator Author

i discussed some related ideas with @eb8680 yesterday. if i'm understanding correctly this is about the vectorized version of map_data? (i.e. the webppl style version that maps the observation function over each data point in the batch should be ok?)

we'll document this more fully elsewhere, but my current thinking is that the functional map version should be the basic version that defines the correct behavior; vectorization is then an optimization that the implementation can try to do. the thought is that this optimization is achievable by overloading the tensor library to have a batch dim, while marking some tensor ops as "unsafe for vectorization"; if an observation function is trying to be vectorized but hits an unsafe op, it will bail out and fall back on the independent map version. does this make sense?
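A toy sketch of that bail-out mechanism (all names here are hypothetical, not proposed Pyro API): ops marked unsafe raise an exception, and map_data catches it and falls back to the semantics-defining per-datum map:

```python
# Hypothetical sketch of "try vectorized, fall back to map". None of these
# names are real Pyro APIs.

class UnsafeForVectorization(Exception):
    """Raised when an observation function hits an op that can't be batched."""

def unsafe_op(batch):
    # stands in for a tensor op marked "unsafe for vectorization"
    raise UnsafeForVectorization

def map_data(data, local_model):
    try:
        # optimization: run the observation function once over the whole batch
        return [local_model(data)]
    except UnsafeForVectorization:
        # fall back to the basic version: map over data points independently
        return [local_model([x]) for x in data]

def model(batch):
    if len(batch) > 1:
        unsafe_op(batch)     # pretend this model can't be vectorized
    return [x * 2 for x in batch]

out = map_data([1, 2, 3], model)   # falls back to the per-datum map
```

The real version would of course need to roll back any side effects (sampled choices, traced sites) made before the unsafe op was hit.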

@null-a
Collaborator

null-a commented Aug 25, 2017

if i'm understanding correctly this is about the vectorized version of map_data?

Yes.

the webppl style version that maps the observation function over each data point in the batch should be ok?

Yes, OK in the sense that we can write the model down and inference will do the right thing. The question I had in my mind, but didn't make explicit, was whether the non-mini-batch version would be OK in terms of the performance goals of the project. I'll measure the performance of both implementations at some point.

does this make sense?

Yep, I think I get the general idea.

@ngoodman
Collaborator Author

The question I had in my mind, but didn't make explicit, was whether the non-mini-batch version would be OK in terms of the performance goals of the project.

right: probably not. but doing it in two steps seems cleaner anyhow.

@null-a
Collaborator

null-a commented Aug 29, 2017

I'm working on this over on this branch.

I'll measure the performance of both implementations at some point.

An initial test suggests that the vectorized version is about 5 times faster than webppl style. This probably underestimates the difference we'll see in the end, since the particular vectorized implementation I measured this on was doing a lot of wasted computation that we could likely avoid. (Each sequence in the batch was padded with extra steps so that all sequences had the same length.)

@eb8680
Member

eb8680 commented Sep 12, 2017

@null-a now that #61 and #62 and #84 are merged are there any other Pyro features we need for this example?

@null-a
Collaborator

null-a commented Sep 12, 2017

@eb8680 The main thing is to use the independence info from map_data when building the dependency graph. (Unless that found its way in without me noticing.)

@ngoodman
Collaborator Author

@null-a do you have a way to get a quick check of how the model is performing now (with neural baseline but sans mapdata independence)?

@null-a
Collaborator

null-a commented Sep 13, 2017

do you have a way to get a quick check of how the model is performing now (with neural baseline but sans mapdata independence)?

@ngoodman No. I don't have baselines implemented yet, and once I do I don't know how to check performance other than by running it, which isn't quick.

@martinjankowiak I'm still unclear about how an RNN can be used to output a baseline for each choice in a sequence. Any chance you could provide a rough sketch of how this would look?

@martinjankowiak
Collaborator

@null-a with the code that's currently in dev, i don't think you can do that. can you provide a pseudo-code snippet that sketches out what you want? then i can see what set of changes would be required. doing this more or less elegantly may require changes in the interface

@null-a
Collaborator

null-a commented Sep 13, 2017

@martinjankowiak I think the idea would be to have an extra RNN (in addition to the inference net) that runs along the sequence, which is used to output a baseline value at each choice. After the choice, the sampled value would be used to produce a new hidden state. So focussing on a single choice in the sequence for a single data point, and ignoring the inference net, we might have something like:

baseline = some_nn(rnn_hid_state)
x = sample('x', dist, baseline=baseline)
new_rnn_hid_state = rnn(rnn_hid_state, embed(x))

I guess as long as we can package the whole baseline net up into a single torch module (so that all of its params are updated) then we can probably make it work with the current interface. That seems like it might be do-able, so I'll make an attempt at some point. Thanks.
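For concreteness, a rough sketch of what packaging the baseline net into a single torch module might look like (hypothetical names and sizes; the interaction with sample and the inference net is omitted):

```python
import torch
import torch.nn as nn

class BaselineNet(nn.Module):
    # Hypothetical sketch, not a proposed interface: the baseline RNN, its
    # output head, and the value embedding all live in one module, so a single
    # .parameters() call covers every parameter the baseline needs updated.
    def __init__(self, embed_size=8, hid_size=16):
        super().__init__()
        self.hid_size = hid_size
        self.embed = nn.Linear(1, embed_size)   # embed the sampled value
        self.rnn = nn.RNNCell(embed_size, hid_size)
        self.head = nn.Linear(hid_size, 1)      # hidden state -> baseline value

    def init_hid(self, batch_size):
        return torch.zeros(batch_size, self.hid_size)

    def baseline(self, hid):
        # baseline for the upcoming choice, given the sequence so far
        return self.head(hid)

    def step(self, hid, x):
        # fold the sampled value into a new hidden state
        return self.rnn(self.embed(x), hid)

net = BaselineNet()
hid = net.init_hid(batch_size=4)
bl = net.baseline(hid)                  # one baseline per datum, shape [4, 1]
hid = net.step(hid, torch.randn(4, 1))  # after sampling, advance the state
```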

@null-a
Collaborator

null-a commented Sep 15, 2017

Progress update: It looks like my implementation is really slow at present. I estimate it would take around a year to get reasonable inferences out of the guide (optimizing on the CPU), and much longer to run optimization for as long as they did for the paper. (Assuming I'm interpreting the results in the paper correctly.)

So, I'll need a bunch of tricks to speed this up, and I guess that one of those will end up being the use of vectorized map_data. This will require Rao-Blackwellization/baselines for vectorized map_data, and perhaps parts of #34.

@ngoodman
Collaborator Author

hmm.... we didn't have vectorization in webppl, so how did we get acceptable performance there (or did we never get that far)?

@null-a
Collaborator

null-a commented Sep 15, 2017

how did we get acceptable performance there (or did we never get that far)?

The latter.

@karalets
Collaborator

Hi guys,

Regarding baselines.
I wanted to throw in a chunk from the paper-appendix that maybe was missed here.

"I.5 Supervised learning
For the baselines trained in a supervised manner we use the ground truth scene variables [z^pres_1:N, z^where_1:N, z^what_1:N] that underly the training scene images as labels and train a network of the same form as the inference network to maximize the conditional log likelihood of the ground truth scene variables given the image."

As such, they basically learn baselines in a supervised way (apparently).

I shot them an email to follow up, but wanted to highlight this here. Linking to #126 .

@null-a
Collaborator

null-a commented Sep 19, 2017

I wanted to throw in a chunk from the paper-appendix that maybe was missed here.

My understanding is that this is used on the 3D scene example and not the multi-mnist example I'm working on. (I'm mentioning this only to point out that I don't think this stands in the way of me reproducing the result I'm shooting for, and not to take anything away from the idea, which is interesting.)

@null-a
Collaborator

null-a commented Sep 25, 2017

Progress update:

It looks like my implementation is really slow at present.

I've now re-written the model in vectorized style. In order to optimize it, I've cobbled together an implementation of kl_qp that supports vectorized Rao-Blackwellization/baselines. This kl_qp is an ugly hack that works for this model but isn't very general. It has allowed me to get results out of this model, but it's nothing like the fully general thing that we need to implement eventually.

I've also switched to running this on a gpu, which together with vectorization makes things run a couple of orders of magnitude faster, making it usable.

Results so far: My goal is to replicate their first result on multi-mnist, and optimization of the pyro implementation seems to be working almost as hoped. I'm seeing similar progress on the elbo to that reported in the paper, the inference net is successfully picking out digits in the image, and reconstructions look reasonable.

The main snag is that rather than always avoiding use of the final time step (which is never necessary to explain the input) the guide is instead wasting the first time step, by using it to explain nothing. For example:

4 input images:
[image: inputs]

Reconstructions, with visualization of (some of the) latents:
[image: final-recon]
(Step one in red, two in green, three in blue.)

I don't yet understand why this is happening, but I'm working on it. (Any thoughts on this are welcome, of course!)

@ngoodman
Collaborator Author

I've also switched to running this on a gpu, which together with vectorization makes things run a couple of orders of magnitude faster, making it usable.

woohoo!

The main snag is that rather than always avoiding use of the final time step (which is never necessary to explain the input) the guide is instead wasting the first time step, by using it to explain nothing.

fascinating. if you only give it two timesteps, then it uses the first appropriately? and you're sure you are counting in the right direction? ;)

This kl_qp is an ugly hack that works for this model but isn't very general. It has allowed me to get results out of this model, but it's nothing like the fully general thing that we need to implement eventually.

we (collectively) should make a plan for getting a clean version worked out and into dev, since i think we'll need it for release.

@null-a
Collaborator

null-a commented Sep 27, 2017

if you only give it two timesteps, then it uses the first appropriately?

Yeah, it appears so:

[image: img-2-steps-recon]

(though here the last step is still used unnecessarily when there's one digit.)

@ngoodman
Collaborator Author

interesting! could it be as simple as decreasing the prior probability of recursion, to better encourage not using extra steps? (this still doesn't explain why the earlier result was punting on the first step, which is very odd.)

@null-a
Collaborator

null-a commented Sep 28, 2017

could it be as simple as decreasing the prior probability of recursion, to better encourage not using extra steps?

Yeah, I already made one change in that direction for the reason you suggest, but I could go further.

Related to this, yesterday I learned that for the paper they "annealed the success probability from a value close to 1 to either 1e-5 or 1e-10 depending on the dataset over the course of 100k training iterations". (I'd rather not have to do that though.)

@ngoodman
Collaborator Author

useful blog post -- nice find!

annealing the success probability doesn't seem crazy (if a bit hacky). it would be nice if this were straightforward to do in pyro (it's the kind of tinkering with learning that we want to make accessible). i think it might be as simple as making success prob an arg to model, which then becomes an arg to kl.step, and changing it as we like over learning? if it's more complex than that, there's no need to implement now, but it'd be nice to think it through sometime.

@null-a
Collaborator

null-a commented Sep 28, 2017

i think it might be as simple as making success prob an arg to model, which then becomes an arg to kl.step, and changing it as we like over learning?

Yeah, I think it's as simple as that.
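For the record, a hypothetical sketch of what that could look like (the schedule shape and the kl_qp.step keyword are made up for illustration; the paper only says the probability is annealed over 100k iterations):

```python
import math

def anneal_schedule(step, n_steps=100000, start=0.99, end=1e-5):
    # Hypothetical schedule: geometric interpolation of the z_pres success
    # probability over training, loosely following the paper's description.
    frac = min(step / n_steps, 1.0)
    return math.exp((1 - frac) * math.log(start) + frac * math.log(end))

# Training-loop shape (pseudo-code: `kl_qp` and `data` stand in for the real
# objects; the annealed value is just another argument threaded through step):
# for step in range(num_steps):
#     p_succ = anneal_schedule(step)
#     loss = kl_qp.step(data, z_pres_prior=p_succ)
```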

@eb8680 eb8680 added this to the Launch milestone Oct 2, 2017
@jpchen jpchen mentioned this issue Oct 5, 2017
@null-a null-a mentioned this issue Oct 13, 2017
@eb8680
Member

eb8680 commented Oct 25, 2017

Closed by #259

@eb8680 eb8680 closed this as completed Oct 25, 2017