In [1]:
import torch
import torch.nn as nn
from src.utils import SequentialBuilder
from abc import abstractmethod, ABC
from torch.distributions import Normal
from src.data.mnist import MNISTDataModule
import pytorch_lightning as pl
import torch.nn.functional as F
import torchmetrics.functional as FM
import torchmetrics
from typing import Dict
from torch import Tensor

# Base model

In [2]:
class Model(ABC, nn.Module):

    @abstractmethod
    def observation_model_gvn_output(self, output: Tensor):
        """Returns p(y | x, theta) given the model output, f(x)"""

    def observation_model(self, x: Tensor):
        """Returns p(y | x, theta) given observation x"""
        return self.observation_model_gvn_output(self.forward(x))

    def loss(self, output: Tensor, targets: Tensor):
        """General loss implementation given model output"""
        return -self.observation_model_gvn_output(output).log_prob(y).mean()

    def get_metrics(self) -> Dict[str, torchmetrics.Metric]:
        """Metrics relevant for model"""
        return {}

class ErrorRate(torchmetrics.Accuracy):

    def compute(self) -> Tensor:
        return 1 - super().compute()


class ClassifierMixin:

    def observation_model_gvn_output(self, logits: torch.FloatTensor):
        return torch.distributions.Categorical(logits=logits)

    def loss(self, output: torch.FloatTensor, y: torch.FloatTensor):
        return F.cross_entropy(output, y)
    
    def get_metrics(self):        
        return {"err" : ErrorRate()}

class MLPModel(Model):

    def __init__(
        self,
        in_features=784,
        out_features=10,
        hidden_layers=[100],
        alpha=1.0,
        beta=1.0,
        precision=1.0,
    ):

        super().__init__()

        seq_builder = SequentialBuilder(in_shape=(in_features,))
        for hidden_size in hidden_layers:
            seq_builder.add(nn.Linear(seq_builder.out_dim(0), hidden_size))
            seq_builder.add(nn.Sigmoid())
        seq_builder.add(nn.Linear(seq_builder.out_dim(0), out_features))

        self.ffnn = seq_builder.build()

    def forward(self, x: torch.Tensor):

        x = x.flatten(-2, -1)
        return self.ffnn(x)

class MLPClassifier(ClassifierMixin, MLPModel):
    ...

# Defining probabilistic models

tensor(-73157.2422, grad_fn=<AddBackward0>)


# Inference

In [5]:
class InferenceModule(pl.LightningModule):
    ...

def test_inference_class(inference_cls: InferenceModule):

    dm = MNISTDataModule()
    model = MLPClassifier()

    inference = inference_cls(model)
    pl.Trainer(max_epochs=5).fit(inference, dm)

## Standard SGD

In [6]:
class SGDInference(InferenceModule):

    def __init__(self, model : Model, lr: float=1e-3):

        super().__init__()
        self.model = model
        self.lr = lr

        self.train_metrics = self.model.get_metrics()
        self.val_metrics = self.model.get_metrics()

        # TODO: Refactor later, probably in a factory?
        self.save_hyperparameters({"inference_type" : "SGD"})
    
    def training_step(self, batch, batch_idx):

        x, y = batch
        output = self.model(x)
        loss = self.model.loss(output, y)
        
        self.log("loss/train", loss)
        for name, metric in self.train_metrics.items():
            self.log(f"{name}/train", metric(output, y), on_epoch=True, on_step=False)

        return loss

    def validation_step(self, batch, batch_idx):

        x, y = batch
        output = self.model(x)

        for name, metric in self.train_metrics.items():
            self.log(f"{name}/val", metric(output, y))

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

## Variational Inference


In [201]:
from src.utils import pairwise
from itertools import accumulate
import math

