
No gradients calculated with dense variational layers #409

Closed
jacobwjs opened this issue May 15, 2019 · 16 comments

@jacobwjs

jacobwjs commented May 15, 2019

Hi,
I'm trying to test a Bayesian approach I'm working on against some of the new variational layers (dense for now) in tfp. I'm throwing together a quick and dirty example on MNIST to see how tfp's variational layers perform by comparison.

I'm running into trouble calculating gradients for 'DenseReparameterization' and 'DenseFlipout' layers using the latest nightly builds of tf and tfp, with eager execution and gradient tape.

Am I missing something simple here, or is there a deeper issue?

Basic example below:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import matplotlib.pyplot as plt
import numpy as np


import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.python.framework import ops



tfd = tfp.distributions

print(tf.__version__)  # 2.0.0-dev20190515
print(tfp.__version__) # 0.7.0-dev20190515


training_size = 40000
(input_images, target_labels), _ = tf.keras.datasets.mnist.load_data()
total_len = len(input_images)

x_train = input_images[:training_size, :, :].reshape(training_size, 784).astype('float32') / 255
y_train = target_labels[:training_size].astype('int32')

test_size = total_len - training_size
x_test = input_images[training_size:, :, :].reshape(test_size, 784).astype('float32') / 255
y_test = target_labels[training_size:].astype('int32')

BATCH_SIZE = int(256)
num_mini_batch = (len(y_train)//BATCH_SIZE)



class DenseModel(tf.keras.Model):
    def __init__(self, dense_layer_type='vanilla'):
        super(DenseModel, self).__init__()
        

        self.layer_type = dense_layer_type

        if dense_layer_type == 'reparam':
            print("Testing reparameterization layers")
            self.dense1 = tfp.layers.DenseReparameterization(100, activation=tf.nn.relu)
            self.dense2 = tfp.layers.DenseReparameterization(10)
        elif dense_layer_type == 'flipout':
            print("Testing flipout layers")            
            self.dense1 = tfp.layers.DenseFlipout(100, activation=tf.nn.relu)
            self.dense2 = tfp.layers.DenseFlipout(10)
        else:
            print("Testing vanilla layers")
            self.dense1 = tf.keras.layers.Dense(100, activation=tf.nn.relu)
            self.dense2 = tf.keras.layers.Dense(10)
            
    def call(self, inputs):
        x = self.dense1(inputs)
        outputs = self.dense2(x)
        return outputs
    
    def predict(self, inputs):
        return tf.nn.softmax(self.call(inputs))
    
        
ops.reset_default_graph()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.WARN)

model = DenseModel('vanilla')
# model = DenseModel('reparam')  # No gradients
# model = DenseModel('flipout')     # No gradients


train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))

epochs = 5
optimizer = tf.optimizers.Adam(learning_rate=0.005)

for epoch in range(1, epochs+1):   
    print("epoch: ", epoch)

    train_examples_shuffled = train_dataset.shuffle(50000).batch(BATCH_SIZE)
    training_set = tf.compat.v1.data.make_one_shot_iterator(train_examples_shuffled)


    for idx, (images, labels) in enumerate(training_set):

        with tf.GradientTape() as tape:

            logits = model(images)
            labels_distribution = tfd.Categorical(logits=logits)
            _neg_log_likelihood = -tf.reduce_mean(input_tensor=labels_distribution.log_prob(labels))
            
            # use tf.reduce_sum rather than np.sum so the KL term stays connected to the tape
            kl_loss = tf.reduce_sum(model.losses)/(training_size//BATCH_SIZE)

            total_loss = kl_loss + _neg_log_likelihood
            
            
        _grad = tape.gradient(total_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(_grad, model.trainable_variables))

        
        if idx % 20 == 0:
            print("mini-batch: {}, total_loss: {}, nll: {}, kl: {}".format(idx,
                                                                           total_loss,
                                                                           _neg_log_likelihood,
                                                                           kl_loss))
    
    
@SiegeLordEx
Member

Fixing tfp.layers.DenseFlipout and tfp.layers.DenseReparameterization is going to be pretty difficult, at least at first glance. I'll discuss why below, but for the immediate future I'd suggest using tfp.layers.DenseVariational, which doesn't have this issue. That layer is written in the new style, but unfortunately we don't have a flipout version of it yet. Worse, we don't yet have a way to write an exact equivalent of DenseReparameterization's default priors and approximate posteriors. Still, here's something that should work (the difference from DenseReparameterization's default is that we're doing a version of empirical Bayes here, partially learning the prior on the parameters, especially the bias):

def posterior_mean_field(kernel_size, bias_size=0, dtype=None):
  n = kernel_size + bias_size
  c = np.log(np.expm1(1.))
  return tf.keras.Sequential([
      tfp.layers.VariableLayer(2 * n, dtype=dtype),
      tfp.layers.DistributionLambda(lambda t: tfd.Independent(  # pylint: disable=g-long-lambda
          tfd.Normal(loc=t[..., :n],
                     scale=1e-5 + tf.nn.softplus(c + t[..., n:])),
          reinterpreted_batch_ndims=1)),
  ])

def prior_trainable(kernel_size, bias_size=0, dtype=None):
  n = kernel_size + bias_size
  return tf.keras.Sequential([
      tfp.layers.VariableLayer(n, dtype=dtype),
      tfp.layers.DistributionLambda(
          lambda t: tfd.Independent(tfd.Normal(loc=t, scale=1),  # pylint: disable=g-long-lambda
                                    reinterpreted_batch_ndims=1)),
  ])

self.dense1 = tfp.layers.DenseVariational(100, posterior_mean_field, prior_trainable, activation=tf.nn.relu, kl_weight=1/training_size)
self.dense2 = tfp.layers.DenseVariational(10, posterior_mean_field, prior_trainable, kl_weight=1/training_size)

As for why DenseFlipout and friends don't work: it boils down to creating tensors from variables inside the layer's build function. This was okay in TensorFlow 1.x, but it breaks gradient tracking in TensorFlow 2.0. We can possibly fix the location's gradient, but I don't see a clear way to fix the scale's gradient.
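
To make that concrete, here is a minimal sketch (my own illustration, not TFP code): a layer that materializes a transformed tensor from one of its variables inside build() evaluates the transform only once, so any GradientTape recorded after the layer is built only ever sees a constant and returns no gradient for that variable.

import tensorflow as tf

class BrokenScaleLayer(tf.keras.layers.Layer):
    # Hypothetical layer reproducing the build()-time pattern described above.
    def build(self, input_shape):
        self.untransformed_scale = self.add_weight(
            'untransformed_scale', shape=[], initializer='zeros')
        # Evaluated eagerly right here; later tapes only see a constant tensor.
        self.scale = tf.nn.softplus(self.untransformed_scale)

    def call(self, inputs):
        return inputs * self.scale

layer = BrokenScaleLayer()
x = tf.ones([2, 3])
_ = layer(x)  # first call builds the layer; softplus has already been evaluated
with tf.GradientTape() as tape:
    y = tf.reduce_sum(layer(x))
print(tape.gradient(y, layer.trainable_variables))  # [None]: no gradient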

@aamini

aamini commented May 24, 2019

We are also experiencing problems with this. Here's an even simpler example to recreate the error!

import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np

x_train = tf.convert_to_tensor(np.linspace(0,1,100).reshape(-1,1), tf.float32)
y_train = tf.convert_to_tensor(2.0*x_train, tf.float32)

model = tf.keras.models.Sequential([
    tfp.layers.DenseReparameterization(100, activation=tf.nn.relu),
    tfp.layers.DenseReparameterization(1),
])
optimizer = tf.optimizers.Adam(learning_rate=1e-3)
mse = tf.keras.losses.MeanSquaredError()

for i in range(100):
    with tf.GradientTape() as tape:
        y_hat = model(x_train, training=True) #forward pass
        loss = mse(y_train, y_hat)
        loss += tf.reduce_mean(model.losses)

    grads = tape.gradient(loss, model.variables) #compute gradient
    optimizer.apply_gradients(zip(grads, model.variables))
    print(loss)

Throws the following error after a single successful gradient update.

ValueError: No gradients provided for any variable: ['sequential/dense_reparameterization/kernel_posterior_loc:0', 'sequential/dense_reparameterization/kernel_posterior_untransformed_scale:0', 'sequential/dense_reparameterization/bias_posterior_loc:0', 'sequential/dense_reparameterization_1/kernel_posterior_loc:0', 'sequential/dense_reparameterization_1/kernel_posterior_untransformed_scale:0', 'sequential/dense_reparameterization_1/bias_posterior_loc:0'].

@SiegeLordEx
Member

Good news: we've recently developed a generic, backwards-compatible way to fix these issues, so a fix for the variational layers should be coming soon.

@jacobwjs
Author

Good news: we've recently developed a generic, backwards-compatible way to fix these issues, so a fix for the variational layers should be coming soon.

Nice! Looking forward to the fix.

@xht033

xht033 commented Jun 21, 2019

Which one is correct when dividing model.losses: the total number of training examples or the number of batches?

@jacobwjs
Author

Which one is correct when dividing model.losses: the total number of training examples or the number of batches?

As with all things, I'd say it depends, and in practice I've seen both. The assumption is that with fewer training examples we should be less certain about our view of the world, meaning we stay strongly tied to our prior (i.e. a large contribution from the KL loss in the total objective). In that context, dividing by the total number of examples makes sense. In practice, with mini-batch stochastic optimisation, I weight the KL loss term by the mini-batch size and restrict that contribution on a per-epoch basis.

I've played with other more complex weighting strategies, and in the end I don't find large changes in my final weight distributions assuming I'm not doing anything extreme.

Blundell et al. have a pretty straightforward strategy defined in their 2015 paper if you're looking for a reference (https://arxiv.org/abs/1505.05424).
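
For concreteness, here is a rough sketch (mine, reusing model, idx, and num_mini_batch from the example at the top of the thread) of the two weightings being discussed:

kl_total = tf.add_n(model.losses)  # sum of the per-layer KL terms

# (1) Uniform weighting: each mini-batch carries 1/num_mini_batch of the KL,
#     so one pass over the data contributes the full KL exactly once per epoch.
kl_uniform = kl_total / num_mini_batch

# (2) Blundell et al. (2015) re-weighting: pi_i = 2**(M - i) / (2**M - 1) for
#     mini-batch i = 1..M, so early mini-batches are dominated by the KL/prior
#     term and later ones by the data term.
def blundell_weight(i, m):
    return 2.0 ** (m - i) / (2.0 ** m - 1.0)

kl_blundell = blundell_weight(idx + 1, num_mini_batch) * kl_total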

I think the more interesting question is which prior, and what don't we know about what we don't know.

@TMullerSG

Good news: we've recently developed a generic, backwards-compatible way to fix these issues, so a fix for the variational layers should be coming soon.

Hi, is there a way to work around this no-gradient problem in tfp 0.7? I tried the latest tfp and it works fine, but for some reason I need to run my code on tf1.14 only. I'd appreciate any information. Many thanks!

@SiegeLordEx
Member

SiegeLordEx commented Nov 1, 2019

There is no easy workaround, as the fix involved quite a number of changes:

  • We needed to make the Normal distribution "Tape safe".
  • We introduced a "DeferredTensor" abstraction.
  • The commit that closed this bug.

Is it conceivable for you to use a tfp-nightly package? The builds going back a few months are still on PyPI. If you grab one from sometime in August, it should still only require tf1.14.
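
For anyone picking this up on a recent release, here is a minimal sketch (mine, not taken from the fix itself, assuming a tfp version new enough to have tfp.util.DeferredTensor with the (pretransformed_input, transform_fn) argument order) of what the DeferredTensor abstraction buys you: the softplus is re-applied every time the scale is read, so the tape sees the underlying variable.

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

raw_scale = tf.Variable(0.54, name='raw_scale')
dist = tfd.Normal(
    loc=tf.Variable(0., name='loc'),
    scale=tfp.util.DeferredTensor(raw_scale, tf.nn.softplus))

with tf.GradientTape() as tape:
    loss = -dist.log_prob(1.0)
# Both the loc variable and raw_scale receive gradients.
print(tape.gradient(loss, dist.trainable_variables))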

@TMullerSG

There is no easy workaround, as the fix involved quite a number of changes:

* We needed to make the Normal distribution "Tape safe".

* We introduced a "DeferredTensor" abstraction.

* The commit that closed this bug.

Is it conceivable for you to use a tfp-nightly package? The builds going back a few months are still on PyPI. If you grab one from sometime in August, it should still only require tf1.14.

Hi, thanks for the reply. I have tried the nightly version and it works OK. May I check whether there is any approach that could push this back further to tf1.12? I checked the nightlies, and even the earliest versions like 0.6.x cannot be loaded with tf1.12.

@nbro
Contributor

nbro commented Apr 23, 2020

@SiegeLordEx This is still a problem with the following versions of TensorFlow and TFP:

tf-nightly             2.2.0.dev20200423  
tfp-nightly            0.11.0.dev20200423 

Specifically, I am getting the error

WARNING:tensorflow:Gradients do not exist for variables ['bayesian_convolution2d/kernel_prior_loc:0', 'bayesian_convolution2d/kernel_prior_untransformed_scale:0', 'bayesian_convolution2d_1/kernel_prior_loc:0', 'bayesian_convolution2d_1/kernel_prior_untransformed_scale:0', 'bayesian_convolution2d_2/kernel_prior_loc:0', 'bayesian_convolution2d_2/kernel_prior_untransformed_scale:0'] when minimizing the loss.

where BayesianConvolution2d is a custom layer that derives from Convolution2DFlipout.

Note that I am not using gradient tape. I am using tf.keras' fit function to train my model, and I am using tf.config.experimental_run_functions_eagerly(True) to avoid other errors. See e.g.
https://stackoverflow.com/q/61391801/3924118.

This and many other bugs in TF and TFP are far from being solved!

@SiegeLordEx
Member

@nbro I'll take a look at the stackoverflow code, but for your trainable prior woes, could you share how you define the kernel_prior_fn? Are you using tfp.util.DeferredTensor or tfp.util.TransformedVariable?
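
For readers following along, a quick sketch (mine, not part of the original comment) of the two utilities named above; both keep the constraining transform attached to the variable so gradients reach it:

import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors

# DeferredTensor wraps an existing unconstrained variable; the transform is
# applied on every read.
scale_a = tfp.util.DeferredTensor(tf.Variable(0.54), tf.nn.softplus)

# TransformedVariable takes the constrained initial value and a bijector; the
# unconstrained variable is created under the hood.
scale_b = tfp.util.TransformedVariable(1.0, bijector=tfb.Softplus())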

@nbro
Contributor

nbro commented Apr 23, 2020

@SiegeLordEx I am using the default prior that the class Convolution2DFlipout uses, i.e. a Gaussian. Do I need to do something else? If so, why didn't you update the Convolution2DFlipout class in the latest versions of TFP?

The only thing my BayesianConvolution2d does differently from Convolution2DFlipout is override kernel_divergence_fn, so that I can compute the divergence between the prior and posterior dynamically: I want to update an internal variable of the custom layer BayesianConvolution2d so that the KL divergence is scaled dynamically at each step of an epoch.

I had initially opened issue #887. Meanwhile, I think I cannot directly change the kernel_divergence_fn property of the custom layer, which is why I was trying to change a non-trainable weight that the KL-divergence function uses, so that the KL divergence is computed dynamically (i.e. based on the input).
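
A rough sketch of this kind of dynamic KL scaling (hypothetical names, not the actual BayesianConvolution2d code), assuming the divergence function keeps the (posterior, prior, sample) signature that tfp.layers passes it:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

# Non-trainable variable that a callback can update between batches/epochs.
kl_weight_var = tf.Variable(1.0, trainable=False, name='kl_weight')

def scaled_kl_fn(q, p, _):
    # Scale the analytic KL(q || p) by the current value of kl_weight_var.
    return kl_weight_var * tfd.kl_divergence(q, p)

layer = tfp.layers.Convolution2DFlipout(
    32, kernel_size=3, activation='relu',
    kernel_divergence_fn=scaled_kl_fn)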

I tried to do that and, apparently, the variable changes, but I am getting unexpected values for the KL divergence when the optimizer (or fit) prints out the loss. See https://stackoverflow.com/q/61371627/3924118. In particular, the KL divergence that I compute manually in the callback is different than the KL divergence that the optimizer prints (in the progress bar) for the training data. The KL divergence that I compute manually in the callback at each step of an epoch is more similar to the KL divergence for the validation data.

I am really lost and stuck (and I can't make progress on my work). I can't understand what's going on. I tried tf.config.experimental_run_functions_eagerly(True), I tried the latest (unstable) versions of TF and TFP, I tried a lot of things, but I don't understand why the KL divergence that I compute manually is different from the KL divergence displayed in the progress bar.

(Btw, I was expecting the KL divergence for the training and test sets to be similar, but maybe this is another issue).

@SiegeLordEx
Member

I am using the default prior that the class Convolution2DFlipout uses, i.e. a Gaussian. Do I need to do something else? If so, why didn't you update the Convolution2DFlipout class in the latest versions of TFP?

The default prior doesn't create variables, yet your error message has things like: bayesian_convolution2d/kernel_prior_loc:0. Are you using default_mean_field_normal_fn perhaps?

@nbro
Contributor

nbro commented Apr 23, 2020

@SiegeLordEx Ha, sorry, I had forgotten about this. I am using the following function to initialise kernel_prior_fn of the layers (both convolutional and dense).

def get_constant_kernel_prior_fn(loc=0, scale=1.0):
    return tfp.layers.default_mean_field_normal_fn(loc_initializer=tf.constant_initializer(loc),
                                                   untransformed_scale_initializer=tf.constant_initializer(
                                                       tfp.math.softplus_inverse(scale).numpy()))

@cevheck

cevheck commented Dec 7, 2021

@SiegeLordEx

There is no easy workaround, as the fix involved quite a number of changes:

  • We needed to make the Normal distribution "Tape safe".
  • We introduced a "DeferredTensor" abstraction.
  • The commit that closed this bug.

Is it conceivable for you to use a tfp-nightly package? The builds going back a few months are still on PyPI. If you grab one from sometime in August, it should still only require tf1.14.

@SiegeLordEx Am I correct in interpreting your answer as meaning the DenseFlipout layer is completely functional now? I am asking because I'm currently working on a Bayesian neural network and have implemented it using these layers. I'm still working on it to obtain better results, but I have stumbled on some posts describing the DenseFlipout layer as outdated (#359 (comment)).

In short, I would like to know whether it's safe and correct to keep using this functionality, or whether it is advised to look into the DenseVariational layer instead.

Thanks in advance!
Cedric

@SiegeLordEx
Member

They should be functional. Strictly speaking DenseVariational is newer, but in practice both should work, with slightly different APIs. If you want even more options, there's tfp.experimental.nn, which has even newer layers, but those are no longer based on Keras.
