In [1]:
# Install TensorFlow Probability
!pip install tensorflow-probability

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfpl = tfp.layers

# Load the data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Preprocess the data
x_train = x_train.reshape(-1, 28 * 28).astype('float32') / 255.0
x_test = x_test.reshape(-1, 28 * 28).astype('float32') / 255.0
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)

# Define the Bayesian neural network model
def build_bnn(input_dim, output_dim):
    model = tf.keras.Sequential([
        tfpl.DenseFlipout(512, activation=tf.nn.relu, input_shape=(input_dim,)),
        tfpl.DenseFlipout(256, activation=tf.nn.relu),
        tfpl.DenseFlipout(output_dim, activation=None)
    ])
    return model

# Create the Bayesian neural network
bnn = build_bnn(28 * 28, 10)

# Define the loss function and metrics
neg_log_likelihood = lambda y, logits: -tfd.Categorical(logits=logits).log_prob(tf.argmax(y, axis=-1))
accuracy = tf.keras.metrics.CategoricalAccuracy()

# Train the Bayesian neural network
num_epochs = 5
batch_size = 128

optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
for epoch in range(num_epochs):
    for i in range(0, len(x_train), batch_size):
        x_batch = x_train[i:i+batch_size]
        y_batch = y_train[i:i+batch_size]

        with tf.GradientTape() as tape:
            logits = bnn(x_batch)
            loss = neg_log_likelihood(y_batch, logits)

        gradients = tape.gradient(loss, bnn.trainable_variables)
        optimizer.apply_gradients(zip(gradients, bnn.trainable_variables))

        accuracy.update_state(y_batch, logits)

    print(f"Epoch {epoch+1}/{num_epochs}, Accuracy: {accuracy.result().numpy()}")
    accuracy.reset_states()

# Evaluate on the test set
logits = bnn(x_test)
test_loss = neg_log_likelihood(y_test, logits)
test_acc = tf.keras.metrics.CategoricalAccuracy()(y_test, logits)
print(f"Test Loss: {test_loss.numpy()}, Test Accuracy: {test_acc.numpy()}")

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


  loc = add_variable_fn(
  untransformed_scale = add_variable_fn(


Epoch 1/5, Accuracy: 0.8822500109672546
Epoch 2/5, Accuracy: 0.9518833160400391
Epoch 3/5, Accuracy: 0.9667333364486694
Epoch 4/5, Accuracy: 0.9749166369438171
Epoch 5/5, Accuracy: 0.9791666865348816
Test Loss: [3.5762781e-07 8.4638241e-06 4.9345139e-03 ... 4.7683704e-07 1.6927575e-05
 1.0967195e-05], Test Accuracy: 0.9713000059127808
