Skip to content

Commit

Permalink
Dependency tracking tutorial (#3197)
Browse files Browse the repository at this point in the history
* Dependency tracking tutorial

* fix typo

* concenpt -> concept; equals to -> equals
  • Loading branch information
ordabayevy committed Apr 20, 2023
1 parent 1311b9e commit dd4e0f8
Showing 1 changed file with 47 additions and 7 deletions.
54 changes: 47 additions & 7 deletions tutorial/source/svi_part_iii.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@
"svi = SVI(model, guide, optimizer, TraceGraph_ELBO())\n",
"```\n",
"\n",
"Note that leveraging this dependency information takes extra computations, so `TraceGraph_ELBO` should only be used in the case where your model has non-reparameterizable random variables; in most applications `Trace_ELBO` suffices."
"Note that leveraging this dependency information might have small computation overhead, so `TraceGraph_ELBO` should only be used in the case where your model has non-reparameterizable random variables; in most applications `Trace_ELBO` suffices."
]
},
{
Expand Down Expand Up @@ -137,7 +137,7 @@
" ks = pyro.sample(\"k\", dist.Categorical(probs))\n",
" pyro.sample(\"obs\", dist.Normal(locs[ks], scale),\n",
" obs=data)\n",
"``` \n",
"```\n",
"\n",
"That's all there is to it."
]
Expand All @@ -148,8 +148,48 @@
"source": [
"### Aside: Dependency tracking in Pyro\n",
"\n",
"Finally, a word about dependency tracking. Tracking dependency within a stochastic function that includes arbitrary Python code is a bit tricky. The approach currently implemented in Pyro is analogous to the one used in WebPPL (cf. reference [5]). Briefly, a conservative notion of dependency is used that relies on sequential ordering. If random variable ${\\bf z}_2$ follows ${\\bf z}_1$ in a given stochastic function then ${\\bf z}_2$ _may be_ dependent on ${\\bf z}_1$ and therefore _is_ assumed to be dependent. To mitigate the overly coarse conclusions that can be drawn by this kind of dependency tracking, Pyro includes constructs for declaring things as independent, namely `plate` and `markov` ([see the previous tutorial](svi_part_ii.ipynb)). For use cases with non-reparameterizable variables, it is therefore important for the user to make use of these constructs (when applicable) to take full advantage of the variance reduction provided by `SVI`. In some cases it may also pay to consider reordering random variables within a stochastic function (if possible).\n",
"Finally, a word about dependency tracking. Pyro uses the concept of provenance for tracking dependency within a stochastic function that includes arbitrary Python code (see reference [5]). In the programming language theory, the provenance of a variable refers to the history of variables or computations that contributed to its value. The simple example below demonstrates how provenance is tracked through PyTorch ops in Pyro, where provenance is a user-defined frozenset of objects:\n",
"\n",
"```python\n",
"from pyro.ops.provenance import get_provenance, track_provenance\n",
"\n",
"a = track_provenance(torch.randn(3), frozenset({\"a\"}))\n",
"b = track_provenance(torch.randn(3), frozenset({\"b\"}))\n",
"c = torch.randn(3) # no provenance information\n",
"\n",
"# For a unary operation, the provenance of the output tensor\n",
"# equals the provenace of the input tensor\n",
"assert get_provenance(a.exp()) == frozenset({\"a\"})\n",
"# In general, the provenance of the output tensors of any op\n",
"# is the union of provenances of input tensors.\n",
"assert get_provenance(a * (b + c)) == frozenset({\"a\", \"b\"})\n",
"``` \n",
"\n",
"This concept is utilized by `TraceGraph_ELBO` to trace the fine-grained dynamic dependency information on _non-reparameterizable_ random variables through intermediate computations as they come together to form a log-likelihood. Internally, non-reparameterizable sample sites are tracked using [TrackNonReparam](https://docs.pyro.ai/en/stable/inference_algos.html#pyro.infer.tracegraph_elbo.TrackNonReparam) messenger:\n",
"\n",
"```python\n",
"def model():\n",
" probs_a = torch.tensor([0.3, 0.7])\n",
" probs_b = torch.tensor([[0.1, 0.9], [0.8, 0.2]])\n",
" probs_c = torch.tensor([[0.5, 0.5], [0.6, 0.4]])\n",
" a = pyro.sample(\"a\", dist.Categorical(probs_a))\n",
" b = pyro.sample(\"b\", dist.Categorical(probs_b[a]))\n",
" pyro.sample(\"c\", dist.Categorical(probs_c[b]), obs=torch.tensor(0))\n",
"\n",
"with TrackNonReparam():\n",
" model_tr = trace(model).get_trace()\n",
"model_tr.compute_log_prob()\n",
"\n",
"assert get_provenance(model_tr.nodes[\"a\"][\"log_prob\"]) == frozenset({'a'})\n",
"assert get_provenance(model_tr.nodes[\"b\"][\"log_prob\"]) == frozenset({'b', 'a'})\n",
"assert get_provenance(model_tr.nodes[\"c\"][\"log_prob\"]) == frozenset({'b', 'a'})\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Reducing Variance with Data-Dependent Baselines\n",
"\n",
"The second strategy for reducing variance in our ELBO gradient estimator goes under the name of baselines (see e.g. reference [6]). It actually makes use of the same bit of math that underlies the variance reduction strategy discussed above, except now instead of removing terms we're going to add terms. Basically, instead of removing terms with zero expectation that tend to _contribute_ to the variance, we're going to add specially chosen terms with zero expectation that work to _reduce_ the variance. As such, this is a control variate strategy.\n",
Expand Down Expand Up @@ -413,9 +453,9 @@
"<br/>&nbsp;&nbsp;&nbsp;&nbsp;\n",
" John Schulman, Nicolas Heess, Theophane Weber, Pieter Abbeel\n",
" \n",
"[5] `Deep Amortized Inference for Probabilistic Programs`\n",
"[5] `Nonstandard Interpretations of Probabilistic Programs for Efficient Inference`\n",
"<br/>&nbsp;&nbsp;&nbsp;&nbsp;\n",
"Daniel Ritchie, Paul Horsfall, Noah D. Goodman\n",
"David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind\n",
"\n",
"[6] `Neural Variational Inference and Learning in Belief Networks`\n",
"<br/>&nbsp;&nbsp;&nbsp;&nbsp;\n",
Expand All @@ -425,7 +465,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -439,7 +479,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.10"
"version": "3.8.16"
}
},
"nbformat": 4,
Expand Down

0 comments on commit dd4e0f8

Please sign in to comment.