
Domain Error in arguments #875

Closed
lazypanda1 opened this issue Mar 10, 2018 · 3 comments
Running the following program raises the exception below in Pyro:

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable

import pyro
import pyro.distributions as dist
from pyro.optim import Adam
from pyro.infer import SVI

class RegressionModel(nn.Module):

    def __init__(self, p):
        super(RegressionModel, self).__init__()
        self.linear = nn.Linear(p, 1)

    def forward(self, x):
        return self.linear(x)
regression_model = RegressionModel(1)
datax = np.array([66.51273632392986, 72.48109160982891, 51.593255272579405, 82.37396951614642, 24.351195093810894, 28.96641791103359, 99.38220110202795, 46.945634876575184, 74.85506975121858, 23.871686140356985, 42.81192218973854, 0.16365441832948413, 67.72794093186648, 1.5779302410995455, 85.61459094106726, 73.20330969160857, 97.22397591828582, 14.313599505994778, 77.26991866013726, 64.14754051119579, 93.5303981879472, 20.089351299968662, 85.04528490029601, 67.33599800099395, 86.68517663935083, 73.75596550033319, 44.8104798197569, 78.20407642948626, 26.66574964878745, 78.1633108272725, 58.678289830177086, 27.31380907315487, 4.02646194795423, 63.96186643228191, 49.082637268825046, 35.03886967734062, 58.61818429732906, 19.38738927550142, 98.52456832035858, 39.25739771133928, 60.79774433190946, 11.752239158453627, 57.486221669801516, 42.48732146320621, 92.4222761681774, 1.4253718046661978, 98.75835953068474, 37.249529584159546, 97.97336302377285, 94.2427128026267, 6.961615660785558, 32.09459629834591, 84.06976500181592, 19.70963412605615, 73.53560318406385, 3.5348040443181183, 63.37678085272886, 17.589970652899012, 23.756901561350062, 48.118915513262436, 13.310386313121093, 93.84208592206876, 42.0106423078408, 0.7016460089032894, 36.249000882381935, 48.097434185085696, 78.76424951716339, 45.02665818712342, 63.695720937909684, 57.593397718229724, 27.644987337721016, 35.2590623392631, 35.32549159337509, 54.3525304361895, 14.813294901937669, 85.34376122983905, 33.38578237754307, 13.190550838966153, 88.38295542333219, 1.4925274443232994, 13.659244846866613, 36.37151708656099, 10.993733689893459, 74.58947258034591, 30.702564887951855, 63.837945840218126, 32.7003556825763, 20.54693155899664, 10.189507990750712, 45.14291762207758, 93.83944291196097, 62.00878790146631, 31.72264988113014, 39.34175954482607, 39.484387340142455, 65.16948881524183, 0.29697889466003824, 73.20371542401847, 50.1177693144412, 89.14416693261028])
datay = np.array([7.0] * 100)
datax = datax.reshape((100, 1))
datay = datay.reshape((100, 1))
datax = Variable(torch.Tensor(datax))
datay = Variable(torch.Tensor(datay))
torch.manual_seed(100)
np.random.seed(100)

def model(datax, datay):
    p1 = dist.Beta(Variable(21.217094954047578 * torch.ones(1, 1)), Variable(86.18401153488547 * torch.ones(1, 1)))
    p2 = dist.Gamma(Variable(23.809913471191347 * torch.ones(1)), Variable(2.344758907621358 * torch.ones(1)))
    priors = {'linear.weight': p1, 'linear.bias': p2}
    lifted_module = pyro.random_module('module', regression_model, priors)
    lifted_reg = lifted_module()
    prediction_mean = lifted_reg(datax).squeeze(1)
    pyro.sample('obs', dist.Exponential(prediction_mean), obs=datay.squeeze())

def guide(datax, datay):
    w_mu = Variable(torch.randn(1, 1), requires_grad=True)
    w_log_sig = Variable(torch.randn(1, 1), requires_grad=True)
    b_mu = Variable(torch.randn(1), requires_grad=True)
    b_log_sig = Variable(torch.randn(1), requires_grad=True)
    w_dist = dist.Gamma(torch.nn.Softplus()(pyro.param('guide_mean_weight', w_mu)), torch.nn.Softplus()(pyro.param('guide_log_sigma_weight', w_log_sig)))
    b_dist = dist.Normal(pyro.param('guide_mean_bias', b_mu), torch.nn.Softplus()(pyro.param('guide_log_sigma_bias', b_log_sig)))
    dists = {'linear.weight': w_dist, 'linear.bias': b_dist}
    lifted_module = pyro.random_module('module', regression_model, dists)
    return lifted_module()
optim = Adam({'lr': 0.01})
svi = SVI(model, guide, optim, loss='ELBO')
for i in range(2000):
    loss = svi.step(datax, datay)
    if i % 100 == 0:
        print(loss / float(len(datax)))
for name in pyro.get_param_store().get_all_param_names():
    print('[%s]: %.3f' % (name, pyro.param(name).data.numpy()))

Output:

$ python3.6 /home/pyro_prog.py
/usr/local/lib/python3.6/site-packages/pyro_ppl-0.1.2-py3.6.egg/pyro/poutine/trace.py:13: UserWarning: Encountered NAN log_pdf at site 'module$$$linear.weight'
  warnings.warn("Encountered NAN log_pdf at site '{}'".format(name))
/usr/local/lib/python3.6/site-packages/pyro_ppl-0.1.2-py3.6.egg/pyro/poutine/trace.py:13: UserWarning: Encountered NAN log_pdf at site 'module$$$linear.bias'
  warnings.warn("Encountered NAN log_pdf at site '{}'".format(name))
/usr/local/lib/python3.6/site-packages/pyro_ppl-0.1.2-py3.6.egg/pyro/poutine/trace.py:13: UserWarning: Encountered NAN log_pdf at site 'obs'
  warnings.warn("Encountered NAN log_pdf at site '{}'".format(name))
/usr/local/lib/python3.6/site-packages/pyro_ppl-0.1.2-py3.6.egg/pyro/infer/trace_elbo.py:183: UserWarning: Encountered NAN loss
  warnings.warn('Encountered NAN loss')
nan
Traceback (most recent call last):
  File "/pyro_prog.py", line 47, in <module>
    loss = svi.step(datax, datay)
  File "/usr/local/lib/python3.6/site-packages/pyro_ppl-0.1.2-py3.6.egg/pyro/infer/svi.py", line 98, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/pyro_ppl-0.1.2-py3.6.egg/pyro/infer/elbo.py", line 65, in loss_and_grads
    return self.which_elbo.loss_and_grads(model, guide, *args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/pyro_ppl-0.1.2-py3.6.egg/pyro/infer/trace_elbo.py", line 136, in loss_and_grads
    for weight, model_trace, guide_trace, log_r in self._get_traces(model, guide, *args, **kwargs):
  File "/usr/local/lib/python3.6/site-packages/pyro_ppl-0.1.2-py3.6.egg/pyro/infer/trace_elbo.py", line 81, in _get_traces
    guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/pyro_ppl-0.1.2-py3.6.egg/pyro/poutine/trace_poutine.py", line 161, in get_trace
    self(*args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/pyro_ppl-0.1.2-py3.6.egg/pyro/poutine/trace_poutine.py", line 149, in __call__
    ret = super(TracePoutine, self).__call__(*args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/pyro_ppl-0.1.2-py3.6.egg/pyro/poutine/poutine.py", line 42, in __call__
    return self.fn(*args, **kwargs)
  File "/home/pyro_prog.py", line 43, in guide
    return lifted_module()
  File "/usr/local/lib/python3.6/site-packages/pyro_ppl-0.1.2-py3.6.egg/pyro/__init__.py", line 417, in _fn
    return lifted_fn(name, nn_copy, update_module_params=True, *args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/pyro_ppl-0.1.2-py3.6.egg/pyro/poutine/poutine.py", line 42, in __call__
    return self.fn(*args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/pyro_ppl-0.1.2-py3.6.egg/pyro/__init__.py", line 369, in module
    returned_param = param(full_param_name, param_value, tags=tags)
  File "/usr/local/lib/python3.6/site-packages/pyro_ppl-0.1.2-py3.6.egg/pyro/__init__.py", line 333, in param
    out_msg = apply_stack(msg)
  File "/usr/local/lib/python3.6/site-packages/pyro_ppl-0.1.2-py3.6.egg/pyro/util.py", line 168, in apply_stack
    msg["value"] = getattr(frame, "_pyro_{}".format(msg["type"]))(msg)
  File "/usr/local/lib/python3.6/site-packages/pyro_ppl-0.1.2-py3.6.egg/pyro/poutine/lift_poutine.py", line 80, in _pyro_param
    return self._pyro_sample(msg)
  File "/usr/local/lib/python3.6/site-packages/pyro_ppl-0.1.2-py3.6.egg/pyro/poutine/poutine.py", line 167, in _pyro_sample
    val = fn(*args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/pyro_ppl-0.1.2-py3.6.egg/pyro/distributions/distribution.py", line 150, in __call__
    return self.sample(*args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/pyro_ppl-0.1.2-py3.6.egg/pyro/distributions/gamma.py", line 57, in sample
    np_sample = spr.gamma.rvs(self.alpha.data.cpu().numpy(), scale=theta.data.cpu().numpy())
  File "/usr/local/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py", line 940, in rvs
    raise ValueError("Domain error in arguments.")
ValueError: Domain error in arguments.

Is this a known bug? Or am I missing something here?

Environment:

python 3.6
pyro-ppl 0.1.2
torch 0.3.1
Member

eb8680 commented Mar 11, 2018

Hi, I ran your code and found a couple of errors in your model. The Exponential distribution you're using as the likelihood expects its rate parameter to be nonnegative, and the Gamma prior you're using for the bias only has positive support, but your variational distribution for the bias is a Normal. That mismatch occasionally produces negative values for linear.bias, and hence for prediction_mean in the model, which leads to the domain errors and NaNs.
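For reference, this failure is reproducible outside Pyro: version 0.1.2 sampled Gamma variates through scipy, and scipy.stats.gamma.rvs rejects non-positive shape parameters with exactly this "Domain error in arguments." message. A minimal sketch of both the failure and the constraint fix (the softplus helper below is illustrative, not part of the reported code):

```python
import math
import scipy.stats as spr

# A negative shape parameter is outside the Gamma distribution's domain --
# exactly what an unconstrained Normal guide can feed into a Gamma site.
try:
    spr.gamma.rvs(-1.0, scale=1.0)
    raised = False
except ValueError:
    raised = True
assert raised

# Passing raw parameters through softplus keeps them strictly positive,
# as the guide in the report already does for the weight (but not the bias).
def softplus(x):
    return math.log1p(math.exp(x))

assert softplus(-5.0) > 0.0

# Sampling with a softplus-constrained shape succeeds:
sample = spr.gamma.rvs(softplus(-5.0), scale=1.0)
assert sample > 0.0
```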

The next release of Pyro will include support for a much-expanded torch.distributions, which has a library of constraints and transformations that make handling support mismatch easier. We also plan to implement some ADVI-style automated transformations in our SVI implementations.
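As a rough sketch of what those constraint utilities look like in torch.distributions (API names taken from later torch releases, so treat this as indicative rather than exact for the version discussed in this thread):

```python
import torch
from torch.distributions import constraints, transform_to

# Each distribution declares the support of its parameters and values;
# transform_to builds a bijection from unconstrained reals into that set.
to_positive = transform_to(constraints.positive)

unconstrained = torch.tensor(-4.0)   # e.g. a raw variational parameter
rate = to_positive(unconstrained)    # mapped into (0, inf)
assert rate.item() > 0.0

# The mapped value is a valid parameter for positively-constrained
# distributions such as Exponential:
sample = torch.distributions.Exponential(rate).sample()
assert sample.item() >= 0.0
```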

@lazypanda1
Author

> The next release of Pyro will include support for a much-expanded torch.distributions, which has a library of constraints and transformations that make handling support mismatch easier. We also plan to implement some ADVI-style automated transformations in our SVI implementations.

That's good to know. It would be helpful to have warnings/errors when something is wrong with the model; otherwise it is quite difficult to debug.

Member

fritzo commented Mar 27, 2018

Closing after #922. You can now pyro.set_validation(True) to catch this sort of error.
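For illustration, the torch.distributions validation that later Pyro releases hook into looks roughly like this (the Pyro-level flag name changed across releases, so this only shows the torch side):

```python
import torch
from torch.distributions import Exponential

# With validation enabled, an out-of-domain parameter fails eagerly at
# construction time with a clear message, instead of silently producing
# NaN log-probs downstream the way the original report did.
try:
    Exponential(torch.tensor(-1.0), validate_args=True)
    caught = False
except ValueError:
    caught = True
assert caught
```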

fritzo closed this as completed Mar 27, 2018