In [1]:
import torch
import torch.distributions as td
from gpytorch.utils.transforms import inv_softplus
import numpy as np

from bnn_priors.models import DenseNet
from bnn_priors.prior import Prior, LocScale, Normal, Gamma, Uniform, Laplace

In [2]:
from bnn_priors.prior import NormalGamma

## Debug priors on CUDA

In [3]:
# Fix the RNG so the priors' (presumably randomly initialised) values are reproducible.
torch.random.manual_seed(1337)

# One prior of each family, all over the same 2x3 shape.
prior_n = Normal([2,3], 0., 1.)
prior_l = Laplace([2,3], 0., 1.)
prior_g = Gamma([2,3], 1., 1.)
prior_u = Uniform([2,3], 0., 1.)

In [4]:
# Baseline on the CPU: one log-prob line per prior (Normal, Laplace, Gamma, Uniform).
for dbg_prior in (prior_n, prior_l, prior_g, prior_u):
    print(dbg_prior.log_prob())

tensor(-11.9458, grad_fn=<SumBackward0>)
tensor(-9.9780, grad_fn=<SumBackward0>)
tensor(-6.4006, grad_fn=<SumBackward0>)
-0.0


In [5]:
# Move every prior to the GPU; Module.cuda() works in place, so its return value is not needed.
for dbg_prior in (prior_n, prior_l, prior_g, prior_u):
    dbg_prior.cuda()

In [6]:
# Same log-probs after .cuda(); compare each line's device tag with the CPU run above.
for dbg_prior in (prior_n, prior_l, prior_g, prior_u):
    print(dbg_prior.log_prob())

tensor(-11.9458, device='cuda:0', grad_fn=<SumBackward0>)
tensor(-9.9780, device='cuda:0', grad_fn=<SumBackward0>)
tensor(-6.4006, grad_fn=<SumBackward0>)
-0.0


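After `.cuda()`, the Normal and Laplace priors return their log-probs on `cuda:0`, but the Gamma prior's `log_prob()` still comes back on the CPU (its output has no `device='cuda:0'`). The Uniform prior returns a plain Python float (`-0.0`), so its device cannot be checked this way.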

## NormalGamma prior

In [7]:
# Earlier draft of NormalGamma, kept for reference only (the whole cell is commented out).
# The NormalGamma used in the rest of this notebook is imported from bnn_priors.prior.
# class NormalGamma(Normal):
#     def __init__(self, shape, loc, scale, rate=1., gradient_scaling=0.001):
#         scale_prior = Gamma(shape=[], concentration=scale, rate=rate)
#         super().__init__(shape, loc, scale_prior)
#         self.scale.p.register_hook(self.hook)
#         self.gradient_scaling = gradient_scaling
        
#     def log_prob(self):
#         return super().log_prob() + self.scale.log_prob()
    
#     def hook(self, grad):
#         return grad*self.gradient_scaling

## Horseshoe prior

In [78]:
class HalfCauchy(Prior):
    """Half-Cauchy prior on a positive quantity stored through a softplus.

    The raw parameter ``p`` is unconstrained; the prior's value is
    ``softplus(p) * multiplyer``.

    Args:
        shape: shape of ``p`` (forwarded to ``Prior``).
        scale: scale of the ``torch.distributions.HalfCauchy`` used by ``log_prob``
            (forwarded to ``Prior``).
        multiplyer: constant factor applied to ``softplus(p)`` in ``forward``
            (spelling kept because callers pass it by keyword).

    NOTE(review): ``_sample_value`` stores ``inv_softplus`` of a draw from the parent
    sampler (presumably ``HalfCauchy(scale)``), so ``forward`` yields ``draw * multiplyer``,
    while ``log_prob`` scores that product under ``HalfCauchy(scale)``. The two agree
    only when ``multiplyer == 1`` -- confirm this is intended.
    NOTE(review): ``log_prob`` scores the constrained value and adds no softplus
    Jacobian term -- confirm this matches the other priors' convention.
    NOTE(review): in this notebook the result of ``log_prob`` stays on the CPU after
    ``.cuda()``; presumably the distribution parameters are not moved with the module.
    """
    _dist = td.HalfCauchy

    def __init__(self, shape, scale=1., multiplyer=1.):
        super().__init__(shape, scale)
        self.multiplyer = multiplyer

    def _sample_value(self, shape: torch.Size):
        # Store the positive draw in unconstrained space so forward()'s softplus recovers it.
        return inv_softplus(super()._sample_value(shape))

    def forward(self):
        positive = torch.nn.functional.softplus(self.p)
        return positive * self.multiplyer

    def log_prob(self):
        dist = self._dist_obj()
        return dist.log_prob(self()).sum()

In [79]:
class Horseshoe(Normal):
    """Normal prior whose scale is a learnable, HalfCauchy-distributed scalar.

    A scalar ``HalfCauchy(scale=hyperscale, multiplyer=scale)`` module is passed to
    ``Normal`` as its scale. Its raw parameter ``p`` is initialised to
    ``inv_softplus(1)``, so the HalfCauchy value ``softplus(p) * scale`` starts at
    ``scale``. During backward, the gradient of that raw parameter is clamped
    elementwise to ``[-gradient_clip, gradient_clip]``.

    Args:
        shape: shape of the weight tensor.
        loc: location of the Normal.
        scale: initial value of the HalfCauchy scale module (its ``multiplyer``).
        hyperscale: scale of the HalfCauchy prior.
        gradient_clip: non-negative bound for the scale-parameter gradient
            (``0`` zeroes that gradient).

    Raises:
        ValueError: if ``gradient_clip`` is negative.
    """
    def __init__(self, shape, loc, scale, hyperscale=1., gradient_clip=1.):
        if gradient_clip < 0:
            # torch.clamp(grad, -c, c) with c < 0 has min > max, which overwrites every
            # gradient entry with c instead of clipping it.
            raise ValueError(f"gradient_clip must be non-negative, got {gradient_clip}")
        scale_prior = HalfCauchy(shape=[], scale=hyperscale, multiplyer=scale)
        with torch.no_grad():
            # Write in place (rather than reassigning .data) so the Parameter object,
            # dtype and device are kept; softplus(inv_softplus(1)) ~= 1, so the
            # scale module's value starts at `scale`.
            scale_prior.p.copy_(inv_softplus(torch.ones_like(scale_prior.p)))
        super().__init__(shape, loc, scale_prior)
        self.clip = gradient_clip
        self.scale.p.register_hook(self.hook)

    def log_prob(self):
        """Parent Normal log-prob plus the HalfCauchy log-prob of the scale."""
        # TODO: it seems like the log prob is too high compared to the other loss terms...
        # NOTE(review): if the model-level log prior also sums log_prob over every nested
        # prior module, the HalfCauchy term is counted twice -- verify.
        return super().log_prob() + self.scale.log_prob()

    def hook(self, grad):
        """Clamp the scale parameter's gradient elementwise to [-self.clip, self.clip]."""
        # TODO: This somehow affects the downstream gradients of the parameters, which it shouldn't
        # It should only affect the actual scale.p parameter
        # NOTE(review): the hook is registered on the leaf Parameter self.scale.p; a hook
        # on a leaf only changes the gradient accumulated into that tensor.
        return torch.clamp(grad, -self.clip, self.clip)

In [80]:
# Horseshoe over a 2x3 tensor with loc=0., scale=1. (hyperscale and gradient_clip keep their default of 1.).
prior_hs = Horseshoe([2,3], 0., 1.)

In [87]:
# Log-prob on the CPU: Normal term plus the HalfCauchy scale term.
prior_hs.log_prob()

tensor(-10.5488, grad_fn=<AddBackward0>)

In [88]:
# Move the Horseshoe (and its nested HalfCauchy scale module) to the GPU; displays the module repr.
prior_hs.cuda()

Horseshoe(
  (scale): HalfCauchy()
)

In [90]:
# Same value as before .cuda(), now reported on cuda:0.
prior_hs.log_prob()

tensor(-10.5488, device='cuda:0', grad_fn=<AddBackward0>)

In [91]:
# NOTE: the HalfCauchy term alone has no device tag, i.e. it is still computed on the CPU.
prior_hs.scale.log_prob()

tensor(-1.1447, grad_fn=<SumBackward0>)

## Test hierarchical priors

In [8]:
# Same seed and 12x10 shape for a plain Normal and a hierarchical NormalGamma so they can be compared.
torch.random.manual_seed(1337)

prior_n = Normal([12,10], 0., 1.)
prior_ng = NormalGamma([12,10], 0., 1.)

In [9]:
# Reference value: plain Normal prior.
prior_n.log_prob()

tensor(-170.5845, grad_fn=<SumBackward0>)

In [10]:
# Hierarchical prior (presumably Normal term + Gamma scale term, as in the commented-out draft).
prior_ng.log_prob()

tensor(-176.1476, grad_fn=<AddBackward0>)

In [11]:
# No backward pass yet, so .grad is None and nothing is displayed.
prior_ng.scale.p.grad

In [12]:
# Backpropagate the hierarchical log-prob so the parameters receive gradients.
prior_ng.log_prob().backward()

In [13]:
# The scale parameter now has a gradient.
prior_ng.scale.p.grad

tensor(1.)

In [14]:
# Weight tensor shape of the plain Normal prior.
prior_n.p.shape

torch.Size([12, 10])

In [15]:
# NormalGamma's weights have the same shape.
prior_ng.p.shape

torch.Size([12, 10])

In [16]:
# Raw scale parameter (a 0-dim tensor); the scale value itself comes from calling prior_ng.scale().
prior_ng.scale.p

Parameter containing:
tensor(0.5413, requires_grad=True)

In [17]:
# 32 inputs -> two hidden layers of width 50 -> 10 outputs (see the module repr further below), Normal weight priors.
model = DenseNet(32, 10, 50, prior_w=Normal)

In [18]:
# Model-level log prior (presumably the sum over all layer priors).
model.log_prior()

tensor(364.1306, grad_fn=<AddBackward0>)

In [19]:
# Forward pass on a random batch of 4 inputs with 32 features; .mean has shape (4, 10).
model(torch.randn([4,32])).mean

tensor([[ 2.1409,  0.1275, -0.7249,  1.2202, -0.8064, -3.7940, -2.1093,  2.6775,
          0.3134, -0.4981],
        [ 3.4943,  0.0230, -1.6881,  1.2467, -0.3613, -0.2536, -1.3516, -0.0400,
          0.4567, -1.9111],
        [ 0.6680,  1.5745,  2.3277,  2.6311, -0.7459, -4.7640, -0.7433,  1.1683,
          1.6391, -0.9630],
        [ 0.9653,  0.7864,  1.8622,  2.1716, -1.0020, -4.4027,  1.5087,  0.6254,
          2.7809,  0.2992]], grad_fn=<AddmmBackward>)

In [20]:
# The first layer's weights are stored as the parameter p of its weight prior.
model.net[0].weight_prior.p

Parameter containing:
tensor([[-0.1778,  0.2452,  0.3774,  ...,  0.1315, -0.2403,  0.3801],
        [ 0.2772, -0.4970,  0.4337,  ..., -0.4379, -0.1111,  0.2818],
        [-0.5138, -0.1108,  0.3719,  ..., -0.0855, -0.2493,  0.0977],
        ...,
        [ 0.0446, -0.2535,  0.6553,  ..., -0.0486,  0.4600, -0.0555],
        [-0.2814,  0.2355,  0.2276,  ...,  0.2829,  0.0587,  0.3589],
        [ 0.1625, -0.0182,  0.5648,  ..., -0.0262,  0.1577,  0.4172]],
       requires_grad=True)

In [21]:
# Check that backward() runs through the model's log prior.
model.log_prior().backward()

In [22]:
# Same architecture, but with hierarchical NormalGamma weight priors.
model = DenseNet(32, 10, 50, prior_w=NormalGamma)

In [23]:
# Model-level log prior with NormalGamma weight priors (CPU).
model.log_prior()

tensor(311.8788, grad_fn=<AddBackward0>)

In [24]:
# Forward pass still works with the hierarchical priors.
model(torch.randn([4,32])).mean

tensor([[-3.7128, -0.0110,  1.3462, -2.3352,  0.2479, -3.1084,  0.1465,  1.3793,
         -0.3899,  1.1232],
        [-3.1499,  2.5128,  2.5549, -0.0200, -1.5431, -3.4215,  0.1247,  0.5477,
         -0.8605,  0.8787],
        [-3.7309,  0.5399,  1.8505, -1.3321, -0.6675,  0.1075,  1.9791,  0.1731,
         -1.0120,  0.3939],
        [-1.4584,  3.0402,  3.1238, -0.1705, -0.9159, -3.5089, -1.3441, -0.6013,
          0.5770,  0.0358]], grad_fn=<AddmmBackward>)

In [25]:
# First-layer weights, still on the CPU.
model.net[0].weight_prior.p

Parameter containing:
tensor([[-0.0391, -0.0748,  0.4394,  ...,  0.0997, -0.2732, -0.4596],
        [ 0.1924, -0.4597,  0.6601,  ...,  0.2342,  0.0754, -0.1371],
        [-0.2355,  0.1034,  0.2625,  ...,  0.1294, -0.4172,  0.3931],
        ...,
        [ 0.5610, -0.2617,  0.1028,  ..., -0.0586, -0.5406, -0.1236],
        [ 0.1791, -0.2281,  0.2730,  ...,  0.0107,  0.4770,  0.1286],
        [ 0.0903, -0.2282,  0.1181,  ...,  0.1988,  0.1022, -0.1525]],
       requires_grad=True)

In [26]:
# Check that backward() runs through the hierarchical log prior.
model.log_prior().backward()

In [27]:
# Move the whole model to the GPU; the repr shows each Linear's NormalGamma weight prior and its nested Gamma scale.
model.cuda()

RegressionModel(
  (net): Sequential(
    (0): Linear(
      in_features=32, out_features=50, bias=True
      (weight_prior): NormalGamma(
        (scale): Gamma()
      )
      (bias_prior): Normal()
    )
    (1): ReLU()
    (2): Linear(
      in_features=50, out_features=50, bias=True
      (weight_prior): NormalGamma(
        (scale): Gamma()
      )
      (bias_prior): Normal()
    )
    (3): ReLU()
    (4): Linear(
      in_features=50, out_features=10, bias=True
      (weight_prior): NormalGamma(
        (scale): Gamma()
      )
      (bias_prior): Normal()
    )
  )
)

In [28]:
# Weights are now on cuda:0.
model.net[0].weight_prior.p

Parameter containing:
tensor([[-0.0391, -0.0748,  0.4394,  ...,  0.0997, -0.2732, -0.4596],
        [ 0.1924, -0.4597,  0.6601,  ...,  0.2342,  0.0754, -0.1371],
        [-0.2355,  0.1034,  0.2625,  ...,  0.1294, -0.4172,  0.3931],
        ...,
        [ 0.5610, -0.2617,  0.1028,  ..., -0.0586, -0.5406, -0.1236],
        [ 0.1791, -0.2281,  0.2730,  ...,  0.0107,  0.4770,  0.1286],
        [ 0.0903, -0.2282,  0.1181,  ...,  0.1988,  0.1022, -0.1525]],
       device='cuda:0', requires_grad=True)

In [29]:
# The nested Gamma scale parameter was moved to cuda:0 as well.
model.net[0].weight_prior.scale.p

Parameter containing:
tensor(-1.2587, device='cuda:0', requires_grad=True)

In [30]:
# Scale value (softplus of the raw parameter), also on cuda:0.
model.net[0].weight_prior.scale()

tensor(0.2500, device='cuda:0', grad_fn=<SoftplusBackward>)

In [31]:
# Same value as on the CPU, now on cuda:0.
model.log_prior()

tensor(311.8788, device='cuda:0', grad_fn=<AddBackward0>)

In [32]:
# First layer's weight-prior log-prob, on cuda:0.
model.net[0].weight_prior.log_prob()

tensor(-111.6239, device='cuda:0', grad_fn=<AddBackward0>)

In [33]:
# NOTE: the Gamma scale term comes back without a device tag (CPU), although its parameter is on cuda:0.
model.net[0].weight_prior.scale.log_prob()

tensor(-0.4983, grad_fn=<SumBackward0>)

In [34]:
# Mean of the weights' Normal distribution: on cuda:0.
model.net[0].weight_prior._dist_obj().mean

tensor(0., device='cuda:0')

In [35]:
# The Gamma distribution's mean has no device tag: its parameters were not moved by .cuda(),
# which is the likely cause of the CPU log-prob seen for the scale term.
model.net[0].weight_prior.scale._dist_obj().mean

tensor(0.2500)

In [36]:
# Rebuild both priors from the same seed to inspect them before and after moving to the GPU.
torch.random.manual_seed(1337)

prior_n = Normal([12,10], 0., 1.)
prior_ng = NormalGamma([12,10], 0., 1.)

In [37]:
# Plain Normal distribution mean, on the CPU.
prior_n._dist_obj().mean

tensor(0.)

In [38]:
# NormalGamma's Normal distribution mean, also on the CPU.
prior_ng._dist_obj().mean

tensor(0.)

In [39]:
# Move both priors to the GPU; only the last call's return value (the NormalGamma repr) is displayed.
prior_n.cuda()
prior_ng.cuda()

NormalGamma(
  (scale): Gamma()
)

In [40]:
# Plain Normal: log-prob on cuda:0.
prior_n.log_prob()

tensor(-170.5846, device='cuda:0', grad_fn=<SumBackward0>)

In [41]:
# NormalGamma: total log-prob on cuda:0 (same value as on the CPU).
prior_ng.log_prob()

tensor(-176.1476, device='cuda:0', grad_fn=<AddBackward0>)

In [42]:
# The Gamma scale term alone is still computed on the CPU (no device tag).
prior_ng.scale.log_prob()

tensor(-1., grad_fn=<SumBackward0>)

In [43]:
# Parameters of the plain Normal prior (output truncated in this export).
list(prior_n.parameters())

[Parameter containing:
 tensor([[ 0.1808, -0.0700, -0.3596, -0.9152,  0.6258,  0.0255,  0.9545,  0.0643,
           0.3612,  1.1679],
         [-1.3499, -0.5102,  0.2360, -0.2398, -0.9211,  1.5433,  1.3488, -0.1396,
           0.2858,  0.9651],
         [-2.0371,  0.4931,  1.4870,  0.5910,  0.1260, -1.5627, -1.1601, -0.3348,
           0.4478, -0.8016],
         [ 1.5236,  2.5086, -0.6631, -0.2513,  1.0101,  0.1215,  0.1584,  1.1340,
          -1.1539, -0.2984],
         [-0.5075, -0.9239,  0.5467, -1.4948, -1.2057,  0.5718, -0.5974, -0.6937,
           1.6455, -0.8030],
         [ 1.3514, -0.2759, -1.5108,  2.1048,  2.7630, -1.7465,  1.4516, -1.5103,
           0.8212, -0.2115],
         [ 0.7789,  1.5333,  1.6097, -0.4032, -0.8345,  0.5978, -0.0514, -0.0646,
          -0.4970,  0.4658],
         [-0.2573, -1.0673,  2.0089, -0.5370,  0.2228,  0.6971, -1.4267,  0.9059,
           0.1446,  0.2280],
         [ 2.4900, -1.2237,  1.0107,  0.5560, -1.5935, -1.2706,  0.6903, -0.1961,
       

In [44]:
# NormalGamma's only child module is its Gamma scale prior.
list(prior_ng.children())

[Gamma()]

In [45]:
# Standalone Gamma prior, to reproduce the CUDA issue outside NormalGamma.
prior_g = Gamma([12,10], concentration=1., rate=1.)

In [46]:
# Gamma log-prob on the CPU.
prior_g.log_prob()

tensor(-109.7370, grad_fn=<SumBackward0>)

In [47]:
# Move the standalone Gamma prior to the GPU.
prior_g.cuda()

Gamma()

In [48]:
# Still no device tag after .cuda(): the standalone Gamma reproduces the issue.
prior_g.log_prob()

tensor(-109.7370, grad_fn=<SumBackward0>)

In [49]:
# Standalone Uniform prior for comparison.
prior_u = Uniform([12,10], low=0., high=1.)

In [50]:
# Uniform's log_prob returns a plain Python float (-0.0), not a tensor.
prior_u.log_prob()

-0.0

In [51]:
# Move the Uniform prior to the GPU.
prior_u.cuda()

Uniform()

In [52]:
# Still a plain float, so the device cannot be inspected this way.
prior_u.log_prob()

-0.0