13 changes: 7 additions & 6 deletions pymc/sampling/mcmc.py
@@ -32,8 +32,8 @@

import pymc as pm

-from pymc.backends import init_traces
-from pymc.backends.base import BaseTrace, IBaseTrace, MultiTrace, _choose_chains
+from pymc.backends import RunType, TraceOrBackend, init_traces
+from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains
from pymc.blocking import DictToArrayBijection
from pymc.exceptions import SamplingError
from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain
@@ -328,7 +328,7 @@ def sample(
init: str = "auto",
jitter_max_retries: int = 10,
n_init: int = 200_000,
-    trace: Optional[BaseTrace] = None,
+    trace: Optional[TraceOrBackend] = None,
discard_tuned_samples: bool = True,
compute_convergence_checks: bool = True,
keep_warning_stat: bool = False,
@@ -609,13 +609,12 @@ def sample(
_check_start_shape(model, ip)

# Create trace backends for each chain
-    traces = init_traces(
+    run, traces = init_traces(
backend=trace,
chains=chains,
expected_length=draws + tune,
step=step,
-        var_dtypes={vn: v.dtype for vn, v in ip.items()},
-        var_shapes={vn: v.shape for vn, v in ip.items()},
+        initial_point=ip,
model=model,
)

@@ -690,6 +689,7 @@ def sample(
# Packaging, validating and returning the result was extracted
# into a function to make it easier to test and refactor.
return _sample_return(
+        run=run,
traces=traces,
tune=tune,
t_sampling=t_sampling,
@@ -704,6 +704,7 @@

def _sample_return(
*,
+    run: Optional[RunType],
traces: Sequence[IBaseTrace],
tune: int,
t_sampling: float,
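Note on the API change above: init_traces now returns a (run, traces) tuple and derives variable dtypes and shapes from the initial point. A minimal sketch of the new call, assuming a toy model and step method (the model and step names are illustrative, and run being None for legacy backends is an assumption based on _sample_return taking run: Optional[RunType]):

import pymc as pm

from pymc.backends import init_traces

with pm.Model() as model:
    pm.Normal("x")
    step = pm.Metropolis()
    ip = model.initial_point()
    run, traces = init_traces(
        backend=None,  # None falls back to the default NDArray backend
        chains=2,
        expected_length=1_500,  # tune + draws
        step=step,
        initial_point=ip,  # replaces the old var_dtypes/var_shapes kwargs
        model=model,
    )
    # Legacy backends have no Run concept (assumption, see lead-in above):
    assert run is None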
41 changes: 31 additions & 10 deletions pymc/step_methods/compound.py
@@ -191,6 +191,11 @@ def stop_tuning(self):
self.tune = False


+def flat_statname(sampler_idx: int, sname: str) -> str:
+    """Get the flat-stats name for a sampler's stat."""
+    return f"sampler_{sampler_idx}__{sname}"


def get_stats_dtypes_shapes_from_steps(
steps: Iterable[BlockedStep],
) -> Dict[str, Tuple[StatDtype, StatShape]]:
@@ -201,7 +206,7 @@ def get_stats_dtypes_shapes_from_steps(
result = {}
for s, step in enumerate(steps):
for sname, (dtype, shape) in step.stats_dtypes_shapes.items():
result[f"sampler_{s}__{sname}"] = (dtype, shape)
result[flat_statname(s, sname)] = (dtype, shape)
return result


Expand Down Expand Up @@ -262,10 +267,21 @@ class StatsBijection:

def __init__(self, sampler_stats_dtypes: Sequence[Mapping[str, type]]) -> None:
# Keep a list of flat vs. original stat names
-        self._stat_groups: List[List[Tuple[str, str]]] = [
-            [(f"sampler_{s}__{statname}", statname) for statname, _ in names_dtypes.items()]
-            for s, names_dtypes in enumerate(sampler_stats_dtypes)
-        ]
+        stat_groups = []
+        for s, names_dtypes in enumerate(sampler_stats_dtypes):
+            group = []
+            for statname, dtype in names_dtypes.items():
+                flatname = flat_statname(s, statname)
+                is_obj = np.dtype(dtype) == np.dtype(object)
+                group.append((flatname, statname, is_obj))
+            stat_groups.append(group)
+        self._stat_groups: List[List[Tuple[str, str, bool]]] = stat_groups
+        self.object_stats = {
+            fname: (s, sname)
+            for s, group in enumerate(self._stat_groups)
+            for fname, sname, is_obj in group
+            if is_obj
+        }

@property
def n_samplers(self) -> int:
@@ -275,9 +291,10 @@ def map(self, stats_list: Sequence[Mapping[str, Any]]) -> StatsDict:
"""Combine stats dicts of multiple samplers into one dict."""
stats_dict = {}
for s, sts in enumerate(stats_list):
-            for statname, sval in sts.items():
-                sname = f"sampler_{s}__{statname}"
-                stats_dict[sname] = sval
+            for fname, sname, is_obj in self._stat_groups[s]:
+                if sname not in sts:
+                    continue
+                stats_dict[fname] = sts[sname]
return stats_dict

def rmap(self, stats_dict: Mapping[str, Any]) -> StatsType:
@@ -286,7 +303,11 @@ def rmap(self, stats_dict: Mapping[str, Any]) -> StatsType:
The ``stats_dict`` can be a subset of all sampler stats.
"""
stats_list = []
-        for namemap in self._stat_groups:
-            d = {statname: stats_dict[sname] for sname, statname in namemap if sname in stats_dict}
+        for group in self._stat_groups:
+            d = {}
+            for fname, sname, is_obj in group:
+                if fname not in stats_dict:
+                    continue
+                d[sname] = stats_dict[fname]
stats_list.append(d)
return stats_list
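The rewritten bookkeeping above stores one (flatname, statname, is_obj) triple per stat, and map/rmap now skip stats that are absent on either side. A small round-trip sketch of that behavior, using only names defined in this module:

from pymc.step_methods.compound import StatsBijection, flat_statname

bij = StatsBijection([{"tune": bool}, {"tune": bool, "warning": object}])
assert flat_statname(1, "warning") == "sampler_1__warning"
# Object-dtyped stats are indexed so that backends can special-case them:
assert bij.object_stats == {"sampler_1__warning": (1, "warning")}
# map() flattens the per-sampler stats dicts into one dict ...
flat = bij.map([{"tune": True}, {"tune": True, "warning": None}])
assert flat == {
    "sampler_0__tune": True,
    "sampler_1__tune": True,
    "sampler_1__warning": None,
}
# ... and rmap() inverts it, tolerating subsets of the flat names:
assert bij.rmap({"sampler_1__warning": None}) == [{}, {"warning": None}]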
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -9,6 +9,7 @@ git+https://github.com/pymc-devs/pymc-sphinx-theme
h5py>=2.7
ipython>=7.16
jupyter-sphinx
+mcbackend>=0.4.0
mypy==0.990
myst-nb
numdifftools>=0.9.40
305 changes: 305 additions & 0 deletions tests/backends/test_mcbackend.py
@@ -0,0 +1,305 @@
# Copyright 2023 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

import arviz
import numpy as np
import pytest

import pymc as pm

from pymc.backends import init_traces
from pymc.step_methods.arraystep import ArrayStepShared

try:
import mcbackend as mcb

from mcbackend.npproto.utils import ndarray_to_numpy
except ImportError:
    pytest.skip("Requires McBackend to be installed.", allow_module_level=True)

from pymc.backends.mcbackend import (
ChainRecordAdapter,
find_data,
get_variables_and_point_fn,
make_runmeta_and_point_fn,
)


@pytest.fixture
def simple_model():
seconds = np.linspace(0, 5)
observations = np.random.normal(0.5 + np.random.uniform(size=3)[:, None] * seconds[None, :])
with pm.Model(
coords={
"condition": ["A", "B", "C"],
}
) as pmodel:
x = pm.ConstantData("seconds", seconds, dims="time")
a = pm.Normal("scalar")
b = pm.Uniform("vector", dims="condition")
pm.Deterministic("matrix", a + b[:, None] * x[None, :], dims=("condition", "time"))
pm.Bernoulli("integer", p=0.5)
obs = pm.MutableData("obs", observations, dims=("condition", "time"))
pm.Normal("L", pmodel["matrix"], observed=obs, dims=("condition", "time"))
return pmodel


def test_find_data(simple_model):
dvars = find_data(simple_model)
dvardict = {d.name: d for d in dvars}
assert set(dvardict) == {"seconds", "obs"}

secs = dvardict["seconds"]
assert isinstance(secs, mcb.DataVariable)
assert secs.dims == ["time"]
assert not secs.is_observed
np.testing.assert_array_equal(ndarray_to_numpy(secs.value), simple_model["seconds"].data)

obs = dvardict["obs"]
assert isinstance(obs, mcb.DataVariable)
assert obs.dims == ["condition", "time"]
assert obs.is_observed
np.testing.assert_array_equal(ndarray_to_numpy(obs.value), simple_model["obs"].get_value())


def test_find_data_skips_deterministics():
data = np.array([0, 1], dtype="float32")
with pm.Model() as pmodel:
a = pm.ConstantData("a", data, dims="item")
b = pm.Normal("b")
pm.Deterministic("c", a + b, dims="item")
assert "c" in pmodel.named_vars
dvars = find_data(pmodel)
assert len(dvars) == 1
assert dvars[0].name == "a"
assert dvars[0].dims == ["item"]
np.testing.assert_array_equal(ndarray_to_numpy(dvars[0].value), data)
assert not dvars[0].is_observed


def test_get_variables_and_point_fn(simple_model):
ip = simple_model.initial_point()
variables, point_fn = get_variables_and_point_fn(simple_model, ip)
assert isinstance(variables, list)
assert callable(point_fn)
vdict = {v.name: v for v in variables}
assert set(vdict) == {"integer", "scalar", "vector", "vector_interval__", "matrix"}
point = point_fn(ip)
assert len(point) == len(variables)
for v, p in zip(variables, point):
assert str(p.dtype) == v.dtype


def test_make_runmeta_and_point_fn(simple_model):
with simple_model:
step = pm.DEMetropolisZ()
rmeta, point_fn = make_runmeta_and_point_fn(
initial_point=simple_model.initial_point(),
step=step,
model=simple_model,
)
assert isinstance(rmeta, mcb.RunMeta)
assert callable(point_fn)
vars = {v.name: v for v in rmeta.variables}
assert set(vars.keys()) == {"scalar", "vector", "vector_interval__", "matrix", "integer"}
# NOTE: Technically the "vector" is deterministic, but from the user perspective it is not.
# This is merely a matter of which version of transformed variables should be traced.
assert not vars["vector"].is_deterministic
assert not vars["vector_interval__"].is_deterministic
assert vars["matrix"].is_deterministic
assert len(rmeta.sample_stats) == 1 + len(step.stats_dtypes[0])


def test_init_traces(simple_model):
with simple_model:
step = pm.DEMetropolisZ()
run, traces = init_traces(
backend=mcb.NumPyBackend(),
chains=2,
expected_length=70,
step=step,
initial_point=simple_model.initial_point(),
model=simple_model,
)
assert isinstance(run, mcb.backends.numpy.NumPyRun)
assert isinstance(traces, list)
assert len(traces) == 2
assert isinstance(traces[0], ChainRecordAdapter)
assert isinstance(traces[0]._chain, mcb.backends.numpy.NumPyChain)


class ToyStepper(ArrayStepShared):
stats_dtypes_shapes = {
"accepted": (bool, []),
"tune": (bool, []),
"s1": (np.float64, []),
}

def astep(self, *args, **kwargs):
raise NotImplementedError()


class ToyStepperWithOtherStats(ToyStepper):
stats_dtypes_shapes = {
"accepted": (bool, []),
"tune": (bool, []),
"s2": (np.float64, []),
}


class TestChainRecordAdapter:
def test_get_sampler_stats(self):
        # Initialize a very simple toy model
N = 45
with pm.Model() as pmodel:
a = pm.Normal("a")
b = pm.Uniform("b")
c = pm.Deterministic("c", a + b)
ip = pmodel.initial_point()
shared = pm.make_shared_replacements(ip, [a, b], pmodel)
run, traces = init_traces(
backend=mcb.NumPyBackend(),
chains=1,
expected_length=N,
step=ToyStepper([a, b], shared),
initial_point=pmodel.initial_point(),
model=pmodel,
)
cra = traces[0]
assert isinstance(run, mcb.backends.numpy.NumPyRun)
assert isinstance(cra, ChainRecordAdapter)

# Simulate recording of draws and stats
rng = np.random.RandomState(2023)
for i in range(N):
draw = {"a": rng.normal(), "b_interval__": rng.normal()}
stats = [dict(tune=(i <= 5), s1=i, accepted=bool(rng.randint(0, 2)))]
cra.record(draw, stats)

# Check final state of the chain
assert len(cra) == N
# Variables b and c were calculated by the point function
draws_a = cra.get_values("a")
draws_b = cra.get_values("b")
draws_c = cra.get_values("c")
np.testing.assert_array_equal(draws_a + draws_b, draws_c)
i = np.random.randint(0, N)
point = cra.point(idx=i)
assert point["a"] == draws_a[i]
assert point["b"] == draws_b[i]
assert point["c"] == draws_c[i]

# Stats come in different shapes depending on the query
s1 = cra.get_sampler_stats("s1", sampler_idx=None, burn=3, thin=2)
assert s1.shape == (21,)
assert s1.dtype == np.dtype("float64")
        np.testing.assert_array_equal(s1, np.arange(N)[3::2])

def test_get_sampler_stats_compound(self, caplog):
        # Initialize a very simple toy model
N = 45
with pm.Model() as pmodel:
a = pm.Normal("a")
b = pm.Uniform("b")
c = pm.Deterministic("c", a + b)
ip = pmodel.initial_point()
shared_a = pm.make_shared_replacements(ip, [a], pmodel)
shared_b = pm.make_shared_replacements(ip, [b], pmodel)
stepA = ToyStepper([a], shared_a)
stepB = ToyStepperWithOtherStats([b], shared_b)
run, traces = init_traces(
backend=mcb.NumPyBackend(),
chains=1,
expected_length=N,
step=pm.CompoundStep([stepA, stepB]),
initial_point=pmodel.initial_point(),
model=pmodel,
)
cra = traces[0]
assert isinstance(cra, ChainRecordAdapter)

# Simulate recording of draws and stats
rng = np.random.RandomState(2023)
for i in range(N):
tune = i <= 5
draw = {"a": rng.normal(), "b_interval__": rng.normal()}
stats = [
dict(tune=tune, s1=i, accepted=bool(rng.randint(0, 2))),
dict(tune=tune, s2=i, accepted=bool(rng.randint(0, 2))),
]
cra.record(draw, stats)

# The 'accepted' stat was emitted by both samplers
assert cra.get_sampler_stats("accepted", sampler_idx=None).shape == (N, 2)
acpt_1 = cra.get_sampler_stats("accepted", sampler_idx=0, burn=3, thin=2)
acpt_2 = cra.get_sampler_stats("accepted", sampler_idx=1, burn=3, thin=2)
assert acpt_1.shape == (21,) # (N-3)/2
assert not np.array_equal(acpt_1, acpt_2)

# s1 and s2 were sampler specific
# they are squeezed into vectors, but warnings are logged at DEBUG level
with caplog.at_level(logging.DEBUG, logger="pymc"):
s1 = cra.get_sampler_stats("s1", burn=10)
assert s1.shape == (35,)
assert s1.dtype == np.dtype("float64")
s2 = cra.get_sampler_stats("s2", thin=5)
assert s2.shape == (9,) # N/5
assert s2.dtype == np.dtype("float64")
assert any("'s1' was not recorded by all samplers" in r.message for r in caplog.records)

with pytest.raises(KeyError, match="No stat"):
cra.get_sampler_stats("notastat")


class TestMcBackendSampling:
@pytest.mark.parametrize("discard_warmup", [False, True])
def test_return_multitrace(self, simple_model, discard_warmup):
with simple_model:
mtrace = pm.sample(
trace=mcb.NumPyBackend(),
tune=5,
draws=7,
cores=1,
chains=3,
step=pm.Metropolis(),
discard_tuned_samples=discard_warmup,
return_inferencedata=False,
)
assert isinstance(mtrace, pm.backends.base.MultiTrace)
tune = mtrace._straces[0].get_sampler_stats("tune")
assert isinstance(tune, np.ndarray)
if discard_warmup:
assert tune.shape == (7, 3)
else:
assert tune.shape == (12, 3)

@pytest.mark.parametrize("cores", [1, 3])
def test_return_inferencedata(self, simple_model, cores):
with simple_model:
idata = pm.sample(
trace=mcb.NumPyBackend(),
tune=5,
draws=7,
cores=cores,
chains=3,
discard_tuned_samples=False,
)
assert isinstance(idata, arviz.InferenceData)
assert idata.warmup_posterior.sizes["draw"] == 5
assert idata.posterior.sizes["draw"] == 7
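The user-facing entry point exercised by these tests is passing a McBackend backend as the trace argument of pm.sample. A minimal usage sketch, assuming mcbackend is installed (the NumPy backend keeps draws in memory; the model and tune/draws/chains values are illustrative):

import mcbackend as mcb

import pymc as pm

with pm.Model():
    pm.Normal("x")
    # Draws and sampler stats are recorded through the ChainRecordAdapter
    # and converted to InferenceData as usual:
    idata = pm.sample(trace=mcb.NumPyBackend(), tune=100, draws=200, chains=2)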
7 changes: 7 additions & 0 deletions tests/sampling/test_mcmc.py
@@ -362,6 +362,7 @@ def test_sample_return_lengths(self):

# MultiTrace without warmup
mtrace_pst = pm.sampling.mcmc._sample_return(
+            run=None,
traces=traces,
tune=50,
t_sampling=123.4,
@@ -380,6 +381,7 @@ def test_sample_return_lengths(self):

# InferenceData with warmup
idata_w = pm.sampling.mcmc._sample_return(
+            run=None,
traces=traces,
tune=50,
t_sampling=123.4,
@@ -398,6 +400,7 @@ def test_sample_return_lengths(self):

# InferenceData without warmup
idata = pm.sampling.mcmc._sample_return(
+            run=None,
traces=traces,
tune=50,
t_sampling=123.4,
@@ -463,6 +466,10 @@ def test_keep_warning_stat_setting(self, keep_warning_stat):
        # This test flattens so we don't have to be exact in accessing (non-)squeezed items.
# Also see https://github.com/pymc-devs/pymc/issues/6207.
warn_objs = list(idata.sample_stats.warning.sel(chain=0).values.flatten())
+        assert warn_objs
+        if isinstance(warn_objs[0], np.ndarray):
+            # Squeeze warning stats. See https://github.com/pymc-devs/pymc/issues/6207
+            warn_objs = [a.tolist() for a in warn_objs]
assert any(isinstance(w, SamplerWarning) for w in warn_objs)
assert any("Asteroid" in w.message for w in warn_objs)
else:
8 changes: 5 additions & 3 deletions tests/step_methods/test_compound.py
@@ -164,20 +164,22 @@ def test_flatten_steps(self):
def test_stats_bijection(self):
step_stats_dtypes = [
{"a": float, "b": int},
{"a": float, "c": int},
{"a": float, "c": Warning},
]
bij = StatsBijection(step_stats_dtypes)
+        assert bij.object_stats == {"sampler_1__c": (1, "c")}
assert bij.n_samplers == 2
+        w = Warning("hmm")
stats_l = [
dict(a=1.5, b=3),
-            dict(a=2.5, c=4),
+            dict(a=2.5, c=w),
]
stats_d = bij.map(stats_l)
assert isinstance(stats_d, dict)
assert stats_d["sampler_0__a"] == 1.5
assert stats_d["sampler_0__b"] == 3
assert stats_d["sampler_1__a"] == 2.5
assert stats_d["sampler_1__c"] == 4
assert stats_d["sampler_1__c"] == w
rev = bij.rmap(stats_d)
assert isinstance(rev, list)
assert len(rev) == len(stats_l)