Skip to content

Commit

Permalink
Initial draft
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol committed Jun 3, 2021
1 parent 7a4b568 commit b8927be
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 26 deletions.
145 changes: 120 additions & 25 deletions pyro/infer/mcmc/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
- minimal memory consumption with multiprocessing and CUDA.
"""

from abc import ABC
import json
import logging
import queue
Expand All @@ -33,6 +34,7 @@
from pyro.infer.mcmc.nuts import NUTS
from pyro.infer.mcmc.util import diagnostics, print_summary
from pyro.util import optional
from pyro.ops.streaming import CountMeanVarianceStats, StatsOfDict

MAX_SEED = 2**32 - 1

Expand Down Expand Up @@ -257,7 +259,42 @@ def run(self, *args, **kwargs):
self.terminate(terminate_workers=exc_raised)


class MCMC:
class AbstractMCMC(ABC):
    """
    Base class for MCMC methods.

    Holds the shared state (``kernel``, ``num_chains``, ``transforms``) and the
    validation/setup helpers common to all concrete MCMC front-ends.
    """

    def __init__(self, kernel, num_chains, transforms):
        self.kernel = kernel
        self.num_chains = num_chains
        self.transforms = transforms

    def _set_transforms(self, *args, **kwargs):
        """Populate ``self.transforms``, preferring the kernel's own transforms."""
        kernel_transforms = getattr(self.kernel, "transforms", None)
        if kernel_transforms is not None:
            # Use `kernel.transforms` when available.
            self.transforms = kernel_transforms
        elif self.kernel.model:
            # Else, derive transforms from the model by running the kernel's
            # setup once with zero warmup steps (e.g. in multiprocessing,
            # where the kernel state is not shared across processes).
            self.kernel.setup(0, *args, **kwargs)
            self.transforms = self.kernel.transforms
        else:
            # Fall back to an empty mapping (no transforms).
            self.transforms = {}

    def _validate_kernel(self, initial_params):
        """Raise if an HMC/NUTS kernel built on `potential_fn` lacks initial params."""
        needs_initial_params = (isinstance(self.kernel, (HMC, NUTS))
                                and self.kernel.potential_fn is not None)
        if needs_initial_params and initial_params is None:
            raise ValueError("Must provide valid initial parameters to begin sampling"
                             " when using `potential_fn` in HMC/NUTS kernel.")

    def _validate_initial_params(self, initial_params):
        """Raise unless every tensor's leading dim equals ``self.num_chains``."""
        if any(v.shape[0] != self.num_chains for v in initial_params.values()):
            raise ValueError("The leading dimension of tensors in `initial_params` "
                             "must match the number of chains.")


