In [1]:
import torch
import pyro
from pyro.contrib.gp import models
import pyro.distributions as dist
from torch.nn import Parameter
from pyro.optim import Adam
from pyro.infer import SVI
from pyro.contrib.gp.util import Parameterized
from torch.distributions import constraints
import numpy as np

The Goal: Understand MAP inference in Pyro.
We reuse the coin-flipping example from the last tutorial, only this time we want to use Maximum A Posteriori (MAP) estimation.

In [2]:
class Coin(Parameterized):
    """Coin-flip example with a Beta-distributed latent fairness.

    The two Beta concentrations are held in log space as the learnable
    parameters ``log_alpha`` and ``log_beta``; they are exponentiated before
    being handed to ``dist.Beta``.
    """

    def __init__(self):
        super().__init__()
        initial_value = 10.0
        self.log_alpha = Parameter(torch.tensor([initial_value]))
        self.log_beta = Parameter(torch.tensor([initial_value]))
        # The positive constraint is applied to the log-space values themselves.
        for param_name in ("log_alpha", "log_beta"):
            self.set_constraint(param_name, constraints.positive)

    def _fairness_distribution(self):
        """Build Beta(exp(log_alpha), exp(log_beta)) from the current parameters."""
        log_alpha = self.get_param('log_alpha')
        log_beta = self.get_param('log_beta')
        return dist.Beta(torch.exp(log_alpha), torch.exp(log_beta))

    def model(self, data):
        """Sample the latent fairness, then condition one Bernoulli site per flip.

        Args:
            data: indexable sequence of observed flips; element ``i`` is passed
                as ``obs`` to the sample site named ``obs_i``.
        """
        self.set_mode('model')
        fairness = pyro.sample('latent_fairness', self._fairness_distribution())
        for index, flip in enumerate(data):
            pyro.sample('obs_{}'.format(index), dist.Bernoulli(fairness), obs=flip)

    def guide(self, data):
        """Sample ``latent_fairness`` from the Beta built from the same parameter names.

        ``data`` is accepted to mirror ``model``'s signature but is not read.
        """
        # NOTE(review): model and guide read identically named parameters; whether
        # they resolve to the same values depends on Parameterized -- confirm.
        # NOTE(review): an earlier draft had a commented-out dist.Delta sample
        # here; confirm whether a point-mass (MAP) guide was intended.
        self.set_mode('guide')
        pyro.sample('latent_fairness', self._fairness_distribution())
        

Next, we define the data that we observe when flipping the coin.

In [3]:
# Observed flips: 6 heads (ones) followed by 4 tails (zeros).
# Each flip is its own freshly allocated 1-element tensor.
data = [torch.ones(1) for _ in range(6)] + [torch.zeros(1) for _ in range(4)]


We use the same procedure as in SVI tutorial I to optimise the parameters.

In [4]:
# Pyro keeps parameters in a global store that outlives a single cell run;
# clear it so re-running this cell starts from the initial values instead of
# resuming from whatever an earlier run left behind.
pyro.clear_param_store()
# Fix the RNG so the sampled ELBO estimates (and hence the fitted values)
# are repeatable across Restart & Run All.
torch.manual_seed(0)

# setup the optimizer
adam_params = {"lr": 0.0005, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

coin = Coin()

# setup the inference algorithm
svi = SVI(coin.model, coin.guide, optimizer, loss="ELBO")

n_steps = 4000
# do gradient steps, printing one dot every 100 steps as a progress marker
for step in range(n_steps):
    svi.step(data)
    if step % 100 == 0:
        print('.', end='')

........................................

In [5]:
# Pull the fitted Beta concentrations out as NumPy scalars
# (the parameters are stored in log space, hence the exp).
alpha_q, beta_q = (
    torch.exp(coin.get_param(param_name)).data.numpy()[0]
    for param_name in ("log_alpha", "log_beta")
)

# Closed-form moments of Beta(a, b):
#   mean = a / (a + b)
#   std  = mean * sqrt(b / (a * (1 + a + b)))
inferred_mean = alpha_q / (alpha_q + beta_q)
factor = beta_q / (alpha_q * (1.0 + alpha_q + beta_q))
inferred_std = inferred_mean * np.sqrt(factor)


In [6]:
# Build the whole template first, then format it once, so the %-substitution
# visibly applies to the complete sentence.
message = ("\nbased on the data and our prior belief, the fairness "
           "of the coin is %.3f +- %.3f") % (inferred_mean, inferred_std)
print(message)


based on the data and our prior belief, the fairness of the coin is 0.601 +- 0.002
