In [None]:
%matplotlib inline
# import some dependencies
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.autograd import Variable

import pyro
import pyro.distributions as dist

# Models in Pyro: from primitive distributions to stochastic functions

The basic unit of Pyro programs is the stochastic function: an arbitrary Python callable (i.e. any Python object with a `__call__()` method, such as a function, a method, or an `nn.Module`) that either contains some irreducible internal randomness or calls other stochastic functions internally via the top-level function `pyro.sample`.

Throughout this tutorial and the rest of Pyro's tutorials and documentation, we will often call stochastic functions *models*, because they are often meant as simplified or stylized descriptions of a process by which data is generated. Expressing models as stochastic functions in Pyro means that models can be composed, reused, imported, and serialized just like regular Python callables.

Conceptually, `pyro.sample` simply applies its function argument to the rest of its arguments, much like
```python
apply_fn = lambda fn, *args, **kwargs: fn(*args, **kwargs)
```
However, its first argument is always a string that we call a name, so a call looks like `pyro.sample(name, fn, *args, **kwargs)`. Each name can appear only once in a stochastic function; Pyro's backend uses these names to uniquely identify `sample` statements and change their behavior at runtime.
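To illustrate why unique names matter, here is a toy, plain-Python sketch of a backend that records every named sample in a trace and can later replay those values instead of drawing new ones. `ToyBackend` and `toy_normal_product` are hypothetical names invented for this illustration; this is not how Pyro's backend is actually implemented, only a demonstration of the idea of "changing behavior at runtime" by name.

```python
import random


class ToyBackend:
    """Hypothetical toy backend, not Pyro's implementation."""

    def __init__(self, replay=None):
        self.trace = {}  # name -> value drawn or reused during this run
        self.replay = replay if replay is not None else {}  # name -> value to force

    def sample(self, name, fn, *args, **kwargs):
        # names must be unique within one run of a stochastic function
        if name in self.trace:
            raise ValueError("duplicate sample name: {}".format(name))
        if name in self.replay:
            value = self.replay[name]  # changed behavior: reuse a recorded value
        else:
            value = fn(*args, **kwargs)  # default behavior: apply fn to its arguments
        self.trace[name] = value
        return value


def toy_normal_product(backend, mu, sigma):
    z1 = backend.sample("z1", random.gauss, mu, sigma)
    z2 = backend.sample("z2", random.gauss, 0.0, 1.0)
    return z1 * z2


# run once and record every named sample
first = ToyBackend()
y1 = toy_normal_product(first, 0.0, 1.0)

# run again while replaying the recorded trace: the output is reproduced exactly
second = ToyBackend(replay=first.trace)
y2 = toy_normal_product(second, 0.0, 1.0)
print(y1 == y2)  # True
```

Because each value is keyed by its name, a duplicated name would make the trace ambiguous, which is why the toy backend rejects it.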

In [None]:
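def normal_product(mu, sigma):
    # z1 ~ N(mu, sigma); z2 ~ N(0, 1) with the same shape as mu and sigma
    z1 = pyro.sample("z1", dist.diagnormal, mu, sigma)
    z2 = pyro.sample("z2", dist.diagnormal,
                     Variable(torch.zeros(mu.size())), Variable(torch.ones(sigma.size())))
    # the output is a deterministic function of the two samples
    y = z1 * z2
    return y

print(normal_product(Variable(torch.zeros(1)), Variable(torch.ones(1))))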

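Stochastic functions induce joint probability distributions `p(y, z | x)` over their outputs `y` and the values `z` of their internal stochastic function calls, given their inputs `x`. Drawing a sample from the marginal distribution `p(y | x)` is easy: just call the function. In general, however, computing marginal probabilities such as `p(y | x) = ∫ p(y, z | x) dz` is intractable, because it requires integrating over every possible value of `z`.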
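To make the joint distribution concrete, the following plain-Python analogue of `normal_product` (the name `normal_product_py` is hypothetical, and it uses Python's `random` module rather than Pyro) draws `z1` and `z2` together on each call and returns `y = z1 * z2`. Collecting many outputs gives an empirical picture of the distribution of `y`. For `mu = 0` and `sigma = 1` its mean and variance can be checked against the exact values 0 and 1.

```python
import random

def normal_product_py(mu, sigma):
    # plain-Python analogue of normal_product: z1 ~ N(mu, sigma^2), z2 ~ N(0, 1)
    z1 = random.gauss(mu, sigma)
    z2 = random.gauss(0.0, 1.0)
    return z1 * z2

random.seed(0)
n = 100000
# each call draws z1 and z2 together; we keep only the output y
ys = [normal_product_py(0.0, 1.0) for _ in range(n)]

mean = sum(ys) / n
var = sum((y - mean) ** 2 for y in ys) / n
# for mu = 0, sigma = 1 the exact values are E[y] = 0 and Var[y] = E[z1^2] E[z2^2] = 1
print(mean, var)
```

The empirical moments only summarize samples; they do not give the density value `p(y | x)` at a particular `y`.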

## Primitive stochastic functions

Primitive stochastic functions, or distributions, are an important class of stochastic functions for which we can explicitly compute the probability of outputs given inputs.  In `normal_product`, the function `diagnormal` is a primitive stochastic function.  Pyro includes a rich standalone library of GPU-accelerated multivariate probability distributions built on PyTorch (LINK).

Users can also implement custom distributions by simply subclassing `pyro.distributions.Distribution`.
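As a rough picture of what a primitive stochastic function provides, the sketch below implements the two operations used in this tutorial: calling the object draws a sample, and `log_pdf(value, *args)` scores a value. It is a hypothetical plain-Python `Exponential` class, not a subclass of `pyro.distributions.Distribution`; the exact methods that subclass must override are not shown here, so consult the Pyro documentation for those.

```python
import math
import random

class Exponential:
    """Hypothetical plain-Python primitive distribution (not a Pyro class)."""

    def __call__(self, rate):
        # draw x ~ Exponential(rate) using Python's built-in sampler
        return random.expovariate(rate)

    def log_pdf(self, x, rate):
        # log density: log(rate) - rate * x for x >= 0, and -inf otherwise
        if x < 0:
            return -math.inf
        return math.log(rate) - rate * x


exponential = Exponential()
x = exponential(2.0)
print(x, exponential.log_pdf(x, 2.0))
print(exponential.log_pdf(0.0, 2.0))  # log(2) ≈ 0.6931
```

Because the density has a closed form, scoring is exact; this is what distinguishes a primitive distribution from an arbitrary stochastic function.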

In [None]:
# samples are drawn by calling the primitive distribution with any necessary parameters
p = Variable(torch.ones(1) * 0.5)
x = dist.bernoulli(p)
print(x)

# primitive stochastic functions can also score values by calling
# distribution.log_pdf(value_to_score, *arguments_to_distribution_sample)
logpx = dist.bernoulli.log_pdf(x, p)
print(logpx)
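For a Bernoulli distribution the score has a simple closed form, `log p(x) = x log p + (1 - x) log(1 - p)`. With `p = 0.5` both outcomes score `log 0.5 ≈ -0.6931`, so the printed `logpx` above should have that value whichever `x` was drawn. The helper `bernoulli_log_pmf` below is a hypothetical plain-Python check of the formula, not part of Pyro.

```python
import math

def bernoulli_log_pmf(x, p):
    # log p(x) = x log p + (1 - x) log(1 - p), for x in {0, 1} and 0 < p < 1
    return x * math.log(p) + (1 - x) * math.log(1 - p)

print(bernoulli_log_pmf(0, 0.5), bernoulli_log_pmf(1, 0.5))  # both log(0.5) ≈ -0.6931
print(bernoulli_log_pmf(1, 0.9))  # log(0.9) ≈ -0.1054
```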

## Stochastic recursion and higher-order stochastic functions

Because Pyro is embedded in Python, stochastic functions can contain arbitrarily complex deterministic Python code, and randomness can freely affect control flow, including the dynamic generation of names. For example, we can construct recursive functions that terminate their recursion nondeterministically, as long as all names are unique:

In [None]:
def geometric(p, t=None):
    if t is None:
        t = 0
    # each recursive call uses a fresh name: x_0, x_1, x_2, ...
    x = pyro.sample("x_{}".format(t), dist.bernoulli, p)
    if torch.equal(x.data, torch.zeros(x.size())):
        return x
    else:
        return x + geometric(p, t + 1)

print(geometric(Variable(torch.Tensor([0.5]))))
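The name scheme in `geometric` stays unique because `t` increases on every recursive call, while the number of `sample` statements executed in one run is itself random. The plain-Python sketch below (the function `geometric_py` and its `trace` dict are hypothetical, using Python's `random` module instead of Pyro) records each draw under its generated name so this can be checked directly.

```python
import random

def geometric_py(p, trace, t=0):
    # hypothetical plain-Python analogue of geometric: record each Bernoulli
    # draw under a fresh name "x_0", "x_1", ... in the trace dict
    name = "x_{}".format(t)
    if name in trace:
        raise ValueError("names must be unique")
    x = 1 if random.random() < p else 0
    trace[name] = x
    if x == 0:
        return x
    return x + geometric_py(p, trace, t + 1)

random.seed(1)
trace = {}
n = geometric_py(0.5, trace)
# n ones were drawn before the first zero, so the trace holds n + 1 names
print(n, list(trace))
```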

We are also free to define stochastic functions that accept as input or produce as output other stochastic functions:

In [None]:
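def make_normal_normal():
    # mu_latent is sampled once, when make_normal_normal is called
    mu_latent = pyro.sample("mu_latent", dist.diagnormal,
                            Variable(torch.zeros(1)), Variable(torch.ones(1)))
    # the returned stochastic function closes over mu_latent
    return lambda: normal_product(mu_latent, Variable(torch.ones(1)))

normal_product_2 = make_normal_normal()
print(normal_product_2())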
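In `make_normal_normal`, the latent mean is drawn once when the factory runs, and every call to the returned function reuses it while drawing fresh inner samples. The plain-Python sketch below makes that sharing visible; `make_normal_normal_py` is a hypothetical analogue that also returns the latent value so it can be inspected, and it uses Python's `random` module rather than Pyro.

```python
import random

def make_normal_normal_py():
    # hypothetical plain-Python analogue: mu_latent is drawn once, here
    mu_latent = random.gauss(0.0, 1.0)

    def normal_product_closure():
        # every call reuses the same mu_latent but draws fresh z1 and z2
        z1 = random.gauss(mu_latent, 1.0)
        z2 = random.gauss(0.0, 1.0)
        return z1 * z2

    return mu_latent, normal_product_closure

random.seed(2)
mu, f = make_normal_normal_py()
samples = [f() for _ in range(3)]
print(mu, samples)
```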

## Implicit distributions and external randomness

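Stochastic functions in Pyro can in principle contain arbitrary external randomness, such as direct calls to `torch.randn` that bypass `pyro.sample`. A function whose randomness comes only from such calls defines an *implicit distribution*: we can draw samples from it, but we have no way to compute the probability of a given output. Note that current inference algorithms expect every `pyro.sample` statement to contain a primitive stochastic function.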

In [None]:
def implicit_normal(mu, sigma):
    # draws from N(mu, sigma^2) with torch's own RNG instead of pyro.sample
    return Variable(torch.randn(mu.size())) * sigma + mu

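However, the arguments passed to `pyro.sample` statements should not depend on external randomness, because inference algorithms only see the named `sample` statements and cannot account for randomness that happens outside them: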

In [None]:
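# ALLOWED: external randomness only affects the output, not any pyro.sample statement
def implicit_normal_product(mu):
    a = pyro.sample("a", dist.diagnormal, mu, Variable(torch.ones(1)))
    b = Variable(torch.randn(mu.size()))
    return a * b

# NOT ALLOWED: sigma comes from external randomness and is passed to pyro.sample
def implicit_normal_product_2(mu):
    sigma = torch.exp(Variable(torch.randn(mu.size())))
    a = pyro.sample("a", dist.diagnormal, mu, sigma)
    b = Variable(torch.randn(mu.size()))
    return a * b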
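One way to see the problem: an inference algorithm that rescores a recorded value of `"a"` needs that value's log density to be a fixed function of what it can see. The plain-Python sketch below (the helpers `normal_log_pdf`, `score_allowed`, and `score_not_allowed` are hypothetical, using Python's `math` and `random` modules rather than Pyro) rescores the same value of `a` twice. In the allowed case the score is reproducible; in the disallowed case `sigma` is redrawn from randomness that no trace records, so the score changes from call to call.

```python
import math
import random

def normal_log_pdf(x, mu, sigma):
    # log density of N(mu, sigma^2) at x
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)

def score_allowed(a, mu):
    # sigma is a fixed input, so the score of a recorded value a is deterministic
    return normal_log_pdf(a, mu, 1.0)

def score_not_allowed(a, mu):
    # sigma is redrawn from external randomness that no trace records,
    # so rescoring the same recorded value a gives a different answer each time
    sigma = math.exp(random.gauss(0.0, 1.0))
    return normal_log_pdf(a, mu, sigma)

random.seed(3)
a = 0.3  # a value of "a" recorded from some earlier run
print(score_allowed(a, 0.0), score_allowed(a, 0.0))          # identical
print(score_not_allowed(a, 0.0), score_not_allowed(a, 0.0))  # almost surely different
```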