In [3]:
%load_ext autoreload
%autoreload 2

from abc import ABC, abstractmethod
from torch import nn
import torch
from torch.distributions import Gamma, Normal

class BayesianMixin(ABC):

    @abstractmethod
    def setup_prior(self):
        pass

    @abstractmethod
    def log_prior(self):
        pass

class BayesianLinear(BayesianMixin, nn.Linear):

    precision : torch.Tensor
    alpha : torch.Tensor
    beta : nn.Parameter

    def setup_prior(self, alpha, beta):
        
        self.register_parameter("precision", nn.Parameter(torch.tensor(1.)))
        self.register_buffer("alpha", torch.tensor(alpha))
        self.register_buffer("beta", torch.tensor(beta))
        
        return self

    def log_prior(self):

        precision_d = Gamma(self.alpha, self.beta)
        param_d = Normal(0, self.precision)

        return (
            precision_d.log_prob(self.precision) 
            + param_d.log_prob(self.weight).sum()
            + param_d.log_prob(self.bias).sum()
        )

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
from src.models.simple import PolynomialToyModel

In [41]:
true_coeffs = torch.tensor([1.0, 2.0, 0.0, -1.0])
t = PolynomialToyModel(true_coeffs)

In [43]:
d = t.generate_data()

In [44]:
d.tensors

(tensor([[-1.0194],
         [ 0.0828],
         [-1.2416],
         [ 0.9428],
         [-2.1301],
         [ 1.8326],
         [-1.8171],
         [-1.9876],
         [-1.5797],
         [ 1.1322],
         [-0.9237],
         [ 0.9355],
         [-2.1218],
         [-1.5168],
         [-0.9179],
         [-0.4913],
         [-1.9072],
         [ 1.6370],
         [-0.5896],
         [ 0.8025],
         [ 1.7679],
         [ 0.4658],
         [ 0.6836],
         [ 2.4131],
         [-1.1275],
         [ 0.7919],
         [-1.1123],
         [ 1.7866],
         [ 1.9966],
         [-2.3049],
         [ 2.1341],
         [ 1.1938],
         [ 1.0894],
         [ 1.0292],
         [ 2.0782],
         [-0.3301],
         [-2.1142],
         [-0.7174],
         [-1.7607],
         [ 0.1653],
         [-0.4668],
         [-1.3410],
         [-0.2273],
         [ 2.3685],
         [-0.1972],
         [ 0.0794],
         [-0.3899],
         [ 0.3930],
         [ 2.2275],
         [ 1.5287],


In [67]:
class BayesianModel(nn.Module):

    def log_prior(self):
        return sum(

            m.log_prior() for m in self.modules() if isinstance(m, BayesianMixin)
        )
    @abstractmethod
    def log_likelihood(self):
        pass






In [72]:
net.log_likelihood(torch.Tensor([[1., 2., 3., 4.]]), torch.tensor([2]))

tensor(-1.0619, grad_fn=<SumBackward0>)

In [47]:
from src.models.simple import ClassifierNet
import copy
net = ClassifierNet(4, 3, [200, 200, 200])

In [51]:
%%timeit
t = copy.deepcopy(net)

814 µs ± 74.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [55]:
%%timeit
t = 

114 µs ± 806 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [11]:
for param in t.parameters():
    print(t.prior_distribution.log_prob(param))

tensor([[-0.9192, -0.9191, -0.9190,  ..., -0.9193, -0.9194, -0.9189],
        [-0.9194, -0.9190, -0.9189,  ..., -0.9193, -0.9193, -0.9190],
        [-0.9195, -0.9191, -0.9189,  ..., -0.9189, -0.9196, -0.9189],
        ...,
        [-0.9194, -0.9193, -0.9190,  ..., -0.9190, -0.9195, -0.9191],
        [-0.9195, -0.9189, -0.9190,  ..., -0.9192, -0.9189, -0.9191],
        [-0.9190, -0.9190, -0.9189,  ..., -0.9195, -0.9189, -0.9190]],
       grad_fn=<SubBackward0>)
tensor([-0.9192, -0.9189, -0.9191, -0.9189, -0.9191, -0.9189, -0.9189, -0.9194,
        -0.9190, -0.9189, -0.9190, -0.9190, -0.9194, -0.9190, -0.9193, -0.9191,
        -0.9190, -0.9191, -0.9191, -0.9190, -0.9190, -0.9195, -0.9193, -0.9190,
        -0.9189, -0.9191, -0.9195, -0.9192, -0.9194, -0.9193, -0.9191, -0.9191,
        -0.9190, -0.9190, -0.9190, -0.9189, -0.9190, -0.9192, -0.9194, -0.9189,
        -0.9193, -0.9190, -0.9189, -0.9189, -0.9191, -0.9191, -0.9190, -0.9192,
        -0.9190, -0.9190, -0.9190, -0.9195, -0.9190, -0