[Feature Request] Distributions should track their variables #1282

Closed

hartikainen opened this issue Mar 26, 2021 · 3 comments

@hartikainen
Contributor

Accessing the variables of a distribution, for example one transformed with a RealNVP bijector, always results in an empty tuple. It would be really useful if these variables were accessible in the same way as they are from tf.keras.layers.Layer and tf.keras.Model. In the script below, nvp.variables is empty, whereas the tf.keras.Model wrapping the distribution correctly returns the variables (model.variables).

import tensorflow as tf
import tensorflow_probability as tfp


tfd = tfp.distributions
tfb = tfp.bijectors


def make_shift_and_log_scale_fn(hidden_layer_sizes, name=None):
    # Hidden layers are created once, outside the closure, so they are reused
    # across calls.
    mlp_body = tf.keras.Sequential([
        tf.keras.layers.Dense(hidden_layer_size)
        for hidden_layer_size in hidden_layer_sizes
    ], name=name)

    def shift_and_log_scale_fn(x, output_units):
        out = mlp_body(x)
        # Project to 2 * output_units so the split below yields `shift` and
        # `log_scale` with `output_units` units each.
        out = tf.keras.layers.Dense(2 * output_units)(out)
        shift, log_scale = tf.split(out, 2, axis=-1)

        return shift, log_scale

    return shift_and_log_scale_fn

nvp = tfd.TransformedDistribution(
    distribution=tfd.MultivariateNormalDiag(loc=[[0., 0., 0.]]),
    bijector=tfb.RealNVP(
        num_masked=2,
        shift_and_log_scale_fn=make_shift_and_log_scale_fn(
            hidden_layer_sizes=(512, 512),
            name='RealNVP')))

x = nvp.sample()
nvp.log_prob(x)
nvp.log_prob([[0.0, 0.0, 0.0]])

y_ = tf.keras.layers.Input(
    shape=nvp.event_shape,
    dtype=tf.float32)
log_prob_y_ = nvp.log_prob(y_)
model = tf.keras.Model(y_, log_prob_y_)
optimizer = tf.optimizers.Adam(1e-3)
model.compile(optimizer=optimizer, loss=lambda _, log_prob: -log_prob)

# These assertions fail: `nvp.variables` is an empty tuple, while
# `model.variables` contains the RealNVP weights.
tf.debugging.assert_equal(nvp.variables, model.variables)
tf.debugging.assert_equal(nvp.trainable_variables, model.trainable_variables)
@hartikainen
Contributor Author

I just found the following in the TensorFlow documentation:

Note: tf.Module is the base class for both tf.keras.layers.Layer and tf.keras.Model, so everything you come across here also applies in Keras. For historical compatibility reasons Keras layers do not collect variables from modules, so your models should use only modules or only Keras layers. However, the methods shown below for inspecting variables are the same in either case.

So it seems like the current behavior is expected.

Does anyone know if there's a way to access a distribution's variables without wrapping it inside tf.keras.Model like I do above?
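For reference, a minimal sketch (not from the thread) of what the quoted note describes: a tf.Module automatically tracks any tf.Variable or sub-module assigned to one of its attributes, and its variables are inspected through the same .variables / .trainable_variables properties as on a Keras layer.

import tensorflow as tf


class Affine(tf.Module):
    """Toy module: attributes holding variables are tracked automatically."""

    def __init__(self, units, name=None):
        super().__init__(name=name)
        self.w = tf.Variable(tf.random.normal([units, units]), name='w')
        self.b = tf.Variable(tf.zeros([units]), name='b')

    def __call__(self, x):
        return tf.matmul(x, self.w) + self.b


affine = Affine(3)
print(len(affine.variables))            # 2
print(len(affine.trainable_variables))  # 2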

@hartikainen
Contributor Author

hartikainen commented Mar 27, 2021

#946 and tensorflow/tensorflow#47264 seem to be related to this.

@hartikainen
Contributor Author

hartikainen commented Apr 1, 2021

To answer my own question: it seems like one acceptable way of handling the variable tracking is to create the shift-and-log-scale layers separately and then manually assign them to the bijector so that the variables get tracked. There are several examples of this pattern in the TFP repo, for example in the glow flow:

exit_layer = coupling_bijector_fn(new_input_shape,
                                  output_chan=ngrab)
exit_bijector_fn = self.make_bijector_fn(
    exit_layer,
    target_shape=target_output_shape,
    scale_fn=tf.exp)
self.exit_layer = exit_layer  # For variable tracking.
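
Translating that pattern back to the example above, a rough sketch (the mlp_body attribute is just illustrative, not part of TFP's API): build the Keras layers outside the closure and attach them to the bijector so that tf.Module-style tracking can find them, just like the glow snippet does with self.exit_layer.

# Sketch only: reuses the setup from the original script.
mlp_body = tf.keras.Sequential(
    [tf.keras.layers.Dense(size) for size in (512, 512)]
    + [tf.keras.layers.Dense(2)])  # 2 == 2 * output_units for this event shape.


def shift_and_log_scale_fn(x, output_units):
    shift, log_scale = tf.split(mlp_body(x), 2, axis=-1)
    return shift, log_scale


bijector = tfb.RealNVP(
    num_masked=2, shift_and_log_scale_fn=shift_and_log_scale_fn)
bijector.mlp_body = mlp_body  # For variable tracking, as in the glow snippet.

nvp = tfd.TransformedDistribution(
    distribution=tfd.MultivariateNormalDiag(loc=[[0., 0., 0.]]),
    bijector=bijector)

_ = nvp.sample()  # Builds the Keras layers and hence their variables.
print(len(nvp.trainable_variables))  # Expected to be non-empty here.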
