In [1]:
import numpy as np
import math
import os
from bnn_priors.models.conv_nets import ClassificationConvNet, CorrelatedClassificationConvNet
from bnn_priors import exp_utils

In [2]:
# Fetch the "mnist" dataset through the project's exp_utils helper.
# NOTE(review): the second argument "cpu" is presumably the target device — confirm.
data = exp_utils.get_data("mnist", "cpu")

In [3]:
# Training inputs and labels, both taken from the `norm` view of the dataset
# (presumably the normalised variant — confirm in exp_utils).
x_train, y_train = data.norm.train_X, data.norm.train_y

In [4]:
# --- Architecture ---
width, depth = 50, 3

# --- Priors: the "gaussian" prior is used for both weights and biases ---
weight_prior = exp_utils.get_prior("gaussian")
bias_prior = exp_utils.get_prior("gaussian")
weight_prior_params = dict()
bias_prior_params = dict()

# Location / scale handed to each prior; the weight scale is sqrt(2).
weight_loc, weight_scale = 0.0, 2.0 ** 0.5
bias_loc, bias_scale = 0.0, 1.0


def scaling_fn(std, dim):
    """Return `std` divided by the square root of `dim`."""
    return std / dim ** 0.5

In [5]:
# Infer the input geometry from the training tensor. A rank-4 tensor is read
# as having channels on axis 1 and height on axis -2; anything else is treated
# as single-channel, with the side length taken as the integer square root of
# the last axis (assumes a flattened square image — TODO confirm).
if len(x_train.shape) != 4:
    in_channels = 1
    img_height = int(math.sqrt(x_train.shape[-1]))
else:
    in_channels = x_train.shape[1]
    img_height = x_train.shape[-2]

# Third positional argument of the network: largest label plus one
# (presumably the number of classes — confirm against the model signature).
n_classes = y_train.max() + 1

# Prior configuration shared with the cells above, passed through by keyword.
prior_kwargs = dict(
    prior_w=weight_prior, loc_w=weight_loc, std_w=weight_scale,
    prior_b=bias_prior, loc_b=bias_loc, std_b=bias_scale,
    scaling_fn=scaling_fn,
    weight_prior_params=weight_prior_params,
    bias_prior_params=bias_prior_params,
)
net = CorrelatedClassificationConvNet(
    in_channels, img_height, n_classes, width, depth,
    softmax_temp=1., **prior_kwargs,
)
# Align the network with the training tensor; assuming `net` is a torch
# nn.Module, `.to(tensor)` adopts that tensor's dtype and device.
net = net.to(x_train)

In [6]:
# Show the weight prior attached to the 9th module from the end of net.modules().
# NOTE(review): the -9 index is hard-coded for this width/depth configuration;
# presumably it selects a conv layer — re-check if the architecture changes.
list(net.modules())[-9].weight_prior

ConvCorrelatedNormal()

In [7]:
# Instantiate the "gaussian" prior on its own over a (50, 50, 3, 3) tensor with
# the configured loc and scale.
# NOTE(review): the shape presumably stands for (out_ch, in_ch, kH, kW) conv weights — confirm.
wp = weight_prior((50, 50, 3, 3), weight_loc, weight_scale)

In [8]:
# Log-probability reported by the Gaussian prior; called with no arguments, so
# it is presumably evaluated at the prior's own stored value — confirm.
wp.log_prob()

tensor(-39534.0586, grad_fn=<SumBackward0>)

In [9]:
# Look up the prior registered under the name "convcorrnormal".
p = exp_utils.get_prior("convcorrnormal")

In [10]:
# Rebind `wp` to the "convcorrnormal" prior with the same shape, loc and scale,
# using a very small lengthscale (1e-6). The scale_tril inspected further down
# comes out diagonal for this setting.
wp = p((50, 50, 3, 3), weight_loc, weight_scale, lengthscale=1e-6)

In [11]:
# Log-probability of the "convcorrnormal" prior, for comparison with the
# Gaussian prior's value computed earlier.
wp.log_prob()

tensor(-39858.3008, grad_fn=<SumBackward0>)

In [12]:
# Inspect the `scale_tril` factor of the distribution object behind the prior.
# NOTE(review): `_dist_obj` is a private method, and the returned object is
# presumably a torch MultivariateNormal — confirm before relying on it.
wp._dist_obj().scale_tril

tensor([[1.4142, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.4142, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.4142, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 1.4142, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 1.4142, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4142, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4142, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4142, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4142]])

In [13]:
# Echo the configured weight scale (2.**0.5) to compare with the diagonal of
# the scale_tril printed above.
weight_scale

1.4142135623730951

In [14]:
# Echo the configured weight location.
weight_loc

0.0

## Hierarchical priors

In [33]:
import torch
import torch.distributions as td
from bnn_priors.prior import PositiveImproper, Normal, ConvCorrelatedNormal

In [16]:
# Scalar-shaped (shape=[]) PositiveImproper prior; the cells below pass it as
# the `scale` of another prior.
scale_prior = PositiveImproper(shape=[], loc=1., scale=1.)

In [23]:
# Baseline: a 2x2 Normal prior with fixed numeric loc and scale.
prior_w = Normal(shape=[2,2], loc=0., scale=1.)
# Return value is discarded; presumably this redraws the stored parameter in
# place — confirm in bnn_priors.prior.
prior_w.sample()
# Display the prior's parameter tensor `p`.
prior_w.p

Parameter containing:
tensor([[-0.1467,  0.2487],
        [ 1.2664,  0.1471]], requires_grad=True)

In [25]:
# Hierarchical variant: the scale is the `scale_prior` object rather than a number.
prior_w = Normal(shape=[2,2], loc=0., scale=scale_prior)
# Return value is discarded; presumably this redraws the stored parameter in
# place — confirm in bnn_priors.prior.
prior_w.sample()
# Display the prior's parameter tensor `p`.
prior_w.p

Parameter containing:
tensor([[ 0.8129,  0.2287],
        [-1.7233,  1.1422]], requires_grad=True)

In [36]:
# torch.distributions.Normal only accepts plain numbers or tensors, so handing
# it the `scale_prior` object raises ValueError — unlike the project's `Normal`
# prior, which accepted `scale=scale_prior` in the previous cell.
# The failure is the point of this cell, so it is caught and shown instead of
# being left to abort a Restart & Run All. Only ValueError is caught so that
# any other, unexpected error still surfaces.
try:
    sample = td.Normal(loc=0., scale=scale_prior).sample()
except ValueError as err:
    print(f"{type(err).__name__}: {err}")
else:
    print(sample)

ValueError: Input arguments must all be instances of numbers.Number or torch.tensor.