# Learning posteriors for BNN on MNIST data

## Prepare

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# !pip install torch torchvision torchmetrics tqdm pandas matplotlib
# !pip install "numpy<2.0" 

## Imports

In [3]:
import numpy as np

import torch
import torch.nn as nn
from torchmetrics.functional.classification import multiclass_calibration_error

import argparse
import sys
import time
import tqdm
import gc

In [4]:
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

In [5]:
import sys

sys.path.append("../")

In [6]:
import reparameterized
from reparameterized import sampling

from reparameterized.likelihoods import (
    categorical_posterior_probs as posterior_predictions,
)
from reparameterized.likelihoods import categorical_log_prob
from reparameterized.bnn_wrapper import elbo_mc



## Config

In [7]:
# When the process was started by an ipykernel launcher, replace the argument
# list with a single empty entry (presumably so the argparse call further down
# only sees defaults); otherwise keep the arguments unchanged.
sys.argv = [""] if "ipykernel" in sys.argv[0] else sys.argv

In [None]:
# Command-line configuration of the experiment. Every option has a default,
# so parsing an empty argument list yields a complete configuration.
parser = argparse.ArgumentParser(description="Reparameterized: BNN+MNIST")

# Run name; None means "no name supplied".
parser.add_argument("--name", type=str, default=None)

# `choices` must be a real tuple: ("mnist") is just the string "mnist", which
# makes argparse match against its characters/substrings instead of the name.
parser.add_argument("--dataset", type=str, default="mnist", choices=("mnist",))
parser.add_argument("--batch_size", type=int, default=1024)

# Name of an optimizer class looked up on torch.optim, and its learning rate.
parser.add_argument("--optimizer", type=str, default="Adam")
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--n_posterior_samples", type=int, default=17)
parser.add_argument("--n_epochs", type=int, default=100)

parser.add_argument(
    "--posterior_model_architecture",
    type=str,
    default="svd_small_rnvp_rezero_on_residuals",
    choices=(
        "rnvp_rezero",
        "rnvp",
        "rnvp_rezero_small",
        "rnvp_small",
        "factorized_gaussian",
        "factorized_gaussian_rezero",
        "gaussian_tril",
        "gaussian_tril_rezero",
        "gaussian_full",
        "gaussian_full_rezero",
        "gaussian_lowrank",
        "gaussian_lowrank_rezero",
        "svd_rnvp",
        "svd_rnvp_rezero",
        "svd_rnvp_rezero_on_residuals",
        "svd_rnvp_small",
        "svd_rnvp_small_rezero",
        "svd_rnvp_small_rezero_on_residuals",
        # The default value itself: argparse does not validate defaults against
        # `choices`, so without this entry the default works implicitly but is
        # rejected when passed explicitly on the command line.
        "svd_small_rnvp_rezero_on_residuals",
        "svd_gaussian_lowrank",
        "svd_factorized_gaussian",
    ),
)
# Substrings of parameter names; a model parameter whose name contains any of
# them gets a learned distribution instead of a point estimate.
parser.add_argument(
    "--distributional_parameters",
    type=str,
    nargs="+",
    default=["last_layer.weight", "common_layers.2.0"],
)
# NOTE: this is a store_false flag, i.e. passing --joint_sampling turns joint
# sampling OFF (the default is True).
parser.add_argument("--joint_sampling", default=True, action="store_false")

parser.add_argument("--seed", type=int, default=1863)
parser.add_argument("--wandb", default=False, action="store_true")

_StoreTrueAction(option_strings=['--wandb'], dest='wandb', nargs=0, const=True, default=False, type=None, choices=None, required=False, help=None, metavar=None)

In [9]:
# Parse the configuration from sys.argv and choose the compute device:
# CUDA when torch reports it as available, the CPU otherwise.
cfg = parser.parse_args()
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [10]:
# Without an explicit --name, assemble one from the main configuration values
# so that runs with different settings get different names.
if cfg.name is None:
    distributional_parameters_str = "_".join(cfg.distributional_parameters)
    name_parts = (
        f"{cfg.dataset}",
        f"{cfg.posterior_model_architecture}",
        f"seed{cfg.seed}",
        f"N{cfg.n_posterior_samples}",
        f"E{cfg.n_epochs}",
        f"lr{cfg.lr}",
        f"P{distributional_parameters_str}",
        f"{cfg.optimizer}",
    )
    cfg.name = "_".join(name_parts)

