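# Learning variational posteriors for a Bayesian neural network (BNN) on MNIST

Parameters selected with `--distributional_parameters` (by default `last_layer.weight`) get a learned variational posterior; all other weights are point estimates. Both are trained jointly by minimising the negative ELBO, and the posterior predictive is evaluated on the MNIST test set (accuracy and expected calibration error).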

## Prepare

In [1]:
# Reload edited Python modules (such as the local `reparameterized` package imported
# below) automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
!pip install torch torchvision torchmetrics tqdm pandas matplotlib
!pip install "numpy<2.0" 

## Imports

In [4]:
import numpy as np

import torch
import torch.nn as nn
from torchmetrics.functional.classification import multiclass_calibration_error

import argparse
import sys
import time
import tqdm
import gc

In [5]:
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

In [3]:
import sys

# Make packages in the parent directory importable (the local `reparameterized`
# package imported below is presumably found this way). "../" is resolved relative
# to the kernel's current working directory, not this notebook's location — confirm
# when launching from elsewhere.
sys.path.append("../")

In [6]:
import reparameterized
from reparameterized import sampling

from reparameterized.likelihoods import (
    categorical_posterior_probs as posterior_predictions,
)
from reparameterized.likelihoods import categorical_log_prob
from reparameterized.bnn_wrapper import elbo_mc



## Config

In [7]:
# Inside a Jupyter kernel, sys.argv holds the kernel launcher's own arguments; replace
# them so that argparse below sees no CLI flags and falls back to its defaults.
if sys.argv[0].find("ipykernel") != -1:
    sys.argv = [""]

In [None]:
parser = argparse.ArgumentParser(description="Reparameterized: BNN+MNIST")

parser.add_argument(
    "--name", type=str, default=None, help="Run name; derived from the config when omitted."
)

# `choices` must be a real tuple: a bare string ("mnist") makes argparse accept any
# substring of it (e.g. "mni"), because membership is tested with `in`.
parser.add_argument("--dataset", type=str, default="mnist", choices=("mnist",))
parser.add_argument("--batch_size", type=int, default=1024)

parser.add_argument(
    "--optimizer", type=str, default="Adam", help="Name of a torch.optim optimizer class."
)
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument(
    "--n_posterior_samples",
    type=int,
    default=17,
    help="Posterior samples drawn per training step for the ELBO estimate.",
)
parser.add_argument("--n_epochs", type=int, default=100)

parser.add_argument(
    "--posterior_model_architecture",
    type=str,
    default="svd_rnvp_small_rezero",
    choices=(
        "rnvp_rezero",
        "rnvp",
        "rnvp_rezero_small",
        "rnvp_small",
        "factorized_gaussian",
        "factorized_gaussian_rezero",
        "gaussian_tril",
        "gaussian_tril_rezero",
        "gaussian_full",
        "gaussian_full_rezero",
        "svd_rnvp_small_rezero",
    ),
)

parser.add_argument(
    "--distributional_parameters",
    type=str,
    nargs="+",
    default=["last_layer.weight"],
    help="Substrings selecting which named model parameters get a variational "
    "posterior; all other parameters are point estimates.",
)

parser.add_argument("--seed", type=int, default=1863)
parser.add_argument("--wandb", default=False, action="store_true");  # ';' hides the Action repr in Jupyter

_StoreTrueAction(option_strings=['--wandb'], dest='wandb', nargs=0, const=True, default=False, type=None, choices=None, required=False, help=None, metavar=None)

In [9]:
cfg = parser.parse_args()
# Prefer the GPU when torch can see one, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [10]:
if cfg.name is None:
    # Derive a descriptive run name from the hyper-parameters that define the run.
    distributional_parameters_str = "_".join(cfg.distributional_parameters)
    name_parts = [
        cfg.dataset,
        cfg.posterior_model_architecture,
        f"seed{cfg.seed}",
        f"N{cfg.n_posterior_samples}",
        f"E{cfg.n_epochs}",
        f"lr{cfg.lr}",
        f"P{distributional_parameters_str}",
        cfg.optimizer,
    ]
    cfg.name = "_".join(name_parts)

In [None]:
def wandb_log(*args, **kwargs):
    """Logging hook; does nothing unless --wandb is set (then it is replaced below)."""
    return None


if cfg.wandb:
    # Imported lazily so that wandb is only required when logging is enabled.
    import wandb

    wandb.init(project="Reparameterized: BNN+MNIST", config=vars(cfg), name=cfg.name)

    def wandb_log(*args, **kwargs):
        """Forward all arguments to `wandb.log` (looked up at call time)."""
        wandb.log(*args, **kwargs)

In [12]:
# Echo the parsed configuration (including `cfg.name`) so the run settings are recorded in the output.
print(f"cfg = {cfg}")

cfg = Namespace(name='mnist_svd_rnvp_small_rezero_seed1863_N17_E100_lr0.001_Plast_layer.weight_Adam', dataset='mnist', batch_size=1024, optimizer='Adam', lr=0.001, n_posterior_samples=17, n_epochs=100, posterior_model_architecture='svd_rnvp_small_rezero', distributional_parameters=['last_layer.weight'], seed=1863, wandb=False)


## Data and Model

In [13]:
def get_mnist_dataloaders(batch_size=128, root="data"):
    """Build MNIST train/test loaders plus a FashionMNIST out-of-distribution loader.

    Args:
        batch_size: batch size used by all three loaders.
        root: directory the datasets are read from / downloaded to.

    Returns:
        ``(train, test, ood)`` DataLoaders: the shuffled MNIST training split, the
        MNIST test split, and the FashionMNIST test split (the latter two unshuffled).
    """
    to_tensor = transforms.ToTensor()

    def _load(dataset_cls, train):
        # Build one split; download=True on every split keeps each dataset
        # self-sufficient (torchvision skips the download when files already exist).
        return dataset_cls(root=root, train=train, transform=to_tensor, download=True)

    mnist_train_dataset = _load(torchvision.datasets.MNIST, train=True)
    mnist_test_dataset = _load(torchvision.datasets.MNIST, train=False)
    # FashionMNIST has the same 28x28 grayscale format, so it serves as an OOD set.
    mnist_ood_dataset = _load(torchvision.datasets.FashionMNIST, train=False)

    dataloader_train = DataLoader(mnist_train_dataset, batch_size=batch_size, shuffle=True)
    dataloader_test = DataLoader(mnist_test_dataset, batch_size=batch_size, shuffle=False)
    dataloader_ood = DataLoader(mnist_ood_dataset, batch_size=batch_size, shuffle=False)

    return dataloader_train, dataloader_test, dataloader_ood

In [14]:
class BNNMnist(nn.Module):
    """MLP classifier for MNIST used as the BNN backbone.

    ``common_layers`` is an input projection followed by ``num_layers`` Linear+ELU
    blocks; ``last_layer`` is kept as a separately named module so its parameters can
    be selected by name (e.g. ``"last_layer.weight"``) for a variational posterior.

    Args:
        in_dim: flattened input size (784 = 28 * 28 pixels for MNIST).
        out_dim: number of output logits.
        hid_dim: hidden width.
        num_layers: number of hidden Linear+ELU blocks after the input projection.
        device: device the parameters are created on. ``None`` selects CUDA when it
            is available and the CPU otherwise.
    """

    def __init__(
        self,
        in_dim=784,
        out_dim=10,
        hid_dim=128,
        num_layers=2,
        device=None,
    ):
        super().__init__()

        # A hard-coded "cuda" default crashed on CPU-only machines; resolve lazily instead.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hid_dim = hid_dim
        self.num_layers = num_layers

        self.common_layers = nn.Sequential(
            # NOTE(review): there is no nonlinearity between this projection and the
            # first Linear+ELU block, so those two Linear layers compose into a single
            # linear map. Kept as-is to preserve parameter names and existing results.
            nn.Linear(self.in_dim, self.hid_dim),
            *[
                nn.Sequential(nn.Linear(self.hid_dim, self.hid_dim), nn.ELU())
                for _ in range(self.num_layers)
            ],
        ).to(self.device)

        self.last_layer = nn.Linear(self.hid_dim, self.out_dim).to(self.device)

    def forward(self, x):
        """Map a batch ``x`` of shape ``(batch, ...)`` to logits ``(batch, out_dim)``."""
        # flatten(1) also handles non-contiguous inputs, where .view() would raise.
        x = torch.flatten(x, start_dim=1)
        x = self.common_layers(x)
        return self.last_layer(x)

In [15]:
if cfg.dataset == "mnist":
    train_dataloader, test_dataloader, _ = get_mnist_dataloaders(
        batch_size=cfg.batch_size
    )
    # Seed before building the network: the seeding cell further down runs only after
    # the model already exists, so without this the initial weights differ every run.
    # torch.manual_seed seeds the CPU and all CUDA generators.
    torch.manual_seed(cfg.seed)
    bnn = BNNMnist(in_dim=784, out_dim=10, hid_dim=128, num_layers=2, device=device)

else:
    raise ValueError(f"Dataset={cfg.dataset} not supported!")

## Priors and Posteriors

In [16]:
# Partition the network's parameters: those whose name contains any selector from
# --distributional_parameters get a posterior; the rest stay point estimates.
pointwise_params, distributional_params = {}, {}

print("Model parameters:")
for param_name, param in bnn.named_parameters():
    is_distributional = any(
        selector in param_name for selector in cfg.distributional_parameters
    )

    print(
        f" - {param_name}: {param.shape} (grad={param.requires_grad}) (distribution={is_distributional})"
    )

    bucket = distributional_params if is_distributional else pointwise_params
    bucket[param_name] = param

Model parameters:
 - common_layers.0.weight: torch.Size([128, 784]) (grad=True) (distribution=False)
 - common_layers.0.bias: torch.Size([128]) (grad=True) (distribution=False)
 - common_layers.1.0.weight: torch.Size([128, 128]) (grad=True) (distribution=False)
 - common_layers.1.0.bias: torch.Size([128]) (grad=True) (distribution=False)
 - common_layers.2.0.weight: torch.Size([128, 128]) (grad=True) (distribution=False)
 - common_layers.2.0.bias: torch.Size([128]) (grad=True) (distribution=False)
 - last_layer.weight: torch.Size([10, 128]) (grad=True) (distribution=True)
 - last_layer.bias: torch.Size([10]) (grad=True) (distribution=False)


In [17]:
def create_gaussian_basic_prior(p):
    """Standard-normal prior N(0, I) over the flattened parameter ``p``.

    The returned distribution has event shape ``(p.numel(),)``; its ``log_prob``
    scores flattened vectors (optionally with leading batch dimensions). Mean and
    scale tensors are created with ``p``'s device and dtype.

    Implemented as independent univariate normals, which has exactly the density of
    ``MultivariateNormal(0, I)`` but avoids materialising a dense ``numel x numel``
    covariance matrix and its Cholesky factorisation (O(d^2) memory, O(d^3) time).
    """
    flat = p.flatten()
    return torch.distributions.Independent(
        torch.distributions.Normal(
            loc=torch.zeros_like(flat), scale=torch.ones_like(flat)
        ),
        reinterpreted_batch_ndims=1,
    )


# One prior per distributional parameter, keyed by the parameter's name.
priors = {
    name: create_gaussian_basic_prior(param)
    for name, param in distributional_params.items()
}


def log_priors(samples):
    """Summed prior log-density of a dict ``{parameter_name: sampled_value}``.

    Each value is flattened and scored by ``priors[parameter_name]``; the result is
    the sum over all entries (the integer 0 for an empty dict).
    """
    total = 0
    for name, value in samples.items():
        total = total + priors[name].log_prob(value.flatten())
    return total

In [18]:
# Seed torch's CPU and CUDA generators before the posterior samplers below are built.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(cfg.seed)

In [19]:
# Each parameter gets its own sampler
parameter2sampler, variational_params, aux_objs = sampling.create_independent_samplers(
    distributional_params, cfg.posterior_model_architecture
)
sampler = reparameterized.parameter_samplers_to_joint_sampler(parameter2sampler)

# All parameters are put together and use a joint sampler
# sampler, variational_params, aux_objs = sampling.create_joint_sampler(
#     distributional_params, cfg.posterior_model_architecture
# )

print(f"Posterior={cfg.posterior_model_architecture}:")
print(f" - target_params = {list(distributional_params.keys())}")
print(f" - sampler = {sampler}")
print(f" - variational_params = {list(variational_params.keys())}")
print(f" - aux_objs = {list(aux_objs.keys())}")

Posterior=svd_rnvp_small_rezero:
 - target_params = ['last_layer.weight']
 - sampler = <function parameter_samplers_to_joint_sampler.<locals>.single_sampler at 0x13802c9a0>
 - variational_params = ['last_layer.weight.alpha', 'last_layer.weight.beta', 'last_layer.weight.t.0.0.weight', 'last_layer.weight.t.0.0.bias', 'last_layer.weight.t.0.2.weight', 'last_layer.weight.t.0.2.bias', 'last_layer.weight.t.0.4.weight', 'last_layer.weight.t.0.4.bias', 'last_layer.weight.t.1.0.weight', 'last_layer.weight.t.1.0.bias', 'last_layer.weight.t.1.2.weight', 'last_layer.weight.t.1.2.bias', 'last_layer.weight.t.1.4.weight', 'last_layer.weight.t.1.4.bias', 'last_layer.weight.t.2.0.weight', 'last_layer.weight.t.2.0.bias', 'last_layer.weight.t.2.2.weight', 'last_layer.weight.t.2.2.bias', 'last_layer.weight.t.2.4.weight', 'last_layer.weight.t.2.4.bias', 'last_layer.weight.t.3.0.weight', 'last_layer.weight.t.3.0.bias', 'last_layer.weight.t.3.2.weight', 'last_layer.weight.t.3.2.bias', 'last_layer.weight.t.3.

## Evaluation

In [20]:
@torch.no_grad()
def eval_posterior(
    bnn,
    samples,
    dataloader,
    name_preffix="",
    device=device,
    num_classes=None,
    n_bins=15,
):
    """Evaluate the posterior predictive of ``bnn`` on ``dataloader``.

    For each batch, ``posterior_predictions(bnn, x, samples=samples)`` yields class
    probabilities that are averaged over dim 0 (presumably one slice per posterior
    sample) to form the predictive distribution; its argmax is the predicted class.

    Args:
        bnn: network passed to ``posterior_predictions``.
        samples: posterior samples, e.g. the first output of ``sampler(n_samples=...)``.
        dataloader: iterable of ``(x, y)`` batches with integer class targets.
        name_preffix: string prepended to every metric key (spelling kept for
            backward compatibility with existing callers).
        device: device inputs and targets are moved to. The default is bound to the
            module-level ``device`` at the moment this cell is executed.
        num_classes: number of classes for the calibration metrics; inferred from the
            last dimension of the predicted probabilities when ``None``.
        n_bins: number of confidence bins used for the calibration errors.

    Returns:
        dict with keys ``{prefix}eval_accuracy``, ``{prefix}eval_ece_l1``,
        ``{prefix}eval_ece_l2`` and ``{prefix}eval_ece_max``, each a scalar tensor.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    all_targets, all_predictions, all_probabilities = [], [], []

    for x, y in dataloader:
        x = x.to(device)
        y = y.to(device)

        # Bayesian model average over the leading dimension of the predictions.
        probs = posterior_predictions(bnn, x, samples=samples).mean(dim=0)

        all_targets.append(y)
        all_predictions.append(probs.argmax(dim=-1))
        all_probabilities.append(probs)

    if not all_targets:
        raise ValueError("eval_posterior: dataloader yielded no batches")

    targets = torch.cat(all_targets)
    predictions = torch.cat(all_predictions)
    probabilities = torch.cat(all_probabilities)

    if num_classes is None:
        num_classes = probabilities.shape[-1]

    metrics = {
        f"{name_preffix}eval_accuracy": (predictions == targets).float().mean()
    }
    # Expected calibration error under three norms; key order matches the printout.
    for norm in ("l1", "l2", "max"):
        metrics[f"{name_preffix}eval_ece_{norm}"] = multiclass_calibration_error(
            preds=probabilities,
            target=targets,
            num_classes=num_classes,
            n_bins=n_bins,
            norm=norm,
        )

    return metrics

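## Learning

Training minimises the weighted negative ELBO, `loss = -(omega * log_lik - beta * KLD)`, where `log_lik` and `KLD` are returned by `elbo_mc`. With `omega = beta = 1` this is the standard negative ELBO. After every epoch the posterior is re-evaluated on the test set.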

In [21]:
# Baseline: evaluate the untrained variational posterior on the test set.
print("Evaluate before training")

samples, q_nlls = sampler(n_samples=111)
# q_nlls (presumably negative log-densities under q) must be NaN-free.
assert not torch.isnan(q_nlls).any(), q_nlls

metrics = eval_posterior(bnn, samples, test_dataloader, device=device)
metrics_str = " ".join(f"{key}={value:.2f}" for key, value in metrics.items())
print(f"[start] metrics={metrics_str}")
wandb_log(metrics, step=0)

Evaluate before training
[start] metrics=eval_accuracy=0.08 eval_ece_l1=0.02 eval_ece_l2=0.02 eval_ece_max=0.02


In [22]:
# Re-seed after the pre-training evaluation (which drew posterior samples) so that the
# training run below starts from a fixed RNG state.
_seed = cfg.seed
torch.manual_seed(_seed)
torch.cuda.manual_seed(_seed)
torch.cuda.manual_seed_all(_seed)

In [23]:
# Number of scalar variational parameters across all posterior samplers.
total_variational_params = sum(map(torch.numel, variational_params.values()))
print(f"#variational params = {total_variational_params}")

#variational params = 471856


In [24]:
# Jointly optimise the variational parameters and the point-estimated network weights.
optimized_parameters = [*variational_params.values(), *pointwise_params.values()]
optimizer_cls = getattr(torch.optim, cfg.optimizer)
optimizer = optimizer_cls(optimized_parameters, lr=cfg.lr)

# Weights of the likelihood (omega) and KL (beta) terms in the training loss;
# 1.0 / 1.0 gives the standard negative ELBO.
omega = 1.0
beta = 1.0

In [25]:
start_time = time.time()
for epoch in range(cfg.n_epochs):

    epoch_loss, epoch_kld, epoch_log_lik = [], [], []

    # Iterating the DataLoader directly lets tqdm read its length and show a total.
    for x, y in tqdm.tqdm(train_dataloader, desc=f"epoch {epoch}"):
        # Move the batch to the model's device, as eval_posterior does; CPU batches
        # would fail against a GPU model unless elbo_mc moved them (not visible here).
        x = x.to(device)
        y = y.to(device)
        # Dataset-size / batch-size ratio handed to elbo_mc (presumably to rescale the
        # minibatch likelihood to the full training set); the last batch may be smaller.
        full2minibatch_ratio = len(train_dataloader.dataset) / len(x)

        optimizer.zero_grad()
        elbo_res = elbo_mc(
            bnn,
            x,
            y,
            log_priors,
            categorical_log_prob,
            sampler,
            cfg.n_posterior_samples,
            full2minibatch_ratio,
        )
        log_lik, KLD = elbo_res["ll"], elbo_res["kl"]
        # Weighted negative ELBO.
        loss_vi = -(omega * log_lik - beta * KLD)
        loss_vi.backward()
        optimizer.step()

        # .item() already detaches and copies the scalar to the host.
        epoch_loss.append(loss_vi.item())
        epoch_kld.append(KLD.item())
        epoch_log_lik.append(log_lik.item())

    # reporting
    train_metrics = {
        "train_loss": float(np.mean(epoch_loss)),
        "train_log_lik": float(np.mean(epoch_log_lik)),
        "train_kld": float(np.mean(epoch_kld)),
    }

    samples, q_nlls = sampler(n_samples=111)
    assert not q_nlls.isnan().any(), q_nlls

    metrics = eval_posterior(bnn, samples, test_dataloader, device=device)
    metrics_str = " ".join(f"{k}={v:.3f}" for k, v in metrics.items())
    # Step 0 holds the pre-training evaluation, so epoch e is logged at step e + 1
    # (logging epoch 0 at step 0 too would collide with it). Training statistics
    # are logged alongside the test metrics.
    wandb_log({**metrics, **train_metrics}, step=epoch + 1)

    res_str = (
        f"train: loss={train_metrics['train_loss']:.2f} "
        f"log_lik={train_metrics['train_log_lik']:.2f} "
        f"KLD={train_metrics['train_kld']:.2f}"
    )
    print(
        f"[{time.time()-start_time:.0f}s][epoch={epoch}] {res_str}  / test metrics: {metrics_str}"
    )

59it [00:13,  4.30it/s]


[ 16s][epoch=0] train: loss=114403.42 log_lik=-113206.98 KLD=1196.44  / test metrics: eval_accuracy=0.843 eval_ece_l1=0.046 eval_ece_l2=0.059 eval_ece_max=0.140


59it [00:13,  4.50it/s]


[ 32s][epoch=1] train: loss=25516.52 log_lik=-23205.41 KLD=2311.11  / test metrics: eval_accuracy=0.920 eval_ece_l1=0.026 eval_ece_l2=0.034 eval_ece_max=0.080


59it [00:16,  3.55it/s]


[ 52s][epoch=2] train: loss=19215.60 log_lik=-16992.15 KLD=2223.45  / test metrics: eval_accuracy=0.927 eval_ece_l1=0.018 eval_ece_l2=0.026 eval_ece_max=0.262


59it [00:16,  3.54it/s]


[ 73s][epoch=3] train: loss=15830.50 log_lik=-13666.61 KLD=2163.89  / test metrics: eval_accuracy=0.941 eval_ece_l1=0.016 eval_ece_l2=0.024 eval_ece_max=0.126


59it [00:17,  3.31it/s]


[ 95s][epoch=4] train: loss=13075.53 log_lik=-11032.54 KLD=2042.99  / test metrics: eval_accuracy=0.953 eval_ece_l1=0.015 eval_ece_l2=0.026 eval_ece_max=0.227


59it [00:15,  3.75it/s]


[ 114s][epoch=5] train: loss=11140.46 log_lik=-9165.95 KLD=1974.50  / test metrics: eval_accuracy=0.956 eval_ece_l1=0.010 eval_ece_l2=0.021 eval_ece_max=0.490


59it [00:15,  3.74it/s]


[ 133s][epoch=6] train: loss=10066.40 log_lik=-8174.38 KLD=1892.02  / test metrics: eval_accuracy=0.961 eval_ece_l1=0.004 eval_ece_l2=0.016 eval_ece_max=0.731


59it [00:15,  3.83it/s]


[ 152s][epoch=7] train: loss=9223.14 log_lik=-7400.09 KLD=1823.05  / test metrics: eval_accuracy=0.964 eval_ece_l1=0.007 eval_ece_l2=0.017 eval_ece_max=0.167


59it [00:16,  3.58it/s]


[ 172s][epoch=8] train: loss=8400.91 log_lik=-6632.11 KLD=1768.80  / test metrics: eval_accuracy=0.966 eval_ece_l1=0.011 eval_ece_l2=0.022 eval_ece_max=0.257


59it [00:15,  3.71it/s]


[ 191s][epoch=9] train: loss=7658.09 log_lik=-5915.29 KLD=1742.80  / test metrics: eval_accuracy=0.969 eval_ece_l1=0.003 eval_ece_l2=0.011 eval_ece_max=0.243


59it [00:16,  3.64it/s]


[ 211s][epoch=10] train: loss=7188.75 log_lik=-5468.07 KLD=1720.68  / test metrics: eval_accuracy=0.969 eval_ece_l1=0.004 eval_ece_l2=0.010 eval_ece_max=0.059


59it [00:15,  3.85it/s]


[ 230s][epoch=11] train: loss=6538.89 log_lik=-4837.72 KLD=1701.17  / test metrics: eval_accuracy=0.969 eval_ece_l1=0.003 eval_ece_l2=0.012 eval_ece_max=0.427


59it [00:15,  3.91it/s]


[ 248s][epoch=12] train: loss=6413.94 log_lik=-4741.52 KLD=1672.42  / test metrics: eval_accuracy=0.969 eval_ece_l1=0.005 eval_ece_l2=0.019 eval_ece_max=0.311


59it [00:15,  3.81it/s]


[ 267s][epoch=13] train: loss=5893.69 log_lik=-4229.39 KLD=1664.30  / test metrics: eval_accuracy=0.972 eval_ece_l1=0.004 eval_ece_l2=0.016 eval_ece_max=0.242


59it [00:15,  3.70it/s]


[ 287s][epoch=14] train: loss=5628.82 log_lik=-3989.40 KLD=1639.42  / test metrics: eval_accuracy=0.973 eval_ece_l1=0.003 eval_ece_l2=0.019 eval_ece_max=0.697


59it [00:15,  3.70it/s]


[ 306s][epoch=15] train: loss=5185.56 log_lik=-3531.89 KLD=1653.67  / test metrics: eval_accuracy=0.973 eval_ece_l1=0.003 eval_ece_l2=0.011 eval_ece_max=0.096


59it [00:14,  3.95it/s]


[ 325s][epoch=16] train: loss=4995.27 log_lik=-3348.72 KLD=1646.55  / test metrics: eval_accuracy=0.972 eval_ece_l1=0.003 eval_ece_l2=0.011 eval_ece_max=0.289


59it [00:15,  3.85it/s]


[ 344s][epoch=17] train: loss=4889.33 log_lik=-3264.36 KLD=1624.98  / test metrics: eval_accuracy=0.974 eval_ece_l1=0.004 eval_ece_l2=0.014 eval_ece_max=0.312


59it [00:15,  3.78it/s]


[ 363s][epoch=18] train: loss=4617.02 log_lik=-3014.94 KLD=1602.08  / test metrics: eval_accuracy=0.975 eval_ece_l1=0.007 eval_ece_l2=0.017 eval_ece_max=0.679


59it [00:16,  3.62it/s]


[ 383s][epoch=19] train: loss=4315.06 log_lik=-2695.53 KLD=1619.53  / test metrics: eval_accuracy=0.975 eval_ece_l1=0.005 eval_ece_l2=0.019 eval_ece_max=0.345


59it [00:15,  3.75it/s]


[ 402s][epoch=20] train: loss=4051.30 log_lik=-2442.46 KLD=1608.84  / test metrics: eval_accuracy=0.974 eval_ece_l1=0.009 eval_ece_l2=0.024 eval_ece_max=0.288


59it [00:15,  3.88it/s]


[ 420s][epoch=21] train: loss=3826.44 log_lik=-2209.97 KLD=1616.46  / test metrics: eval_accuracy=0.976 eval_ece_l1=0.004 eval_ece_l2=0.012 eval_ece_max=0.295


59it [00:16,  3.67it/s]


[ 440s][epoch=22] train: loss=3698.02 log_lik=-2082.89 KLD=1615.14  / test metrics: eval_accuracy=0.974 eval_ece_l1=0.008 eval_ece_l2=0.024 eval_ece_max=0.311


59it [00:15,  3.71it/s]


[ 459s][epoch=23] train: loss=3544.90 log_lik=-1948.49 KLD=1596.41  / test metrics: eval_accuracy=0.976 eval_ece_l1=0.009 eval_ece_l2=0.028 eval_ece_max=0.703


59it [00:15,  3.74it/s]


[ 479s][epoch=24] train: loss=3496.47 log_lik=-1909.08 KLD=1587.38  / test metrics: eval_accuracy=0.975 eval_ece_l1=0.008 eval_ece_l2=0.023 eval_ece_max=0.310


59it [00:18,  3.15it/s]


[ 502s][epoch=25] train: loss=3346.60 log_lik=-1764.56 KLD=1582.04  / test metrics: eval_accuracy=0.977 eval_ece_l1=0.008 eval_ece_l2=0.023 eval_ece_max=0.249


59it [00:16,  3.69it/s]


[ 521s][epoch=26] train: loss=3104.50 log_lik=-1521.97 KLD=1582.54  / test metrics: eval_accuracy=0.977 eval_ece_l1=0.008 eval_ece_l2=0.023 eval_ece_max=0.295


59it [00:15,  3.69it/s]


[ 541s][epoch=27] train: loss=3037.86 log_lik=-1452.44 KLD=1585.42  / test metrics: eval_accuracy=0.977 eval_ece_l1=0.010 eval_ece_l2=0.027 eval_ece_max=0.311


59it [00:15,  3.72it/s]


[ 560s][epoch=28] train: loss=2898.02 log_lik=-1317.72 KLD=1580.30  / test metrics: eval_accuracy=0.975 eval_ece_l1=0.010 eval_ece_l2=0.024 eval_ece_max=0.312


59it [00:17,  3.35it/s]


[ 581s][epoch=29] train: loss=2710.22 log_lik=-1121.02 KLD=1589.19  / test metrics: eval_accuracy=0.975 eval_ece_l1=0.010 eval_ece_l2=0.025 eval_ece_max=0.680


59it [00:19,  3.06it/s]


[ 605s][epoch=30] train: loss=2785.43 log_lik=-1223.95 KLD=1561.48  / test metrics: eval_accuracy=0.977 eval_ece_l1=0.010 eval_ece_l2=0.027 eval_ece_max=0.317


59it [00:17,  3.37it/s]


[ 626s][epoch=31] train: loss=2713.29 log_lik=-1159.94 KLD=1553.35  / test metrics: eval_accuracy=0.978 eval_ece_l1=0.009 eval_ece_l2=0.026 eval_ece_max=0.435


59it [00:15,  3.78it/s]


[ 645s][epoch=32] train: loss=2514.11 log_lik=-958.11 KLD=1555.99  / test metrics: eval_accuracy=0.977 eval_ece_l1=0.010 eval_ece_l2=0.030 eval_ece_max=0.691


59it [00:16,  3.58it/s]


[ 665s][epoch=33] train: loss=2474.59 log_lik=-925.63 KLD=1548.96  / test metrics: eval_accuracy=0.977 eval_ece_l1=0.011 eval_ece_l2=0.030 eval_ece_max=0.275


59it [00:16,  3.65it/s]


[ 684s][epoch=34] train: loss=2402.82 log_lik=-851.34 KLD=1551.48  / test metrics: eval_accuracy=0.977 eval_ece_l1=0.010 eval_ece_l2=0.025 eval_ece_max=0.371


59it [00:15,  3.73it/s]


[ 704s][epoch=35] train: loss=2303.36 log_lik=-759.02 KLD=1544.34  / test metrics: eval_accuracy=0.975 eval_ece_l1=0.014 eval_ece_l2=0.035 eval_ece_max=0.211


59it [00:16,  3.56it/s]


[ 724s][epoch=36] train: loss=2268.20 log_lik=-726.59 KLD=1541.60  / test metrics: eval_accuracy=0.974 eval_ece_l1=0.013 eval_ece_l2=0.037 eval_ece_max=0.686


59it [00:16,  3.55it/s]


[ 745s][epoch=37] train: loss=2181.51 log_lik=-651.08 KLD=1530.42  / test metrics: eval_accuracy=0.979 eval_ece_l1=0.011 eval_ece_l2=0.030 eval_ece_max=0.688


59it [00:17,  3.33it/s]


[ 766s][epoch=38] train: loss=2218.16 log_lik=-695.14 KLD=1523.01  / test metrics: eval_accuracy=0.978 eval_ece_l1=0.011 eval_ece_l2=0.026 eval_ece_max=0.383


59it [00:16,  3.65it/s]


[ 786s][epoch=39] train: loss=2053.38 log_lik=-527.98 KLD=1525.40  / test metrics: eval_accuracy=0.979 eval_ece_l1=0.010 eval_ece_l2=0.025 eval_ece_max=0.213


59it [00:16,  3.63it/s]


[ 805s][epoch=40] train: loss=1975.15 log_lik=-459.21 KLD=1515.94  / test metrics: eval_accuracy=0.977 eval_ece_l1=0.012 eval_ece_l2=0.032 eval_ece_max=0.302


59it [00:15,  3.74it/s]


[ 825s][epoch=41] train: loss=1923.16 log_lik=-414.81 KLD=1508.35  / test metrics: eval_accuracy=0.978 eval_ece_l1=0.012 eval_ece_l2=0.035 eval_ece_max=0.376


59it [00:16,  3.67it/s]


[ 845s][epoch=42] train: loss=1859.57 log_lik=-349.07 KLD=1510.50  / test metrics: eval_accuracy=0.979 eval_ece_l1=0.012 eval_ece_l2=0.034 eval_ece_max=0.325


59it [00:17,  3.34it/s]


[ 866s][epoch=43] train: loss=1893.50 log_lik=-392.41 KLD=1501.09  / test metrics: eval_accuracy=0.977 eval_ece_l1=0.012 eval_ece_l2=0.036 eval_ece_max=0.355


59it [00:16,  3.58it/s]


[ 886s][epoch=44] train: loss=1783.03 log_lik=-291.64 KLD=1491.39  / test metrics: eval_accuracy=0.978 eval_ece_l1=0.012 eval_ece_l2=0.036 eval_ece_max=0.410


59it [00:15,  3.75it/s]


[ 906s][epoch=45] train: loss=1790.20 log_lik=-299.31 KLD=1490.89  / test metrics: eval_accuracy=0.978 eval_ece_l1=0.012 eval_ece_l2=0.030 eval_ece_max=0.322


13it [00:04,  3.14it/s]


KeyboardInterrupt: 