In [2]:
import tensorflow as tf 
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import trange
import numpy as np

In [3]:
# Define the prior weight distribution as Normal of mean=0 and stddev=1.
# Note that, in this example, the prior distribution is not trainable,
# as we fix its parameters.
def prior(kernel_size, bias_size, dtype=None):
    """Build the fixed (non-trainable) prior over a layer's weights.

    The prior is a multivariate normal with zero mean and unit diagonal
    scale over the concatenated kernel and bias parameters. It holds no
    variables, so nothing in it is updated during training.

    Args:
        kernel_size: Number of kernel parameters.
        bias_size: Number of bias parameters.
        dtype: Dtype of the distribution parameters. When ``None`` the
            parameters are created as ``tf.float32``.

    Returns:
        A ``tf.keras.Sequential`` with a single ``DistributionLambda`` layer
        that ignores its input and returns the prior distribution.
    """
    n = kernel_size + bias_size
    # Honour the dtype requested by the caller. Previously the argument was
    # ignored and the prior was always float32, which mismatches a posterior
    # built with any other dtype.
    param_dtype = tf.float32 if dtype is None else dtype
    prior_model = tf.keras.Sequential(
        [
            tfp.layers.DistributionLambda(
                # The layer input `t` is deliberately unused: the prior is constant.
                lambda t: tfp.distributions.MultivariateNormalDiag(
                    loc=tf.zeros(n, dtype=param_dtype),
                    scale_diag=tf.ones(n, dtype=param_dtype),
                )
            )
        ]
    )
    return prior_model

In [4]:
# Define variational posterior weight distribution as multivariate Gaussian.
# Note that the learnable parameters for this distribution are the means,
# variances, and covariances.
def posterior(kernel_size, bias_size, dtype=None):
    """Build the trainable variational posterior over a layer's weights.

    A ``VariableLayer`` holds a flat trainable vector whose length is the
    number of parameters ``MultivariateNormalTriL`` needs for an event of
    size ``kernel_size + bias_size``; that vector is then turned into the
    distribution by the ``MultivariateNormalTriL`` layer.

    Args:
        kernel_size: Number of kernel parameters.
        bias_size: Number of bias parameters.
        dtype: Dtype forwarded to the ``VariableLayer``.

    Returns:
        A ``tf.keras.Sequential`` producing the posterior distribution.
    """
    num_weights = kernel_size + bias_size
    # Length of the flat parameter vector the distribution layer consumes.
    num_params = tfp.layers.MultivariateNormalTriL.params_size(num_weights)
    trainable_params = tfp.layers.VariableLayer(num_params, dtype=dtype)
    weight_distribution = tfp.layers.MultivariateNormalTriL(num_weights)
    return tf.keras.Sequential([trainable_params, weight_distribution])

In [5]:
def _variational_dense(units, train_sz, activation):
    """Create one DenseVariational layer wired to `posterior` and `prior`."""
    return tfp.layers.DenseVariational(
        units=units,
        make_posterior_fn=posterior,
        make_prior_fn=prior,
        # Scale the KL term by the training-set size.
        kl_weight=1 / train_sz,
        activation=activation,
    )


def get_model(layers, train_sz, activation):
    """Build a Sequential network made only of DenseVariational layers.

    Args:
        layers: Sequence of unit counts. Every entry except the last becomes
            a hidden layer using `activation`; the last entry is the output
            layer and has no activation. A single-entry sequence yields just
            the output layer.
        train_sz: Number of training examples; each layer uses
            ``kl_weight = 1 / train_sz``.
        activation: Activation passed to every hidden layer.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.

    Raises:
        IndexError: If `layers` is empty.
        ZeroDivisionError: If `train_sz` is zero.
    """
    model = tf.keras.Sequential()
    # Hidden layers: everything but the final entry.
    for units in layers[:-1]:
        model.add(_variational_dense(units, train_sz, activation))
    # Output layer: linear, so the network can produce unbounded values.
    model.add(_variational_dense(layers[-1], train_sz, None))
    return model