In [11]:
def wandb_log(*args, **kwargs):
    """Default logging hook: accept any arguments and do nothing."""
    return None


# With --wandb set: import wandb lazily, start a run described by the parsed
# configuration, and rebind wandb_log to a forwarder to wandb.log.
if cfg.wandb:
    import wandb

    wandb.init(
        name=cfg.name,
        config=vars(cfg),
        project="Reparameterized: BNN+MNIST",
    )

    def wandb_log(*args, **kwargs):
        """Pass every positional and keyword argument on to wandb.log."""
        wandb.log(*args, **kwargs)

In [12]:
# Echo the effective configuration (including a generated run name, if any).
print(f"cfg = {cfg}")

cfg = Namespace(name='mnist_svd_small_rnvp_rezero_on_residuals_seed1863_N17_E100_lr0.001_Plast_layer.weight_common_layers.2.0_Adam', dataset='mnist', batch_size=1024, optimizer='Adam', lr=0.001, n_posterior_samples=17, n_epochs=100, posterior_model_architecture='svd_small_rnvp_rezero_on_residuals', distributional_parameters=['last_layer.weight', 'common_layers.2.0'], joint_sampling=True, seed=1863, wandb=False)


## Data and Model

In [13]:
def get_mnist_dataloaders(batch_size=128):
    """Create data loaders for MNIST (train, test) and FashionMNIST (test).

    All datasets are stored under the local ``data`` directory and converted
    with ``transforms.ToTensor()``. The MNIST training split and the
    FashionMNIST test split are requested with ``download=True``; the MNIST
    test split is created without a download flag.

    Args:
        batch_size: Batch size used by all three loaders.

    Returns:
        A tuple ``(train_loader, test_loader, ood_loader)``: the shuffled
        MNIST training loader, the unshuffled MNIST test loader and the
        unshuffled FashionMNIST test loader.
    """

    def _as_loader(dataset, shuffle):
        # Every loader shares the batch size and differs only in shuffling.
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    train_set = torchvision.datasets.MNIST(
        root="data", train=True, transform=transforms.ToTensor(), download=True
    )
    test_set = torchvision.datasets.MNIST(
        root="data", train=False, transform=transforms.ToTensor()
    )
    ood_set = torchvision.datasets.FashionMNIST(
        root="data", train=False, transform=transforms.ToTensor(), download=True
    )

    train_loader = _as_loader(train_set, shuffle=True)
    test_loader = _as_loader(test_set, shuffle=False)
    ood_loader = _as_loader(ood_set, shuffle=False)
    return train_loader, test_loader, ood_loader

In [14]:
class BNNMnist(nn.Module):
    """Fully connected network mapping flattened inputs to ``out_dim`` outputs.

    Layout: ``common_layers`` is ``Linear(in_dim, hid_dim)`` followed by
    ``num_layers`` blocks of ``Linear(hid_dim, hid_dim)`` + ``ELU``;
    ``last_layer`` is ``Linear(hid_dim, out_dim)``. No activation is applied
    after ``last_layer``.

    Args:
        in_dim: Number of input features after flattening.
        out_dim: Number of outputs of the last layer.
        hid_dim: Width of every hidden layer.
        num_layers: Number of ``Linear + ELU`` blocks after the first layer.
        device: Device the layers are moved to (``torch.device`` or anything
            ``torch.device`` accepts). ``None`` selects CUDA when it is
            available and the CPU otherwise, so the class can be built with
            default arguments on CPU-only machines.
    """

    def __init__(
        self,
        in_dim=784,
        out_dim=10,
        hid_dim=128,
        num_layers=2,
        device=None,
    ):
        super().__init__()

        # Resolve the device at construction time rather than hard-coding CUDA
        # in the signature.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        elif not isinstance(device, torch.device):
            device = torch.device(device)
        self.device = device

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hid_dim = hid_dim
        self.num_layers = num_layers

        self.common_layers = nn.Sequential(
            nn.Linear(self.in_dim, self.hid_dim),
            *[
                nn.Sequential(nn.Linear(self.hid_dim, self.hid_dim), nn.ELU())
                for _ in range(self.num_layers)
            ],
        ).to(self.device)

        self.last_layer = nn.Linear(self.hid_dim, self.out_dim).to(self.device)

    def forward(self, x):
        """Flatten each sample of ``x`` and return the last layer's outputs.

        Args:
            x: Batch whose first dimension indexes samples; all remaining
                dimensions are flattened into one.

        Returns:
            Tensor with one row of ``out_dim`` values per sample.
        """
        # reshape (unlike view) also handles non-contiguous input batches.
        x = x.reshape(x.shape[0], -1)
        x = self.common_layers(x)
        x = self.last_layer(x)
        return x