class ParameterView:

    def __init__(self, model : nn.Module):

        self.model = model
        self.param_shapes = {k: x.shape for k, x in self.model.named_parameters()}
        indices = accumulate(
            self.param_shapes.values(), lambda x, y: x + math.prod(y), initial=0
        )
        self.flat_index_pairs = list(pairwise(indices))

        self.n_params = self.flat_index_pairs[-1][-1]

    def __getitem__(self, key):

        if type(key) is slice:
            return self._get_slice(key)
        else:
            raise NotImplementedError

    def __setitem__(self, key, value):
        
        if type(key) is slice:
            self._set_slice(key, value)
    
    @property
    @torch.no_grad()
    def flat_grad(self):
        return self._flatten(x.grad for x in self.model.parameters())

    def _get_slice(self, slice_):

        if slice_.start is None and slice_.stop is None and slice_.step is None:
            return self._flatten(self.model.parameters())
        else:
            raise NotImplementedError

    def _set_slice(self, slice_, value):

        if slice_.start is None and slice_.stop is None and slice_.step is None:
            state_dict = self._unflatten(value)
            for name, parameter in self.model.named_parameters():
                parameter.copy_(state_dict[name])
        else:
            raise NotImplementedError
            
    def _flatten(self, tensor_iter):
        return torch.cat([x.flatten() for x in self.model.parameters()])

    def _unflatten(self, value):
        return {
            k: value[a:b].view(shape)
            for (k, shape), (a, b) in zip(
                self.param_shapes.items(), self.flat_index_pairs
            )
        }

    def apply(self, fnc):
        for param in self.model.parameters():
            fnc(param)


In [193]:
dm = MNISTDataModule()
dm.setup()
x, y = next(iter(dm.train_dataloader()))

In [249]:
for param in p_model.parameters():
    param

In [252]:
param.is_leaf = False

AttributeError: attribute 'is_leaf' of 'torch._C._TensorBase' objects is not writable

In [None]:
nn.Parameter()

In [268]:
p_model = ProbabilisticModel(MLPClassifier())
parameter_view = ParameterView(p_model)

# Need grad for variational parameters
p_model.zero_grad()
parameter_view.apply(lambda x: setattr(x, "requires_grad", False))
rho = torch.zeros(parameter_view.n_params, requires_grad=True)
mu = torch.zeros(parameter_view.n_params, requires_grad=True)

# Resample parameters
sigma = torch.log(1 + rho.exp())
eps = torch.randn_like(mu)
w = mu + sigma * eps
parameter_view[:] = w

# Retain weight/bias grads
parameter_view.apply(lambda x: x.retain_grad())

elbo = Normal(mu, sigma).log_prob(w).sum() - p_model.log_likelihood(x, y).sum() - p_model.log_prior()
elbo.backward()

In [239]:
for parameter in p_model.parameters():
    parameter.retain_grad()

In [269]:
class VariationalInference(InferenceModule):

    def __init__(self, model : Model, lr: float=1e-3):

        super().__init__()

        self.p_model = ProbabilisticModel(model)

        self.lr = lr

        self.parameter_view = ParameterView(self.p_model)
        self.register_buffer("rho", torch.zeros(parameter_view.n_params, requires_grad=True))
        self.register_buffer("mu", torch.zeros(parameter_view.n_params, requires_grad=True))

        self.train_metrics = self.p_model.model.get_metrics()
        self.val_metrics = self.p_model.model.get_metrics()

        # TODO: Refactor later, probably in a factory?
        self.save_hyperparameters({"inference_type" : "VI"})
    
    def training_step(self, batch, batch_idx):

        x, y = batch
        parameter_view.apply(lambda x: setattr(x, "requires_grad", False))
        rho = torch.zeros(parameter_view.n_params, requires_grad=True)
        mu = torch.zeros(parameter_view.n_params, requires_grad=True)

        # Resample parameters
        sigma = torch.log(1 + rho.exp())
        eps = torch.randn_like(mu)
        w = mu + sigma * eps
        parameter_view[:] = w

        # Retain weight/bias grads
        parameter_view.apply(lambda x: x.retain_grad())

        elbo = Normal(mu, sigma).log_prob(w).sum() - p_model.log_likelihood(x, y).sum() - p_model.log_prior()

        return elbo

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

