## Table of Contents

* How PyroModule works

* How to create a PyroModule

* How effects work

* How to constrain parameters

* How to make a PyroModule Bayesian

* Caution: accessing attributes inside plates

* How to create a complex nested PyroModule

* How naming works

* Caution: avoiding duplicate names

In [1]:
import os
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from torch.distributions import constraints
from pyro.nn import PyroModule, PyroParam, PyroSample
from pyro.nn.module import to_pyro_module_
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.8.4')


### How PyroModule works

PyroModule combines Pyro’s primitives and effect handlers with PyTorch’s nn.Module idiom, enabling both Bayesian treatment of existing nn.Modules and model serving via torch.jit.trace_module. Before you start using PyroModules, it helps to understand how they work so that you can avoid pitfalls.

PyroModule is a subclass of nn.Module. PyroModule enables Pyro effects by inserting effect handling logic on module attribute access, overriding the .__getattr__(), .__setattr__(), and .__delattr__() methods. Additionally, because some effects (like sampling) apply only once per model invocation, PyroModule overrides the .__call__() method to ensure samples are generated at most once per .__call__() invocation (note nn.Module subclasses typically implement a .forward() method that is called by .__call__()).

### How to create a PyroModule

There are three ways to create a PyroModule. Let’s start with an nn.Module that is not a PyroModule:

In [5]:
class Linear(nn.Module):
    def __init__(self, in_size, out_size):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_size, out_size))
        self.bias = nn.Parameter(torch.randn(out_size))

    def forward(self, input_):
        return self.bias + input_ @ self.weight
    
linear = Linear(5, 2)
assert isinstance(linear, nn.Module)
assert not isinstance(linear, PyroModule)

example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)


The first way to create a PyroModule is to create a subclass of PyroModule. You can update any nn.Module you’ve written to be a PyroModule, e.g.


```diff
- class Linear(nn.Module):
+ class Linear(PyroModule):
      def __init__(self, in_size, out_size):
          super().__init__()
          self.weight = ...
          self.bias = ...
      ...
```
Alternatively, if you want to use third-party code like the Linear above, you can subclass it, using PyroModule as a mixin class:

In [6]:
class PyroLinear(Linear, PyroModule):
    pass

linear = PyroLinear(5, 2)
assert isinstance(linear, nn.Module)
assert isinstance(linear, Linear)
assert isinstance(linear, PyroModule)

example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)

The second way to create a PyroModule is to use the bracket syntax PyroModule[-], which automatically creates a trivial mixin class like the one above:

```diff
- linear = Linear(5, 2)
+ linear = PyroModule[Linear](5, 2)
```
In our case we can write:

In [7]:
linear = PyroModule[Linear](5, 2)
assert isinstance(linear, nn.Module)
assert isinstance(linear, Linear)
assert isinstance(linear, PyroModule)

example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)

The one difference between manual subclassing and using PyroModule[-] is that PyroModule[-] ensures that all nn.Module superclasses also become PyroModules, which is important for class hierarchies in library code. For example, since nn.GRU is a subclass of nn.RNNBase, PyroModule[nn.GRU] is also a subclass of PyroModule[nn.RNNBase].

The third way to create a PyroModule is to change the type of an existing nn.Module instance in-place using to_pyro_module_(). This is useful if you’re using a third-party module factory helper or updating an existing script, e.g.

In [8]:
linear = Linear(5, 2)
assert isinstance(linear, nn.Module)
assert not isinstance(linear, PyroModule)

to_pyro_module_(linear)  # this operates in-place
assert isinstance(linear, nn.Module)
assert isinstance(linear, Linear)
assert isinstance(linear, PyroModule)

example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)