I tried to render the model from https://pyro.ai/examples/dirichlet_process_mixture.html using pyro.render_model(model, model_args=(data)), but got the following error:
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_29904/3865686802.py in <module>
     38         pyro.sample("obs", MultivariateNormal(mu[z], torch.eye(2)), obs=data)
     39
---> 40 pyro.render_model(model, model_args=(data))

~/thesis/pyro/pyro/infer/inspect.py in render_model(model, model_args, model_kwargs, filename, render_distributions)
    492     :rtype: graphviz.Digraph
    493     """
--> 494     relations = get_model_relations(model, model_args, model_kwargs)
    495     graph_spec = generate_graph_specification(relations)
    496     graph = render_graph(graph_spec, render_distributions=render_distributions)

~/thesis/pyro/pyro/infer/inspect.py in get_model_relations(model, model_args, model_kwargs)
    274     with torch.random.fork_rng(), torch.no_grad(), pyro.validation_enabled(False):
    275         with TrackProvenance():
--> 276             trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
    277
    278     sample_sample = {}

~/thesis/pyro/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
    196         Calls this poutine and returns its trace instead of the function's return value.
    197         """
--> 198         self(*args, **kwargs)
    199         return self.msngr.get_trace()

~/thesis/pyro/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    172                 )
    173             try:
--> 174                 ret = self.fn(*args, **kwargs)
    175             except (ValueError, RuntimeError) as e:
    176                 exc_type, exc_value, traceback = sys.exc_info()

TypeError: model() takes 1 positional argument but 200 were given
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

import pyro
from pyro.distributions import *
from pyro.infer import SVI
from pyro.optim import Adam

assert pyro.__version__.startswith('1.8.0')
pyro.set_rng_seed(0)

data = torch.cat((MultivariateNormal(-8 * torch.ones(2), torch.eye(2)).sample([50]),
                  MultivariateNormal(8 * torch.ones(2), torch.eye(2)).sample([50]),
                  MultivariateNormal(torch.tensor([1.5, 2]), torch.eye(2)).sample([50]),
                  MultivariateNormal(torch.tensor([-0.5, 1]), torch.eye(2)).sample([50])))

N = data.shape[0]

def mix_weights(beta):
    beta1m_cumprod = (1 - beta).cumprod(-1)
    return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)

def model(data):
    with pyro.plate("beta_plate", T-1):
        beta = pyro.sample("beta", Beta(1, alpha))

    with pyro.plate("mu_plate", T):
        mu = pyro.sample("mu", MultivariateNormal(torch.zeros(2), 5 * torch.eye(2)))

    with pyro.plate("data", N):
        z = pyro.sample("z", Categorical(mix_weights(beta)))
        pyro.sample("obs", MultivariateNormal(mu[z], torch.eye(2)), obs=data)

pyro.render_model(model, model_args=(data))
Can you try this:
-pyro.render_model(model, model_args=(data))
+pyro.render_model(model, model_args=(data,))
I think pyro.render_model expects model_args to be a tuple.
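To spell out why the trailing comma matters: in Python, `(data)` is just `data` in parentheses, while `(data,)` is a one-element tuple. The traceback shows the model being called as `self.fn(*args, **kwargs)`, so a bare 200-row tensor is unpacked row by row into 200 positional arguments. The sketch below reproduces this without Pyro or PyTorch; the plain list standing in for the tensor and the toy `model` are assumptions for illustration only.

```python
# Stand-in for the (200, 2) data tensor: 200 rows of 2 values each.
# (Assumption: a plain list is used here so the example needs no torch.)
data = [[0.0, 0.0] for _ in range(200)]


def model(data):
    # Toy model that only reports how many rows it received.
    return len(data)


not_a_tuple = (data)   # parentheses alone only group; this is still the list
one_tuple = (data,)    # the trailing comma makes a 1-element tuple

# Calling model(*model_args) with the bare data unpacks it into
# 200 positional arguments, which model() does not accept.
try:
    model(*not_a_tuple)
    message = None
except TypeError as err:
    message = str(err)

# With the 1-element tuple, unpacking yields the single argument `data`.
n_rows = model(*one_tuple)
```

Here `message` holds the same kind of TypeError as in the report, and `n_rows` is the number of rows passed through intact.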
That worked, thanks!
Gosh we should probably assert isinstance(model_args, tuple)
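A minimal sketch of what such a check could look like, assuming it runs at the top of the rendering function. `check_render_args` is a hypothetical helper, not part of Pyro, and treating `None` as "no arguments" is also an assumption. It raises `TypeError` rather than using a bare `assert`, so the check still runs under `python -O` and the message can point at the likely fix.

```python
def check_render_args(model_args=None, model_kwargs=None):
    """Hypothetical validation helper (not part of Pyro).

    Returns normalized (model_args, model_kwargs), or raises TypeError
    if either has the wrong container type.
    """
    # Assumption: None means "no arguments were supplied".
    if model_args is None:
        model_args = ()
    if model_kwargs is None:
        model_kwargs = {}

    if not isinstance(model_args, tuple):
        raise TypeError(
            "model_args must be a tuple, got "
            f"{type(model_args).__name__}; for a single argument, "
            "write model_args=(x,) with a trailing comma"
        )
    if not isinstance(model_kwargs, dict):
        raise TypeError(
            f"model_kwargs must be a dict, got {type(model_kwargs).__name__}"
        )
    return model_args, model_kwargs


# Usage: the mistake from this issue is caught before the model is called.
data = [[0.0, 0.0] for _ in range(200)]  # stand-in for the data tensor
try:
    check_render_args(model_args=(data))
    error = None
except TypeError as err:
    error = str(err)

ok_args, ok_kwargs = check_render_args(model_args=(data,))
```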