In [270]:
vi = VariationalInference(MLPClassifier())
pl.Trainer(fast_dev_run=True).fit(vi, dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Running in fast_dev_run mode: will run a full train, val, test and prediction loop using 1 batch(es).
  rank_zero_warn(f"you passed in a {loader_name} but have no {step_name}. Skipping {stage} loop")
  rank_zero_deprecation(

  | Name    | Type               | Params
-----------------------------------------------
0 | p_model | ProbabilisticModel | 79.5 K
-----------------------------------------------
79.5 K    Trainable params
0         Non-trainable params
79.5 K    Total params
0.318     Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(


Epoch 0:   0%|          | 0/1 [00:00<00:00, 1201.81it/s]  

RuntimeError: you can only change requires_grad flags of leaf variables. If you want to use a computed variable in a subgraph that doesn't require differentiation use var_no_grad = var.detach().

In [263]:
list(vi.buffers())

[tensor([0., 0., 0.,  ..., 0., 0., 0.], requires_grad=True),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], requires_grad=True),
 tensor(1.),
 tensor(0.),
 tensor(1.),
 tensor(0.),
 tensor(1.),
 tensor(0.),
 tensor(1.),
 tensor(0.)]

In [227]:
with torch.no_grad():
    d_mu = parameter_view.flat_grad + mu.grad
    d_rho = parameter_view.flat_grad * (eps / (1 + torch.exp(-rho))) + rho.grad 

In [228]:
d_mu

tensor([-0.5299, -1.3379,  0.6065,  ..., -1.9135, -5.3752,  0.0174])

In [229]:
d_rho

tensor([-0.6201, -0.0757, -0.5887,  ..., -1.6976,  1.9461, -0.7208])

In [121]:
a = torch.tensor(1., requires_grad=True)
b = torch.empty_like(a)
b.copy_(a)
b

tensor(1., grad_fn=<CopyBackwards>)

In [None]:
class VariationalInference(InferenceModule):

    def __init__(self, model : Model, lr: float=1e-3):

        super().__init__()
        self.model = ProbabilisticModel(model)







    #     self.lr = lr

    #     self.train_metrics = self.model.get_metrics()
    #     self.val_metrics = self.model.get_metrics()

    #     # TODO: Refactor later, probably in a factory?
    #     self.save_hyperparameters({"inference_type" : "VI"})
    
    # def training_step(self, batch, batch_idx):

    #     x, y = batch
    #     output = self.model(x)
    #     loss = self.model.loss(output, y)
        
    #     self.log("loss/train", loss)
    #     for name, metric in self.train_metrics.items():
    #         self.log(f"{name}/train", metric(output, y), on_epoch=True, on_step=False)

    #     return loss

    # def validation_step(self, batch, batch_idx):

    #     x, y = batch
    #     output = self.model(x)

    #     for name, metric in self.train_metrics.items():
    #         self.log(f"{name}/val", metric(output, y))

    # def configure_optimizers(self):
    #     optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
    #     return optimizer

In [33]:

mu = p_model

In [32]:
mu 

ProbabilisticModel(
  (model): MLPClassifier(
    (ffnn): Sequential(
      (0): Linear(
        in_features=784, out_features=100, bias=True
        (priors): ModuleDict(
          (weight): KnownPrecisionNormalPrior()
          (bias): KnownPrecisionNormalPrior()
        )
      )
      (1): Sigmoid()
      (2): Linear(
        in_features=100, out_features=10, bias=True
        (priors): ModuleDict(
          (weight): KnownPrecisionNormalPrior()
          (bias): KnownPrecisionNormalPrior()
        )
      )
    )
  )
)