
FR Streaming MCMC interface for big models #2843

Open
3 of 7 tasks
fritzo opened this issue May 14, 2021 · 9 comments
Assignees
Labels
enhancement, help wanted (Issues suitable for, and inviting external contributions)

Comments

@fritzo
Member

fritzo commented May 14, 2021

This issue proposes a streaming architecture for MCMC on models with large memory footprint.

The problem this addresses: in models with high-dimensional latents (say >1M latent variables), it becomes difficult to store a list of samples, especially on GPUs with limited memory. The proposed solution is to eagerly compute statistics on samples as they are drawn, discarding the samples themselves during inference.

@fehiepsi suggested creating a new MCMC class (say StreamingMCMC) with an interface similar to MCMC's, still independent of the kernel (HMC or NUTS), but built around an internal streaming architecture. Since large models like these usually run on GPU or are otherwise memory constrained, it is reasonable to omit multiprocessing support in StreamingMCMC.

Along with the new StreamingMCMC class, there should be a set of helpers that compute statistics from sample streams incrementally, e.g. mean, variance, covariance, and r_hat.
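For the mean/variance helpers, Welford's online algorithm is the standard building block: it keeps only a count, a running mean, and a sum of squared deviations, so memory is O(1) in the number of samples. A dependency-free sketch (the class name here is illustrative, not the eventual pyro.ops.streaming API):

```python
class WelfordMeanVariance:
    """Streaming mean and variance via Welford's online algorithm."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, x):
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)

    def get(self):
        variance = self.m2 / (self.count - 1) if self.count > 1 else 0.0
        return {"mean": self.mean, "variance": variance}

stat = WelfordMeanVariance()
for x in [1.0, 2.0, 3.0, 4.0]:
    stat.update(x)
print(stat.get())  # mean 2.5, sample variance 5/3
```

The same update works elementwise on tensors, which is essentially what pyro.ops.welford already does for HMC mass-matrix adaptation, albeit with a non-general interface.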

Tasks (to be split into multiple PRs)

@mtsokol

  • StreamingMCMC class #2857 Create a StreamingMCMC class with interface identical to MCMC (except disallowing parallel chains).
  • StreamingMCMC class #2857 Generalize unit tests of MCMC to parametrize over both MCMC and StreamingMCMC
  • Add some tests ensuring StreamingMCMC and MCMC perform identical computations, up to numerical precision
  • Create a tutorial using StreamingMCMC on a big model

@fritzo

@fritzo fritzo assigned fritzo and unassigned fritzo May 14, 2021
@fritzo fritzo added the help wanted Issues suitable for, and inviting external contributions label May 17, 2021
@mtsokol
Contributor

mtsokol commented Jun 1, 2021

Hi @fritzo!

I was searching for an issue to work on; if this one's free I would like to try solving it.

So as I understand it, the main point here is to implement a StreamingMCMC that doesn't have a get_samples method and instead keeps incrementally updated statistics in its state (assuming all of them can be computed incrementally; can they?). Something like this:

class StreamingMCMC:
    def __init__(...):
        self.incremental_mean = ...
        self.incremental_variance = ...
        # ...and the rest of the statistics used by the 'summary' method

    def run(self, *args, **kwargs):
        num_samples = 0
        ...
        for x, chain_id in self.sampler.run(*args, **kwargs):
            num_samples += 1
            self.incremental_mean += (x - self.incremental_mean) / num_samples
            # ...and the rest of the statistics
            del x  # discard the sample to free memory
            ...

    def diagnostics(self):
        ...

    def summary(self, prob=0.9):
        # just returns the computed incremental statistics
        ...
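As a quick sanity check of the incremental-mean update above (plain Python, not Pyro code), the streaming result matches the batch mean up to floating-point error:

```python
samples = [0.5, 1.5, 2.0, 4.0, 7.0]

num_samples = 0
incremental_mean = 0.0
for x in samples:
    num_samples += 1
    # same update rule as in the sketch above
    incremental_mean += (x - incremental_mean) / num_samples

batch_mean = sum(samples) / len(samples)
assert abs(incremental_mean - batch_mean) < 1e-12
```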

Also, in test_mcmc_api.py, additional tests should run both MCMC classes and compare their final statistics.

As described, StreamingMCMC shouldn't support multiprocessing via _Worker, because CUDA, which is expected to be the main beneficiary of this new class, handles vectorization by itself. (Is that correct?)

Follow-up question: should StreamingMCMC have a num_chains argument and, for num_chains > 1, just compute the chains sequentially, or should it omit this argument?


It's seemingly straightforward, but I've just started looking at the source code. Are there any pitfalls that I should bear in mind?

@fritzo
Member Author

fritzo commented Jun 1, 2021

Hi @mtsokol, that sounds great; I'm happy to provide review and guidance.

Your sketch looks good. The only change I'd suggest is that we think hard about making a fully extensible interface for computing streaming statistics, so that users can easily stream other, custom quantities. For task 2 above I was thinking of creating a new module, say pyro.ops.streaming, with a class hierarchy of basic streamable statistics:

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Dict

import torch

class StreamingStatistic(ABC):
    """Base class for streamable statistics."""

    @abstractmethod
    def update(self, sample: Dict[str, torch.Tensor]) -> None:
        """Update state from a single sample."""
        raise NotImplementedError

    @abstractmethod
    def merge(self, other: StreamingStatistic) -> StreamingStatistic:
        """Combine two aggregate statistics, e.g. from different chains."""
        assert type(self) == type(other)
        raise NotImplementedError

    @abstractmethod
    def get(self) -> Dict[str, torch.Tensor]:
        """Return the aggregate statistic."""
        raise NotImplementedError

Together with a set of basic concrete statistics (see also pyro.ops.welford for an existing implementation with a non-general interface):

class Count(StreamingStatistic): ...
class Mean(StreamingStatistic): ...
class MeanAndVariance(StreamingStatistic): ...
class MeanAndCovariance(StreamingStatistic): ...
class RHat(StreamingStatistic): ...
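To make the interface concrete, here is a dependency-free sketch of what a Mean statistic with a working merge() could look like, using plain floats in place of torch.Tensors (illustrative only; the actual pyro.ops.streaming classes may differ):

```python
from typing import Dict

class Mean:
    """Streaming per-site mean with a merge() for combining chains."""

    def __init__(self):
        self.count = 0
        self.sums: Dict[str, float] = {}

    def update(self, sample: Dict[str, float]) -> None:
        self.count += 1
        for name, value in sample.items():
            self.sums[name] = self.sums.get(name, 0.0) + value

    def merge(self, other: "Mean") -> "Mean":
        # Aggregates from different chains combine by summing counts and sums.
        assert type(self) == type(other)
        out = Mean()
        out.count = self.count + other.count
        for name in set(self.sums) | set(other.sums):
            out.sums[name] = self.sums.get(name, 0.0) + other.sums.get(name, 0.0)
        return out

    def get(self) -> Dict[str, float]:
        return {name: s / self.count for name, s in self.sums.items()}

chain1, chain2 = Mean(), Mean()
chain1.update({"z": 1.0})
chain1.update({"z": 2.0})
chain2.update({"z": 5.0})
merged = chain1.merge(chain2)
print(merged.get())  # mean over both chains: 8/3
```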

And maybe a restriction to a subset of names

class SubsetStatistic(StreamingStatistic):
    def __init__(self, names: Set[str], base_stat: StreamingStatistic):
        self.names = names
        self.base_stat = base_stat

    def update(self, sample):
        sample = {k: v for k, v in sample.items() if k in self.names}
        self.base_stat.update(sample)

    def get(self):
        return self.base_stat.get()

I think that might be enough of an interface, but we might want more details in the __init__ methods.

Then once we have basic statistics we can make your interface generic and extensible:

class StreamingMCMC:
    def __init__(..., statistics=None):
        if statistics is None:
            statistics = [Count(), MeanAndVariance()]
        self._statistics = statistics

    def run(self, *args, **kwargs):
        num_samples = 0
        ...
        for x, chain_id in self.sampler.run(*args, **kwargs):
            num_samples += 1
            for stat in self._statistics:
                stat.update(x)
            del x  # discard the sample to free memory
            ...

    def diagnostics(self):
        ...

    def summary(self, prob=0.9):
        # just returns the computed incremental statistics
        ...

What I'd really like is to be able to define custom statistics for a particular problem, e.g. saving a list of norms

from collections import defaultdict

import torch

class ListOfNorms(StreamingStatistic):
    def __init__(self):
        self._lists = defaultdict(list)

    def update(self, data):
        for k, v in data.items():
            self._lists[k].append(torch.linalg.norm(v.detach().reshape(-1)).item())

    def get(self):
        return dict(self._lists)

my_mcmc = StreamingMCMC(..., statistics=[MeanAndVariance(), ListOfNorms()])

WDYT?

@fritzo
Member Author

fritzo commented Jun 1, 2021

Addressing your earlier questions:

Also, in test_mcmc_api.py, additional tests should run both MCMC classes and compare their final statistics.

Correct, most existing tests should be parametrized with

@pytest.mark.parametrize("mcmc_cls", [MCMC, StreamingMCMC])

As described, StreamingMCMC shouldn't support multiprocessing via _Worker, because CUDA, which is expected to be the main beneficiary of this new class, handles vectorization by itself. (Is that correct?)

Almost. The main beneficiaries here are large models that push against memory limits and therefore necessitate streaming rather than saving all samples in memory. And if you're pushing against memory limits, you'll want to avoid parallelizing chains and instead compute them sequentially (which can itself be seen as a streaming operation). In practice, yes, most models that hit memory limits run on GPU, but multicore CPU models can also be very performant.

Should StreamingMCMC have a num_chains argument and, for num_chains > 1, just compute the chains sequentially, or omit this argument?

StreamingMCMC should still support num_chains > 1 (which is valuable for determining convergence), but should compute them sequentially.
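Running chains sequentially still supports cross-chain diagnostics: each chain keeps its own aggregate statistic, and at the end the per-chain means and variances are combined. As a self-contained illustration (plain Python, not Pyro code), here is the classic Gelman-Rubin r_hat computed from per-chain draws; a streaming version would retain only each chain's running mean and variance:

```python
import math

def r_hat(chains):
    """Gelman-Rubin potential scale reduction factor over equal-length chains."""
    m = len(chains)     # number of chains
    n = len(chains[0])  # draws per chain
    means = [sum(c) / n for c in chains]
    variances = [sum((x - mu) ** 2 for x in c) / (n - 1)
                 for c, mu in zip(chains, means)]
    grand_mean = sum(means) / m
    b = n * sum((mu - grand_mean) ** 2 for mu in means) / (m - 1)  # between-chain
    w = sum(variances) / m                                         # within-chain
    var_hat = (n - 1) / n * w + b / n
    return math.sqrt(var_hat / w)

# Two well-mixed chains should give r_hat close to 1.
chain_a = [0.1, -0.2, 0.05, 0.3, -0.1, 0.0]
chain_b = [-0.05, 0.2, -0.3, 0.1, 0.0, 0.15]
print(r_hat([chain_a, chain_b]))
```

Since b and w depend only on per-chain counts, means, and variances, each chain's contribution can be accumulated incrementally and combined via a merge()-style operation.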

@fritzo
Member Author

fritzo commented Jun 2, 2021

@mtsokol would you want to work on this in parallel? Maybe you could implement the StreamingMCMC class using hand-coded statistics, I could implement a basic pyro.ops.streaming module, and over the course of a few PRs we could meet in the middle?

@mtsokol
Contributor

mtsokol commented Jun 2, 2021

@fritzo thanks for the guidance! Right now I'm looking at the current implementation and starting to work on this.
The StreamingStatistic abstraction sounds right to me: StreamingMCMC will only iterate and call methods on the passed objects implementing that interface.

Sure! I can start working on StreamingMCMC and already follow the StreamingStatistic notion. When your PR is ready I will adjust my implementation.

Should I introduce an AbstractMCMC interface that the existing MCMC and the new StreamingMCMC would both implement?

@fritzo
Member Author

fritzo commented Jun 2, 2021

Feel free to implement an AbstractMCMC interface if you like. I defer to your design judgement here.

@mtsokol
Contributor

mtsokol commented Jun 2, 2021

@fritzo After thinking about handling those streamed samples I wanted to ask a few more questions:

  1. Right now, samples are yielded by the sampler and each one is appended to the right chain's list by z_flat_acc[chain_id].append(x_cloned). Then we reshape to get rid of the last dimension and have dict entries in its place (based on the yielded structure), and finally apply element-wise transforms (with self.transforms, where the transform is determined by the dict entry).
    The streaming-based approach would go as follows: again, each sample is yielded by the sampler. The sample is used to construct a dict (based on the yielded structure). That single dict is then transformed (with self.transforms) and fed to each statistic via update(self, sample: Dict[str, torch.Tensor]). (So each single sample results in constructing a new dict; is that OK?) WDYT?

  2. Should StreamingStatistic's update be chain_id-aware, like update(self, chain_id: int, sample: Dict[str, torch.Tensor]), so that it can compute per-chain diagnostics and support the group_by_chain argument?

  3. Why do we need to clone: x_cloned = x.clone() when num_chains > 1?


Follow-up on the first question: if constructing a dict per sample makes a performance difference (I'm just wondering; it may be irrelevant), maybe instead of streaming each sample to the statistics it could work in batches, e.g. with an additional batch_size=100 argument, so StreamingMCMC waits until it has aggregated 100 samples, then constructs the dict, performs the transformations, and feeds the whole batch to the statistics. (But maybe constructing and transforming a dict per sample isn't really an overhead; with a ready implementation I can run memory and time measurements.) WDYT?

@fritzo
Member Author

fritzo commented Jun 2, 2021

@mtsokol answering your latest questions:

  1. tl;dr keep it simple.
    I do not foresee a performance hit here: it is cheap to create dicts, and StreamingMCMC will typically be used with large memory-bound models with huge tensors, where the python overhead is negligible. For this same reason I think we should avoid batching since that increases memory overhead. (In fact I suspect the bottleneck will be in pyro.ops.streaming where we may need to refactor to perform tensor operations in-place).
  2. Yes, I believe we will want to compute both per-chain and total-aggregated statistics. I have added a .merge() operation in Implement pyro.ops.streaming module #2856 to make this easy for you. The main motivation is to compute cross-chain statistics like r_hat.
  3. It looks like the cloning is explained earlier in the file. I would recommend keeping that logic.

pyro/pyro/infer/mcmc/api.py

Lines 389 to 392 in 4a61ef2

# XXX we clone CUDA tensor args to resolve the issue "Invalid device pointer"
# at https://github.com/pytorch/pytorch/issues/10375
# This also resolves "RuntimeError: Cowardly refusing to serialize non-leaf tensor which
# requires_grad", which happens with `jit_compile` under PyTorch 1.7

@mtsokol
Contributor

mtsokol commented Jun 24, 2021

Hi @fritzo!

I was wondering what I can try to do next.

With "Add r_hat to pyro.ops.streaming" completed, I tried streaming n_eff (ess), but after a short inspection of the current implementation it looks infeasible to me (since it requires e.g. the autocorrelation lags).

Apart from that I can definitely try:

Create a tutorial using StreamingMCMC on a big model

Could you suggest a problem and model suitable for that?

Also, I could combine the new tutorial with your suggestion in the last bullet point of #2803 (comment) (showing how Predictive can be interchanged with poutine methods).

WDYT?


This would be a documentation task, and I was also looking for an implementation one. Have you got something I could try?
