Implement pyro.ops.streaming module #2856

Merged
eb8680 merged 12 commits into dev from ops-streaming
Jun 7, 2021

Conversation

fritzo
Member

@fritzo fritzo commented Jun 2, 2021

Addresses #2843

This implements a new module, pyro.ops.streaming, that tracks various statistics in a streaming fashion. The first intended use case is the planned StreamingMCMC class, which will track statistics rather than store samples. There are other potential uses in high-dimensional inference, e.g. recording statistics of gradients during SVI and computing sample moments from predictive when the samples don't fit in memory.

Design choices

The two basic operations are .update() and .get(). The third operation .merge() will be useful for multiple-chain MCMC and computing things like rhat.
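To make the three operations concrete, here is a minimal pure-Python sketch on scalars rather than tensors. The class name ScalarCountMeanVariance and its fields are hypothetical and are not the code in this PR. It assumes Welford's single-pass update, the standard pairwise formula for combining two partial summaries, and a .merge() that returns a new object.

```python
from typing import Dict


class ScalarCountMeanVariance:
    """Hypothetical scalar tracker; illustrates update/merge/get only."""

    def __init__(self) -> None:
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, sample: float) -> None:
        # Welford's single-pass update: no samples are stored.
        self.count += 1
        delta = sample - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (sample - self.mean)

    def merge(self, other: "ScalarCountMeanVariance") -> "ScalarCountMeanVariance":
        # Pairwise combination of two partial summaries (assumed to return
        # a new object and leave both inputs unchanged).
        result = ScalarCountMeanVariance()
        result.count = self.count + other.count
        if result.count == 0:
            return result
        delta = other.mean - self.mean
        result.mean = self.mean + delta * other.count / result.count
        result.m2 = (
            self.m2
            + other.m2
            + delta ** 2 * self.count * other.count / result.count
        )
        return result

    def get(self) -> Dict[str, float]:
        # Only report statistics that are defined for the current count.
        stats: Dict[str, float] = {"count": self.count}
        if self.count > 0:
            stats["mean"] = self.mean
        if self.count > 1:
            stats["variance"] = self.m2 / (self.count - 1)
        return stats


# Usage: two "chains" tracked separately, then combined.
chain_a, chain_b = ScalarCountMeanVariance(), ScalarCountMeanVariance()
for x in [1.0, 2.0]:
    chain_a.update(x)
for x in [3.0, 4.0]:
    chain_b.update(x)
combined = chain_a.merge(chain_b).get()
```

Memory use is constant in the number of samples, which is the point of tracking statistics instead of storing draws.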

I have restricted the data type to dictionaries of tensors, which is the basic datatype in pyro.infer.mcmc and in much of NumPyro. We could easily generalize this to pytrees by adding classes StatsOfList and StatsOfTuple.

Tested

  • tested commutativity of update-get
  • tested commutativity and associativity of update-merge-get
  • ran mypy locally
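The algebraic properties listed above can be checked along the following lines. This is only a sketch of the idea, not the tests in this PR: CountSum, summarize, and the two check functions are hypothetical names, and integer samples are used so that equality is exact.

```python
import itertools


class CountSum:
    """Hypothetical tracker with exact integer state."""

    def __init__(self):
        self.count = 0
        self.total = 0

    def update(self, sample):
        self.count += 1
        self.total += sample

    def merge(self, other):
        result = CountSum()
        result.count = self.count + other.count
        result.total = self.total + other.total
        return result

    def get(self):
        return {"count": self.count, "total": self.total}


def summarize(samples):
    stats = CountSum()
    for sample in samples:
        stats.update(sample)
    return stats


def check_update_commutes(samples):
    # update-get: the result must not depend on the order of updates.
    expected = summarize(samples).get()
    return all(
        summarize(perm).get() == expected
        for perm in itertools.permutations(samples)
    )


def check_merge_laws(a, b, c):
    # update-merge-get: merge must be commutative and associative.
    x, y, z = summarize(a), summarize(b), summarize(c)
    commutative = x.merge(y).get() == y.merge(x).get()
    associative = x.merge(y).merge(z).get() == x.merge(y.merge(z)).get()
    return commutative and associative
```

With floating-point or tensor state, the comparisons would need a tolerance rather than exact equality.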

@fritzo fritzo added awaiting review and removed WIP labels Jun 2, 2021
Member Author

fritzo commented Jun 2, 2021

@eb8680 I've added you as a reviewer because these streaming classes create a new semigroup abstraction and you're the resident algebra expert.

Member Author

fritzo commented Jun 2, 2021

@mtsokol I believe you can use something like the following statistics in #2843:

from pyro.ops.streaming import CountMeanVariance, StatsOfDict

...
stats = StatsOfDict(default=CountMeanVariance)
for mcmc_sample in ...:  # learning loop
    stats.update({
        name: transformed_sample for name, transformed_sample in mcmc_sample.items()
    })
result = stats.get()

Let me know if it looks like you'll need any changes to this PR.

Member

@eb8680 eb8680 left a comment

Neat API! I'm a little confused by some of the types. Did you try running mypy locally?

pyro/ops/streaming.py (outdated)
    self.count += 1

def merge(self, other: "CountStats"):
    assert isinstance(other, type(self))
Member

nit: these assertions should no longer be necessary with type hints

Member Author

Hmm, good point more generally. However I'd like to argue that we should include both type hints and assertions until all common tools can leverage type hints. My reasoning is that I'd really like to catch errors as early as possible, e.g. when users (like me) are working in a jupyter notebook. I think until Jupyter dynamically checks types while editing we'll want extra guard rails especially for tricky interfaces like this.
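A small sketch of the "both type hints and assertions" position, using a hypothetical CountTracker class rather than the code under review: the annotation serves static checkers such as mypy, while the assertion fails immediately at runtime, for example in a notebook cell.

```python
class CountTracker:
    """Hypothetical tracker used only to illustrate the guard rail."""

    def __init__(self) -> None:
        self.count = 0

    def update(self, sample: object) -> None:
        self.count += 1

    def merge(self, other: "CountTracker") -> "CountTracker":
        # The type hint is not enforced at runtime; the assertion is.
        assert isinstance(other, type(self))
        result = CountTracker()
        result.count = self.count + other.count
        return result


def merge_fails_fast(bad_argument: object) -> bool:
    # Returns True if merging with a wrongly typed argument is rejected.
    try:
        CountTracker().merge(bad_argument)  # type: ignore[arg-type]
    except AssertionError:
        return True
    return False
```

One caveat on this design: assert statements are skipped when Python runs with the -O flag, so they act as a development-time guard rather than as input validation.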

"""
def __init__(
    self,
    types: Dict[object, Type[StreamingStats]] = {},
Member

Can this be strengthened to Dict[str, Type[StreamingStats]]?

Member Author

I strengthened to Hashable, but I think we do want to support e.g. integer keys among chains.

Member Author

fritzo commented Jun 3, 2021

@eb8680 thanks for reviewing!

> I'm a little confused by some of the types. Did you try running mypy locally?

Sorry, I didn't run mypy locally, and some of the types are stale after refactoring. I'll fix... UPDATE ...fixed and ran mypy locally.

Contributor

mtsokol commented Jun 3, 2021

The current #2857 draft isn't chain-aware, and I'm wondering how to handle that. It could either be handled by pyro.ops.streaming, e.g.

class CountMeanStats(StreamingStats):
    def __init__(self, num_chains=1):
        self.counts = [0] * num_chains
        ...

    def update(self, sample, chain_index=0):
        ...

    def get(self, group_by_chain=True):
        # we can sum across chains
        ...

so the update in StreamingMCMC would be easy:

self._statistics.update({
    name: transformed_sample for name, transformed_sample in z_acc.items()
}, chain_index)

Otherwise it can be handled by StreamingMCMC, e.g. with a separate CountMeanStats for each chain; these can be returned as-is or merged into one when group_by_chain=False, somewhere in summary.
WDYT?

Member Author

fritzo commented Jun 3, 2021

@mtsokol I think it's best to keep chain logic in the StreamingMCMC class so as to keep StreamingStats subclasses as simple as possible (and hence easy to extend by creating new subclasses). However, I think that with the latest couple of commits you can easily separate by chain, either by making a nested StatsOfDict or by using keys of the form (chain_id, site["name"]). Let me know if you have any ideas about changing the StreamingStats interface to make this easier.
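A rough pure-Python sketch of the tuple-key idea, on scalars rather than tensors. CountMean, KeyedStats, update_chain, and pool_chains are hypothetical stand-ins and not the classes in this PR; the sketch assumes a container that creates one tracker per hashable key and a merge that returns a new tracker.

```python
from collections import defaultdict
from typing import Callable, Dict, Hashable


class CountMean:
    """Hypothetical scalar count/mean tracker."""

    def __init__(self) -> None:
        self.count = 0
        self.mean = 0.0

    def update(self, sample: float) -> None:
        self.count += 1
        self.mean += (sample - self.mean) / self.count

    def merge(self, other: "CountMean") -> "CountMean":
        result = CountMean()
        result.count = self.count + other.count
        if result.count > 0:
            weighted = self.mean * self.count + other.mean * other.count
            result.mean = weighted / result.count
        return result

    def get(self) -> Dict[str, float]:
        return {"count": self.count, "mean": self.mean}


class KeyedStats:
    """Hypothetical container: one tracker per hashable key."""

    def __init__(self, default: Callable[[], CountMean] = CountMean) -> None:
        self.stats: Dict[Hashable, CountMean] = defaultdict(default)

    def update(self, sample: Dict[Hashable, float]) -> None:
        for key, value in sample.items():
            self.stats[key].update(value)

    def get(self) -> Dict[Hashable, Dict[str, float]]:
        return {key: tracker.get() for key, tracker in self.stats.items()}


def update_chain(stats: KeyedStats, chain_id: int, sample: Dict[str, float]) -> None:
    # Chain logic stays with the caller: it only prefixes each site name.
    stats.update({(chain_id, name): value for name, value in sample.items()})


def pool_chains(stats: KeyedStats) -> Dict[str, Dict[str, float]]:
    # Merge per-chain trackers that share a site name.
    pooled: Dict[str, CountMean] = {}
    for (chain_id, name), tracker in stats.stats.items():
        pooled[name] = pooled[name].merge(tracker) if name in pooled else tracker
    return {name: tracker.get() for name, tracker in pooled.items()}
```

Calling get() on the container gives per-chain results keyed by (chain_id, name), while pool_chains gives the merged view, so the trackers themselves never need to know about chains.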

eb8680 previously approved these changes Jun 4, 2021
Member

@eb8680 eb8680 left a comment

LGTM after merge conflicts are resolved

Member Author

fritzo commented Jun 4, 2021

Thanks for reviewing @eb8680! Looks like I'll be using this right away in my mutation models 😄

@eb8680 eb8680 merged commit 9bcaa38 into dev Jun 7, 2021
@eb8680 eb8680 deleted the ops-streaming branch June 7, 2021 21:12
3 participants