
Update tensorflow probability examples #698

Merged
merged 11 commits into tensorflow:master on Feb 11, 2020

Conversation

Pyrsos
Contributor

@Pyrsos Pyrsos commented Dec 24, 2019

As discussed in issue #607, the tensorflow probability examples currently use the 'compat' tensorflow subpackage for TF2 compatibility. In this PR I have updated two of the examples ('logistic_regression.py' and 'bayesian_neural_network.py') to follow the TF2 paradigms, using the Keras high-level API. I have kept the same functionality for plots and retained the model architectures. On the other hand, I have replaced the 'tf.data' input pipeline functions with the Keras Sequence class, as these can be used very easily with the 'model.(fit|evaluate|predict)_generator' functions. Also, in 'keras_bayesian_neural_network.py' I have replaced the MNIST dataset path (the data appears to no longer be at that location) with the keras.datasets.mnist path.
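
For readers unfamiliar with the Keras Sequence class mentioned above, a minimal sketch of this kind of input pipeline follows (the class and variable names here are illustrative, not necessarily the ones used in the PR):

import numpy as np
import tensorflow as tf

class MNISTSequence(tf.keras.utils.Sequence):
  """Yields (images, labels) batches that the *_generator methods accept."""

  def __init__(self, images, labels, batch_size=128):
    # Scale pixel values to [0, 1], add a channel axis, one-hot the labels.
    self.images = (images[..., np.newaxis] / 255.).astype(np.float32)
    self.labels = tf.keras.utils.to_categorical(labels)
    self.batch_size = batch_size

  def __len__(self):
    return int(np.ceil(len(self.images) / self.batch_size))

  def __getitem__(self, idx):
    lo = idx * self.batch_size
    hi = lo + self.batch_size
    return self.images[lo:hi], self.labels[lo:hi]

# Example usage:
# (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
# train_seq = MNISTSequence(x_train, y_train)
# model.fit_generator(train_seq, epochs=10)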

One further issue I have noticed is that, when eager execution is enabled, the Keras model.fit_generator operation returns the loss metric as a Tensor object, while the accuracy is returned as a scalar (float value). For this reason 'keras_bayesian_neural_network.py' uses the 'model.train_on_batch' operation instead (which also allows for plotting weights/images and performing Monte Carlo sampling on the validation set while training).

@googlebot googlebot added the "cla: yes" (Declares that the user has signed CLA) label on Dec 24, 2019
Contributor

@davmre davmre left a comment

Thanks so much for this contribution! Examples are really important and the TF2 switch has made a lot of ours outdated (while we've been mostly focused on updating core TFP to work with TF2), so this is awesome and much appreciated.

Sorry for being slow to get back to you --- it's been a slow time due to the Christmas / New Year's holidays. This week might still be a bit slow, but I think enough of us are back that we should be able to move forward.

The code generally looks great. I've left a few mostly stylistic comments requesting some minor changes, but overall this is a big improvement on the status quo; it'll be great to get it checked in. Maybe @jburnim can also take a quick look as a second approver?

@Pyrsos
Contributor Author

Pyrsos commented Jan 1, 2020

@davmre No worries about the delay, I just had some spare time and thought it would be good to give it a go. Thanks for the comments; I have addressed them and pushed the new changes. Because I have now renamed the files, the git log might look a bit different; hopefully that will not be a massive issue. Let me know if there is anything else you would like me to adjust.

I will try and update some more examples in the coming weeks.

Comment on lines 310 to 356
for epoch in range(FLAGS.num_epochs):
  epoch_accuracy = []
  for step, (batch_x, batch_y) in enumerate(train_seq):
    # Eager mode returns a Tensor object and not a scalar
    # value for the loss, therefore only the mean accuracy
    # is displayed.
    batch_accuracy = model.train_on_batch(
        batch_x, batch_y)[1]
    epoch_accuracy.append(batch_accuracy)

    if step % 100 == 0:
      print('Epoch: {}, Batch index: {}, Accuracy: {:.3f}'.format(
          epoch, step, np.mean(epoch_accuracy)))

    if (step+1) % FLAGS.viz_steps == 0:
      # Compute log prob of heldout set by averaging draws from the model:
      # p(heldout | train) = int_model p(heldout|model) p(model|train)
      #                   ~= 1/n * sum_{i=1}^n p(heldout | model_i)
      # where model_i is a draw from the posterior p(model|train).
      print(" ... Running monte carlo inference")
      probs = np.asarray([model.predict_generator(heldout_seq, verbose=1)
                          for _ in range(FLAGS.num_monte_carlo)])
      mean_probs = np.mean(probs, axis=0)
      heldout_log_prob = np.mean(np.log(mean_probs))
      print(" ... Held-out nats: {:.3f}".format(heldout_log_prob))

      if HAS_SEABORN:
        names = [layer.name for layer in model.layers
                 if 'flipout' in layer.name]
        qm_vals = [layer.kernel_posterior.mean()
                   for layer in model.layers
                   if 'flipout' in layer.name]
        qs_vals = [layer.kernel_posterior.stddev()
                   for layer in model.layers
                   if 'flipout' in layer.name]
        plot_weight_posteriors(names, qm_vals, qs_vals,
                               fname=os.path.join(
                                   FLAGS.model_dir,
                                   "epoch{}_step{:05d}_weights.png".format(
                                       epoch, step)))
        plot_heldout_prediction(heldout_seq.images, probs,
                                fname=os.path.join(
                                    FLAGS.model_dir,
                                    "epoch{}_step{}_pred.png".format(
                                        epoch, step)),
                                title="mean heldout logprob {:.2f}"
                                      .format(heldout_log_prob))
Contributor

@nbro nbro Jan 19, 2020

Why not use Keras' fit method? Is it because you have faced issue #620? The code still looks unnecessarily verbose. I had actually updated this example to TF2, but I hadn't opened a pull request.

@Pyrsos
Contributor Author

The main reason for using train_on_batch instead of fit was to preserve the operations of the original example, as I think these add some valuable insight into how Bayesian networks are trained and into their behaviour regarding uncertainty. I thought this was an easier way to implement those operations than using custom Keras callbacks, which I think would be harder to follow. When I first submitted the PR this example was working fine, and it still works with version 2.0, but since version 2.1 I am also getting the problem you described in issue #620.

@davmre Is there any chance you could point me to a starting point for dealing with issue #620? I am making the naive assumption that the fix should not be too complicated, since, as @nbro says, the problem appears to be caused by the Convolution2DFlipout layers (specifically the kernel divergence) and not the DenseFlipout layers.
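
For context, this is roughly how a kernel divergence is attached to the Flipout layers in the TFP examples (a sketch only; the scaling constant and layer sizes are stand-ins, not necessarily the values used in this PR):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
NUM_TRAIN_EXAMPLES = 60000  # stand-in value for illustration

# Scale the KL term by the number of training examples so that the
# per-example loss approximates the ELBO over the full dataset.
kl_divergence_fn = (lambda q, p, _: tfd.kl_divergence(q, p) /
                    tf.cast(NUM_TRAIN_EXAMPLES, dtype=tf.float32))

conv = tfp.layers.Convolution2DFlipout(
    filters=6, kernel_size=5, padding='SAME', activation=tf.nn.relu,
    kernel_divergence_fn=kl_divergence_fn)
dense = tfp.layers.DenseFlipout(
    units=10, kernel_divergence_fn=kl_divergence_fn)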

@nbro
Contributor

I don't agree with you. I think that, as a rule of thumb, people should use fit, and callbacks are perfectly fine and intuitive. If you decide to use fit, you can solve that issue by following the instructions here: tensorflow/tensorflow#33729 (comment).
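
As an illustration of the fit-plus-callback route suggested here, a rough sketch follows; it assumes a compiled Keras model and a train_seq/heldout_seq Sequence pair as in the PR, and the callback name is made up:

import numpy as np
import tensorflow as tf

class MonteCarloEval(tf.keras.callbacks.Callback):
  """Averages predictive probabilities over several posterior draws."""

  def __init__(self, heldout_seq, num_monte_carlo=50):
    super().__init__()
    self.heldout_seq = heldout_seq
    self.num_monte_carlo = num_monte_carlo

  def on_epoch_end(self, epoch, logs=None):
    # Each forward pass samples new weights, so repeated prediction
    # approximates the posterior predictive distribution.
    probs = np.stack([self.model.predict(self.heldout_seq, verbose=0)
                      for _ in range(self.num_monte_carlo)])
    heldout_log_prob = np.mean(np.log(np.mean(probs, axis=0)))
    print(' ... Held-out nats: {:.3f}'.format(heldout_log_prob))

# Example usage:
# model.fit(train_seq, epochs=FLAGS.num_epochs,
#           callbacks=[MonteCarloEval(heldout_seq, FLAGS.num_monte_carlo)])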

@Pyrsos
Contributor Author

@nbro I am not particularly fond of this solution because I can similarly solve the issue by disabling eager execution. In both cases the underlying cause is still not treated. However, I will make the changes you suggest (both using the fit function and callbacks) if @davmre also agrees with that implementation.

Contributor

@nbro nbro Jan 20, 2020

"In both cases the underlying cause is still not treated.", because this is a bug in TensorFlow! Hopefully, someone is already taking care of this issue. experimental_run_tf_function, in fact, exists because there's an experimental feature related to distributed training that may not work in all cases (see #519 (comment)).

…hon.eager.core._SymbolicException in Conv2DFlipout
@brianwa84
Contributor

@davmre do you want to follow up and import this, or are any more changes needed?

@nbro
Contributor

nbro commented Feb 1, 2020

You should not merge this code. I have a better version than this one. In fact, I've noticed that, in this new version of the code, the Bayesian model is trained with categorical cross-entropy, but this won't work well (the model will not converge with the CCE). You would be better off creating a custom function that calculates the negative log-likelihood (which I did in my updated version). I may open a pull request later.

@Pyrsos
Contributor Author

Pyrsos commented Feb 1, 2020

@nbro I am testing this script and it is converging for me. Can you let me know what happens when you run it? As I stated before I am happy to make any changes that the reviewers @brianwa84 @davmre suggest.

@nbro
Contributor

nbro commented Feb 1, 2020

@Pyrsos What accuracy do you obtain after the first epoch? And what do you mean by converging? Do you mean that the loss decreases?

@Pyrsos
Contributor Author

Pyrsos commented Feb 1, 2020

@nbro Test set accuracy after the first epoch is above 95%. And yes, when I mentioned the script is converging, I was referring to the model loss decreasing (sorry, I should have been clearer).

@nbro
Contributor

nbro commented Feb 1, 2020

@Pyrsos I executed part of your code, and, as you say, the accuracy is "decent" at the end of the first epoch. I have an example where the accuracy remains at 10% throughout training (even though the loss decreases), but this was due to the fact that I was not using the softmax function in the last layer.

Anyway, given your concerns about the compatibility with the original example (and as I said above), you should use a loss function that uses the tfd.Categorical rather than the CCE.

@nbro
Contributor

nbro commented Feb 1, 2020

@Pyrsos What do the distributions of the weights (mean and std) of each layer look like after, say, the 1st, 5th and 10th epochs? Have you noticed any particular pattern? For example, that the means do not change much and actually tend to concentrate around zero, while the variances tend to increase (spread out)?

@Pyrsos
Contributor Author

Pyrsos commented Feb 3, 2020

@nbro I am glad that the code worked for you. With regard to the weights, I have noticed the same behaviour you are describing, but have not gone much deeper in experimenting with this. It is definitely an interesting research area to check, especially when considering different prior/posterior distribution settings.

With regard to the loss function, I think that the terms negative log likelihood and cross entropy are interchangeable. In the documentation for the DenseFlipout (https://www.tensorflow.org/probability/api_docs/python/tfp/layers/DenseFlipout), the negative log likelihood is defined as:

neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
    labels=labels, logits=logits)

Also, for the tfd.Categorical class, when the log_prob method is called, the (negated) result of sparse_softmax_cross_entropy_with_logits is returned:

def _log_prob(self, k):
  logits = self.logits_parameter()
  if self.validate_args:
    k = distribution_util.embed_check_integer_casting_closed(
        k, target_dtype=self.dtype)
  k, logits = _broadcast_cat_event_and_params(
      k, logits, base_dtype=dtype_util.base_dtype(self.dtype))
  return -tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=k, logits=logits)

In that case the sparse version is used, which explains why the labels had to be integer values instead of one-hot (categorical) values. Notice also that the function returns the negative value, which is why the original example had to compensate by doing:
# Compute the -ELBO as the loss, averaged over the batch size.
neg_log_likelihood = -tf.reduce_mean(
    input_tensor=labels_distribution.log_prob(labels))

The Keras 'categorical_crossentropy' loss string, when passed to the model.compile method, also ends up calling softmax_cross_entropy_with_logits_v2 (I guess the v2 is a remnant from TF v1):
https://github.com/tensorflow/tensorflow/blob/cf7fcf164c9846502b21cebb7d3d5ccf6cb626e8/tensorflow/python/keras/backend.py#L4487-L4504

So I think that is what is happening, but please correct me if I have this wrong and there is something else going on, @brianwa84 @davmre.

@nbro
Contributor

nbro commented Feb 3, 2020

@Pyrsos In any case, TFP provides abstractions for distributions, which should be used (in an example that shows the usage of TFP), no? I know that NLL is equivalent to CE.
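
For reference, a minimal sketch of a tfd.Categorical-based loss of the kind being discussed (it assumes the model's last layer outputs logits and that the labels are one-hot; the function name is illustrative):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def neg_log_likelihood(y_true, logits):
  # Per-example negative log-likelihood; numerically this matches
  # softmax cross-entropy, as argued in the comments above.
  labels = tf.argmax(y_true, axis=-1)
  return -tfd.Categorical(logits=logits).log_prob(labels)

# Example usage:
# model.compile(optimizer='adam', loss=neg_log_likelihood,
#               metrics=['accuracy'])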

@Pyrsos Pyrsos requested a review from brianwa84 February 4, 2020 08:28
@nbro
Contributor

nbro commented Feb 5, 2020

@Pyrsos I noticed that in your updated example https://github.com/Pyrsos/probability/blob/update_examples/tensorflow_probability/examples/bayesian_neural_network.py you are using compile(..., experimental_run_tf_function=False), but that option is (only?) required if you use fit to train the model, which isn't the case here, given that you are still using the hardcoded training loop. Or do you also get the error described in tensorflow/tensorflow#33729 if you use train_on_batch? I haven't tried it, so it's possible this problem also occurs with train_on_batch.

@nbro
Contributor

nbro commented Feb 5, 2020

@brianwa84, @davmre Given that there are two models under the models folder, wouldn't it also be opportune to write the definition of the basic Bayesian CNN (and, actually, of any Bayesian model) in a standalone module under that folder, which would then be imported by the corresponding example?

@Pyrsos
Contributor Author

Pyrsos commented Feb 5, 2020

@nbro Yes, the error you are describing also appears with train_on_batch, which is why, as you suggested, I set experimental_run_tf_function=False in model.compile. As far as I can tell it is associated with the kernel divergence of the convolutional Flipout layers; the dense layers appear to be unaffected.

@nbro
Contributor

nbro commented Feb 5, 2020

@Pyrsos No, you are confused: I hadn't suggested using experimental_run_tf_function=False because you were using train_on_batch. In fact, you were using train_on_batch before my suggestion (see https://github.com/Pyrsos/probability/blob/a945f418533b5827ff1966c8186e58604ceb43fc/tensorflow_probability/examples/bayesian_neural_network.py, where you are not using experimental_run_tf_function=False but are using train_on_batch), and you didn't say you were getting any error. I suggested you use experimental_run_tf_function=False in case you decided to convert that hardcoded training loop into a simple call to fit (my recommendation).

So, I will ask you again: do you get that error with train_on_batch and experimental_run_tf_function=True?

And yes, as I had already reported in several GitHub issues, the problem appears to be related only to the Bayesian convolution layers.

@Pyrsos
Contributor Author

Pyrsos commented Feb 5, 2020

@nbro If you set experimental_run_tf_function=True with train_on_batch, you will get the error you are describing in tensorflow/tensorflow#33729. The error is not exclusive to the fit operation.

@nbro
Contributor

nbro commented Feb 5, 2020

@Pyrsos So, why did you create this pull request if, initially, you were not using experimental_run_tf_function=False but were using train_on_batch, and so were getting an error?

@Pyrsos
Contributor Author

Pyrsos commented Feb 5, 2020

@nbro The original version of the script was using TF 2.0, where train_on_batch did not produce an error. The error appeared when I updated to TF 2.1, at which point I passed experimental_run_tf_function=False to the compile function.
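
Concretely, the workaround described here amounts to something like the following sketch (the optimizer, loss and metrics are placeholders; experimental_run_tf_function was a temporary compile argument in TF 2.0/2.1 and no longer exists in later releases):

import tensorflow as tf

# `model` is assumed to be an already-built Keras model.
# Falling back to the legacy execution path avoids the _SymbolicException
# raised by the Flipout convolution layers under tf.function.
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
    experimental_run_tf_function=False)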

@nbro
Contributor

nbro commented Feb 5, 2020

@Pyrsos It's strange. You're the second person saying that experimental_run_tf_function=False was added in TF 2.1, but I had been using this option in TF 2.0 too.

@davmre
Contributor

davmre commented Feb 11, 2020

Sorry for the long delay on this; it fell through the cracks partly because I was on vacation for a while. We're putting this through internal review now; hopefully it will go in within the next couple of days.

A couple of notes:

tensorflow-copybara pushed a commit that referenced this pull request Feb 11, 2020
@tensorflow-copybara tensorflow-copybara merged commit ff173d9 into tensorflow:master Feb 11, 2020
@nbro
Contributor

nbro commented Feb 11, 2020

@davmre The merged examples now include both experimental_run_tf_function=False and your suggested model.build(...) call. Maybe this was not intended.

@Pyrsos
Contributor Author

Pyrsos commented Feb 13, 2020

Thank you very much @davmre ! I learned a lot from this process, and I am looking forward to contributing again!
