# In [None]:
import pyro
import torch
import torchvision
from pyro.distributions import Categorical, Normal, Uniform
from pyro.infer import MCMC, NUTS, SVI, Predictive
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.nn import PyroSample
from pyro.nn.module import to_pyro_module_
from pyro.optim import Adam

# Define the model
class BNN(torch.nn.Module):
  """Two-layer fully connected network: Linear -> ReLU -> Linear.

  Args:
    input_dim: Number of features per (flattened) input sample.
    hidden_dim: Width of the hidden layer.
    output_dim: Number of values produced per sample.
  """

  def __init__(self, input_dim, hidden_dim, output_dim):
    super().__init__()
    self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
    self.fc2 = torch.nn.Linear(hidden_dim, output_dim)

  def forward(self, x):
    """Run the network on `x` and return the output of `fc2`.

    Inputs whose last dimension already equals `fc1.in_features` are passed
    through untouched. Any other input with more than two dimensions (for
    example image batches) is flattened to one row per sample first.
    """
    # Only reshape inputs that fc1 would otherwise reject, so callers that
    # already pass (..., input_dim) tensors see exactly the old behaviour.
    if x.dim() > 2 and x.shape[-1] != self.fc1.in_features:
      x = x.flatten(start_dim=1)
    x = torch.nn.functional.relu(self.fc1(x))
    return self.fc2(x)

  def guide(self, x):
    """Sample one set of weights and biases and load it into the layers.

    Sample sites, in order: `w1_loc`, `w1_scale`, `b1_loc`, `w2_loc`,
    `w2_scale`, `b2_loc`, then `w1`, `b1`, `w2`, `b2`.

    NOTE(review): despite the name, every site here is a `pyro.sample`; there
    are no `pyro.param` sites, so nothing in this method is learnable.

    Args:
      x: Unused; kept so the method can be called with the same arguments
        as `forward`.
    """
    fc1_w, fc1_b = self.fc1.weight, self.fc1.bias
    fc2_w, fc2_b = self.fc2.weight, self.fc2.bias

    def standard_normal_like(ref):
      # Shapes, dtype and device come from the layer itself rather than from
      # module-level globals. `to_event` marks all tensor dims as event dims
      # so each site is one joint draw, not an unplated batch.
      return Normal(torch.zeros_like(ref), torch.ones_like(ref)).to_event(ref.dim())

    def scale_prior(ref):
      return Uniform(ref.new_tensor(0.1), ref.new_tensor(2.0))

    # Hyper-level draws: a location per weight/bias entry, one scalar scale
    # per weight matrix.
    w1_loc = pyro.sample("w1_loc", standard_normal_like(fc1_w))
    w1_scale = pyro.sample("w1_scale", scale_prior(fc1_w))
    b1_loc = pyro.sample("b1_loc", standard_normal_like(fc1_b))
    w2_loc = pyro.sample("w2_loc", standard_normal_like(fc2_w))
    w2_scale = pyro.sample("w2_scale", scale_prior(fc2_w))
    b2_loc = pyro.sample("b2_loc", standard_normal_like(fc2_b))

    # Weight/bias draws centred on the sampled locations.
    w1 = pyro.sample("w1", Normal(w1_loc, w1_scale).to_event(2))
    b1 = pyro.sample("b1", Normal(b1_loc, torch.ones_like(fc1_b)).to_event(1))
    w2 = pyro.sample("w2", Normal(w2_loc, w2_scale).to_event(2))
    b2 = pyro.sample("b2", Normal(b2_loc, torch.ones_like(fc2_b)).to_event(1))

    # torch.nn.Module refuses to rebind a registered Parameter name to a plain
    # tensor (TypeError), so copy the sampled values into the existing
    # parameters instead. This is a value copy: no gradient flows back from
    # the layers to the sample sites.
    with torch.no_grad():
      fc1_w.copy_(w1)
      fc1_b.copy_(b1)
      fc2_w.copy_(w2)
      fc2_b.copy_(b2)

# Load and preprocess MNIST data
# `x_train` is a DataLoader that yields shuffled (images, labels) batches of
# 32; a DataLoader cannot be unpacked into tensors directly, so it is iterated
# batch by batch during training.
x_train = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("data", train=True, download=True, transform=torchvision.transforms.ToTensor()),
    batch_size=32, shuffle=True)
# Build the test dataset first so its length is known, then pull the whole
# split out as a single (images, labels) batch.
test_dataset = torchvision.datasets.MNIST("data", train=False, download=True, transform=torchvision.transforms.ToTensor())
(x_test, y_test) = next(iter(torch.utils.data.DataLoader(
    test_dataset, batch_size=len(test_dataset), shuffle=False)))

# Define model and guide
input_dim = 28 * 28
hidden_dim = 128
output_dim = 10
model = BNN(input_dim, hidden_dim, output_dim)

# Replace every weight and bias of the network with a latent variable that has
# a standard-normal prior, so each call to `model` inside Pyro uses one joint
# draw of the weights (sites are named "fc1.weight", "fc1.bias", ...).
to_pyro_module_(model)
for submodule in model.modules():
  for param_name, param in list(submodule.named_parameters(recurse=False)):
    weight_prior = Normal(0.0, 1.0).expand(param.shape).to_event(param.dim())
    setattr(submodule, param_name, PyroSample(prior=weight_prior))

# `x_train` is expected to be the DataLoader over the training split.
num_train = len(x_train.dataset)


def classify(x, y=None):
  """Pyro model: sample network weights, then observe labels `y` given `x`.

  Returns the logits so they can be collected through the "_RETURN" site.
  """
  logits = model(x.flatten(start_dim=1))
  # Rescale the mini-batch likelihood to the full training set so it is
  # weighted correctly against the (unscaled) prior on the weights.
  with pyro.poutine.scale(scale=num_train / x.shape[0]), pyro.plate("data", x.shape[0]):
    pyro.sample("obs", Categorical(logits=logits), obs=y)
  return logits


# Mean-field normal approximation over all weights and biases.
guide = AutoDiagonalNormal(classify)

# Define inference algorithm and loss function
pyro.clear_param_store()
loss_fn = pyro.infer.Trace_ELBO()
svi = SVI(classify, guide, Adam({"lr": 1e-3}), loss=loss_fn)

# Train the model
for epoch in range(5):
  for i, (x_batch, y_batch) in enumerate(x_train):
    # Distinct, reproducible seed for every optimisation step.
    pyro.util.set_rng_seed(epoch * len(x_train) + i)
    svi.step(x_batch, y_batch)

# Evaluate on test data
# Average class probabilities over several posterior weight draws instead of
# relying on a single deterministic forward pass.
predictive = Predictive(classify, guide=guide, num_samples=20, return_sites=("_RETURN",))
with torch.no_grad():
  logit_samples = predictive(x_test)["_RETURN"]  # (num_samples, N, output_dim)
  predictions = logit_samples.softmax(dim=-1).mean(dim=0)
  accuracy = (predictions.argmax(-1) == y_test).float().mean().item()
  print("Test accuracy:", accuracy)

# Make predictions with uncertainty
# ... (requires sampling from posterior predictive distribution)

# Note: