Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyTorch Lightning example #3189

Merged
merged 16 commits into from
Mar 16, 2023
Merged

PyTorch Lightning example #3189

merged 16 commits into from
Mar 16, 2023

Conversation

ordabayevy
Copy link
Member

This example shows how to train Pyro models using PyTorch Lightning and is adapted from Horovod example.

@ordabayevy
Copy link
Member Author

Addresses #3171.

Copy link
Member

@eb8680 eb8680 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! I've been using Lightning recently as well, so I left some (optional) suggestions aimed at making the example slightly more PyTorch-idiomatic using new features from #3149


def main(args):
# Create a model, synthetic data, a guide, and a lightning module.
pyro.set_rng_seed(args.seed)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This option added in #3149 ensures that parameters of PyroModules will not be implicitly shared across model instances via the Pyro parameter store:

Suggested change
pyro.set_rng_seed(args.seed)
pyro.set_rng_seed(args.seed)
pyro.settings.set(module_local_params=True)

It's not really exercised in this simple example since there's only one model and guide but I think it's good practice to enable it whenever models and guides can be written as PyroModules and trained using generic PyTorch infrastructure like torch.optim and PyTorch Lightning.

Comment on lines 79 to 80
guide = AutoNormal(model)
training_plan = PyroLightningModule(model, guide, args.learning_rate)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change uses the new __call__ method added to the base pyro.infer.elbo.ELBO in #3149 that takes a model and guide returns a torch.nn.Module wrapper around the loss:

Suggested change
guide = AutoNormal(model)
training_plan = PyroLightningModule(model, guide, args.learning_rate)
guide = AutoNormal(model)
loss_fn = Trace_ELBO()(model, guide)
training_plan = PyroLightningModule(loss_fn, args.learning_rate)

It saves you from having to pass around a model and guide everywhere or deal with the Pyro parameter store, which makes SVI a little easier to use with other PyTorch tools like Lightning and the PyTorch JIT.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't know about ELBOModule. This is much neater!

Comment on lines 86 to 90
# All relevant parameters need to be initialized before ``configure_optimizer`` is called.
# Since we used AutoNormal guide our parameters have not be initialized yet.
# Therefore we warm up the guide by running one mini-batch through it.
mini_batch = dataset[: args.batch_size]
guide(*mini_batch)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# All relevant parameters need to be initialized before ``configure_optimizer`` is called.
# Since we used AutoNormal guide our parameters have not be initialized yet.
# Therefore we warm up the guide by running one mini-batch through it.
mini_batch = dataset[: args.batch_size]
guide(*mini_batch)
# All relevant parameters need to be initialized before ``configure_optimizer`` is called.
# Since we used AutoNormal guide our parameters have not be initialized yet.
# Therefore we initialize the model and guide by running one mini-batch through the loss.
mini_batch = dataset[: args.batch_size]
loss_fn(*mini_batch)

Comment on lines +4 to +7
# Distributed training via Pytorch Lightning.
#
# This tutorial demonstrates how to distribute SVI training across multiple
# machines (or multiple GPUs on one or more machines) using the PyTorch Lightning
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is the distributed training in this example? Is it hidden in the default configuration of the DataLoader and TrainingPlan in main below?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Argparse arguments are passed to the pl.Trainer:

trainer = pl.Trainer.from_argparse_args(args)

So you can run the script as follows:

$ python examples/svi_lightning.py --accelerator gpu --devices 2 --max_epochs 100 --strategy ddp

When there are multiple devices DataLoader will use DistributedSampler automatically.

Comment on lines 54 to 58
def __init__(self, model, guide, lr):
super().__init__()
self.pyro_model = model
self.pyro_guide = guide
self.loss_fn = Trace_ELBO().differentiable_loss
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def __init__(self, model, guide, lr):
super().__init__()
self.pyro_model = model
self.pyro_guide = guide
self.loss_fn = Trace_ELBO().differentiable_loss
def __init__(self, loss_fn: pyro.infer.elbo.ELBOModule, lr: float):
super().__init__()
self.loss_fn = loss_fn
self.model = loss_fn.model
self.guide = loss_fn.guide


def training_step(self, batch, batch_idx):
"""Training step for Pyro training."""
loss = self.loss_fn(self.pyro_model, self.pyro_guide, *batch)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
loss = self.loss_fn(self.pyro_model, self.pyro_guide, *batch)
loss = self.loss_fn(*batch)


def configure_optimizers(self):
"""Configure an optimizer."""
return torch.optim.Adam(self.pyro_guide.parameters(), lr=self.lr)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return torch.optim.Adam(self.pyro_guide.parameters(), lr=self.lr)
return torch.optim.Adam(self.loss_fn.parameters(), lr=self.lr)

