# Bayesian Flows

Bayesian neural networks are similar to ordinary networks built from Linear layers, except that each weight is replaced by a learnable Gaussian distribution.

At each evaluation, a Bayesian model samples its weights from these distributions, so repeated calls on the same input return different results.

During training, the mean and standard deviation of each weight distribution are learned so that the distribution over the weights reproduces the target posterior distribution.
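
The sampling step can be sketched with the reparameterization trick: a weight is written as `mean + std * noise`, so gradients reach both the mean and the log-variance. The `ToyBayesianLinear` class below is a hypothetical illustration in plain PyTorch, not zuko's implementation; its name and the log-variance parameterization are assumptions.

```python
import torch
from torch import nn


class ToyBayesianLinear(nn.Module):
    """Hypothetical linear layer with one Gaussian distribution per weight."""

    def __init__(self, in_features, out_features, init_logvar=-3.0):
        super().__init__()
        # Mean of each weight distribution
        self.weight_mean = nn.Parameter(0.1 * torch.randn(out_features, in_features))
        # Log-variance of each weight distribution (assumed parameterization)
        self.weight_logvar = nn.Parameter(
            torch.full((out_features, in_features), init_logvar)
        )

    def forward(self, x):
        # Reparameterization: weight = mean + std * eps, with eps ~ N(0, 1)
        std = torch.exp(0.5 * self.weight_logvar)
        weight = self.weight_mean + std * torch.randn_like(std)
        return x @ weight.T


torch.manual_seed(0)
layer = ToyBayesianLinear(3, 2)
x = torch.ones(1, 3)

# Two calls on the same input draw different weights
y1, y2 = layer(x), layer(x)

# Gradients flow to both the means and the log-variances
y1.sum().backward()
```

Because the noise enters only through `eps`, the loss stays differentiable with respect to the distribution parameters, which is what makes training by gradient descent possible.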

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

from sklearn.datasets import make_moons
from torch import nn

import zuko

Any model can be made Bayesian by wrapping it in `BayesianModel` from zuko.

In [None]:
from zuko.bayesian import BayesianModel

adjacency = torch.randn(4, 3) < 0
net = zuko.nn.MaskedMLP(adjacency, [16, 32], activation=nn.ELU)

# Create a Bayesian version of the network
# init_logvar controls the initial uncertainty of the weights
bnet = BayesianModel(net, init_logvar=-3.0)
bnet
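
To get a feel for the scale of `init_logvar`, the short sketch below converts a few values into standard deviations. It assumes that `init_logvar` is the natural logarithm of the initial weight variance, which is an interpretation of the parameter name rather than something stated here.

```python
import math

# Assumption: init_logvar is the natural log of the initial weight variance,
# so the standard deviation is exp(0.5 * init_logvar)
stds = {v: math.exp(0.5 * v) for v in (-3.0, -8.0, -9.0)}

for init_logvar, std in stds.items():
    print(f"init_logvar={init_logvar:+.1f} -> std={std:.4f}")
```

Under this assumption, lower values mean the sampled weights start closer to their means.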

The Bayesian model needs to be sampled before it can be used.
There are two ways to do this:
- **Sampled model instance**: `sample_model()` creates a copy of the original model in which the weights of the Linear layers are replaced with samples drawn from the Bayesian model. The copy does not propagate gradients to the original model's parameters, so it must not be used for training.
- **Context manager**: `reparameterize()` does not copy the model. Instead, it replaces the forward method of each Linear layer on the fly so that the weights are sampled from the Bayesian model. This is the **recommended** way of using the Bayesian model, as it is much more memory efficient.
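
The difference between the two modes can be illustrated in plain PyTorch. This is a minimal sketch, not zuko's implementation: the `means`, `logvars`, and `draw` names are hypothetical, and `torch.func.functional_call` (PyTorch 2.0 or later) stands in for the on-the-fly weight replacement.

```python
import copy

import torch
from torch import nn
from torch.func import functional_call

torch.manual_seed(0)
net = nn.Linear(3, 2)

# Hypothetical variational parameters: one mean and one log-variance per weight
means = {n: p.detach().clone().requires_grad_(True) for n, p in net.named_parameters()}
logvars = {
    n: torch.full_like(p.detach(), -3.0).requires_grad_(True)
    for n, p in net.named_parameters()
}


def draw():
    """Sample every weight as mean + std * eps (differentiable)."""
    return {
        n: means[n] + torch.exp(0.5 * logvars[n]) * torch.randn_like(means[n])
        for n in means
    }


x = torch.randn(4, 3)

# Mode 1: copy the module and write a detached sample into the copy
frozen = copy.deepcopy(net)
sample = draw()
with torch.no_grad():
    for n, p in frozen.named_parameters():
        p.copy_(sample[n])
frozen(x).sum().backward()
copy_reaches_means = means["weight"].grad is not None  # gradients stop at the copy

# Mode 2: evaluate the original module with freshly sampled weights, no copy
functional_call(net, draw(), (x,)).sum().backward()
live_reaches_means = means["weight"].grad is not None  # gradients reach the means
```

In this sketch only the second mode lets an optimizer update the means and log-variances, which mirrors why the sampled copy is unsuitable for training.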


In [None]:
sampled_model = bnet.sample_model()

x = torch.randn(3)
jac = torch.autograd.functional.jacobian(sampled_model, x)
jac

In [None]:
with bnet.reparameterize() as sampled_model:
    x = torch.randn(3)
    print(torch.autograd.functional.jacobian(sampled_model, x))

At each evaluation, the weights of the Bayesian MLP are sampled from their Gaussian distributions, producing a different output.

In [None]:
for i in range(10):
    with bnet.reparameterize() as sampled_model:
        print(sampled_model(x))

## Creating Bayesian flows

Any flow can be converted to a Bayesian flow. Each parameter is replaced by a Gaussian distribution over its value.

In [None]:
model = zuko.flows.spline.NSF(
    features=3, context=5, bins=10, transforms=3, hidden_features=[64, 64]
)

bmodel = BayesianModel(model, init_logvar=-8.0)
bmodel

In [None]:
c = torch.rand((10, 5))
x = torch.rand((10, 3))

In [None]:
for i in range(3):
    with bmodel.reparameterize() as model:
        print(model(c).log_prob(x))

The log-probability estimated with a Bayesian flow is itself a distribution: each weight sample gives a different value.

In [None]:
o = []
for i in range(100):
    with bmodel.reparameterize() as model:
        o.append(model(c).log_prob(x)[0])

plt.hist(torch.stack(o).squeeze().cpu().detach().numpy())

Parameters can be explicitly included or excluded with pattern-matching expressions when defining the Bayesian model. For example, to make only the last layer of each transformation Bayesian, use this expression:

In [None]:
model = zuko.flows.spline.NSF(
    features=3, context=5, bins=10, transforms=3, hidden_features=[64, 64]
)

bmodel = BayesianModel(model, init_logvar=-8.0, include_params=["transform.transforms.*.hyper.4"])
bmodel

To exclude all the bias parameters:

In [None]:
model = zuko.flows.spline.NSF(
    features=3, context=5, bins=10, transforms=3, hidden_features=[64, 64]
)

bmodel = BayesianModel(model, init_logvar=-8.0, exclude_params=["**.bias"])
bmodel
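
Before writing such patterns, it helps to list the parameter names of the model. The sketch below does this for a small stand-in module and filters the names with the standard-library `fnmatch`. This is only an illustration of name-based selection: the `toy` module is hypothetical, and zuko's own matching rules (for example the meaning of `**`) may differ from `fnmatch`.

```python
from fnmatch import fnmatch

from torch import nn

# Stand-in module; a real flow has longer, nested parameter names
toy = nn.Sequential(nn.Linear(3, 8), nn.ELU(), nn.Linear(8, 2))

# Fully qualified parameter names, as returned by PyTorch
names = [name for name, _ in toy.named_parameters()]
print(names)

# Select names with glob-style patterns (fnmatch semantics, not zuko's)
bias_names = [n for n in names if fnmatch(n, "*.bias")]
last_layer_names = [n for n in names if fnmatch(n, "2.*")]
```

The same `named_parameters()` call on a flow shows the prefixes that appear in the patterns above.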

## Training

Let's learn a Bayesian flow on the classic two-moons example. We want to evaluate the uncertainty that the flow assigns to its density estimate at each point of the phase space.

In [None]:
# Generate the two-moon dataset
Nevents = 500000
X, y = make_moons(n_samples=Nevents, noise=0.2, random_state=42)

# Separate the data points by class
X_class0 = X[y == 0]
X_class1 = X[y == 1]
plt.scatter(X_class0[:, 0], X_class0[:, 1])
plt.scatter(X_class1[:, 0], X_class1[:, 1])

In [None]:
x = torch.from_numpy(X).to(torch.float32)
c = torch.from_numpy(y).to(torch.float32).unsqueeze(-1)

In [None]:
model = zuko.flows.spline.NSF(
    features=2,
    context=1,
    bins=5,
    transforms=2,
    hidden_features=[32, 32],
)

bmodel = BayesianModel(model, init_logvar=-9.0)

### Training loop

In [None]:
opt = torch.optim.AdamW(bmodel.parameters(), lr=1e-4)
batch_size = 256
nepochs = 5

In [None]:
log = []
indices = np.arange(Nevents)
for e in range(nepochs):
    np.random.shuffle(indices)
    for i in range(Nevents // batch_size):
        x_i = x[indices[i * batch_size : (i + 1) * batch_size]]
        c_i = c[indices[i * batch_size : (i + 1) * batch_size]]
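        # Draw one set of weights and evaluate the negative log-likelihood of the batch
        with bmodel.reparameterize() as smodel:
            flow_loss = -smodel(c_i).log_prob(x_i).mean()

        # The KL divergence is automatically computed for all
        # Bayesian layers in the model when calling this method.
        # Dividing by the dataset size puts it on the same per-event
        # scale as the mean negative log-likelihood above.
        kl_loss = bmodel.kl_divergence(prior_var=1.0) / Nevents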
        loss = flow_loss + kl_loss

        opt.zero_grad()
        loss.backward()
        opt.step()
        log.append((loss.item(), flow_loss.item(), kl_loss.item()))
        if i % 500 == 0:
            print(
                f"epoch={e}, step={i}, total loss: {loss.item():.3f}, flow loss:{flow_loss.item():.3f}, KL loss: {kl_loss.item():.3f}"
            )
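
The KL term in the loss has a closed form when both the weight distribution and the prior are Gaussian. The sketch below assumes a zero-mean Gaussian prior with variance `prior_var` and a log-variance parameterization; whether `kl_divergence` in zuko uses exactly this expression is an assumption. The result is cross-checked against `torch.distributions`.

```python
import math

import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mean = torch.randn(5)  # stand-in weight means
logvar = torch.full((5,), -3.0)  # stand-in weight log-variances
prior_var = 1.0

# Closed-form KL( N(mean, exp(logvar)) || N(0, prior_var) ), summed over weights
kl_closed = 0.5 * torch.sum(
    (logvar.exp() + mean**2) / prior_var - 1.0 - logvar + math.log(prior_var)
)

# Cross-check against the generic implementation in torch.distributions
kl_reference = kl_divergence(
    Normal(mean, torch.exp(0.5 * logvar)),
    Normal(torch.zeros(5), math.sqrt(prior_var) * torch.ones(5)),
).sum()
```

The KL term is a sum over all weights and does not depend on the batch, which is why the training loop rescales it by the number of events before adding it to the per-event likelihood term.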

In [None]:
plt.plot([i[0] for i in log], label="total loss")
plt.plot([i[1] for i in log], label="flow loss")
plt.legend();

In [None]:
plt.plot([i[2] for i in log], label="KL loss")
plt.legend();

In [None]:
bmodel.eval()
with torch.no_grad():
    with bmodel.reparameterize() as smodel:
        samples_1 = smodel(torch.ones((20000, 1))).sample((1,)).cpu().squeeze().numpy()
        samples_2 = smodel(torch.zeros((20000, 1))).sample((1,)).cpu().squeeze().numpy()

plt.scatter(samples_1[:, 0], samples_1[:, 1])
plt.scatter(samples_2[:, 0], samples_2[:, 1])

In [None]:
import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np


def profiled_histogram_2d(
    data_x,
    data_y,
    values,
    bins=10,
    range=None,
    title="2D Profiled Histogram",
    xlabel="X Bins",
    ylabel="Y Bins",
    cbar_label="Average Value",
    norm=None,
    vmax=2,
):
    """Plot the mean of `values` in each 2D bin of (data_x, data_y)."""
    # Sum of `values` in each bin (weighted 2D histogram)
    sums, x_edges, y_edges = np.histogram2d(
        data_x, data_y, bins=bins, range=range, weights=values
    )

    # Number of points in each bin, needed to turn the sums into means
    counts, _, _ = np.histogram2d(data_x, data_y, bins=bins, range=range)

    # Mean value per bin; empty bins become NaN and are left blank in the plot
    with np.errstate(divide="ignore", invalid="ignore"):
        bin_means = np.where(counts > 0, sums / counts, np.nan)

    # Use the caller's norm if given, otherwise scale colors from the smallest mean to vmax
    if norm is None:
        norm = matplotlib.colors.Normalize(vmin=np.nanmin(bin_means), vmax=vmax)

    # pcolormesh expects values indexed as [y, x], hence the transpose
    mesh = plt.pcolormesh(x_edges, y_edges, bin_means.T, cmap="viridis", norm=norm)
    plt.colorbar(mesh, label=cbar_label)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

In [None]:
bmodel.eval()

x_test = x[(c == 1).squeeze(-1)][0:100000]
c_test = torch.ones((x_test.shape[0], 1))

densities = []
with torch.no_grad():
    for i in range(30):
        with bmodel.reparameterize() as smodel:
            densities.append(smodel(c_test).log_prob(x_test))

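# Shape (n_points, n_weight_samples): one log-density per weight sample
D = torch.stack(densities, dim=1)
# Mean of the log-density; exp() of this is the geometric mean of the density
D_mean = D.mean(dim=1)
# Standard deviation of the density itself across weight samples
D_std = D.exp().std(dim=1)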
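
The cell above averages the log-densities, so exponentiating `D_mean` gives the geometric mean of the density over weight samples. If the arithmetic mean of the density is wanted instead, it can be computed in log space with `torch.logsumexp`. The sketch below uses a random stand-in tensor `log_p` with the same layout as `D`; the variable names are illustrative.

```python
import math

import torch

torch.manual_seed(0)

# Stand-in for D: log-densities with shape (n_points, n_weight_samples)
log_p = torch.randn(5, 30) - 2.0
n_samples = log_p.shape[1]

# log of the arithmetic mean density: logsumexp(log_p) - log(n_samples)
log_mean_density = torch.logsumexp(log_p, dim=1) - math.log(n_samples)
mean_density = log_mean_density.exp()

# exp of the mean log-density, i.e. the geometric mean of the density
geometric_mean = log_p.mean(dim=1).exp()
```

By Jensen's inequality the geometric mean never exceeds the arithmetic mean, and the gap grows with the spread of the log-densities.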

Average probability density over the space (the exponential of the mean log-density across weight samples):

In [None]:
profiled_histogram_2d(
    x_test[:, 0].cpu().numpy(),
    x_test[:, 1].cpu().numpy(),
    D_mean.exp().cpu().numpy(),
    bins=100,
    vmax=0.8,
)

Average uncertainty (standard deviation across weight samples) of the probability density over the space:

In [None]:
profiled_histogram_2d(
    x_test[:, 0].cpu().numpy(), x_test[:, 1].cpu().numpy(), D_std.cpu().numpy(), bins=100, vmax=0.5
)

Average relative uncertainty of the probability density over the phase space:

In [None]:
profiled_histogram_2d(
    x_test[:, 0].cpu().numpy(),
    x_test[:, 1].cpu().numpy(),
    D_std.cpu().numpy() / D_mean.exp().cpu().numpy(),
    bins=100,
    vmax=1,
)

## Looking at the parameters of the model

In [None]:
bmodel

In [None]:
logsvars = []
for n, p in bmodel.logvars.items():
    logsvars.append(p.detach().cpu().numpy().flatten())

logsvars = np.concatenate(logsvars, axis=0)
logsvars

In [None]:
plt.hist(logsvars, bins=50)
plt.xlabel("Log-variance of the weight distributions")