In [1]:
## Import packages

from collections import OrderedDict

import numpy as np

import torch

from torch.utils.data import DataLoader
from torch.distributions import Normal

from eeyore.data import Iris
from eeyore.models import mlp
from eeyore.mcmc import MetropolisHastings, MALA, PowerPosteriorSampler

from timeit import default_timer as timer
from datetime import timedelta

In [2]:
## Load iris data

iris = Iris()
# Full-batch loading. The previous hard-coded 150 is the size of the classic iris dataset,
# so deriving the batch size from the dataset keeps "one batch = whole dataset" explicit
# instead of depending on a magic number.
dataloader = DataLoader(iris, batch_size=len(iris))

In [3]:
## Setup MLP model

# Layer widths [4, 3, 3]; presumably 4 iris features -> 3 hidden units -> 3 classes
# (NOTE(review): confirm against mlp.Hyperparameters).
hparams = mlp.Hyperparameters(dims=[4, 3, 3])
model = mlp.MLP(hparams=hparams)

# Elementwise Normal prior over every model parameter: mean 0, standard deviation sqrt(3).
# torch.full pins the scale tensor to model.dtype explicitly rather than relying on
# NumPy-scalar/tensor type promotion in `np.sqrt(3) * torch.ones(...)`; num_params() is
# queried once and reused.
num_params = model.num_params()
model.prior = Normal(
    torch.zeros(num_params, dtype=model.dtype),
    torch.full((num_params,), np.sqrt(3), dtype=model.dtype),
)

In [4]:
## Setup PowerPosteriorSampler

# Seed torch's global RNG so theta0 and the categorical draws further down reproduce on re-run.
torch.manual_seed(0)

# Initial parameter vector drawn from the model's prior.
theta0 = model.prior.sample()

# One ['MALA', options] config per chain. Each chain gets its own fresh list and dict:
# the previous `11 * [[...]]` repeated references to a single list/dict, so if
# PowerPosteriorSampler ever modified one chain's options in place, the change would
# silently apply to every chain.
num_chains = 11
per_chain_samplers = [['MALA', {'step': 0.025}] for _ in range(num_chains)]

In [5]:
# Build the power-posterior sampler from the model, the initial parameters, the data loader
# and the per-chain sampler configs (positional order as expected by PowerPosteriorSampler).
sampler = PowerPosteriorSampler(
    model,
    theta0,
    dataloader,
    per_chain_samplers,
)

In [6]:
# Inspect the probability vector of chain 4's categorical distribution
# (presumably the chance of picking each other chain as a partner — confirm in eeyore).
chain4_categorical = sampler.categoricals[4]
chain4_categorical.probs

tensor([0.0484, 0.0798, 0.1315, 0.2168, 0.2168, 0.1315, 0.0798, 0.0484, 0.0293,
        0.0178])

In [7]:
# Sanity check: chain 4's categorical probabilities must sum to 1 (within float tolerance).
prob_sum = torch.sum(sampler.categoricals[4].probs)
assert torch.isclose(prob_sum, torch.ones_like(prob_sum)), f"probabilities sum to {prob_sum.item()}, expected 1"
prob_sum

tensor(1.)

In [8]:
# Log-probability computed by hand: map (10, 4) to a sequence index via from_events_to_seq,
# look up that entry of chain 4's probability vector, and take its log as a Python float.
seq_idx = sampler.from_events_to_seq(10, 4)
chain4_probs = sampler.categoricals[4].probs
chain4_probs[seq_idx].log().item()

-4.028770446777344

In [9]:
# The same log-probability via Categorical.log_prob at a hard-coded index. 9 appears to be the
# index that from_events_to_seq(10, 4) yields (the recorded outputs of these cells agree).
index_nine = torch.tensor(9)
sampler.categoricals[4].log_prob(index_nine)

tensor(-4.0288)

In [14]:
# Cross-check: Categorical.log_prob at the index from from_events_to_seq(10, 4) must equal the
# log of the matching entry of the probability vector. Assert it instead of eyeballing outputs.
chain4_categorical = sampler.categoricals[4]
seq_idx = torch.tensor(sampler.from_events_to_seq(10, 4))
log_p = chain4_categorical.log_prob(seq_idx)
assert torch.isclose(log_p, chain4_categorical.probs[seq_idx].log()), log_p
log_p

tensor(-4.0288)

In [10]:
# Draw one random index directly from chain 4's categorical distribution.
chain4_categorical = sampler.categoricals[4]
chain4_categorical.sample()

tensor(5)

In [11]:
# Twenty draws through the sampler's own categorical_sample helper for chain 4.
# (In the recorded output 4 itself never appears — presumably draws are mapped back to
# the other chains' indices; verify in eeyore.)
draws = []
for _ in range(20):
    draws.append(sampler.categorical_sample(4))
draws

[8, 3, 2, 5, 8, 3, 1, 2, 5, 7, 5, 5, 3, 6, 2, 3, 2, 2, 5, 6]