Add a simple PyTorch training example (#3327)
eb8680 committed Feb 15, 2024
1 parent b696ccf commit 98441c3
Showing 4 changed files with 131 additions and 1 deletion.
111 changes: 111 additions & 0 deletions examples/svi_torch.py
@@ -0,0 +1,111 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

# Using vanilla PyTorch to perform optimization in SVI.
#
# This tutorial demonstrates how to use standard PyTorch optimizers, dataloaders and training loops
# to perform optimization in SVI. This is useful when you want to use custom optimizers,
# learning rate schedules, dataloaders, or other advanced training techniques,
# or just to simplify integration with other elements of the PyTorch ecosystem.

import argparse
from typing import Callable

import torch

import pyro
import pyro.distributions as dist
from pyro.infer import Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.nn import PyroModule


# We define a model as usual. This model is data parallel and supports subsampling.
class Model(PyroModule):
    def __init__(self, size):
        super().__init__()
        self.size = size
        # We register a buffer for a constant scalar tensor to represent zero.
        # This is useful for making priors that do not depend on inputs
        # or learnable parameters compatible with the Module.to() method
        # for setting the device or dtype of a module and its parameters.
        self.register_buffer("zero", torch.tensor(0.0))
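        # For example (illustrative only, not executed in this script),
        # Model(10).to(torch.float64) would cast this buffer, and with it the
        # priors sampled in forward(), to double precision.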

    def forward(self, covariates, data=None):
        # Sample parameters from priors that make use of the zero buffer trick.
        coeff = pyro.sample("coeff", dist.Normal(self.zero, 1))
        bias = pyro.sample("bias", dist.Normal(self.zero, 1))
        scale = pyro.sample("scale", dist.LogNormal(self.zero, 1))

        # Since we'll use a PyTorch dataloader during training, we need to
        # manually pass minibatches of (covariates, data) that are smaller than
        # the full self.size, rather than relying on pyro.plate to automatically subsample.
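        # (Had we relied on automatic subsampling, we would instead pass
        # subsample_size to pyro.plate and index into the full dataset inside
        # the plate; here the dataloader already delivers each minibatch.)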
        with pyro.plate("data", self.size, len(covariates)):
            loc = bias + coeff * covariates
            return pyro.sample("obs", dist.Normal(loc, scale), obs=data)


def main(args):
    # Make PyroModule parameters local (like ordinary torch.nn.Parameters),
    # rather than shared by name through Pyro's global parameter store.
    # This is highly recommended whenever models can be written without pyro.param().
    pyro.settings.set(module_local_params=True)
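    # (With this setting, e.g., two separately constructed guides keep
    # independent parameters rather than sharing them by name.)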

    # Set the seed for reproducibility.
    pyro.set_rng_seed(args.seed)

    # Create a synthetic dataset from a randomly initialized model.
    device = torch.device("cuda" if args.cuda else "cpu")
    with torch.no_grad():
        covariates = torch.randn(args.size)
        data = Model(args.size)(covariates)
    covariates = covariates.to(device=device)
    data = data.to(device=device)

    # Create a model and a guide, both as (Pyro)Modules.
    model: torch.nn.Module = Model(args.size)
    guide: torch.nn.Module = AutoNormal(model)

    # Create a loss function as a Module that includes model and guide parameters.
    # Every Pyro ELBO estimator can be called with a (model, guide) pair to return
    # a loss function Module that takes the same arguments as the model and guide
    # and exposes all of their torch.nn.Parameter and pyro.nn.PyroParam parameters.
    elbo: Callable[[torch.nn.Module, torch.nn.Module], torch.nn.Module] = Trace_ELBO()
    loss_fn: torch.nn.Module = elbo(model, guide)
    loss_fn.to(device=device)

    # Create a dataloader.
    dataset = torch.utils.data.TensorDataset(covariates, data)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size)

    # All relevant parameters need to be initialized before an optimizer can be created.
    # Since we are using an AutoNormal guide, the parameters have not been initialized yet,
    # so we initialize the model and guide by running one mini-batch through the loss.
    mini_batch = dataset[: args.batch_size]
    loss_fn(*mini_batch)
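    # (Before this call, loss_fn.parameters() would not yet yield the guide's
    # lazily created parameters.)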

    # Create a PyTorch optimizer for the parameters of the model and guide in loss_fn.
    optimizer = torch.optim.Adam(loss_fn.parameters(), lr=args.learning_rate)
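    # Any optimizer from torch.optim works here, and a learning rate scheduler
    # could be attached in the usual way, e.g. (not used in this script):
    #   scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
    # with scheduler.step() called once per epoch in the loop below.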

    # Run stochastic variational inference using PyTorch optimizers from torch.optim.
    for epoch in range(args.num_epochs):
        for batch in dataloader:
            optimizer.zero_grad()
            loss = loss_fn(*batch)
            loss.backward()
            optimizer.step()
        print(f"epoch {epoch} loss = {loss}")


if __name__ == "__main__":
    assert pyro.__version__.startswith("1.8.6")
    parser = argparse.ArgumentParser(
        description="Using vanilla PyTorch to perform optimization in SVI"
    )
    parser.add_argument("--size", default=10000, type=int)
    parser.add_argument("--batch-size", default=100, type=int)
    parser.add_argument("--learning-rate", default=0.01, type=float)
    parser.add_argument("--seed", default=20200723, type=int)
    parser.add_argument("--num-epochs", default=10, type=int)
    parser.add_argument("--cuda", action="store_true", default=False)
    args = parser.parse_args()
    main(args)
2 changes: 2 additions & 0 deletions tests/test_examples.py
@@ -110,6 +110,7 @@
    "sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide custom",
    "sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide auto",
    "sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide easy",
    "svi_torch.py --num-epochs=2 --size=400",
    "svi_horovod.py --num-epochs=2 --size=400 --no-horovod",
    pytest.param(
        "svi_lightning.py --max_epochs=2 --size=400 --accelerator cpu --devices 1",
@@ -181,6 +182,7 @@
    "sir_hmc.py -t=2 -w=2 -n=4 -d=2 -m=1 --enum --cuda",
    "sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential --cuda",
    "sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 --cuda",
    "svi_torch.py --num-epochs=2 --size=400 --cuda",
    "svi_horovod.py --num-epochs=2 --size=400 --cuda --no-horovod",
    pytest.param(
        "svi_lightning.py --max_epochs=2 --size=400 --accelerator gpu --devices 1",
4 changes: 3 additions & 1 deletion tutorial/source/index.rst
@@ -22,7 +22,8 @@ and look carefully through the series :ref:`practical-pyro-and-pytorch`,
especially the :doc:`first Bayesian regression tutorial <bayesian_regression>`.
This tutorial goes step-by-step through solving a simple Bayesian machine learning problem with Pyro,
grounding the concepts from the introductory tutorials in runnable code.
Industry users interested in serving predictions from a trained model in C++ should also read :doc:`the PyroModule tutorial <modules>`.
Users interested in integrating with existing PyTorch training and serving infrastructure should also read :doc:`the PyroModule tutorial <modules>`
and look at the :doc:`SVI with PyTorch <svi_torch>` and :doc:`SVI with Lightning <svi_lightning>` examples.

Most users who reach this point will also find our :doc:`guide to tensor shapes in Pyro <tensor_shapes>` essential reading.
Pyro makes extensive use of the behavior of `"array broadcasting" <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_
@@ -95,6 +96,7 @@ List of Tutorials
   workflow
   prior_predictive
   jit
   svi_torch
   svi_horovod
   svi_lightning
   svi_flow_guide
15 changes: 15 additions & 0 deletions tutorial/source/svi_torch.rst
@@ -0,0 +1,15 @@
Example: using vanilla PyTorch to perform optimization in SVI
=============================================================

This script uses argparse arguments to construct the PyTorch optimizer and dataloader, for example::

$ python examples/svi_torch.py --size 10000 --batch-size 100 --num-epochs 100
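
or, say, to train on a GPU with a smaller learning rate (both flags are defined
by the script's argument parser; this invocation is only illustrative)::

    $ python examples/svi_torch.py --cuda --learning-rate 0.005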

`View svi_torch.py on github`__

.. _github: https://github.com/pyro-ppl/pyro/blob/dev/examples/svi_torch.py

__ github_

.. literalinclude:: ../../examples/svi_torch.py
    :language: python
