Skip to content

Commit

Permalink
Update and fix docs.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben Zickel committed Mar 26, 2024
1 parent 260e528 commit 5030d41
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
14 changes: 14 additions & 0 deletions pyro/infer/predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def _guess_max_plate_nesting(model, args, kwargs):


class _predictiveResults(NamedTuple):
"""
Return value of call to ``_predictive`` and ``_predictive_sequential``.
"""

samples: dict
trace: Union[Trace, List[Trace]]

Expand Down Expand Up @@ -313,6 +317,10 @@ def get_vectorized_trace(self, *args, **kwargs):


class WeighedPredictiveResults(NamedTuple):
"""
Return value of call to instance of :class:`WeighedPredictive`.
"""

samples: Union[dict, tuple]
log_weights: torch.Tensor
guide_log_prob: torch.Tensor
Expand Down Expand Up @@ -351,6 +359,9 @@ def call(self, *args, **kwargs):
Method `.call` that is backwards compatible with the same method found in :class:`Predictive`
but can be called with an additional keyword argument `model_guide`
which is the model used to create and optimize the guide.
Returns :class:`WeighedPredictiveResults` which has attributes ``.samples`` and per sample
weights ``.log_weights``.
"""
result = self.forward(*args, **kwargs)
return WeighedPredictiveResults(
Expand All @@ -365,6 +376,9 @@ def forward(self, *args, **kwargs):
Method `.forward` that is backwards compatible with the same method found in :class:`Predictive`
but can be called with an additional keyword argument `model_guide`
which is the model used to create and optimize the guide.
Returns :class:`WeighedPredictiveResults` which has attributes ``.samples`` and per sample
weights ``.log_weights``.
"""
model_guide = kwargs.pop("model_guide", self.model)
return_sites = self.return_sites
Expand Down
19 changes: 11 additions & 8 deletions pyro/ops/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,14 +277,17 @@ def weighed_quantile(
:param int dim: dimension to take quantiles from ``input``.
:returns torch.Tensor: quantiles of ``input`` at ``probs``.
Example:
>>> from pyro.ops.stats import weighed_quantile
>>> import torch
>>> input = torch.Tensor([[10, 50, 40], [20, 30, 0]])
>>> probs = torch.Tensor([0.2, 0.8])
>>> log_weights = torch.Tensor([0.4, 0.5, 0.1]).log()
>>> result = weighed_quantile(input, probs, log_weights, -1)
>>> torch.testing.assert_close(result, torch.Tensor([[40.4, 47.6], [9.0, 26.4]]))
**Example:**
.. doctest::
>>> from pyro.ops.stats import weighed_quantile
>>> import torch
>>> input = torch.Tensor([[10, 50, 40], [20, 30, 0]])
>>> probs = torch.Tensor([0.2, 0.8])
>>> log_weights = torch.Tensor([0.4, 0.5, 0.1]).log()
>>> result = weighed_quantile(input, probs, log_weights, -1)
>>> torch.testing.assert_close(result, torch.Tensor([[40.4, 47.6], [9.0, 26.4]]))
"""
dim = dim if dim >= 0 else (len(input.shape) + dim)
if isinstance(probs, (list, tuple)):
Expand Down

0 comments on commit 5030d41

Please sign in to comment.