PyTorch Lightning example (#3189)
* Bump to version 1.5.2 (#2755)

* PyTorch Lightning example

* fixes

* fix test

* update comments

* fix pip install pyro-ppl

* address comments

* add svi_lightning to toctree

---------

Co-authored-by: Fritz Obermeyer <fritz.obermeyer@gmail.com>
ordabayevy and fritzo committed Mar 16, 2023
1 parent 9afb089 commit c6851b8
Showing 7 changed files with 153 additions and 1 deletion.
2 changes: 1 addition & 1 deletion examples/svi_horovod.py
@@ -12,7 +12,7 @@
# https://horovod.readthedocs.io/en/stable
#
# This assumes you have installed horovod, e.g. via
-# pip install pyro[horovod]
+# pip install pyro-ppl[horovod]
# For detailed instructions see
# https://horovod.readthedocs.io/en/stable/install.html
# On my mac laptop I was able to install horovod with
116 changes: 116 additions & 0 deletions examples/svi_lightning.py
@@ -0,0 +1,116 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

# Distributed training via PyTorch Lightning.
#
# This tutorial demonstrates how to distribute SVI training across multiple
# machines (or multiple GPUs on one or more machines) using the PyTorch Lightning
# library. PyTorch Lightning enables data-parallel training by aggregating stochastic
# gradients at each step of training. We focus on integration between PyTorch Lightning and Pyro.
# For further details on distributed computing with PyTorch Lightning, see
# https://lightning.ai/docs/pytorch/latest
#
# This assumes you have installed PyTorch Lightning, e.g. via
# pip install pyro-ppl[lightning]

import argparse

import pytorch_lightning as pl
import torch

import pyro
import pyro.distributions as dist
from pyro.infer import Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.nn import PyroModule


# We define a model as usual, with no reference to PyTorch Lightning.
# This model is data parallel and supports subsampling.
class Model(PyroModule):
    def __init__(self, size):
        super().__init__()
        self.size = size

    def forward(self, covariates, data=None):
        coeff = pyro.sample("coeff", dist.Normal(0, 1))
        bias = pyro.sample("bias", dist.Normal(0, 1))
        scale = pyro.sample("scale", dist.LogNormal(0, 1))

        # Since we'll use a distributed dataloader during training, we need to
        # manually pass minibatches of (covariates,data) that are smaller than
        # the full self.size. In particular we cannot rely on pyro.plate to
        # automatically subsample, since that would lead to all workers drawing
        # identical subsamples.
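        # Passing the full population size together with the minibatch length
        # lets pyro.plate rescale the minibatch log-likelihood by
        # self.size / len(covariates), keeping the ELBO estimate unbiased.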
        with pyro.plate("data", self.size, len(covariates)):
            loc = bias + coeff * covariates
            return pyro.sample("obs", dist.Normal(loc, scale), obs=data)


# We define an ELBO loss, a PyTorch optimizer, and a training step in our
# PyroLightningModule. Note that we use a PyTorch optimizer and Lightning's
# ``training_step`` in place of a Pyro optimizer and Pyro's SVI machinery.
class PyroLightningModule(pl.LightningModule):
    def __init__(self, loss_fn: pyro.infer.elbo.ELBOModule, lr: float):
        super().__init__()
        self.loss_fn = loss_fn
        self.model = loss_fn.model
        self.guide = loss_fn.guide
        self.lr = lr
        self.predictive = pyro.infer.Predictive(
            self.model, guide=self.guide, num_samples=1
        )
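        # Predictive draws latent samples from the guide and replays the model
        # against them, so forward() below yields posterior predictions.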

    def forward(self, *args):
        return self.predictive(*args)

    def training_step(self, batch, batch_idx):
        """Training step for Pyro training."""
        loss = self.loss_fn(*batch)
        # Logging to TensorBoard by default
        self.log("train_loss", loss)
        return loss

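    # loss_fn is an ELBOModule holding the model and guide as submodules, so
    # loss_fn.parameters() below covers both (their parameters live on the
    # modules themselves thanks to module_local_params=True, set in main).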
    def configure_optimizers(self):
        """Configure an optimizer."""
        return torch.optim.Adam(self.loss_fn.parameters(), lr=self.lr)


def main(args):
    # Create a model, synthetic data, a guide, and a lightning module.
    pyro.set_rng_seed(args.seed)
    pyro.settings.set(module_local_params=True)
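    # Keeping parameters local to each PyroModule (rather than in Pyro's
    # global parameter store) lets Lightning manage them per-process and
    # synchronize their gradients across workers.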
    model = Model(args.size)
    covariates = torch.randn(args.size)
    data = model(covariates)
    guide = AutoNormal(model)
    loss_fn = Trace_ELBO()(model, guide)
    training_plan = PyroLightningModule(loss_fn, args.learning_rate)

    # Create a dataloader.
    dataset = torch.utils.data.TensorDataset(covariates, data)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size)
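    # Note: under a distributed strategy such as DDP, Lightning by default
    # wraps this dataloader in a DistributedSampler, so each worker receives
    # distinct minibatches (cf. the subsampling comment in Model.forward).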

    # All relevant parameters need to be initialized before ``configure_optimizers``
    # is called. Since we used an AutoNormal guide, our parameters have not been
    # initialized yet. Therefore we initialize the model and guide by running one
    # mini-batch through the loss.
    mini_batch = dataset[: args.batch_size]
    loss_fn(*mini_batch)

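    # Trainer flags (e.g. --accelerator, --devices, --max_epochs, --strategy)
    # are supplied on the command line and parsed by pl.Trainer.add_argparse_args
    # below.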
    # Run stochastic variational inference using PyTorch Lightning Trainer.
    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(training_plan, train_dataloaders=dataloader)


if __name__ == "__main__":
    assert pyro.__version__.startswith("1.8.4")
    parser = argparse.ArgumentParser(
        description="Distributed training via PyTorch Lightning"
    )
    parser.add_argument("--size", default=1000000, type=int)
    parser.add_argument("--batch_size", default=100, type=int)
    parser.add_argument("--learning_rate", default=0.01, type=float)
    parser.add_argument("--seed", default=20200723, type=int)
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()
    main(args)
1 change: 1 addition & 0 deletions setup.py
@@ -137,6 +137,7 @@
"yapf",
],
"horovod": ["horovod[pytorch]>=0.19"],
"lightning": ["pytorch_lightning"],
"funsor": [
# This must be a released version when Pyro is released.
# "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@7bb52d0eae3046d08a20d1b288544e1a21b4f461",
8 changes: 8 additions & 0 deletions tests/common.py
@@ -68,6 +68,14 @@ def wrapper(*args, **kwargs):
    horovod is None, reason="horovod is not available"
)

+try:
+    import pytorch_lightning
+except ImportError:
+    pytorch_lightning = None
+requires_lightning = pytest.mark.skipif(
+    pytorch_lightning is None, reason="pytorch lightning is not available"
+)

try:
    import funsor
except ImportError:
9 changes: 9 additions & 0 deletions tests/test_examples.py
@@ -14,6 +14,7 @@
    requires_cuda,
    requires_funsor,
    requires_horovod,
+    requires_lightning,
    xfail_param,
)

@@ -110,6 +111,10 @@
"sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide auto",
"sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide easy",
"svi_horovod.py --num-epochs=2 --size=400 --no-horovod",
pytest.param(
"svi_lightning.py --max_epochs=2 --size=400 --accelerator cpu --devices 1",
marks=[requires_lightning],
),
"toy_mixture_model_discrete_enumeration.py --num-steps=1",
"sparse_regression.py --num-steps=100 --num-data=100 --num-dimensions 11",
"vae/ss_vae_M2.py --num-epochs=1",
@@ -177,6 +182,10 @@
"sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential --cuda",
"sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 --cuda",
"svi_horovod.py --num-epochs=2 --size=400 --cuda --no-horovod",
pytest.param(
"svi_lightning.py --max_epochs=2 --size=400 --accelerator gpu --devices 1",
marks=[requires_lightning],
),
"vae/vae.py --num-epochs=1 --cuda",
"vae/ss_vae_M2.py --num-epochs=1 --cuda",
"vae/ss_vae_M2.py --num-epochs=1 --aux-loss --cuda",
1 change: 1 addition & 0 deletions tutorial/source/index.rst
@@ -96,6 +96,7 @@ List of Tutorials
   prior_predictive
   jit
   svi_horovod
+   svi_lightning

.. toctree::
   :maxdepth: 1
17 changes: 17 additions & 0 deletions tutorial/source/svi_lightning.rst
@@ -0,0 +1,17 @@
Example: distributed training via PyTorch Lightning
===================================================

This script passes argparse arguments to PyTorch Lightning ``Trainer`` automatically_, for example::

    $ python examples/svi_lightning.py --accelerator gpu --devices 2 --max_epochs 100 --strategy ddp
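
    # or, on a single CPU process (as in this commit's tests):
    $ python examples/svi_lightning.py --accelerator cpu --devices 1 --max_epochs 100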

.. _automatically: https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-in-python-scripts

`View svi_lightning.py on github`__

.. _github: https://github.com/pyro-ppl/pyro/blob/dev/examples/svi_lightning.py

__ github_

.. literalinclude:: ../../examples/svi_lightning.py
   :language: python
