# Variational Inference overview

## Existing Variational Inference implementation

The best way to get a sense for the current implementation is to walk backwards from how it's used.

In [1]:
import numpy as np
import pymc as pm
import arviz as az

In [2]:
data = np.random.normal(size=10_000)

In [3]:
with pm.Model() as model:
    d = pm.Data("data", data)
    batched_data = pm.Minibatch(d, batch_size=100)
    x = pm.Normal("x", 0., 1.)
    y = pm.Normal("y", x, total_size=len(data), observed=batched_data)

In [4]:
with model:
    approx = pm.fit(n=10_000, method="advi")  # pm.fit returns an Approximation

Finished [100%]: Average Loss = 144.77


But what does `fit` do? It roughly dispatches on the method, so the above is roughly equivalent to:

In [5]:
with model:
    advi = pm.ADVI()
    approx = advi.fit(n=100_000)

Finished [100%]: Average Loss = 143.83


But what is this `ADVI` object? If you look at its implementation with the documentation removed, you see it's a type of `KLqp`:

````python
class ADVI(KLqp):
    def __init__(self, *args, **kwargs):
        super().__init__(MeanField(*args, **kwargs))
````

So what's a `KLqp`? Looking at its implementation with the documentation removed, you see it's an `Inference` object:

````python
class KLqp(Inference):
    def __init__(self, approx, beta=1.0):
        super().__init__(KL, approx, None, beta=beta)
````

So what's an `Inference` object? Looking at its docstring, we finally get a sense of the main abstractions we will be working with.

In [6]:
pm.Inference?

Init signature: pm.Inference(op, approx, tf, **kwargs)
Docstring:
**Base class for Variational Inference**.

Communicates Operator, Approximation and Test Function to build Objective Function

Parameters
----------
op : Operator class    #:class:`~pymc.variational.operators`
approx : Approximation class or instance    #:class:`~pymc.variational.approximations`
tf : TestFunction instance  #?
model : Model
    PyMC Model
kwargs : kwargs passed to :class:`Operator` #:class:`~pymc.variational.operators`, optional
File:           ~/upstream/pymc/pymc/variational/inference.py
Type:           type
Subclasses:     KLqp, ImplicitGradient

Now things are falling into place. The `Inference` class is how we perform variational inference; this is where the actual fit machinery lives. It also highlights what variational inference needs: a `Model`, an `Operator`, and an `Approximation`. We already know that for `ADVI` the `Operator` is `KL` and the `Approximation` is `MeanField`.

But what do these things mean? And how are they combined to perform inference?

The `__init__` method of `Inference` is where we can find our answer.

In [7]:
pm.Inference.__init__??

Signature: pm.Inference.__init__(self, op, approx, tf, **kwargs)
Docstring: Initialize self.  See help(type(self)) for accurate signature.
Source:   
    def __init__(self, op, approx, tf, **kwargs):
        self.hist = np.asarray(())
        self.objective = op(approx, **kwargs)(tf)
        self.…

Alright, so let's go ahead and explore the operator `KL`:

````python
class KL(Operator):
    def __init__(self, approx, beta=1.0):
        super().__init__(approx)
        self.beta = pm.floatX(beta)

    def apply(self, f):
        return -self.datalogp_norm + self.beta * (self.logq_norm - self.varlogp_norm)
````

`KL` defines no `__call__` of its own, but it does call the `__init__` of `Operator`, so the call in `op(approx, **kwargs)(tf)` must be handled by the base class. The `apply` method returns what looks like the negative ELBO. For now, let's inline the `ADVI` case and see what we get:

In [8]:
objective = pm.operators.KL(pm.MeanField(model=model))(None)
objective

<pymc.variational.opvi.ObjectiveFunction at 0x70ee59332810>

So how'd that happen? If you look in the `Operator` class you see:

````python
    objective_class = ObjectiveFunction

    def __call__(self, f=None):
        if self.has_test_function:
            if f is None:
                raise ParametrizationError(f"Operator {self} requires TestFunction")
            else:
                if not isinstance(f, TestFunction):
                    f = TestFunction.from_function(f)
        else:
            if f is not None:
                warnings.warn(f"TestFunction for {self} is redundant and removed", stacklevel=3)
            else:
                pass
            f = TestFunction()
        f.setup(self.approx)
        return self.objective_class(self, f)
````
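Stripped of the test-function handling, the chain `op(approx, **kwargs)(tf)` boils down to the pattern below. The classes are toy stand-ins that only share names with the PyMC ones; they are meant to show the control flow, not the real behaviour.

```python
class ObjectiveFunction:
    """Toy stand-in: pairs an operator with a (possibly absent) test function."""

    def __init__(self, op, tf):
        self.op = op
        self.tf = tf


class Operator:
    objective_class = ObjectiveFunction
    has_test_function = False  # KL-style operators need no test function

    def __init__(self, approx):
        self.approx = approx

    def __call__(self, f=None):
        # Calling an operator instance wraps it in its objective class
        if self.has_test_function and f is None:
            raise ValueError(f"{type(self).__name__} requires a test function")
        return self.objective_class(self, f)


class KL(Operator):
    def __init__(self, approx, beta=1.0):
        super().__init__(approx)
        self.beta = beta


# Mirrors `op(approx, **kwargs)(tf)` from `Inference.__init__`
objective = KL("mean-field placeholder", beta=1.0)(None)
print(type(objective).__name__, objective.op.beta)
```

So the operator never computes anything when called; it only hands itself to the objective class, which is where the work happens.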

Which finally brings us to `ObjectiveFunction`.

This is the class that sets up the actual loss function and performs the updates on it.
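As a concrete picture of that loss in the `ADVI` case, the expression returned by `KL.apply` can be written out in plain NumPy for the toy model above (`x ~ Normal(0, 1)`, `y ~ Normal(x, 1)`) with `q(x) = Normal(mu, sigma)`. This is only a sketch, not PyMC's implementation: `normal_logpdf` and `kl_objective` are names invented here, the expectation is approximated with Monte Carlo draws, and the minibatch normalisation hinted at by the `_norm` suffixes is left out.

```python
import numpy as np


def normal_logpdf(value, mu, sigma):
    """Log density of Normal(mu, sigma) evaluated at value."""
    return -0.5 * np.log(2 * np.pi) - np.log(sigma) - 0.5 * ((value - mu) / sigma) ** 2


def kl_objective(mu, sigma, data, beta=1.0, n_mc=1000, seed=0):
    """Monte Carlo estimate of -datalogp + beta * (logq - varlogp)."""
    rng = np.random.default_rng(seed)
    x = mu + sigma * rng.standard_normal(n_mc)  # draws from q
    # Log likelihood of all observations, one value per draw of x
    datalogp = normal_logpdf(data[None, :], x[:, None], 1.0).sum(axis=1)
    varlogp = normal_logpdf(x, 0.0, 1.0)  # log prior of the latent variable
    logq = normal_logpdf(x, mu, sigma)  # log density of the approximation
    return np.mean(-datalogp + beta * (logq - varlogp))


data = np.random.default_rng(1).normal(size=200)

# The model is conjugate, so the exact posterior is known in closed form
post_mu = data.sum() / (len(data) + 1)
post_sigma = 1 / np.sqrt(len(data) + 1)

loss_at_posterior = kl_objective(post_mu, post_sigma, data)
loss_elsewhere = kl_objective(post_mu + 1.0, post_sigma, data)
```

With `beta=1.0` this is the negative ELBO, so the value at the exact posterior should be lower than at an approximation whose mean has been shifted away from it.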

In [9]:
pm.opvi.ObjectiveFunction.step_function?

Signature:
pm.opvi.ObjectiveFunction.step_function(
    self,
    obj_n_mc=None,
    tf_n_mc=None,
    obj_optimizer=<function adagrad_window at 0x70ee648da480>,
    test_optimizer=<function adagrad_window at 0x70ee648da480>,
    more_obj_params=None,
    more_tf_params=None,
    more_updates=None,
    more_repl…

## Proposed Improvements

There is a lot to like here, but there is also a lot of indirection. Further, much of it isn't used in the `ADVI` case; it exists in service of `SVGD` and `ASVGD`.

Further, the `Inference` class has to be aware of too many of these details. Ideally `Inference` should be reworked to take only a step function. It could be renamed `Trainer` to match PyTorch Lightning. I think forcing all `VI` through `OPVI` makes it more challenging to write new `VI` algorithms or port them to `pymc`.

### PyTorch Lightning and Optax optimization

How would this look? One possibility is having each Variational Inference technique encapsulated into an object that takes a model and optimizer as inputs and provides a step function as a method.

````python
class ADVI(Inference):
    def __init__(self, model=None, optimizers=None):
        ...

    def step(self, batch):
        ...
        return loss
````

This is then passed to a `Trainer` object for fitting:

````python
with model:
    trainer = Trainer(method=ADVI(), dataloader= ...)
    trainer.fit(n=10_000)
````

Under this setup most of the optimization logic moves into the `__init__` and `step` methods. How those are implemented can be decided separately, but something like optax might not be so bad. We could end up with code that resembles the following:

````python
class ADVI(Inference):
    def __init__(self, model=None, optimizers=None):
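        if model is None:
            model = modelcontext(None)
        if optimizers is None:
            optimizers = [pm.opt.Adam(1e-3)]  # hypothetical optax-style optimizer
        self.model = model
        self.optimizer = optimizers[0]
        self.params = self.optimizer.init(model.basic_RVs)

    def step(self, batch):
        # `value_and_grad` is a placeholder for the backend's autodiff: a gradient
        # needs the loss function itself, not an already-computed loss value
        loss, grads = value_and_grad(self.loss_function)(self.params, batch)
        self.params = self.optimizer.update(grads, self.params)
        return loss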
````
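To check that the `step`/`Trainer` split is workable, here is a self-contained NumPy sketch for the toy model used earlier (`x ~ Normal(0, 1)`, `y ~ Normal(x, 1)`). Everything in it is an assumption made for illustration: the `Adam` class is a hand-written stand-in loosely modelled on optax's `init`/`update` split, the gradients are derived by hand for this one model rather than by autodiff, and `minibatches` is a hypothetical helper.

```python
import numpy as np


class Adam:
    """Tiny Adam with an optax-like init/update split (illustrative stand-in)."""

    def __init__(self, learning_rate=0.02, b1=0.9, b2=0.99, eps=1e-8):
        self.lr, self.b1, self.b2, self.eps = learning_rate, b1, b2, eps

    def init(self, params):
        return {"m": np.zeros_like(params), "v": np.zeros_like(params), "t": 0}

    def update(self, grads, state, params=None):
        t = state["t"] + 1
        m = self.b1 * state["m"] + (1 - self.b1) * grads
        v = self.b2 * state["v"] + (1 - self.b2) * grads**2
        m_hat = m / (1 - self.b1**t)  # bias-corrected first moment
        v_hat = v / (1 - self.b2**t)  # bias-corrected second moment
        updates = -self.lr * m_hat / (np.sqrt(v_hat) + self.eps)
        return updates, {"m": m, "v": v, "t": t}


class ADVI:
    """Mean-field ADVI for the toy model, with q(x) = Normal(mu, exp(rho))."""

    def __init__(self, total_size, optimizer=None, seed=0):
        self.total_size = total_size
        self.optimizer = Adam() if optimizer is None else optimizer
        self.params = np.zeros(2)  # [mu, rho]
        self.opt_state = self.optimizer.init(self.params)
        self.rng = np.random.default_rng(seed)

    def step(self, batch):
        mu, rho = self.params
        sigma = np.exp(rho)
        eps = self.rng.standard_normal()
        x = mu + sigma * eps  # reparameterised draw from q
        scale = self.total_size / len(batch)  # rescale the minibatch likelihood
        # Single-sample negative ELBO up to additive constants (entropy of q is rho + const)
        loss = scale * 0.5 * np.sum((batch - x) ** 2) + 0.5 * x**2 - rho
        dloss_dx = scale * np.sum(x - batch) + x
        grads = np.array([dloss_dx, dloss_dx * sigma * eps - 1.0])
        updates, self.opt_state = self.optimizer.update(grads, self.opt_state, self.params)
        self.params = self.params + updates
        return loss


class Trainer:
    def __init__(self, method, dataloader):
        self.method = method
        self.dataloader = dataloader

    def fit(self, n):
        # The trainer only knows about `step`; it returns the loss history
        return np.array([self.method.step(next(self.dataloader)) for _ in range(n)])


def minibatches(data, batch_size, seed=0):
    """Endless stream of random minibatches (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    while True:
        yield rng.choice(data, size=batch_size, replace=False)


data = np.random.default_rng(42).normal(loc=1.0, size=10_000)
method = ADVI(total_size=len(data))
history = Trainer(method, minibatches(data, batch_size=100)).fit(n=3_000)
mu, sigma = method.params[0], np.exp(method.params[1])
```

With these settings `mu` should end up near the sample mean and `sigma` should shrink well below its initial value of 1; the exact numbers depend on the seed, the step count, and the optimizer settings. The `b2=0.99` default is a choice made for this sketch, because the gradient scale changes by orders of magnitude as `sigma` shrinks.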

### Model and Guide programs

Additionally, it would be nice if we could easily support variational inference with guide programs à la Pyro/NumPyro.

The way this could look is that we define both as `pymc` models and then pass them to an `SVI` method:

````python
with pm.Model() as model:
    data = pm.Data("data", ...)
    x = pm.Normal("x", 0, 1)
    y = pm.Normal("y", x, 1, observed=data)

with pm.Model() as guide:
    mu = pt.tensor("mu", param=True)
    sd = pt.tensor("sd", param=True)
    x = pm.Normal("x", mu, sd)


with model:
    trainer = Trainer(method=SVI(model, guide), dataloader= ...)
    trainer.fit(n=10_000)
````

Naturally, `SVI` is a very general inference method, and in fact we can redefine `ADVI` in terms of it. Following the lead of Pyro/NumPyro, we can generate the guide automatically:

````python
with model:
    guide = AutoGuide(model)
    trainer = Trainer(method=SVI(model, guide), dataloader= ...)
    trainer.fit(n=10_000)
````
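What an automatically generated guide amounts to can be sketched without any PyMC machinery: one mean and one log standard deviation per latent variable, plus a reparameterised sampler. The function names and the `latent_shapes` argument below are made up for this sketch, and the `weights` variable is only there to show a non-scalar shape; a real `AutoGuide` would read the latent variables off the model instead.

```python
import numpy as np


def auto_normal_guide(latent_shapes):
    """Build mean-field Normal parameters for each latent variable.

    `latent_shapes` maps variable names to shapes (an assumption of this sketch).
    """
    return {
        name: {"mu": np.zeros(shape), "rho": np.zeros(shape)}
        for name, shape in latent_shapes.items()
    }


def sample_guide(params, rng):
    """Draw one reparameterised sample per latent variable: mu + exp(rho) * eps."""
    return {
        name: p["mu"] + np.exp(p["rho"]) * rng.standard_normal(np.shape(p["mu"]))
        for name, p in params.items()
    }


guide_params = auto_normal_guide({"x": (), "weights": (3,)})
draw = sample_guide(guide_params, np.random.default_rng(0))
```

Under this view, `ADVI` is `SVI` with the guide parameters created by a helper like `auto_normal_guide` rather than written by hand.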

### Reworking Minibatch

Another small change we should consider is moving `pm.Minibatch` out of the model. Max already has a [proposal](https://github.com/pymc-devs/pymc/issues/7496) that I think can be adopted with only a few changes.

Where we currently minibatch the data explicitly inside the model, I think we should instead have dataloaders that stream batches into it.

````python
with pm.Model() as model:
    data = pm.Data("data", None)
    x = pm.Normal("x", 0, 1)
    y = pm.Normal("y", x, 1, observed=data)

dataloader = pm.Dataloader(np.random.normal(size=10_000), batch_size=64)

with model:
    trainer = Trainer(method=ADVI(), dataloader=dataloader)
    trainer.fit(n=10_000)
````

Importantly, the model doesn't need to know about the dataloader. We will need to tweak the inference object, but not by much:

````python
class ADVI(Inference):
    def step(self, batch):
        self.model.set_data("data", batch)
        ...
````
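For completeness, here is a minimal sketch of what the hypothetical `pm.Dataloader` above might look like. The class, its constructor arguments, and the choice to sample batches at random without replacement are all assumptions; nothing like it exists in PyMC today.

```python
import numpy as np


class Dataloader:
    """Endless iterator over random minibatches of an in-memory array."""

    def __init__(self, data, batch_size, seed=0):
        self.data = np.asarray(data)
        self.batch_size = batch_size
        self.rng = np.random.default_rng(seed)

    def __len__(self):
        # Total number of observations, needed to rescale the minibatch likelihood
        return len(self.data)

    def __iter__(self):
        return self

    def __next__(self):
        idx = self.rng.choice(len(self.data), size=self.batch_size, replace=False)
        return self.data[idx]


dataloader = Dataloader(np.random.default_rng(0).normal(size=10_000), batch_size=64)
batch = next(dataloader)
```

Exposing `len(dataloader)` matters because the inference object still needs the total dataset size, which is what `total_size=len(data)` supplied when the minibatching lived inside the model.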