Comment on lines 59 to 60
self.lr = lr

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding a forward method that calls Predictive is sometimes helpful:

Suggested change
self.lr = lr
self.lr = lr
self.predictive = pyro.infer.Predictive(self.model, guide=self.guide)
def forward(self, *args):
return self.predictive(*args)

Copy link
Member Author

@ordabayevy ordabayevy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for reviewing @eb8680. I think it is much neater now using ELBOModule!

Comment on lines +4 to +7
# Distributed training via Pytorch Lightning.
#
# This tutorial demonstrates how to distribute SVI training across multiple
# machines (or multiple GPUs on one or more machines) using the PyTorch Lightning
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Argparse arguments are passed to the pl.Trainer:

trainer = pl.Trainer.from_argparse_args(args)

So you can run the script as follows:

$ python examples/svi_lightning.py --accelerator gpu --devices 2 --max_epochs 100 --strategy ddp

When there are multiple devices DataLoader will use DistributedSampler automatically.

Comment on lines 79 to 80
guide = AutoNormal(model)
training_plan = PyroLightningModule(model, guide, args.learning_rate)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't know about ELBOModule. This is much neater!

fritzo
fritzo previously approved these changes Mar 14, 2023
Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! Can you just confirm the generated docs are readable, i.e. after running make tutorial? Also ensure the title isn't too long when it appears on the left hand side TOC.

@ordabayevy
Copy link
Member Author

@fritzo There is something wrong with building tutorials when I run make tutorial:

make tutorial
make -C tutorial html
make[1]: Entering directory '/mnt/disks/dev/repos/pyro/tutorial'
Running Sphinx v6.1.3
building [mo]: targets for 0 po files that are out of date
writing output... 
building [html]: targets for 80 source files that are out of date
updating environment: [new config] 80 added, 0 changed, 0 removed
reading sources... [100%] svi_part_ii .. working_memory                                                                                                             

Warning, treated as error:
/mnt/disks/dev/repos/pyro/tutorial/source/gp.ipynb:973:Duplicate substitution definition name: "image0".
make[1]: *** [Makefile:20: html] Error 2
make[1]: Leaving directory '/mnt/disks/dev/repos/pyro/tutorial'
make: *** [Makefile:18: tutorial] Error 2

Trying to figure out what is wrong ... (if you know a quick fix would appreciate it)

@fritzo
Copy link
Member

fritzo commented Mar 15, 2023

@ordabayevy not sure what's causing the build issue...

Unrelated, I see

.../pyro-ppl/pyro/tutorial/source/svi_lightning.rst: WARNING: document isn't included in any toctree

Could you add svi_lightning to tutorial/source/index.rst so it shows up on the website?

@ordabayevy
Copy link
Member Author

Still no luck with make tutorial. When I try to build tutorials on dev branch I get this:

make -C tutorial html
make[1]: Entering directory '/mnt/disks/dev/repos/pyro/tutorial'
Running Sphinx v6.1.3
making output directory... done
building [mo]: targets for 0 po files that are out of date
writing output... 
building [html]: targets for 79 source files that are out of date
updating environment: [new config] 79 added, 0 changed, 0 removed
reading sources... [100%] tensor_shapes .. working_memory                                                                                                           

Warning, treated as error:
/mnt/disks/dev/repos/pyro/tutorial/source/logistic-growth.ipynb:1220:File not found: 'workflow.html'
make[1]: *** [Makefile:20: html] Error 2
make[1]: Leaving directory '/mnt/disks/dev/repos/pyro/tutorial'
make: *** [Makefile:18: tutorial] Error 2

@ordabayevy
Copy link
Member Author

Can you just confirm the generated docs are readable, i.e. after running make tutorial? Also ensure the title isn't too long when it appears on the left hand side TOC.

I was able to build the tutorial by ignoring warnings and can confirm that the generated doc is readable and the title in the left hand side TOC is not too long.

Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, thanks for building tutorials. I'll look into fixing those warnings.

@eb8680 any further comments? I'll hold off merging, feel free to merge

Copy link
Member

@eb8680 eb8680 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@eb8680 eb8680 merged commit c6851b8 into dev Mar 16, 2023
@eb8680 eb8680 deleted the svi-lightning branch March 16, 2023 03:43
@ordabayevy
Copy link
Member Author

Thanks @eb8680 and @fritzo for reviewing!

luisdiaz1997 added a commit to luisdiaz1997/pyro that referenced this pull request Mar 16, 2023
luisdiaz1997 added a commit to luisdiaz1997/pyro that referenced this pull request Mar 16, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants