FR Streaming MCMC interface for big models #2843
Hi @fritzo! I was searching for an issue and if this one's free I would like to try solving it. So, as I understand it, the main point here would be to implement:

class StreamingMCMC:
    def __init__(...):
        self.incremental_mean = ...
        self.incremental_variance = ...
        # ...and the rest of the statistics that will be used by the summary method

    def run(self, *args, **kwargs):
        ...
        for x, chain_id in self.sampler.run(*args, **kwargs):
            num_samples += 1
            self.incremental_mean += (x - self.incremental_mean) / num_samples
            # ...and the rest of the statistics
            del x
        ...

    def diagnostics(self):
        ...

    def summary(self, prob=0.9):
        # just returns the computed incremental statistics
        ...

Also, in [...]. As described, [...]. Follow-up question: should [...]?

It's seemingly straightforward, but I've just started looking at the source code. Are there any pitfalls that I should bear in mind?
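(For reference, a minimal standalone sketch - not Pyro code - of the kind of numerically stable incremental update the loop above relies on: Welford's algorithm for the mean and variance of a scalar stream. All names below are illustrative.)

import torch

def streaming_mean_var(samples):
    """Welford-style running mean/variance, one sample at a time."""
    mean, m2, n = 0.0, 0.0, 0
    for x in samples:
        n += 1
        delta = x - mean
        mean = mean + delta / n          # same update rule as in the sketch above
        m2 = m2 + delta * (x - mean)     # running sum of squared deviations
    return mean, m2 / (n - 1)

xs = torch.randn(1000)
mean, var = streaming_mean_var(xs)
assert torch.allclose(mean, xs.mean(), atol=1e-4)
assert torch.allclose(var, xs.var(), atol=1e-4)

The point is that each statistic keeps O(1) state regardless of how many samples pass through, which is what lets the samples themselves be discarded.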
Hi @mtsokol, that sounds great, and I'm happy to provide review and guidance. Your sketch looks good. The only difference I'd suggest would be for us to think hard about making a fully extensible interface for computing streaming statistics, so that users can easily stream other custom things. I was thinking, with task 2 above, to create a new module (say pyro.ops.streaming) containing something like:

from abc import ABC, abstractmethod
from typing import Dict, Set

import torch

class StreamingStatistic(ABC):
    """Base class for streamable statistics."""

    @abstractmethod
    def update(self, sample: Dict[str, torch.Tensor]) -> None:
        """Update state from a single sample."""
        raise NotImplementedError

    @abstractmethod
    def merge(self, other: "StreamingStatistic") -> "StreamingStatistic":
        """Combine two aggregate statistics, e.g. from different chains."""
        assert type(self) == type(other)
        raise NotImplementedError

    @abstractmethod
    def get(self) -> Dict[str, torch.Tensor]:
        """Return the aggregate statistic."""
        raise NotImplementedError

Together with a set of basic concrete statistics (see also pyro.ops.welford for an existing implementation, though with a non-general interface):

class Count(StreamingStatistic): ...
class Mean(StreamingStatistic): ...
class MeanAndVariance(StreamingStatistic): ...
class MeanAndCovariance(StreamingStatistic): ...
class RHat(StreamingStatistic): ...

And maybe a restriction to a subset of names:

class SubsetStatistic(StreamingStatistic):
    def __init__(self, names: Set[str], base_stat: StreamingStatistic):
        self.names = names
        self.base_stat = base_stat

    def update(self, sample):
        sample = {k: v for k, v in sample.items() if k in self.names}
        self.base_stat.update(sample)

    def get(self):
        return self.base_stat.get()

I think that might be enough of an interface, but we might want more details. Then once we have basic statistics we can make your interface generic and extensible:

class StreamingMCMC:
    def __init__(..., statistics=None):
        if statistics is None:
            statistics = [Count(), MeanAndVariance()]
        self._statistics = statistics

    def run(self, *args, **kwargs):
        ...
        for x, chain_id in self.sampler.run(*args, **kwargs):
            num_samples += 1
            for stat in self._statistics:
                stat.update(x)
            del x
        ...

    def diagnostics(self):
        ...

    def summary(self, prob=0.9):
        # just returns the computed incremental statistics
        ...

What I'd really like is to be able to define custom statistics for a particular problem, e.g. saving a list of norms:

class ListOfNorms(StreamingStatistic):
    def __init__(self):
        self._lists = defaultdict(list)  # from collections import defaultdict

    def update(self, data):
        for k, v in data.items():
            self._lists[k].append(torch.linalg.norm(v.detach().reshape(-1)).item())

    def get(self):
        return dict(self._lists)

my_mcmc = StreamingMCMC(..., statistics=[MeanAndVariance(), ListOfNorms()])

WDYT?
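(For concreteness, here is a rough sketch - not existing Pyro code - of what one such statistic could look like under the StreamingStatistic interface proposed above, including a merge across chains; the class name is made up.)

from typing import Dict

import torch

class StreamingMean(StreamingStatistic):
    """Running per-site mean, mergeable across chains."""

    def __init__(self):
        self._count = 0
        self._sums: Dict[str, torch.Tensor] = {}

    def update(self, sample: Dict[str, torch.Tensor]) -> None:
        self._count += 1
        for k, v in sample.items():
            v = v.detach()
            self._sums[k] = self._sums.get(k, torch.zeros_like(v)) + v

    def merge(self, other: "StreamingMean") -> "StreamingMean":
        assert type(self) == type(other)
        result = StreamingMean()
        result._count = self._count + other._count
        for k in set(self._sums) | set(other._sums):
            result._sums[k] = self._sums.get(k, 0) + other._sums.get(k, 0)
        return result

    def get(self) -> Dict[str, torch.Tensor]:
        return {k: v / self._count for k, v in self._sums.items()}

Per-chain instances could then be combined with something like chain_a_stat.merge(chain_b_stat).get(), which is what makes the interface compatible with sequentially-run chains.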
Addressing your earlier questions:
1. Correct, most existing tests should be parametrized with @pytest.mark.parametrize("mcmc_cls", [MCMC, StreamingMCMC]).
2. Almost. The main beneficiaries here are large models that push against memory limits and therefore necessitate streaming rather than saving all samples in memory. And if you're pushing against memory limits, you'll want to avoid parallelizing and instead compute chains sequentially (which can itself be seen as a streaming operation). In practice, yes, most models that hit memory limits are run on GPU, but multicore CPU models can also be very performant.
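(As a rough illustration of that parametrization - a hypothetical smoke test, not the actual refactor; StreamingMCMC does not exist yet and is assumed here to mirror the MCMC constructor and diagnostics() API.)

import pytest
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

def model():
    pyro.sample("x", dist.Normal(0.0, 1.0))

@pytest.mark.parametrize("mcmc_cls", [MCMC, StreamingMCMC])  # StreamingMCMC assumed
def test_smoke(mcmc_cls):
    mcmc = mcmc_cls(NUTS(model), num_samples=100, warmup_steps=100)
    mcmc.run()
    # both classes are expected to expose the same diagnostics()/summary() API
    assert mcmc.diagnostics() is not None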
@mtsokol would you want to work on this in parallel? Maybe you could implement the [...]
@fritzo thanks for the guidance! Right now I'm looking at the current implementation and starting to work on this. Sure! I can start working on [...]. Should I introduce some [...]?
Feel free to implement an [...]
@fritzo After thinking about handling those streamed samples I wanted to ask a few more questions: [...]

Follow-up on the first question: if such a thing makes a performance difference (but I'm just wondering - it might be irrelevant), maybe instead of streaming each sample to the statistics it could work in batches instead, e.g. by introducing an additional argument [...]
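(One way such batching could look - purely a sketch under the interface discussed above; stream_in_batches, update_batch, and batch_size are hypothetical names, not part of any existing Pyro API.)

from typing import Dict, Iterable, List

import torch

def stream_in_batches(samples: Iterable[Dict[str, torch.Tensor]],
                      statistics: list,
                      batch_size: int = 64) -> None:
    """Feed samples to statistics in chunks rather than one at a time."""
    buffer: List[Dict[str, torch.Tensor]] = []
    for x in samples:
        buffer.append(x)
        if len(buffer) >= batch_size:
            for stat in statistics:
                stat.update_batch(buffer)  # vectorized update; could default to per-sample update()
            buffer.clear()
    if buffer:  # flush the remainder
        for stat in statistics:
            stat.update_batch(buffer)

Note the trade-off: the buffer keeps batch_size samples alive at once, so for the very large models motivating this issue batch_size would have to stay small.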
@mtsokol answering your latest questions: [...]

(Lines 389 to 392 in 4a61ef2.)
Hi @fritzo! I was wondering what I can try to do next. As [...]. Apart from that I can definitely try: [...]

Could you suggest a problem with a model that would be suitable for that? Also, I could join the new tutorial with your suggestion in the last bullet point in #2803 (comment) (showing how [...]). WDYT? That would be a documentation task, and I was also looking for an implementation one - have you got something that I can try?
This issue proposes a streaming architecture for MCMC on models with a large memory footprint.
The problem this addresses is that, in models with high-dimensional latents (say >1M latent variables), it becomes difficult to save a list of samples, especially on GPUs with limited memory. The proposed solution is to eagerly compute statistics on those samples, and discard them during inference.
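(Rough back-of-the-envelope numbers; the kept-sample count and dtype below are assumptions, just to show the scale.)

num_latents = 1_000_000   # >1M latent variables, as above
num_samples = 1_000       # a typical number of kept posterior samples (assumed)
bytes_per_value = 4       # float32

print(num_latents * num_samples * bytes_per_value / 1e9)  # 4.0 GB of raw samples alone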
@fehiepsi suggested creating a new MCMC class (say StreamingMCMC) with a similar interface to MCMC and still independent of kernel (using either HMC or NUTS), but that follows an internal streaming architecture. Since large models like these usually run on GPU or are otherwise memory constrained, it is reasonable to avoid multiprocessing support in StreamingMCMC.

Along with the new StreamingMCMC class I think there should be a set of helpers to streamingly compute statistics from sample streams, e.g. mean, variance, covariance, and r_hat statistics.

Tasks (to be split into multiple PRs)

@mtsokol
- a StreamingMCMC class with interface identical to MCMC (except disallowing parallel chains)
- MCMC tests parametrized over both MCMC and StreamingMCMC
- verify that StreamingMCMC and MCMC perform identical computations, up to numerical precision
- demonstrate StreamingMCMC on a big model

@fritzo
- add r_hat to pyro.ops.streaming
- add n_eff = ess to pyro.ops.streaming