## Table of Contents

* How PyroModule works

* How to create a PyroModule

* How effects work

* How to constrain parameters

* How to make a PyroModule Bayesian

* Caution: accessing attributes inside plates

* How to create a complex nested PyroModule

* How naming works

* Caution: avoiding duplicate names

```python
import os
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from torch.distributions import constraints
from pyro.nn import PyroModule, PyroParam, PyroSample
from pyro.nn.module import to_pyro_module_
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.8.4')
```



### How PyroModule works

PyroModule aims to combine Pyro’s primitives and effect handlers with PyTorch’s nn.Module idiom, thereby enabling Bayesian treatment of existing nn.Modules and enabling model serving via jit.trace_module. Before you start using PyroModules it will help to understand how they work, so you can avoid pitfalls.

PyroModule is a subclass of nn.Module. PyroModule enables Pyro effects by inserting effect handling logic on module attribute access, overriding the .__getattr__(), .__setattr__(), and .__delattr__() methods. Additionally, because some effects (like sampling) apply only once per model invocation, PyroModule overrides the .__call__() method to ensure samples are generated at most once per .__call__() invocation (note nn.Module subclasses typically implement a .forward() method that is called by .__call__()).

### How to create a PyroModule

There are three ways to create a PyroModule. Let’s start with an nn.Module that is not a PyroModule:

```python
class Linear(nn.Module):
    def __init__(self, in_size, out_size):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_size, out_size))
        self.bias = nn.Parameter(torch.randn(out_size))

    def forward(self, input_):
        return self.bias + input_ @ self.weight

linear = Linear(5, 2)
assert isinstance(linear, nn.Module)
assert not isinstance(linear, PyroModule)

example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)
```


The first way to create a PyroModule is to create a subclass of PyroModule. You can update any nn.Module you’ve written to be a PyroModule, e.g.


```diff
- class Linear(nn.Module):
+ class Linear(PyroModule):
      def __init__(self, in_size, out_size):
          super().__init__()
          self.weight = ...
          self.bias = ...
      ...
```
Alternatively, if you want to use third-party code like the Linear above, you can subclass it, using PyroModule as a mixin class:

```python
class PyroLinear(Linear, PyroModule):
    pass

linear = PyroLinear(5, 2)
assert isinstance(linear, nn.Module)
assert isinstance(linear, Linear)
assert isinstance(linear, PyroModule)

example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)
```

The second way to create a PyroModule is to use the bracket syntax PyroModule[-] to automatically denote a trivial mixin class as above:

```diff
- linear = Linear(5, 2)
+ linear = PyroModule[Linear](5, 2)
```
In our case, we can write:

```python
linear = PyroModule[Linear](5, 2)
assert isinstance(linear, nn.Module)
assert isinstance(linear, Linear)
assert isinstance(linear, PyroModule)

example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)
```

The one difference between manual subclassing and using PyroModule[-] is that PyroModule[-] also ensures all nn.Module superclasses become PyroModules, which is important for class hierarchies in library code. For example, since nn.GRU is a subclass of nn.RNNBase, PyroModule[nn.GRU] will also be a subclass of PyroModule[nn.RNNBase].

The third way to create a PyroModule is to change the type of an existing nn.Module instance in-place using to_pyro_module_(). This is useful if you’re using a third-party module factory helper or updating an existing script, e.g.

```python
linear = Linear(5, 2)
assert isinstance(linear, nn.Module)
assert not isinstance(linear, PyroModule)

to_pyro_module_(linear)  # this operates in-place
assert isinstance(linear, nn.Module)
assert isinstance(linear, Linear)
assert isinstance(linear, PyroModule)

example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)
```

### How effects work

So far we’ve created PyroModules but haven’t made use of Pyro effects. However, the nn.Parameter attributes of our PyroModules already act like pyro.param statements: they synchronize with Pyro’s param store, and they can be recorded in traces.