In [15]:
# Only MNIST is supported: reject anything else up front, then build the
# train/test loaders (the third loader is discarded) and the network.
if cfg.dataset != "mnist":
    raise ValueError(f"Dataset={cfg.dataset} not supported!")

train_dataloader, test_dataloader, _ = get_mnist_dataloaders(
    batch_size=cfg.batch_size
)
bnn = BNNMnist(in_dim=784, out_dim=10, hid_dim=128, num_layers=2, device=device)

## Priors and Posteriors

In [16]:
# Split the model parameters into two groups by name: a parameter whose name
# contains any of the configured selectors is "distributional", every other
# parameter is "pointwise".
pointwise_params = {}
distributional_params = {}

print("Model parameters:")
for param_name, param in bnn.named_parameters():
    is_distributional = False
    for param_selector in cfg.distributional_parameters:
        if param_selector in param_name:
            is_distributional = True
            break

    print(
        f" - {param_name}: {param.shape} (grad={param.requires_grad})"
        f" (learn distribution={is_distributional})"
    )

    destination = distributional_params if is_distributional else pointwise_params
    destination[param_name] = param

Model parameters:
 - common_layers.0.weight: torch.Size([128, 784]) (grad=True) (learn distribution=False)
 - common_layers.0.bias: torch.Size([128]) (grad=True) (learn distribution=False)
 - common_layers.1.0.weight: torch.Size([128, 128]) (grad=True) (learn distribution=False)
 - common_layers.1.0.bias: torch.Size([128]) (grad=True) (learn distribution=False)
 - common_layers.2.0.weight: torch.Size([128, 128]) (grad=True) (learn distribution=True)
 - common_layers.2.0.bias: torch.Size([128]) (grad=True) (learn distribution=True)
 - last_layer.weight: torch.Size([10, 128]) (grad=True) (learn distribution=True)
 - last_layer.bias: torch.Size([10]) (grad=True) (learn distribution=False)


In [17]:
def create_gaussian_basic_prior(p):
    """Return a standard normal prior over the flattened entries of ``p``.

    The distribution has zero mean and identity covariance. It is built as
    independent unit normals grouped into a single event instead of a
    ``MultivariateNormal`` with an explicit diagonal covariance: the density
    is the same, but no ``n x n`` matrix is allocated or factorised, which
    matters for large parameters (a 128 x 128 weight has n = 16384).

    Args:
        p: Tensor whose number of elements, dtype and device define the prior.

    Returns:
        A ``torch.distributions.Independent`` distribution with event shape
        ``(p.numel(),)``; ``log_prob`` takes a flattened parameter vector
        (optionally with leading batch dimensions).
    """
    flat = p.flatten()
    unit_normals = torch.distributions.Normal(
        loc=torch.zeros_like(flat), scale=torch.ones_like(flat)
    )
    # Reinterpret the element dimension as the event dimension so that
    # log_prob sums over the parameter entries, as a multivariate normal does.
    return torch.distributions.Independent(unit_normals, reinterpreted_batch_ndims=1)


# One prior per distributional parameter, keyed by the parameter's name.
priors = {
    param_name: create_gaussian_basic_prior(param)
    for param_name, param in distributional_params.items()
}


def log_priors(samples):
    """Add up the prior log-densities of all entries of ``samples``.

    Args:
        samples: Mapping from parameter name to a sampled value; each name
            must be a key of the module-level ``priors`` dict.

    Returns:
        The sum over all entries of ``priors[name].log_prob(value.flatten())``
        (the integer ``0`` when ``samples`` is empty).
    """
    total = 0
    for name, value in samples.items():
        total = total + priors[name].log_prob(value.flatten())
    return total

