In [1]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch import nn, utils, func
import torchopt
import posteriors

dataset = MNIST(root="./data", transform=ToTensor(), download=True)
train_loader = utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
num_data = len(dataset)

classifier = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 10))
params = dict(classifier.named_parameters())


def log_posterior(params, batch):
    images, labels = batch
    images = images.view(images.size(0), -1)
    output = func.functional_call(classifier, params, images)
    log_post_val = (
        -nn.functional.cross_entropy(output, labels)
        + posteriors.diag_normal_log_prob(params) / num_data
    )
    return log_post_val, output


transform = posteriors.vi.diag.build(
    log_posterior, torchopt.adam(), temperature=1 / num_data
)  # Can swap out for any posteriors algorithm

state = transform.init(params)

for batch in train_loader:
    state, aux = transform.update(state, batch)

W0901 01:12:36.424000 23648 Lib\site-packages\torch\distributed\elastic\multiprocessing\redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
  from optree.integration.torch import tree_ravel
100%|██████████| 9.91M/9.91M [00:01<00:00, 6.35MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 210kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 2.16MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 1.51MB/s]


In [4]:
state

VIDiagState(
    log_sd_diag=NonTensorData(data={'0.weight': tensor(  ...  _fn=<AddBackward0>)}, batch_size=torch.Size([]), device=None),
    nelbo=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
    opt_state=NonTensorData(data=(ScaleByAdamState(mu  ...  75)]), EmptyState()), batch_size=torch.Size([]), device=None),
    params=NonTensorData(data={'0.weight': tensor(  ...  _fn=<AddBackward0>)}, batch_size=torch.Size([]), device=None),
    step=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

In [5]:
params

{'0.weight': Parameter containing:
 tensor([[-0.0052,  0.0240,  0.0209,  ...,  0.0121,  0.0262, -0.0279],
         [-0.0175,  0.0174, -0.0278,  ...,  0.0083,  0.0111,  0.0032],
         [-0.0202,  0.0014, -0.0341,  ...,  0.0337,  0.0331, -0.0287],
         ...,
         [-0.0056, -0.0087, -0.0113,  ...,  0.0127,  0.0035, -0.0339],
         [-0.0041,  0.0234, -0.0162,  ..., -0.0269,  0.0226, -0.0132],
         [-0.0045, -0.0061,  0.0332,  ...,  0.0236, -0.0133,  0.0164]],
        requires_grad=True),
 '0.bias': Parameter containing:
 tensor([ 0.0225,  0.0145, -0.0231,  0.0010,  0.0088, -0.0040,  0.0124,  0.0150,
         -0.0170,  0.0101, -0.0312,  0.0069,  0.0251,  0.0317,  0.0264, -0.0161,
         -0.0174,  0.0206, -0.0267, -0.0046,  0.0246, -0.0042, -0.0082,  0.0013,
          0.0116, -0.0339, -0.0340, -0.0124,  0.0152, -0.0239,  0.0214,  0.0008,
          0.0182, -0.0315,  0.0104,  0.0230, -0.0215, -0.0023,  0.0154,  0.0176,
          0.0330,  0.0032,  0.0150,  0.0003, -0.0224, -0.

In [None]:
state.params

{'0.weight': tensor([[ 1.2148e-11,  4.4999e-11,  2.4448e-12,  ..., -5.6119e-11,
          -3.6033e-11,  5.6699e-12],
         [ 2.7890e-11, -2.8715e-12, -1.1627e-11,  ..., -7.4761e-14,
          -3.8051e-11,  1.3771e-05],
         [ 7.3005e-11, -1.2755e-10,  1.4253e-12,  ..., -2.6487e-11,
           1.0827e-11, -5.9081e-11],
         ...,
         [-7.1553e-10, -1.7301e-05, -2.5226e-11,  ..., -1.7831e-12,
           1.6229e-06, -7.3980e-11],
         [ 2.4941e-04,  2.6498e-11, -1.3135e-12,  ..., -1.0180e-10,
          -1.0406e-11, -1.2345e-10],
         [-1.2255e-09, -5.8497e-08, -2.2756e-11,  ...,  3.8858e-11,
          -8.6702e-12, -6.2145e-12]], grad_fn=<AddBackward0>),
 '0.bias': tensor([-0.2730, -0.2983, -0.3322, -0.3118, -0.3095, -0.3152, -0.2916, -0.2834,
         -0.3261, -0.3092, -0.3483, -0.2879, -0.2817, -0.2944, -0.2982, -0.3312,
         -0.3203, -0.2900, -0.3435, -0.3187, -0.2641, -0.2929, -0.3270, -0.2912,
         -0.2935, -0.3539, -0.3218, -0.3300, -0.2803, -0.3293, -0

In [8]:
import torch

In [10]:
# evaluate the model
classifier.eval()
total_loss_state = 0.0
total_loss_original = 0.0
num_batches = 0

with torch.no_grad():
    for batch in train_loader:
        images, labels = batch
        images = images.view(images.size(0), -1)
        
        # Loss for state.params
        output_state = func.functional_call(classifier, state.params, images)
        loss_state = nn.functional.cross_entropy(output_state, labels)
        total_loss_state += loss_state.item()
        
        # Loss for original params
        output_original = func.functional_call(classifier, params, images)
        loss_original = nn.functional.cross_entropy(output_original, labels)
        total_loss_original += loss_original.item()
        
        num_batches += 1

avg_loss_state = total_loss_state / num_batches
avg_loss_original = total_loss_original / num_batches

print(f"Average Loss (state.params): {avg_loss_state:.4f}")
print(f"Average Loss (original params): {avg_loss_original:.4f}")


Average Loss (state.params): 2.3025
Average Loss (original params): 2.3074
