# Introduction to Monad Bayes, Part 4: Inference Representations

## Imports

In [27]:
:e OverloadedStrings

import qualified Graphics.Vega.VegaLite as VL

import Control.Monad (replicateM, forM_)

import Control.Monad.Bayes.Class (score, normalPdf, normal, uniform, MonadSample, MonadInfer)
import Control.Monad.Bayes.Free (FreeSampler)
import Control.Monad.Bayes.Enumerator (Enumerator)
import Control.Monad.Bayes.Sampler (SamplerIO, SamplerST, sampleIO, sampleIOwith, sampleIOfixed, runSamplerST, sampleST, sampleSTfixed)
import Control.Monad.Bayes.Traced (Traced, mh)
import Control.Monad.Bayes.Sequential (Sequential)
import Control.Monad.Bayes.Population (Population, runPopulation)
import Control.Monad.Bayes.Weighted (Weighted, prior)
import Control.Monad.Bayes.Inference.PMMH (pmmh)
import Control.Monad.Bayes.Inference.RMSMC (rmsmc)
import Control.Monad.Bayes.Inference.SMC (smcMultinomial, smcSystematic, smcMultinomialPush)
import Control.Monad.Bayes.Inference.SMC2 (smc2)
import Numeric.Log (Log(Exp))

## Introduction

This post is a continuation of Tweag's _**Probabilistic Programming with monad‑bayes Series**_. You can find the previous parts here:

* [Part 1, Introduction](https://www.tweag.io/posts/2019-09-20-monad-bayes-1.html)
* [Part 2, Linear Regression](https://www.tweag.io/posts/2019-11-08-monad-bayes-2.html)
* [Part 3, Neural Networks](https://www.tweag.io/posts/2020-02-26-monad-bayes-3.html)

Want to make this post interactive? Try our [notebook version](https://github.com/tweag/blog-resources/tree/master/monad-bayes-series). It includes a Nix shell, the required imports, and some helper routines for plotting. Let's start modeling!

## Motivation

Bayes' theorem provides us with a powerful modelling tool.
Computing the posterior distribution of model parameters follows a standardized recipe.
In practice, however, this recipe is not easy to execute because it requires computing with distributions.
While adding, subtracting, multiplying or dividing numbers is straightforward in today's programming languages, computing with distributions is much less common.
How does `monad-bayes` compute with distributions? This is the topic of this post.

If you think about it, even computing with numbers isn't that straightforward.
First, you have the choice between various number representations, such as `Float`, `Double`, `Int`, `Log Double`.
Then, each of these representations supports different operations, with different efficiency and different accuracy.
For example, an `Int` does not support arbitrary division, and a `Float` might be more efficient but less accurate than a `Double`.
We can also _compose new representations from existing ones_, for example a `Log Double` type, to get representations that are tailored to a specific use case.
The same computation, for example `\x -> 1 + x`, can then be executed with any of these representations.
In Haskell, we use type classes to define operations that are shared between different representations, that is, types.
Types define the internal data representation, and their associated class instances define the operations.
For example, all number representations mentioned above are instances of the `Num` type class and come with implementations of `+`, `-` and `*` that know how to deal with the respective underlying data structures. Division is not part of `Num`: `/` belongs to the `Fractional` class, which `Int` does not implement.
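
As a small illustration of this idea, the sketch below runs one computation with three number representations. The names `increment`, `asInt`, `asDouble` and `asFloat` are introduced here for illustration only.

```haskell
-- One computation, written once against the Num interface.
increment :: Num a => a -> a
increment x = 1 + x

-- The type annotation at the use site selects the representation.
asInt :: Int
asInt = increment 2

asDouble :: Double
asDouble = increment 2.5

asFloat :: Float
asFloat = increment 2.5
```

Loaded into GHCi or a notebook cell, `asInt` evaluates to `3` and `asDouble` to `3.5`.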

Just as number representations support the operations defined in `Num`, distributions can be represented in different ways that share common operations.
This gives us the flexibility to adapt a generic probabilistic computation, such as the ones imposed by Bayes' theorem, on the fly to different situations.
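
To make the analogy concrete, here is a toy sketch that uses only `base` and is not the `monad-bayes` API. The class `MonadCoin` and the types `Exact` and `Draw` are hypothetical names invented for this example. `Exact` represents a distribution as a table of outcomes with probabilities, while `Draw` represents it as a single draw fed by a list of pre-supplied coin flips. The generic computation `twoFlips` is written once and runs with both.

```haskell
-- Hypothetical toy interface: a distribution type is a Monad with a fair coin.
class Monad d => MonadCoin d where
  coin :: d Bool

-- Representation 1: an exact table of outcomes and their probabilities.
newtype Exact a = Exact { runExact :: [(a, Double)] }

instance Functor Exact where
  fmap f (Exact xs) = Exact [(f x, p) | (x, p) <- xs]

instance Applicative Exact where
  pure x = Exact [(x, 1)]
  Exact fs <*> Exact xs = Exact [(f x, p * q) | (f, p) <- fs, (x, q) <- xs]

instance Monad Exact where
  Exact xs >>= k = Exact [(y, p * q) | (x, p) <- xs, (y, q) <- runExact (k x)]

instance MonadCoin Exact where
  coin = Exact [(True, 0.5), (False, 0.5)]

-- Representation 2: one draw, consuming a list of pre-supplied coin flips.
newtype Draw a = Draw { runDraw :: [Bool] -> (a, [Bool]) }

instance Functor Draw where
  fmap f (Draw g) = Draw $ \s -> let (x, s') = g s in (f x, s')

instance Applicative Draw where
  pure x = Draw $ \s -> (x, s)
  Draw f <*> Draw g = Draw $ \s ->
    let (h, s')  = f s
        (x, s'') = g s'
    in (h x, s'')

instance Monad Draw where
  Draw g >>= k = Draw $ \s -> let (x, s') = g s in runDraw (k x) s'

instance MonadCoin Draw where
  coin = Draw $ \s -> case s of
    (b : rest) -> (b, rest)
    []         -> (False, [])  -- out of flips: fall back to False

-- One generic computation: the number of heads in two flips.
twoFlips :: MonadCoin d => d Int
twoFlips = do
  a <- coin
  b <- coin
  return (fromEnum a + fromEnum b)
```

Here `runExact twoFlips` yields the full table `[(2,0.25),(1,0.25),(1,0.25),(0,0.25)]`, whereas `fst (runDraw twoFlips [True, False])` yields the single outcome `1`.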

Let's dive into this!

## Example Setup

We need a simple example setup to continue.
Let's set up the likelihood, prior and posterior like this:

In [6]:
type Param = Double
type Point = Double

likelihood :: Param -> Point -> Log Double
likelihood = normalPdf mean
    where mean = 0.0

prior :: MonadSample m => m Double
prior = uniform 0.1 5

post :: MonadInfer m => [Point] -> m Param
post obs = do
  std <- prior
  forM_ obs (score . likelihood std) -- score observation after observation
  return std
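
Written out, and assuming that each `score` call multiplies the current weight by the likelihood of one observation, `post` describes the unnormalized posterior over the standard deviation $\sigma$ given the observations $x_1, \dots, x_n$:

$$
p(\sigma \mid x_1, \dots, x_n) \;\propto\; p(\sigma) \prod_{i=1}^{n} \mathcal{N}(x_i \mid 0, \sigma),
\qquad p(\sigma) = \mathrm{Uniform}(0.1,\, 5)
$$

Here $\mathcal{N}(x \mid 0, \sigma)$ denotes the normal density with mean $0$ and standard deviation $\sigma$, matching `normalPdf mean` with `mean = 0.0`.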

Here is the posterior for some zero-mean observations:

In [7]:
points = [0.1, -0.2, 0.3, 0.2, 0.0, -0.4]
posterior = post points

## Sampling Representations

In [8]:
:t prior

`prior` can be instantiated at any type that implements the `MonadSample` interface.
Let's see what representations we have:

In [15]:
prior1 = prior :: SamplerIO Double
prior2 = prior :: SamplerST Double
prior3 = prior :: Enumerator Double

We have some operations that are _not_ polymorphic. They need a specific type:

In [24]:
:t sampleIO
:t sampleIOfixed
:t sampleIOwith

In [28]:
:t sampleST
:t sampleSTfixed
:t runSamplerST
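
The following sketch shows how two of these runners fix the representation of a polymorphic prior. It assumes the `monad-bayes` version used in this post, in which `sampleIOfixed` takes a `SamplerIO` and `sampleSTfixed` takes a `SamplerST`. The names `stdPrior`, `drawIO` and `drawST` are chosen for this example.

```haskell
import Control.Monad.Bayes.Class (MonadSample, uniform)
import Control.Monad.Bayes.Sampler (sampleIOfixed, sampleSTfixed)

-- Same definition as the model prior above, under a different name
-- to avoid a clash with `prior` from Control.Monad.Bayes.Weighted.
stdPrior :: MonadSample m => m Double
stdPrior = uniform 0.1 5

-- sampleIOfixed forces the SamplerIO representation; the result lives in IO.
drawIO :: IO Double
drawIO = sampleIOfixed stdPrior

-- sampleSTfixed forces the SamplerST representation; the result is a pure value.
drawST :: Double
drawST = sampleSTfixed stdPrior
```

In both cases the call site alone determines which representation `stdPrior` takes.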

## Inference Representations

We have seen in multiple examples now how we can _use_ a probabilistic computation such as `post`.
But what _is_ `post`?
What type does it have?
All we know is that its type is `MonadInfer m => [Point] -> m Param`: the result type `m` is constrained to be an instance of `MonadInfer`, and the computation returns a single parameter.

In [4]:
post1 = posterior :: Weighted SamplerIO Param
post2 = posterior :: Traced (Weighted SamplerIO) Param
post3 = posterior :: Population (Weighted SamplerIO) Param
post4 = posterior :: Sequential (Population (Weighted SamplerIO)) Param
post5 = posterior :: Sequential (Population (Traced (Weighted SamplerIO))) Param

#### SamplerIO - elementary sampling operations

```haskell
newtype SamplerIO a = SamplerIO (ReaderT GenIO IO a)
instance MonadSample SamplerIO -- Defined in ‘Control.Monad.Bayes.Sampler’
```

In [5]:
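post6 = posterior :: SamplerIO Param
-- ^ does not type check: SamplerIO can sample but cannot score, so it is not an instance of MonadInfer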

#### Weighted - a state variable to track likelihood

```haskell
newtype Weighted (m :: * -> *) a = Weighted (StateT (Log Double) m a)
instance MonadSample m => MonadInfer (Weighted m)
instance MonadSample m => MonadSample (Weighted m)
```
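
To see the weight that `Weighted` tracks, the sketch below runs the model once and returns the sampled parameter together with its accumulated weight. It assumes that `runWeighted` is exported from `Control.Monad.Bayes.Weighted` in the `monad-bayes` version used in this post. The model is restated so that the block is self-contained, and `weightedDraw` is a name chosen for this example.

```haskell
import Control.Monad (forM_)
import Control.Monad.Bayes.Class (MonadInfer, normalPdf, score, uniform)
import Control.Monad.Bayes.Sampler (sampleIOfixed)
import Control.Monad.Bayes.Weighted (runWeighted)
import Numeric.Log (Log)

-- The model from the example setup, restated for self-containment.
post :: MonadInfer m => [Double] -> m Double
post obs = do
  std <- uniform 0.1 5
  forM_ obs (score . normalPdf 0 std)
  return std

-- One draw from the prior, paired with the weight collected by the score calls.
-- The inferred representation of the model is Weighted SamplerIO.
weightedDraw :: IO (Double, Log Double)
weightedDraw = sampleIOfixed (runWeighted (post [0.1, -0.2, 0.3, 0.2, 0.0, -0.4]))
```

The second component of the result is the product of the likelihoods of all observations for the sampled standard deviation.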

#### Traced - keep access to a prior execution trace

```haskell
data Traced (m :: * -> *) a = Traced (Weighted (FreeSampler m) a) (m (Trace a))
instance MonadInfer m => MonadInfer (Traced m)
instance MonadSample m => MonadSample (Traced m)
```

In [6]:
post7 = posterior :: Traced SamplerIO Param
-- ^ does not type check: Traced m is only a MonadInfer when m is one, and SamplerIO cannot score (it is not an instance of MonadCond)

#### Population - Run many computations in parallel

```haskell
newtype Population (m :: * -> *) a = Population (Weighted (ListT m) a)
instance MonadSample m => MonadInfer (Population m)
instance MonadSample m => MonadSample (Population m)
```

#### Sequential - stop and do something after a new observation

```haskell
newtype Sequential (m :: * -> *) a = Sequential {runSequential :: Coroutine (Await ()) m a}
instance MonadInfer m => MonadInfer (Sequential m)
instance MonadSample m => MonadSample (Sequential m)
```

## What is a Sampler?

We now understand that various types can represent a probabilistic computation, which is defined simply by having access to a sampling and a scoring operation.
We can build new types by composing various building blocks together.
Each of these building blocks adds functionality to interact with some underlying data representation.
For example, the `Weighted` monad provides operations that modify a state variable.
The `Traced` sampler adds operations that access a prior execution trace.
A sampler interprets some of these interactions and replaces them with more basic operations.
Here are the steps involved:

* build a probabilistic computation
* choose a representation for the computation that has access to abstract operations such as state, traces, etc.
* choose a sampler that reduces the abstractions to basic operations
* once only elementary IO sampling operations are left, run the computation to get a sample.

In practice, these steps reduce to

* build a probabilistic computation that can be represented with any `MonadInfer` type
* choose a sampler that returns elementary sampling operations and feed it the computation. Since the computation is polymorphic, the sampler (supported by Haskell's type inference) automatically chooses the representation that it needs.
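
The reduction can be made visible by naming each intermediate stage and annotating its type. This is a sketch under the assumption that `mh`, `prior` and `sampleIOfixed` have the types of the `monad-bayes` version used in this post. The names `model`, `afterMh`, `afterPrior` and `chain` are chosen for this example.

```haskell
import Control.Monad (forM_)
import Control.Monad.Bayes.Class (MonadInfer, normalPdf, score, uniform)
import Control.Monad.Bayes.Sampler (SamplerIO, sampleIOfixed)
import Control.Monad.Bayes.Traced (Traced, mh)
import Control.Monad.Bayes.Weighted (Weighted, prior)

-- A polymorphic model: any MonadInfer representation will do.
model :: MonadInfer m => m Double
model = do
  std <- uniform 0.1 5
  forM_ [0.1, -0.2, 0.3] (score . normalPdf 0 std)
  return std

-- mh interprets the Traced layer and leaves a Weighted SamplerIO computation.
afterMh :: Weighted SamplerIO [Double]
afterMh = mh 10 (model :: Traced (Weighted SamplerIO) Double)

-- prior (from Weighted) discards the weight and leaves a plain SamplerIO.
afterPrior :: SamplerIO [Double]
afterPrior = prior afterMh

-- sampleIOfixed runs the elementary sampling operations in IO.
chain :: IO [Double]
chain = sampleIOfixed afterPrior
```

The explicit annotation on `model` is only there for exposition. Writing `sampleIOfixed $ prior $ mh 10 model` lets type inference pick the same representation.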

#### elementary sampling (sampleIOfixed)

In [28]:
:t sampleIOfixed

In [16]:
samples = sampleIOfixed $ uniform 0 1
samples

2.481036288296201e-2

In [18]:
samples = sampleIOfixed $ replicateM 3 $ uniform 0 1
samples

[2.481036288296201e-2,0.7408640679453008,0.15936354678388287]

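#### remove weight (prior)

Here `prior` is the function imported from `Control.Monad.Bayes.Weighted`, not the model prior defined earlier, which shares its name.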

In [29]:
:t prior

In [35]:
posteriorNoWeight = prior posterior
:t posterior
:t posteriorNoWeight

In [34]:
samples = sampleIOfixed posteriorNoWeight
samples

0.22157077812651388

#### metropolis-hastings (mh)

In [36]:
:t mh

In [22]:
nsamples = 50
samples = sampleIOfixed $ prior $ mh nsamples posterior
samples

[0.3889676437623444,0.3889676437623444,0.30795313989187806,0.30795313989187806,0.30795313989187806,0.30795313989187806,0.30795313989187806,0.30795313989187806,0.30795313989187806,0.30795313989187806,0.30795313989187806,0.30795313989187806,0.30795313989187806,0.30795313989187806,0.30795313989187806,0.30795313989187806,0.30795313989187806,0.30795313989187806,0.30795313989187806,0.30795313989187806,0.30795313989187806,0.1746032034868105,0.1746032034868105,0.1746032034868105,0.1746032034868105,0.1746032034868105,0.1746032034868105,0.1746032034868105,0.1746032034868105,0.1746032034868105,0.1746032034868105,0.1746032034868105,0.1746032034868105,0.1746032034868105,0.22157077812651388,0.22157077812651388,0.22157077812651388,0.22157077812651388,0.22157077812651388,0.22157077812651388,0.22157077812651388,0.22157077812651388,0.22157077812651388,0.22157077812651388,0.22157077812651388,0.22157077812651388,0.22157077812651388,0.22157077812651388,0.22157077812651388,0.22157077812651388,0.221570778126

#### Sequential importance resampling (smcMultinomial, smcSystematic, smcMultinomialPush, smcSystematicPush)

In [53]:
:t smcMultinomial

In [58]:
samples = sampleIOfixed $ runPopulation $ smcMultinomial 6 300 posterior
l = length <$> samples
l

300
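
`runPopulation` returns the particles together with their weights, so the length above is the number of particles. As a minimal sketch of how such a weighted sample could be summarized, the hypothetical helper `weightedMean` below computes a weight-normalized average. It assumes that the weights have already been converted to the linear scale; a `Log Double` weight `w` could be converted with `exp (ln w)` using `ln` from `Numeric.Log`.

```haskell
-- Weighted particles as (value, weight) pairs, with weights on the linear scale.
-- The result is the sum of value * weight divided by the total weight.
-- Assumes a non-empty list with a non-zero total weight.
weightedMean :: [(Double, Double)] -> Double
weightedMean particles =
  sum [x * w | (x, w) <- particles] / sum (map snd particles)
```

For example, `weightedMean [(1, 0.25), (3, 0.75)]` evaluates to `2.5`.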

#### Resample-move Sequential Monte Carlo (rmsmc)

In [43]:
:t rmsmc
:t posterior

In [52]:
samples = sampleIOfixed $ runPopulation $ rmsmc 6 300 5 posterior
l = length <$> samples
l

300

#### Particle Marginal Metropolis-Hastings (pmmh)

In [37]:
:t pmmh

In [39]:
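samples = sampleIOfixed $ prior $ pmmh 5 6 300 posterior
samples
-- ^ fails to type check as written: pmmh appears to expect a parameter prior followed by a separate model function, and only one of the two is supplied here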

#### Sequential Monte Carlo squared (smc2)

In [59]:
:t smc2