In [None]:
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import numpy as np
import corner

# Short aliases for the TensorFlow Probability / Keras namespaces used below.
tfd = tfp.distributions  # probability distributions (Normal, Independent, ...)
tfpl = tfp.layers  # probabilistic Keras layers (Flipout convolution / dense)
tfk = tf.keras  # Keras API bundled with TensorFlow

# Fetch MNIST and bring it into the form the network consumes.
(x_train, y_train), (x_test, y_test) = tfk.datasets.mnist.load_data()

# Integer class ids -> one-hot rows with 10 columns.
y_train = tfk.utils.to_categorical(y_train, num_classes=10)
y_test = tfk.utils.to_categorical(y_test, num_classes=10)

# Pixel values -> float32 in [0, 1], with an explicit trailing channel axis.
x_train = (x_train.astype(np.float32) / 255.0).reshape(-1, 28, 28, 1)
x_test = (x_test.astype(np.float32) / 255.0).reshape(-1, 28, 28, 1)

# Define the prior and posterior distributions
def prior(kernel_size, bias_size, dtype=None):
    """Return a standard-normal prior over a flattened kernel-plus-bias vector.

    Args:
        kernel_size: Number of scalar kernel parameters.
        bias_size: Number of scalar bias parameters.
        dtype: Floating-point dtype of the distribution parameters. ``None``
            (the default) selects ``tf.float32``.

    Returns:
        A ``tfd.Independent`` wrapping ``Normal(loc=0, scale=1)`` whose event
        shape is ``[kernel_size + bias_size]``.
    """
    n = kernel_size + bias_size
    # The dtype argument used to be ignored; honour it so the prior can be
    # combined with non-float32 parameters.
    param_dtype = tf.float32 if dtype is None else dtype
    return tfd.Independent(
        tfd.Normal(loc=tf.zeros(n, dtype=param_dtype),
                   scale=tf.ones(n, dtype=param_dtype)),
        reinterpreted_batch_ndims=1)

def posterior(kernel_size, bias_size, dtype=None):
    """Return a trainable mean-field normal over a flattened kernel-plus-bias vector.

    Args:
        kernel_size: Number of scalar kernel parameters.
        bias_size: Number of scalar bias parameters.
        dtype: Floating-point dtype of the trainable parameters. ``None``
            (the default) selects ``tf.float32``.

    Returns:
        A ``tfd.Independent`` wrapping a ``Normal`` whose ``loc`` is a
        randomly initialised variable and whose ``scale`` is a trainable,
        strictly positive variable initialised to 1. The event shape is
        ``[kernel_size + bias_size]``.
    """
    n = kernel_size + bias_size
    param_dtype = tf.float32 if dtype is None else dtype
    loc = tf.Variable(tf.random.normal([n], dtype=param_dtype),
                      name='posterior_loc')
    # A raw tf.Variable scale can be pushed to zero or below by gradient
    # updates, which makes the Normal invalid. Optimising in softplus space
    # keeps the scale positive while preserving the initial value of 1.
    scale = tfp.util.TransformedVariable(
        tf.ones([n], dtype=param_dtype),
        bijector=tfp.bijectors.Softplus(),
        name='posterior_scale')
    return tfd.Independent(tfd.Normal(loc=loc, scale=scale),
                           reinterpreted_batch_ndims=1)

# Define the Bayesian neural network model
def build_bnn(input_shape, output_dim, num_train_examples=None):
    """Build a Flipout-based Bayesian convolutional classifier.

    The network is three ``Convolution2DFlipout`` layers (32, 64, 64 filters,
    3x3 kernels, ReLU) with 2x2 max-pooling after the first two, followed by
    a 64-unit ``DenseFlipout`` (ReLU) and a final ``DenseFlipout`` with no
    activation, so the model outputs unnormalised logits.

    Args:
        input_shape: Shape of one input example, without the batch axis.
        output_dim: Number of output units (classes).
        num_train_examples: Optional size of the training set. When given,
            each layer's kernel KL divergence is divided by this number so
            that it is on the same scale as a per-example mean likelihood
            loss. When ``None`` the TFP default (unscaled KL) is used.

    Returns:
        An uncompiled ``tf.keras.Model``.

    Raises:
        ValueError: If ``num_train_examples`` is given but not positive.
    """
    if num_train_examples is None:
        divergence_kwargs = {}
    else:
        if num_train_examples <= 0:
            raise ValueError(
                f"num_train_examples must be positive, got {num_train_examples}")
        num_examples = tf.cast(num_train_examples, dtype=tf.float32)

        def scaled_kl(q, p, _):
            # The KL term is added once per batch, so divide by the dataset
            # size to weight it correctly against a per-example mean loss.
            return tfd.kl_divergence(q, p) / num_examples

        divergence_kwargs = {'kernel_divergence_fn': scaled_kl}

    inputs = tfk.Input(shape=input_shape)
    x = tfpl.Convolution2DFlipout(
        32, kernel_size=(3, 3), activation=tf.nn.relu, **divergence_kwargs)(inputs)
    x = tfk.layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = tfpl.Convolution2DFlipout(
        64, kernel_size=(3, 3), activation=tf.nn.relu, **divergence_kwargs)(x)
    x = tfk.layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = tfpl.Convolution2DFlipout(
        64, kernel_size=(3, 3), activation=tf.nn.relu, **divergence_kwargs)(x)
    x = tfk.layers.Flatten()(x)
    x = tfpl.DenseFlipout(64, activation=tf.nn.relu, **divergence_kwargs)(x)
    # No activation on the last layer: the outputs are logits.
    outputs = tfpl.DenseFlipout(output_dim, **divergence_kwargs)(x)
    return tfk.Model(inputs=inputs, outputs=outputs)