In [6]:
def compute_predictions(model, x, num_results):
    """Call `model` on `x` repeatedly and collect the outputs.

    Args:
        model: Callable whose result exposes ``.numpy()``; its output must
            have a trailing axis of length 1 (it is squeezed away).
        x: Input exposing ``.numpy()``, also with a trailing axis of length 1.
        num_results: Number of forward passes to run.

    Returns:
        A tuple ``(X, predictions)`` of NumPy arrays. Row ``k`` of
        ``predictions`` is the squeezed output of the k-th call, and row ``k``
        of ``X`` is the squeezed input, repeated so both arrays line up.
    """
    # The input is identical for every pass, so convert it once rather than
    # on each loop iteration.
    x_np = x.numpy().squeeze(-1)
    predictions = [
        model(x).numpy().squeeze(-1)
        for _ in trange(num_results, desc="Computing predictions")
    ]
    X = np.array([x_np] * num_results)
    predictions = np.array(predictions)
    return X, predictions

In [7]:
# Fix TensorFlow's global seed so that data generation, weight initialisation
# and sampling give the same results on every Restart & Run All.
tf.random.set_seed(42)

# Number of training points; also used to scale the KL term in get_model.
n_train = 100
# One hidden layer of 10 units followed by a single-unit output layer.
layers = [10, 1]
model = get_model(layers, n_train, activation="relu")
model.compile(optimizer="adam", loss="mse")

Metal device set to: Apple M1


2021-12-07 13:44:27.137416: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2021-12-07 13:44:27.137615: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [8]:
def f(x):
    """Target function used to generate training labels: element-wise sin(x)."""
    return tf.math.sin(x)


# Training inputs are drawn from a normal with mean 0 and stddev 2;
# the targets are the target function evaluated on them (no added noise).
x_train = tf.random.normal(shape=(n_train, 1), mean=0., stddev=2.)
y_train = f(x_train)

In [9]:
# Full-batch training: tie the batch size to n_train instead of repeating the
# literal 100, and keep the History object so the loss curve can be inspected
# (this also stops the bare History repr from being echoed as cell output).
history = model.fit(
    x=x_train, y=y_train, batch_size=n_train, epochs=100, verbose=0
)

Instructions for updating:
`scale_identity_multiplier` is deprecated; please combine it into `scale_diag` directly instead.


2021-12-07 13:44:29.024896: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
2021-12-07 13:44:29.141504: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
2021-12-07 13:44:29.635432: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.


<keras.callbacks.History at 0x28f8dcbb0>

In [10]:
# Number of repeated forward passes requested from compute_predictions.
num_results = 100
# Number of test inputs to draw.
n_test = 1
# Test inputs come from the same normal (mean 0, stddev 2) as the training inputs.
x_test = tf.random.normal(
    shape=(n_test, 1),
    mean=0.,
    stddev=2.,
)

In [None]:
# Single forward pass on the test input, displayed as the cell output.
# NOTE(review): presumably each call draws fresh weights from the variational
# posterior, so re-running this cell gives a different value — confirm.
model(x_test)

In [None]:
X, predictions = compute_predictions(model, x_test, num_results=num_results)
print(predictions)
print(predictions.shape)


# Pass x / y by keyword: positional data vectors are deprecated in seaborn 0.11
# and rejected by seaborn >= 0.12.
sns.lineplot(x=X.ravel(), y=predictions.ravel())

# Overlay the target function on a dense grid for reference.
x = np.linspace(-2 * np.pi, 2 * np.pi, 1001)
plt.plot(x, f(x), label="True function")
plt.xlabel("x")
plt.ylabel("y")
# Without an explicit legend call the label above is never displayed.
plt.legend()
plt.show()


Computing predictions:   1%|█                                                                                                         | 1/100 [00:00<00:26,  3.69it/s]