Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for the McBackend adapter #6835

Merged
merged 3 commits into from Jul 24, 2023

Conversation

michaelosthege
Copy link
Member

@michaelosthege michaelosthege commented Jul 20, 2023

What is this PR about?
This fixes two bugs that were reported by @thelogicalgrammar in pymc-devs/mcbackend#93.

The first was a leftover of a previous refactoring of sampler stats, and the way how the "tune" stat is handled in PyMC.

For the record: Our current approach of passing information about tuning/not-tuning via sample stats is brittle. It's a per-sampler thing, but in practice it's actually a per-iteration thing from the PyMC, ArviZ and McBackend perspective. IMO we should consider taking it out of the step method attributes and stats, and instead pass tune as a parameter to the astep method.

However, this is out of scope for this PR, so I aligned with the current design and added "tune" to two step methods that didn't have it, much like Ricardo did for Slice in fee9a02.

The second bug was just a small dtype thing for pickled stat objects.
I tested it locally with a ClickHouse server:

def test_issue_93_b():
    seconds = np.linspace(0, 5)
    observations = np.random.normal(0.5 + np.random.uniform(size=3)[:, None] * seconds[None, :])
    with pm.Model(
        coords={
            "condition": ["A", "B", "C"],
        }
    ) as pmodel:
        x = pm.ConstantData("seconds", seconds, dims="time")
        a = pm.Normal("scalar")
        b = pm.Uniform("vector", dims="condition")
        pm.Deterministic("matrix", a + b[:, None] * x[None, :], dims=("condition", "time"))
        obs = pm.MutableData("obs", observations, dims=("condition", "time"))
        pm.Normal("L", pmodel["matrix"], observed=obs, dims=("condition", "time"))

    ch_client = clickhouse_driver.Client("localhost")
    backend = ClickHouseBackend(ch_client)
    with pmodel:
        pm.sample(trace=backend, tune=5, draws=7, discard_tuned_samples=False)
    assert idata.warmup_posterior.sizes["draw"] == 5
    assert idata.posterior.sizes["draw"] == 7
    pass

Checklist

Major / Breaking Changes

  • None

New features

  • None

Bugfixes

  • Two bugs in McBackend adapter were fixed

Documentation

  • None

Maintenance

  • All of PyMC's step methods now emit a "tune" stat

馃摎 Documentation preview 馃摎: https://pymc--6835.org.readthedocs.build/en/6835/

@codecov
Copy link

codecov bot commented Jul 20, 2023

Codecov Report

Merging #6835 (33364d5) into main (82c6318) will decrease coverage by 2.55%.
The diff coverage is 92.59%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #6835      +/-   ##
==========================================
- Coverage   92.04%   89.49%   -2.55%     
==========================================
  Files          95       95              
  Lines       16298    16322      +24     
==========================================
- Hits        15001    14607     -394     
- Misses       1297     1715     +418     
Impacted Files Coverage 螖
pymc/step_methods/compound.py 97.17% <75.00%> (-0.46%) 猬囷笍
pymc/backends/mcbackend.py 100.00% <100.00%> (酶)
pymc/step_methods/metropolis.py 86.92% <100.00%> (+0.33%) 猬嗭笍

... and 9 files with indirect coverage changes

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what's going on with the "object" vs "str" dtype. Otherwise LGTM

@michaelosthege
Copy link
Member Author

michaelosthege commented Jul 24, 2023

Not sure what's going on with the "object" vs "str" dtype. Otherwise LGTM

thanks for the review!

To clarify this part: Sampler warnings are SamplerWarning objects, but as such they can't be stored in all backends. Therefore, the ChainRecordAdapter pickles them to a string, which can be stored in HDF5/ClickHouse/whatever.

That's why they must be announced as str-typed when setting up the backend

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug trace-backend Traces and ArviZ stuff
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants