Skip to content

Commit

Permalink
Handle nan loss in Trace_ELBO when enum_discrete=True
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo committed Oct 24, 2017
1 parent fc71dab commit 8c14f36
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 23 deletions.
20 changes: 20 additions & 0 deletions pyro/infer/trace_elbo.py
@@ -1,6 +1,9 @@
import numbers

import pyro
import pyro.poutine as poutine
from pyro.infer.enum import iter_discrete_traces
from pyro.distributions.util import torch_zeros_like


class Trace_ELBO(object):
Expand Down Expand Up @@ -63,6 +66,13 @@ def loss(self, model, guide, *args, **kwargs):
elbo_particle += model_trace.nodes[name][log_pdf]
elbo_particle -= guide_trace.nodes[name][log_pdf]

# drop terms of weight zero to avoid nans
if isinstance(weight, numbers.Number):
if weight == 0.0:
elbo_particle = torch_zeros_like(elbo_particle)
else:
elbo_particle[weight == 0] = 0.0

elbo += (weight * elbo_particle).data.sum()

loss = -elbo
Expand Down Expand Up @@ -102,6 +112,16 @@ def loss_and_grads(self, model, guide, *args, **kwargs):
surrogate_elbo_particle += model_trace.nodes[name][log_pdf] + \
log_r.detach() * guide_trace.nodes[name][log_pdf]

# drop terms of weight zero to avoid nans
if isinstance(weight, numbers.Number):
if weight == 0.0:
elbo_particle = torch_zeros_like(elbo_particle)
surrogate_elbo_particle = torch_zeros_like(surrogate_elbo_particle)
else:
weight_eq_zero = (weight == 0)
elbo_particle[weight_eq_zero] = 0.0
surrogate_elbo_particle[weight_eq_zero] = 0.0

elbo += (weight * elbo_particle).data.sum()
surrogate_elbo += (weight * surrogate_elbo_particle).sum()

Expand Down
67 changes: 44 additions & 23 deletions tests/infer/test_enum.py
@@ -1,4 +1,5 @@
import itertools
import math

import pytest
import torch
Expand All @@ -20,31 +21,18 @@
reason="pytorch segfaults at 0.2.0_4, fixed by 0.2.0+f964105")


# A purely discrete model, no batching.
def model0():
p = pyro.param("p", Variable(torch.Tensor([0.05])))
ps = pyro.param("ps", Variable(torch.Tensor([0.1, 0.2, 0.3, 0.4])))
x = pyro.sample("x", dist.Bernoulli(p))
y = pyro.sample("y", dist.Categorical(ps, one_hot=False))
return dict(x=x, y=y)


# A discrete model with batching.
def model1():
p = pyro.param("p", Variable(torch.Tensor([[0.05], [0.15]])))
ps = pyro.param("ps", Variable(torch.Tensor([[0.1, 0.2, 0.3, 0.4],
[0.4, 0.3, 0.2, 0.1]])))
x = pyro.sample("x", dist.Bernoulli(p))
y = pyro.sample("y", dist.Categorical(ps, one_hot=False))
assert x.size() == (2, 1)
assert y.size() == (2, 1)
return dict(x=x, y=y)


@pytest.mark.parametrize("graph_type", ["flat", "dense"])
def test_iter_discrete_traces_scalar(graph_type):
pyro.clear_param_store()
traces = list(iter_discrete_traces(graph_type, model0))

def model():
p = pyro.param("p", Variable(torch.Tensor([0.05])))
ps = pyro.param("ps", Variable(torch.Tensor([0.1, 0.2, 0.3, 0.4])))
x = pyro.sample("x", dist.Bernoulli(p))
y = pyro.sample("y", dist.Categorical(ps, one_hot=False))
return dict(x=x, y=y)

traces = list(iter_discrete_traces(graph_type, model))

p = pyro.param("p").data
ps = pyro.param("ps").data
Expand All @@ -61,7 +49,18 @@ def test_iter_discrete_traces_scalar(graph_type):
@pytest.mark.parametrize("graph_type", ["flat", "dense"])
def test_iter_discrete_traces_vector(graph_type):
pyro.clear_param_store()
traces = list(iter_discrete_traces(graph_type, model1))

def model():
p = pyro.param("p", Variable(torch.Tensor([[0.05], [0.15]])))
ps = pyro.param("ps", Variable(torch.Tensor([[0.1, 0.2, 0.3, 0.4],
[0.4, 0.3, 0.2, 0.1]])))
x = pyro.sample("x", dist.Bernoulli(p))
y = pyro.sample("y", dist.Categorical(ps, one_hot=False))
assert x.size() == (2, 1)
assert y.size() == (2, 1)
return dict(x=x, y=y)

traces = list(iter_discrete_traces(graph_type, model))

p = pyro.param("p").data
ps = pyro.param("ps").data
Expand All @@ -76,6 +75,28 @@ def test_iter_discrete_traces_vector(graph_type):
assert_equal(scale, expected_scale)


@pytest.mark.parametrize("enum_discrete", [True, False], ids=["sum", "sample"])
@pytest.mark.parametrize("trace_graph", [False, True], ids=["dense", "flat"])
def test_iter_discrete_traces_nan(enum_discrete, trace_graph):
pyro.clear_param_store()

def model():
p = Variable(torch.Tensor([0.0, 0.5, 1.0]))
pyro.sample("z", dist.Bernoulli(p))

def guide():
p = pyro.param("p", Variable(torch.Tensor([0.0, 0.5, 1.0]), requires_grad=True))
pyro.sample("z", dist.Bernoulli(p))

Elbo = TraceGraph_ELBO if trace_graph else Trace_ELBO
elbo = Elbo(enum_discrete=enum_discrete)
with xfail_if_not_implemented():
loss = elbo.loss(model, guide)
assert isinstance(loss, float) and not math.isnan(loss), loss
loss = elbo.loss_and_grads(model, guide)
assert isinstance(loss, float) and not math.isnan(loss), loss


# A simple Gaussian mixture model, with no vectorization.
def gmm_model(data, verbose=False):
p = pyro.param("p", Variable(torch.Tensor([0.3]), requires_grad=True))
Expand Down

0 comments on commit 8c14f36

Please sign in to comment.