Custom loss documentation (#3122)
* Improving custom objectives example

* Rename custom_objectives.ipynb to custom_objectives_training.ipynb

* Minuscule changes as discussed
e-pet committed Jul 31, 2022
1 parent 0e0b1ed commit bca60c9
Showing 1 changed file with 23 additions and 13 deletions.
@@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Custom SVI Objectives\n",
"# Customizing SVI objectives and training loops\n",
"\n",
"Pyro provides support for various optimization-based approaches to Bayesian inference, with `Trace_ELBO` serving as the basic implementation of SVI (stochastic variational inference).\n",
"See the [docs](http://docs.pyro.ai/en/dev/inference_algos.html#module-pyro.infer.svi) for more information on the various SVI implementations and SVI \n",
@@ -13,7 +13,7 @@
"and [III](http://pyro.ai/examples/svi_part_iii.html) for background on SVI.\n",
"\n",
"In this tutorial we show how advanced users can modify and/or augment the variational\n",
"objectives (alternatively: loss functions) provided by Pyro to support special use cases."
"objectives (alternatively: loss functions) and the training step implementation provided by Pyro to support special use cases."
]
},
{
@@ -57,19 +57,28 @@
"- `SVI.step()` zeros gradients between gradient steps\n",
"\n",
"If we want more control, we can directly manipulate the differentiable loss method of \n",
"the various `ELBO` classes. For example, (assuming we know all the parameters in advance) \n",
"this is equivalent to the previous code snippet:\n",
"the various `ELBO` classes. For example, this optimization loop:\n",
"```python\n",
"svi = pyro.infer.SVI(model, guide, optimizer, loss=pyro.infer.Trace_ELBO())\n",
"for i in range(n_iter):\n",
" loss = svi.step(X_train, y_train)\n",
"```\n",
"is equivalent to this low-level pattern:\n",
"\n",
"```python\n",
"# define optimizer and loss function\n",
"optimizer = torch.optim.Adam(my_parameters, {\"lr\": 0.001, \"betas\": (0.90, 0.999)})\n",
"loss_fn = pyro.infer.Trace_ELBO().differentiable_loss\n",
"# compute loss\n",
"loss = loss_fn(model, guide, model_and_guide_args)\n",
"loss.backward()\n",
"# take a step and zero the parameter gradients\n",
"optimizer.step()\n",
"optimizer.zero_grad()\n",
"loss_fn = lambda model, guide: pyro.infer.Trace_ELBO().differentiable_loss(model, guide, X_train, y_train)\n",
"with pyro.poutine.trace(param_only=True) as param_capture:\n",
" loss = loss_fn(model, guide)\n",
"params = set(site[\"value\"].unconstrained()\n",
" for site in param_capture.trace.nodes.values())\n",
"optimizer = torch.optim.Adam(params, lr=0.001, betas=(0.90, 0.999))\n",
"for i in range(n_iter):\n",
" # compute loss\n",
" loss = loss_fn(model, guide)\n",
" loss.backward()\n",
" # take a step and zero the parameter gradients\n",
" optimizer.step()\n",
" optimizer.zero_grad()\n",
"```\n",
"\n",
"## Example: Custom Regularizer\n",
@@ -106,6 +115,7 @@
"+ optimizer = pyro.optim.Adam({\"lr\": 0.001, \"betas\": (0.90, 0.999)}, {\"clip_norm\": 10.0})\n",
"```\n",
"\n",
"Further variants of gradient clipping can also be implemented manually by modifying the low-level pattern described above.\n",
"\n",
"## Example: Scaling the Loss\n",
"\n",