In [18]:
# Seed torch's generators (default, current CUDA device, all CUDA devices)
# with the configured seed.
for seed_fn in (
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    seed_fn(cfg.seed)

In [19]:
# Build the posterior sampler for the distributional parameters: either one
# joint sampler over all of them, or one sampler per parameter combined into
# a single joint sampling function.
if cfg.joint_sampling:
    print("All parameters are put together and use a joint sampler")
    sampler, variational_params, aux_objs = sampling.create_joint_sampler(
        distributional_params, cfg.posterior_model_architecture
    )
else:
    print("Each parameter gets its own sampler")
    independent_samplers = sampling.create_independent_samplers(
        distributional_params, cfg.posterior_model_architecture
    )
    parameter2sampler, variational_params, aux_objs = independent_samplers
    sampler = reparameterized.parameter_samplers_to_joint_sampler(parameter2sampler)

# Summarise what was created.
print(f"Posterior={cfg.posterior_model_architecture}:")
print(f" - target_params = {list(distributional_params.keys())}")
print(f" - sampler = {sampler}")
print(f" - variational_params = {list(variational_params.keys())}")
print(f" - aux_objs = {list(aux_objs.keys())}")



All parameters are put together and use a joint sampler
Posterior=svd_small_rnvp_rezero_on_residuals:
 - target_params = ['common_layers.2.0.weight', 'common_layers.2.0.bias', 'last_layer.weight']
 - sampler = <function create_multiparameter_svd_sampler.<locals>.sampler at 0x12e47e200>
 - variational_params = ['alpha', 'beta', 't.0.0.weight', 't.0.0.bias', 't.0.2.weight', 't.0.2.bias', 't.0.4.weight', 't.0.4.bias', 't.1.0.weight', 't.1.0.bias', 't.1.2.weight', 't.1.2.bias', 't.1.4.weight', 't.1.4.bias', 't.2.0.weight', 't.2.0.bias', 't.2.2.weight', 't.2.2.bias', 't.2.4.weight', 't.2.4.bias', 't.3.0.weight', 't.3.0.bias', 't.3.2.weight', 't.3.2.bias', 't.3.4.weight', 't.3.4.bias', 't.4.0.weight', 't.4.0.bias', 't.4.2.weight', 't.4.2.bias', 't.4.4.weight', 't.4.4.bias', 't.5.0.weight', 't.5.0.bias', 't.5.2.weight', 't.5.2.bias', 't.5.4.weight', 't.5.4.bias', 't.6.0.weight', 't.6.0.bias', 't.6.2.weight', 't.6.2.bias', 't.6.4.weight', 't.6.4.bias', 't.7.0.weight', 't.7.0.bias', 't.7.2.weig

## Evaluation

In [20]:
@torch.no_grad()
def eval_posterior(bnn, samples, dataloader, name_preffix="", device=device):
    """Compute accuracy and calibration errors of the sample-averaged predictions.

    For every batch the predictions returned by ``posterior_predictions`` are
    averaged over their first dimension; the class with the highest averaged
    value is the predicted label. Runs with gradient tracking disabled.

    Args:
        bnn: Model passed through to ``posterior_predictions``.
        samples: Parameter samples passed through to ``posterior_predictions``.
        dataloader: Iterable of ``(x, y)`` batches.
        name_preffix: String prepended to every metric name.
        device: Device the batches are moved to before prediction.

    Returns:
        Dict with the keys ``eval_accuracy``, ``eval_ece_l1``, ``eval_ece_l2``
        and ``eval_ece_max`` (each prefixed with ``name_preffix``). The
        calibration errors use 10 classes and 15 bins.
    """
    targets = []
    predictions = []
    probabilities = []

    for x, y in dataloader:
        x, y = x.to(device), y.to(device)

        # Average over the first dimension of the returned predictions.
        batch_probs = posterior_predictions(bnn, x, samples=samples).mean(dim=0)

        targets.append(y)
        predictions.append(batch_probs.argmax(dim=-1))
        probabilities.append(batch_probs)

    targets = torch.cat(targets)
    predictions = torch.cat(predictions)
    probabilities = torch.cat(probabilities)

    accuracy = (predictions == targets).sum() / len(targets)

    # The three calibration errors differ only in the norm.
    calibration_errors = {
        norm: multiclass_calibration_error(
            preds=probabilities,
            target=targets,
            num_classes=10,
            n_bins=15,
            norm=norm,
        )
        for norm in ("l1", "l2", "max")
    }

    metrics = {f"{name_preffix}eval_accuracy": accuracy}
    for norm, value in calibration_errors.items():
        metrics[f"{name_preffix}eval_ece_{norm}"] = value
    return metrics

## Learning

In [21]:
# Baseline: evaluate the untrained posterior on the test set.
print("Evaluate before training")

# Draw 111 posterior samples and make sure none of the returned values is NaN.
samples, q_nlls = sampler(n_samples=111)
assert not q_nlls.isnan().any(), q_nlls

metrics = eval_posterior(bnn, samples, test_dataloader, device=device)
formatted_metrics = [f"{key}={value:.2f}" for key, value in metrics.items()]
metrics_str = " ".join(formatted_metrics)
print(f"[start] metrics={metrics_str}")
wandb_log(metrics, step=0)

Evaluate before training
[start] metrics=eval_accuracy=0.11 eval_ece_l1=0.04 eval_ece_l2=0.04 eval_ece_max=0.20


In [22]:
# Re-seed torch's generators (default, current CUDA device, all CUDA devices)
# with the configured seed.
for seed_fn in (
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    seed_fn(cfg.seed)

In [23]:
# Total number of scalar entries across all variational parameters.
total_variational_params = sum(
    tensor.numel() for tensor in variational_params.values()
)
print(f"#variational params = {total_variational_params}")

#variational params = 34420528


In [24]:
# Optimise the variational parameters together with the pointwise parameters
# (variational ones first).
optimized_parameters = [
    *variational_params.values(),
    *pointwise_params.values(),
]
# The optimizer class is looked up on torch.optim by its configured name.
optimizer_class = getattr(torch.optim, cfg.optimizer)
optimizer = optimizer_class(optimized_parameters, lr=cfg.lr)

# Weights of the log-likelihood term (omega) and of the KL term (beta) in the
# training loss.
omega = 1.0
beta = 1.0

In [25]:
# Training: minimise the negative weighted ELBO, -(omega * log_lik - beta * KLD),
# over mini-batches; after every epoch evaluate on the test set and report.
start_time = time.time()
for epoch in range(cfg.n_epochs):

    # Per-iteration statistics of the current epoch.
    loss_vi_k = []
    KLD_k = []
    log_lik_k = []
    # tqdm wraps the dataloader itself (not the enumerate generator) so that
    # it knows the number of batches and can show progress and an ETA.
    for it, (x, y) in enumerate(tqdm.tqdm(train_dataloader)):
        # Move the batch to the model's device, as eval_posterior does; this
        # is a no-op when the data already lives there.
        x = x.to(device)
        y = y.to(device)

        # Ratio between the dataset size and the current batch size, passed
        # to elbo_mc (the last batch of an epoch may be smaller).
        full2minibatch_ratio = len(train_dataloader.dataset) / len(x)

        optimizer.zero_grad()
        elbo_res = elbo_mc(
            bnn,
            x,
            y,
            log_priors,
            categorical_log_prob,
            sampler,
            cfg.n_posterior_samples,
            full2minibatch_ratio,
        )
        log_lik, KLD = elbo_res["ll"], elbo_res["kl"]
        loss_vi = -(omega * log_lik - beta * KLD)
        loss_vi.backward()

        loss_vi_k.append(loss_vi.item())
        KLD_k.append(KLD.item())
        log_lik_k.append(log_lik.item())

        optimizer.step()

    # reporting: epoch averages of the training statistics
    log_lik = np.mean(log_lik_k)
    KLD = np.mean(KLD_k)
    loss_vi = np.mean(loss_vi_k)

    samples, q_nlls = sampler(n_samples=111)
    assert not q_nlls.isnan().any(), q_nlls

    metrics = eval_posterior(bnn, samples, test_dataloader, device=device)
    metrics_str = " ".join([f"{k}={v:.3f}" for k, v in metrics.items()])
    # Step 0 is already used by the pre-training evaluation; offset by one so
    # the first epoch does not overwrite it.
    wandb_log(metrics, step=epoch + 1)

    res_str = f"train: loss={loss_vi:.2f} log_lik={log_lik:.2f} KLD={KLD:.2f}"
    print(
        f"[{time.time()-start_time:.0f}s][epoch={epoch}] {res_str}  / test metrics: {metrics_str}"
    )

1it [01:34, 94.95s/it]


KeyboardInterrupt: 