```python
pyro.clear_param_store()

# This is not traced:
linear = Linear(5, 2)
with poutine.trace() as tr:
    linear(example_input)
print(type(linear).__name__)
print(list(tr.trace.nodes.keys()))
print(list(pyro.get_param_store().keys()))

# Now this is traced:
to_pyro_module_(linear)
with poutine.trace() as tr:
    linear(example_input)
print(type(linear).__name__)
print(list(tr.trace.nodes.keys()))
print(list(pyro.get_param_store().keys()))
```

```
Linear
[]
[]
PyroLinear
['bias', 'weight']
['bias', 'weight']
```


### How to constrain parameters
Pyro parameters allow constraints, and often we want our nn.Module parameters to obey constraints. You can constrain a PyroModule’s parameters by replacing nn.Parameter with a PyroParam attribute. For example, to ensure the .bias attribute is positive, we can set it to:

```python
print("params before:", [name for name, _ in linear.named_parameters()])

linear.bias = PyroParam(torch.randn(2).exp(), constraint=constraints.positive)
print("params after:", [name for name, _ in linear.named_parameters()])
print("bias:", linear.bias)

example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)
```

```
params before: ['weight', 'bias']
params after: ['weight', 'bias_unconstrained']
bias: tensor([0.6404, 5.2888], grad_fn=<AddBackward0>)
```


Now PyTorch will optimize the .bias_unconstrained parameter, and each time we access the .bias attribute it will read and transform the .bias_unconstrained parameter (similar to a Python @property).

If you know the constraint beforehand, you can build it into the module constructor, e.g.

```diff
  class Linear(PyroModule):
      def __init__(self, in_size, out_size):
          super().__init__()
          self.weight = ...
-         self.bias = nn.Parameter(torch.randn(out_size))
+         self.bias = PyroParam(torch.randn(out_size).exp(),
+                               constraint=constraints.positive)
      ...

```

### How to make a PyroModule Bayesian
So far our Linear module is still deterministic. To make it randomized and Bayesian, we’ll replace nn.Parameter and PyroParam attributes with PyroSample attributes, specifying a prior. Let’s put a simple prior over the weights, taking care to expand its shape to [5,2] and declare event dimensions with .to_event() (as explained in the tensor shapes tutorial).

```python
print("params before:", [name for name, _ in linear.named_parameters()])

linear.weight = PyroSample(dist.Normal(0, 1).expand([5, 2]).to_event(2))
print("params after:", [name for name, _ in linear.named_parameters()])
print("weight:", linear.weight)
print("weight:", linear.weight)

example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)
```

```
params before: ['weight', 'bias_unconstrained']
params after: ['bias_unconstrained']
weight: tensor([[ 1.0344,  1.7259],
        [-0.9617, -0.0387],
        [ 1.3176,  0.3925],
        [ 1.2339,  1.1451],
        [-0.8739,  0.0127]])
weight: tensor([[-0.8188, -0.7677],
        [ 0.0810, -1.2442],
        [ 0.4361,  0.5980],
        [-0.5942,  0.2333],
        [ 0.4862, -0.2085]])
```


Notice that the .weight parameter has disappeared, and each time we read .weight or call linear(), a new weight is sampled from the prior. In fact, the weight is sampled when Linear.forward() accesses the .weight attribute: this attribute now has the special behavior of sampling from the prior.

We can see all the Pyro effects that appear in the trace:

```python
with poutine.trace() as tr:
    linear(example_input)
for site in tr.trace.nodes.values():
    print(site["type"], site["name"], site["value"])
```

```
param bias tensor([0.6404, 5.2888], grad_fn=<AddBackward0>)
sample weight tensor([[ 1.8861, -1.9781],
        [ 1.6661,  1.0092],
        [-0.6315,  0.1163],
        [ 0.6082,  1.3411],
        [ 0.0353, -1.5681]])
```


So far we’ve modified a third-party module to be Bayesian:

```
linear = Linear(...)
to_pyro_module_(linear)
linear.bias = PyroParam(...)
linear.weight = PyroSample(...)
```