

[bug] Unable to render the model in Dirichlet Process Mixture Models in Pyro #3082

Closed
dilaragokay opened this issue May 5, 2022 · 3 comments

Comments

@dilaragokay
Contributor

dilaragokay commented May 5, 2022

Issue Description

I tried to render the model from https://pyro.ai/examples/dirichlet_process_mixture.html using pyro.render_model(model, model_args=(data)), but got the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_29904/3865686802.py in <module>
     38         pyro.sample("obs", MultivariateNormal(mu[z], torch.eye(2)), obs=data)
     39 
---> 40 pyro.render_model(model, model_args=(data))

~/thesis/pyro/pyro/infer/inspect.py in render_model(model, model_args, model_kwargs, filename, render_distributions)
    492     :rtype: graphviz.Digraph
    493     """
--> 494     relations = get_model_relations(model, model_args, model_kwargs)
    495     graph_spec = generate_graph_specification(relations)
    496     graph = render_graph(graph_spec, render_distributions=render_distributions)

~/thesis/pyro/pyro/infer/inspect.py in get_model_relations(model, model_args, model_kwargs)
    274     with torch.random.fork_rng(), torch.no_grad(), pyro.validation_enabled(False):
    275         with TrackProvenance():
--> 276             trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
    277 
    278     sample_sample = {}

~/thesis/pyro/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
    196         Calls this poutine and returns its trace instead of the function's return value.
    197         """
--> 198         self(*args, **kwargs)
    199         return self.msngr.get_trace()

~/thesis/pyro/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    172             )
    173             try:
--> 174                 ret = self.fn(*args, **kwargs)
    175             except (ValueError, RuntimeError) as e:
    176                 exc_type, exc_value, traceback = sys.exc_info()

TypeError: model() takes 1 positional argument but 200 were given

Environment

  • OS: Ubuntu 18.04.3 LTS
  • Python version: 3.7.11
  • PyTorch version: 1.11.0
  • Pyro version: 1.8.0+5ab7da2a

Code Snippet

import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn.functional as F

import pyro
from pyro.distributions import *
from pyro.infer import SVI
from pyro.optim import Adam

assert pyro.__version__.startswith('1.8.0')
pyro.set_rng_seed(0)

data = torch.cat((MultivariateNormal(-8 * torch.ones(2), torch.eye(2)).sample([50]),
                  MultivariateNormal(8 * torch.ones(2), torch.eye(2)).sample([50]),
                  MultivariateNormal(torch.tensor([1.5, 2]), torch.eye(2)).sample([50]),
                  MultivariateNormal(torch.tensor([-0.5, 1]), torch.eye(2)).sample([50])))

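N = data.shape[0]
T = 6        # truncation level of the stick-breaking construction (assumed value)
alpha = 0.1  # Dirichlet process concentration parameter (assumed value)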

def mix_weights(beta):
    beta1m_cumprod = (1 - beta).cumprod(-1)
    return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)

def model(data):
    with pyro.plate("beta_plate", T-1):
        beta = pyro.sample("beta", Beta(1, alpha))

    with pyro.plate("mu_plate", T):
        mu = pyro.sample("mu", MultivariateNormal(torch.zeros(2), 5 * torch.eye(2)))

    with pyro.plate("data", N):
        z = pyro.sample("z", Categorical(mix_weights(beta)))
        pyro.sample("obs", MultivariateNormal(mu[z], torch.eye(2)), obs=data)

pyro.render_model(model, model_args=(data))
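For reference, mix_weights above is the truncated stick-breaking construction: given T-1 Beta draws beta_k, component k receives pi_k = beta_k * prod_{j<k}(1 - beta_j), and the last component takes whatever stick remains, prod_{j<T}(1 - beta_j), so the T weights sum to 1. The following is a minimal plain-Python sketch of the same computation on a list instead of a tensor; mix_weights_py is a hypothetical name used only for illustration, and it is handy for checking that the two F.pad calls and the cumprod line up.

```python
from itertools import accumulate
import operator


def mix_weights_py(beta):
    # Hypothetical list-based mirror of mix_weights (not part of Pyro).
    # Running products of (1 - beta_j), with a leading 1:
    # [1, (1-b1), (1-b1)(1-b2), ...]  -> same as F.pad(cumprod, (1, 0), value=1)
    remaining = [1.0] + list(accumulate((1 - b for b in beta), operator.mul))
    # Append a 1 to beta so the last component takes all remaining stick:
    # same as F.pad(beta, (0, 1), value=1)
    return [b * r for b, r in zip(list(beta) + [1.0], remaining)]


print(mix_weights_py([0.5, 0.5]))  # [0.5, 0.25, 0.25]
```

With T-1 = 2 Beta draws this yields T = 3 weights, matching the beta_plate of size T-1 and the mu_plate of size T in the model.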
@ordabayevy
Member

Can you try this:

-pyro.render_model(model, model_args=(data))
+pyro.render_model(model, model_args=(data,))

I think pyro.render_model expects model_args to be a tuple.
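Why the missing comma produces "200 were given": in Python, (data) is just data wrapped in parentheses; only the trailing comma in (data,) builds a one-element tuple. The traceback shows the model being called as get_trace(*model_args, **model_kwargs), so a bare tensor gets unpacked row by row, and the 200 x 2 data tensor turns into 200 positional arguments. The stdlib-only sketch that follows reproduces this with a list of 200 rows standing in for the tensor; call_like_render_model is a hypothetical helper that imitates the unpacking, not a Pyro function.

```python
def model(data):
    # Stand-in for the DPMM model: it takes exactly one positional argument.
    return len(data)


def call_like_render_model(model, model_args=(), model_kwargs=None):
    # Hypothetical helper: unpacks its arguments the same way the traceback
    # shows Pyro doing it, i.e. model(*model_args, **model_kwargs).
    if model_kwargs is None:
        model_kwargs = {}
    return model(*model_args, **model_kwargs)


# 200 rows of 2 values, standing in for the 200 x 2 `data` tensor.
data = [[0.0, 0.0] for _ in range(200)]

assert (data) is data            # parentheses alone do not create a tuple
assert (data,) == tuple([data])  # the trailing comma does

try:
    call_like_render_model(model, model_args=(data))  # *data -> 200 arguments
except TypeError as err:
    print(err)  # ...takes 1 positional argument but 200 were given

print(call_like_render_model(model, model_args=(data,)))  # 200
```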

@dilaragokay
Contributor Author

That worked, thanks!

@fritzo
Member

fritzo commented May 5, 2022

Gosh, we should probably assert isinstance(model_args, tuple).
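A minimal sketch of what such an up-front check could look like, assuming it runs before the model is traced. validate_model_args is a hypothetical name and this is not Pyro's actual code; it only shows failing fast with a readable message instead of the opaque TypeError raised deep inside the trace.

```python
def validate_model_args(model_args=None, model_kwargs=None):
    # Hypothetical check, not Pyro's implementation.
    # Normalize the "not given" defaults first.
    if model_args is None:
        model_args = ()
    if model_kwargs is None:
        model_kwargs = {}
    # Reject anything that is not a tuple, e.g. a bare tensor from (data).
    if not isinstance(model_args, tuple):
        raise TypeError(
            f"model_args must be a tuple, got {type(model_args).__name__}; "
            "for a single argument write model_args=(data,)"
        )
    return model_args, model_kwargs


print(validate_model_args(([1, 2],)))  # (([1, 2],), {})
```

A strict isinstance check also rejects lists; a more lenient variant could convert a list with tuple(model_args), at the cost of silently accepting other iterables' mistakes less often than it catches them.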
