Notes following the Pyro introductory tutorial (Part I): https://pyro.ai/examples/intro_part_i.html

In [1]:
import torch
import pyro

In [2]:
SEED = 101  # fixed RNG seed so a fresh-kernel Run All draws the same samples
pyro.set_rng_seed(SEED)

### Primitive Stochastic Function (in torch)

In [3]:
# Standard normal: mean 0 and standard deviation (scale) 1, hence also variance 1.
loc = 0.
scale = 1.

# Build the distribution and draw one reparameterized sample (rsample) from it.
normal = torch.distributions.Normal(loc=loc, scale=scale)
x = normal.rsample()

In [8]:
# `!s` keeps the "tensor(...)" display; a plain {x} would format the bare Python float.
print(f"sample {x!s}")

sample tensor(-1.3905)


In [10]:
float(x)  # unwrap the 0-dim tensor into a plain Python float

-1.3905061483383179

In [9]:
log_density = normal.log_prob(x)  # log of the Normal density evaluated at the sample x
print(f"log prob {log_density!s}")

log prob tensor(-1.8857)


### A Simple Model (in torch)

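A toy weather model: first sample whether the day is cloudy or sunny, then sample a temperature whose mean and spread depend on that outcome.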

In [24]:
# torch version
def weather_torch():
    cloudy = torch.distributions.Bernoulli(0.3).sample()
    cloudy = 'cloudy' if cloudy.item() == 1.0 else 'sunny'
    
    mean_temp = {'cloudy': 55.0, 'sunny': 75.0}[cloudy]
    scale_temp = {'cloudy': 10.0, 'sunny': 15.0}[cloudy]
    
    temp = torch.distributions.Normal(mean_temp, scale_temp).rsample()
    return cloudy, temp.item()

In [23]:
N_TORCH_DRAWS = 3  # number of independent forward runs of the torch model to show
for _ in range(N_TORCH_DRAWS):
    print(weather_torch())

('cloudy', 38.191741943359375)
('sunny', 76.34708404541016)
('sunny', 67.69495391845703)


### The pyro.sample primitive

In [30]:
# Same standard normal as before, but drawn through pyro.sample at the named site 'my_sample'.
std_normal = pyro.distributions.Normal(loc, scale)
x = pyro.sample('my_sample', std_normal)
print(x)

tensor(-0.5929)


In [31]:
def weather_pyro():
    """Draw one (sky, temperature) pair from the toy weather model using Pyro sample sites.

    Same generative story as the torch version, but every random choice goes through
    ``pyro.sample`` with a site name: 'cloudy' for the Bernoulli(0.3) sky draw and
    'temp' for the temperature. The temperature is drawn from a Normal with
    (mean, std) of (55.0, 10.0) when cloudy and (75.0, 15.0) when sunny.

    Returns:
        tuple[str, float]: the sky label ('cloudy' or 'sunny') and the sampled temperature.
    """
    is_cloudy = pyro.sample('cloudy', pyro.distributions.Bernoulli(0.3)).item() == 1.0
    sky = 'cloudy' if is_cloudy else 'sunny'

    # (mean, standard deviation) of the temperature for each sky condition
    temp_params = {'cloudy': (55.0, 10.0), 'sunny': (75.0, 15.0)}
    mean_temp, scale_temp = temp_params[sky]

    temperature = pyro.sample('temp', pyro.distributions.Normal(mean_temp, scale_temp))
    return sky, temperature.item()

In [32]:
# three independent forward runs of the Pyro model, one tuple per line
print(*(weather_pyro() for _ in range(3)), sep='\n')

('sunny', 74.95360565185547)
('sunny', 92.0722427368164)
('cloudy', 61.8774528503418)


### Universality: Stochastic Recursion, Higher-order Stochastic Functions, and Random Control Flow

In [33]:
def ice_cream_sales():
    """Sample ice-cream sales whose mean depends on a random weather draw.

    The weather comes from ``weather_pyro``. Expected sales are 200 only when the sky
    is 'sunny' AND the temperature exceeds 80.0, otherwise 50. The sales figure is
    then sampled from Normal(expected_sales, 10.0) at the Pyro site 'ice_cream'.
    Branching on sampled values like this is an example of random control flow.

    Returns:
        torch.Tensor: the sampled sales value (a 0-dim tensor from ``pyro.sample``).
    """
    sky, temp = weather_pyro()  # sky is the label 'cloudy' or 'sunny'
    expected_sales = 200. if sky == 'sunny' and temp > 80.0 else 50.

    ice_cream = pyro.sample('ice_cream', pyro.distributions.Normal(expected_sales, 10.0))
    return ice_cream

In [51]:
sales = ice_cream_sales()  # one draw from the ice-cream model
sales

tensor(194.1904)

In [36]:
def geometric(p, t=None):
    """Recursively sample a geometric count: the number of failures before the first success.

    Trial ``t`` is a Bernoulli(p) draw registered with Pyro under the unique site
    name ``x_{t}``. Recursion stops at the first draw equal to 1; each earlier
    draw adds one to the count.

    Args:
        p: success probability of each Bernoulli trial.
        t: index of the current trial; ``None`` means start at trial 0.

    Returns:
        int: number of failed trials before the first success.

    NOTE(review): with p == 0 no trial ever succeeds, so this recurses until Python's
    recursion limit is hit.
    """
    trial = 0 if t is None else t
    success = pyro.sample(f'x_{trial}', pyro.distributions.Bernoulli(p))
    if success.item() == 1:
        return 0
    # this trial failed: count it and move on to the next uniquely named trial
    return 1 + geometric(p, trial + 1)

In [39]:
n_failures = geometric(0.5)  # failures before the first success with p = 0.5
print(n_failures)

0


In [41]:
def normal_product(loc, scale):
    """Multiply two separate Normal(loc, scale) draws taken at Pyro sites 'z1' and 'z2'.

    Args:
        loc: mean shared by both Normal draws.
        scale: standard deviation shared by both Normal draws.

    Returns:
        The product of the two sampled values.
    """
    # 'z1' is sampled before 'z2', matching the order of the site names below
    z1, z2 = (pyro.sample(site, pyro.distributions.Normal(loc, scale)) for site in ('z1', 'z2'))
    return z1 * z2

def make_normal_normal():
    """Build a stochastic function whose location is itself a random draw.

    ``mu_latent`` is sampled once from Normal(0, 1) at the site 'mu_latent' when this
    factory runs. The returned one-argument function maps ``scale`` to
    ``normal_product(mu_latent, scale)``, reusing that same ``mu_latent`` on every call.
    This illustrates a higher-order stochastic function.

    Returns:
        callable: ``scale -> normal_product(mu_latent, scale)``.
    """
    mu_latent = pyro.sample('mu_latent', pyro.distributions.Normal(0, 1))
    return lambda scale: normal_product(mu_latent, scale)

In [46]:
sample_product = make_normal_normal()  # draws mu_latent once, then closes over it
print(sample_product(1.0))

tensor(0.7510)