# Build the Bayesian neural network, scaling each layer's KL term by the
# number of training images.
bnn = build_bnn((28, 28, 1), 10, num_train_examples=x_train.shape[0])

# Define the negative log-likelihood loss function
def negative_log_likelihood(y_true, y_pred):
    """Mean categorical negative log-likelihood computed from logits.

    Args:
        y_true: One-hot encoded labels; the last axis indexes classes.
        y_pred: Unnormalised class scores (logits) with the same shape as
            ``y_true``, as emitted by a final layer without activation.

    Returns:
        Scalar tensor: the negative log-likelihood averaged over all
        leading axes.
    """
    # Taking tf.math.log of raw logits yields NaN for non-positive values.
    # The fused op applies log-softmax in a numerically stable way.
    per_example_nll = tf.nn.softmax_cross_entropy_with_logits(
        labels=y_true, logits=y_pred)
    return tf.reduce_mean(per_example_nll)

# Variational Inference
def run_variational_inference(epochs=10, batch_size=128, learning_rate=0.001):
    """Train the module-level ``bnn`` by variational inference and evaluate it.

    Compiles ``bnn`` with Adam and ``negative_log_likelihood``, fits it on
    ``(x_train, y_train)`` using ``(x_test, y_test)`` as validation data,
    then prints the test loss and accuracy.

    Args:
        epochs: Number of passes over the training set.
        batch_size: Mini-batch size used for training.
        learning_rate: Adam learning rate.
    """
    # The Flipout layers register their KL divergence terms as layer losses
    # when they are built, and Keras adds `model.losses` to the compiled loss
    # automatically. No callback is needed for that, and TFP does not provide
    # a `tfp.keras.callbacks` module. Any rescaling of the KL term has to be
    # configured through `kernel_divergence_fn` when the layers are created.
    bnn.compile(optimizer=tfk.optimizers.Adam(learning_rate=learning_rate),
                loss=negative_log_likelihood,
                metrics=[tfk.metrics.CategoricalAccuracy()],
                experimental_run_tf_function=False)

    # Maximise the ELBO (data term from the loss, KL term from layer losses).
    bnn.fit(x_train, y_train,
            batch_size=batch_size,
            epochs=epochs,
            validation_data=(x_test, y_test))

    # Evaluate the model on the test set
    test_loss, test_acc = bnn.evaluate(x_test, y_test)
    print(f"Variational Inference - Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

# Hamiltonian Monte Carlo
def _posterior_mean_weights(model):
    """Collect ``[kernel, bias, kernel, bias, ...]`` posterior means from ``model``.

    Only layers exposing ``kernel_posterior`` (the TFP variational layers)
    contribute; they are visited in ``model.layers`` order.
    """
    weights = []
    for layer in model.layers:
        if hasattr(layer, 'kernel_posterior'):
            weights.append(tf.convert_to_tensor(layer.kernel_posterior.mean()))
            weights.append(tf.convert_to_tensor(layer.bias_posterior.mean()))
    return weights


def _forward_with_weights(model, weights, inputs):
    """Run ``model``'s layer stack deterministically with explicit weights.

    Variational layers are replaced by a plain convolution (rank-4 kernel)
    or matrix multiply (rank-2 kernel) using the supplied tensors, so the
    output is differentiable with respect to ``weights``. All other layers
    (pooling, flatten) are called as they are. This assumes ``model.layers``
    forms a simple chain, which holds for the network built by ``build_bnn``.

    Args:
        model: Keras model whose architecture is reused.
        weights: Sequence ordered like ``_posterior_mean_weights(model)``.
        inputs: Batch of input examples.

    Returns:
        The logits produced by the last layer.
    """
    remaining = iter(weights)
    x = tf.convert_to_tensor(inputs)
    for layer in model.layers:
        if isinstance(layer, tfk.layers.InputLayer):
            continue
        if not hasattr(layer, 'kernel_posterior'):
            x = layer(x)
            continue
        kernel = next(remaining)
        bias = next(remaining)
        if len(kernel.shape) == 4:
            x = tf.nn.conv2d(x, kernel,
                             strides=list(layer.strides),
                             padding=layer.padding.upper(),
                             dilations=list(layer.dilation_rate))
        else:
            x = tf.matmul(x, kernel)
        x = tf.nn.bias_add(x, bias)
        if layer.activation is not None:
            x = layer.activation(x)
    return x


# Hamiltonian Monte Carlo
def run_hmc(num_results=500, num_burnin_steps=500, num_leapfrog_steps=50,
            step_size=0.03, max_corner_params=8):
    """Sample network weights with HMC, report test accuracy and plot a corner plot.

    The kernels and biases of ``bnn``'s variational layers are treated as
    ordinary weights with an independent standard-normal prior (built with
    ``prior``). The chain starts at the current posterior means of ``bnn``.
    The likelihood is evaluated on the full training set at every leapfrog
    step, so this is computationally expensive.

    Args:
        num_results: Number of posterior draws kept after burn-in.
        num_burnin_steps: Number of burn-in steps; the step size is adapted
            during all of them.
        num_leapfrog_steps: Leapfrog steps per HMC proposal.
        step_size: Initial leapfrog step size.
        max_corner_params: Maximum number of output-layer bias components
            shown in the corner plot (plotting every weight is infeasible).
    """
    initial_state = _posterior_mean_weights(bnn)
    num_params = sum(int(tf.size(w)) for w in initial_state)
    weight_prior = prior(num_params, 0)

    # Convert once so the tf.function captures tensors instead of embedding
    # the NumPy arrays as graph constants.
    train_images = tf.convert_to_tensor(x_train)
    train_labels = tf.convert_to_tensor(y_train)

    # Define the joint log probability function. sample_chain calls it with
    # one positional argument per state part, and it must return a scalar.
    def joint_log_prob(*weights):
        flat_weights = tf.concat([tf.reshape(w, [-1]) for w in weights], axis=0)
        prior_log_prob = weight_prior.log_prob(flat_weights)
        logits = _forward_with_weights(bnn, weights, train_images)
        log_likelihood = -tf.reduce_sum(
            tf.nn.softmax_cross_entropy_with_logits(
                labels=train_labels, logits=logits))
        return prior_log_prob + log_likelihood

    # Initialize the HMC kernel
    adaptive_hmc = tfp.mcmc.SimpleStepSizeAdaptation(
        tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn=joint_log_prob,
            num_leapfrog_steps=num_leapfrog_steps,
            step_size=step_size),
        num_adaptation_steps=num_burnin_steps)

    # Run the HMC chain
    @tf.function
    def run_chain():
        # trace_fn=None returns only the states and avoids storing kernel
        # results (including per-step gradients) for every draw.
        return tfp.mcmc.sample_chain(
            num_results=num_results,
            num_burnin_steps=num_burnin_steps,
            current_state=initial_state,
            kernel=adaptive_hmc,
            trace_fn=None)

    # `samples` is a list with one tensor per state part, each shaped
    # [num_results, ...].
    samples = run_chain()

    # Make predictions using the HMC samples: average class probabilities
    # over draws. zip(*samples) yields one full weight set per draw.
    test_images = tf.convert_to_tensor(x_test)
    prob_sum = 0.0
    for draw in zip(*samples):
        prob_sum += tf.nn.softmax(
            _forward_with_weights(bnn, draw, test_images), axis=-1)
    y_pred_mean = prob_sum / num_results
    y_pred_classes = tf.argmax(y_pred_mean, axis=-1)

    accuracy = tf.reduce_mean(
        tf.cast(tf.equal(y_pred_classes, tf.argmax(y_test, axis=-1)), tf.float32))
    print(f"HMC - Test Accuracy: {float(accuracy):.4f}")

    # Plot the corner plot to visualize uncertainties. Rows are draws and
    # columns are parameters; only a few output-layer bias components are
    # shown to keep the plot tractable.
    output_bias_draws = tf.reshape(samples[-1], [num_results, -1]).numpy()
    corner.corner(output_bias_draws[:, :max_corner_params])
    plt.show()

# Run variational inference
# (compiles, trains and evaluates the shared module-level `bnn` model).
run_variational_inference()

# Run Hamiltonian Monte Carlo
# NOTE: this runs after the call above, so it sees `bnn` in whatever state
# variational inference left it; keep the order of these two calls.
run_hmc()