class MCMC(AbstractMCMC):
"""
Wrapper class for Markov Chain Monte Carlo algorithms. Specific MCMC algorithms
are TraceKernel instances and need to be supplied as a ``kernel`` argument
Expand Down Expand Up @@ -307,28 +344,21 @@ class MCMC:
def __init__(self, kernel, num_samples, warmup_steps=None, initial_params=None,
num_chains=1, hook_fn=None, mp_context=None, disable_progbar=False,
disable_validation=True, transforms=None, save_params=None):
super().__init__(kernel, num_chains, transforms)
self.warmup_steps = num_samples if warmup_steps is None else warmup_steps # Stan
self.num_samples = num_samples
self.kernel = kernel
self.transforms = transforms
self.disable_validation = disable_validation
self._samples = None
self._args = None
self._kwargs = None
if save_params is not None:
kernel.save_params = save_params
if isinstance(self.kernel, (HMC, NUTS)) and self.kernel.potential_fn is not None:
if initial_params is None:
raise ValueError("Must provide valid initial parameters to begin sampling"
" when using `potential_fn` in HMC/NUTS kernel.")
self._validate_kernel(initial_params)
parallel = False
if num_chains > 1:
# check that initial_params is different for each chain
if initial_params:
for v in initial_params.values():
if v.shape[0] != num_chains:
raise ValueError("The leading dimension of tensors in `initial_params` "
"must match the number of chains.")
self._validate_initial_params(initial_params)
# FIXME: probably we want to use "spawn" method by default to avoid the error
# CUDA initialization error https://github.com/pytorch/pytorch/issues/2517
# even that we run MCMC in CPU.
Expand All @@ -348,10 +378,7 @@ def __init__(self, kernel, num_samples, warmup_steps=None, initial_params=None,
else:
if initial_params:
initial_params = {k: v.unsqueeze(0) for k, v in initial_params.items()}

self.num_chains = num_chains
self._diagnostics = [None] * num_chains

if parallel:
self.sampler = _MultiSampler(kernel, num_samples, self.warmup_steps, num_chains, mp_context,
disable_progbar, initial_params=initial_params, hook=hook_fn)
Expand Down Expand Up @@ -422,17 +449,7 @@ def model(data):
# If transforms is not explicitly provided, infer automatically using
# model args, kwargs.
if self.transforms is None:
# Use `kernel.transforms` when available
if getattr(self.kernel, "transforms", None) is not None:
self.transforms = self.kernel.transforms
# Else, get transforms from model (e.g. in multiprocessing).
elif self.kernel.model:
warmup_steps = 0
self.kernel.setup(warmup_steps, *args, **kwargs)
self.transforms = self.kernel.transforms
# Assign default value
else:
self.transforms = {}
self._set_transforms(*args, **kwargs)

# transform samples back to constrained space
for name, z in z_acc.items():
Expand Down Expand Up @@ -496,3 +513,81 @@ def summary(self, prob=0.9):
if 'divergences' in self._diagnostics[0]:
print("Number of divergences: {}".format(
sum([len(self._diagnostics[i]['divergences']) for i in range(self.num_chains)])))


class StreamingMCMC(AbstractMCMC):
    """
    MCMC front-end that folds samples into streaming statistics instead of
    retaining them, so memory use stays constant in the number of samples.

    Samples are drawn with a single-process :class:`_UnarySampler` and each
    (transformed) sample is fed into ``statistics`` (a streaming-stats object,
    by default ``StatsOfDict(default=CountMeanVarianceStats)``) as it arrives.
    """

    def __init__(self, kernel, num_samples, warmup_steps=None, initial_params=None,
                 statistics=None, num_chains=1, hook_fn=None, disable_progbar=False,
                 disable_validation=True, transforms=None, save_params=None):
        super().__init__(kernel, num_chains, transforms)
        # Default warmup to num_samples, following Stan's convention.
        self.warmup_steps = num_samples if warmup_steps is None else warmup_steps  # Stan
        self.num_samples = num_samples
        self.disable_validation = disable_validation
        self._samples = None
        self._args = None
        self._kwargs = None
        if statistics is None:
            # Default: per-site count/mean/variance tracked in a streaming fashion.
            statistics = StatsOfDict(default=CountMeanVarianceStats)
        self._statistics = statistics
        if save_params is not None:
            kernel.save_params = save_params
        self._validate_kernel(initial_params)
        if num_chains > 1:
            # Each chain needs its own starting point along the leading dim.
            if initial_params:
                self._validate_initial_params(initial_params)
        else:
            # Single chain: add a leading chain dimension of size 1.
            if initial_params:
                initial_params = {k: v.unsqueeze(0) for k, v in initial_params.items()}
        self._diagnostics = [None] * num_chains
        # Always a unary (single-process, sequential-chains) sampler; no
        # multiprocessing path here, unlike MCMC.
        self.sampler = _UnarySampler(kernel, num_samples, self.warmup_steps, num_chains, disable_progbar,
                                     initial_params=initial_params, hook=hook_fn)

    @poutine.block
    def run(self, *args, **kwargs):
        """
        Run the sampler, streaming each sample into ``self._statistics``.

        Per chain, the sampler is expected to yield: first a dict mapping site
        name -> shape (the latent structure), then ``num_samples`` flat sample
        tensors, and finally a diagnostics object — TODO confirm this protocol
        against ``_UnarySampler``.
        """
        self._args, self._kwargs = args, kwargs
        # Per-chain count of items received so far (structure + samples).
        num_samples = [0] * self.num_chains

        # If transforms is not explicitly provided, infer automatically using
        # model args, kwargs.
        if self.transforms is None:
            self._set_transforms(*args, **kwargs)

        with optional(pyro.validation_enabled(not self.disable_validation),
                      self.disable_validation is not None):
            # Detach tensor args so sampling does not build autograd graphs on them.
            args = [arg.detach() if torch.is_tensor(arg) else arg for arg in args]
            for x, chain_id in self.sampler.run(*args, **kwargs):
                if num_samples[chain_id] == 0:
                    # First item from a chain: the latent structure
                    # (site name -> shape), used to unpack flat samples below.
                    num_samples[chain_id] += 1
                    z_structure = x
                elif num_samples[chain_id] == self.num_samples + 1:
                    # Item after the last sample: per-chain diagnostics.
                    self._diagnostics[chain_id] = x
                else:
                    num_samples[chain_id] += 1
                    if self.num_chains > 1:
                        # Clone so the retained values do not alias the
                        # sampler's buffer — presumably needed when chains are
                        # interleaved; TODO confirm.
                        x_cloned = x.clone()
                        del x
                    else:
                        x_cloned = x

                    # unpack latent: slice the flat sample tensor into named
                    # sites using the shapes recorded in z_structure.
                    pos = 0
                    z_acc = z_structure.copy()
                    for k in sorted(z_structure):
                        shape = z_structure[k]
                        next_pos = pos + shape.numel()
                        z_acc[k] = x_cloned[pos:next_pos].reshape(shape)
                        pos = next_pos

                    # Map samples back to the constrained space.
                    for name, z in z_acc.items():
                        if name in self.transforms:
                            z_acc[name] = self.transforms[name].inv(z)

                    # Fold this sample into the streaming statistics.
                    self._statistics.update({
                        name: transformed_sample for name, transformed_sample in z_acc.items()
                    })

        # terminate the sampler (shut down worker processes)
        self.sampler.terminate(True)

    def summary(self, prob=0.9):
        # NOTE(review): `prob` is currently unused; only the accumulated
        # streaming statistics are returned.
        return self._statistics.get()
19 changes: 18 additions & 1 deletion tests/infer/mcmc/test_mcmc_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.mcmc import HMC, NUTS
from pyro.infer.mcmc.api import MCMC, _MultiSampler, _UnarySampler
from pyro.infer.mcmc.api import MCMC, StreamingMCMC, _MultiSampler, _UnarySampler
from pyro.infer.mcmc.mcmc_kernel import MCMCKernel
from pyro.infer.mcmc.util import initialize_model
from pyro.util import optional
from pyro.ops.streaming import CountMeanVarianceStats, StatsOfDict, CountStats
from tests.common import assert_close


Expand Down Expand Up @@ -73,6 +74,22 @@ def normal_normal_model(data):
return y


@pytest.mark.parametrize("mcmc_cls", [StreamingMCMC])
@pytest.mark.parametrize('num_chains', [1, 2])
@pytest.mark.filterwarnings("ignore:num_chains")
def test_mcmc_summary(mcmc_cls, num_chains):
num_samples = 2000
data = torch.tensor([1.0])
initial_params, _, transforms, _ = initialize_model(normal_normal_model, model_args=(data,),
num_chains=num_chains)
kernel = PriorKernel(normal_normal_model)
mcmc = StreamingMCMC(kernel=kernel, num_samples=num_samples, warmup_steps=100,
statistics=StatsOfDict(default=CountMeanVarianceStats),
num_chains=num_chains, initial_params=initial_params, transforms=transforms)
mcmc.run(data)
print(mcmc.summary()) # TODO Draft test


@pytest.mark.parametrize('num_draws', [None, 1800, 2200])
@pytest.mark.parametrize('group_by_chain', [False, True])
@pytest.mark.parametrize('num_chains', [1, 2])
Expand Down

0 comments on commit b8927be

Please sign in to comment.