While following the tutorial on Gaussian Mixture Models, I noticed that the following error is raised when full_guide() is called on the new_data created earlier:
```
ValueError: Shape mismatch inside plate('fg_data') at site assignment_new dim -1, 180 vs 5
/opt/conda/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
9 def _context_wrap(context, fn, *args, **kwargs):
10 with context:
---> 11 return fn(*args, **kwargs)
12
13
<ipython-input-197-9648525bdb5f> in full_guide(fg_data)
22 assignment_probs = pyro.param('assignment_probs', torch.ones(len(fg_data), K) / K,
23 constraint=constraints.unit_interval)
---> 24 pyro.sample('assignment_new', dist.Categorical(assignment_probs))
25
26 """
/opt/conda/lib/python3.7/site-packages/pyro/primitives.py in sample(name, fn, *args, **kwargs)
111 msg["is_observed"] = True
112 # apply the stack and return its return value
--> 113 apply_stack(msg)
114 return msg["value"]
115
/opt/conda/lib/python3.7/site-packages/pyro/poutine/runtime.py in apply_stack(initial_msg)
191 pointer = pointer + 1
192
--> 193 frame._process_message(msg)
194
195 if msg["stop"]:
/opt/conda/lib/python3.7/site-packages/pyro/poutine/plate_messenger.py in _process_message(self, msg)
13 def _process_message(self, msg):
14 super()._process_message(msg)
---> 15 return BroadcastMessenger._pyro_sample(msg)
16
17 def __enter__(self):
/opt/conda/lib/python3.7/contextlib.py in inner(*args, **kwds)
72 def inner(*args, **kwds):
73 with self._recreate_cm():
---> 74 return func(*args, **kwds)
75 return inner
76
/opt/conda/lib/python3.7/site-packages/pyro/poutine/broadcast_messenger.py in _pyro_sample(msg)
57 if target_batch_shape[f.dim] is not None and target_batch_shape[f.dim] != f.size:
58 raise ValueError("Shape mismatch inside plate('{}') at site {} dim {}, {} vs {}".format(
---> 59 f.name, msg['name'], f.dim, f.size, target_batch_shape[f.dim]))
60 target_batch_shape[f.dim] = f.size
61 # Starting from the right, if expected size is None at an index,
ValueError: Shape mismatch inside plate('fg_data') at site assignment_new dim -1, 180 vs 5
```
The assignment_new site should not have a batch size of 5 here; that value was presumably carried over from the training dataset, which has length 5.
Environment
OS : Ubuntu 16.04
python version: 3.7.7 [GCC 7.3.0]
PyTorch version: 1.4.0
Pyro version: 1.3.1
Code Snippet
```python
new_data = torch.arange(5.5, 6.0, 0.005)

@config_enumerate
def full_guide(fg_data):
    # Global variables.
    with poutine.block(hide_types=["param"]):  # Keep our learned values of global parameters.
        global_guide(fg_data)

    # Local variables.
    with pyro.plate('fg_data', len(fg_data)):
        assignment_probs = pyro.param('assignment_probs', torch.ones(len(fg_data), K) / K,
                                      constraint=constraints.unit_interval)
        pyro.sample('assignment_new', dist.Categorical(assignment_probs))

full_guide(new_data)
```
Any update on this? Is there a temporary workaround?
fritzo changed the title from "[discussion][bug] Gaussian Mixture Model: Predicting for new-data with Enumeration in Guide" to "Gaussian Mixture Model: Predicting for new-data with Enumeration in Guide" on Jul 7, 2020.
The problem here is that full_guide() learns a pyro.param "assignment_probs" for your specific training data: it has one row per training point, and since pyro.param only uses its initial value when the name is first created, calling full_guide(new_data) reuses that 5-row tensor. To support novel data, you could instead learn an amortized guide, where "assignment_probs" is not a fixed parameter but a function (typically a neural net) of the observed data. Here's a sketch of an amortized guide that might work for you:
```python
class AmortizedGuide(PyroModule):
    def __init__(self):
        super().__init__()
        self.nn = MyNeuralNet()  # you'll need to decide on an architecture

    @config_enumerate
    def forward(self, fg_data):
        # Global variables.
        with poutine.block(hide_types=["param"]):  # Keep our learned values of global parameters.
            global_guide(fg_data)

        # Local variables.
        with pyro.plate('fg_data', len(fg_data)):
            assignment_probs = self.nn(fg_data)  # these can now handle novel data
            pyro.sample('assignment_new', dist.Categorical(assignment_probs))

full_guide = AmortizedGuide()
```
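MyNeuralNet is left open in the sketch above. One possible choice, assuming one-dimensional data as in the tutorial and K = 2 components (both assumptions, and the class itself is hypothetical), is a small MLP that maps each data point independently to a row of K assignment probabilities, so its output has one row per point whatever the input length:

```python
import torch
import torch.nn as nn

class MyNeuralNet(nn.Module):
    """Hypothetical architecture: maps each scalar data point to K assignment probabilities."""

    def __init__(self, K=2, hidden_dim=16):  # K=2 and hidden_dim=16 are assumptions
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, K),
        )

    def forward(self, x):
        # x has shape (N,); add a feature dimension to get (N, 1).
        logits = self.layers(x.unsqueeze(-1))
        # Softmax over the last dimension: shape (N, K), each row sums to 1.
        return logits.softmax(dim=-1)

torch.manual_seed(0)
net = MyNeuralNet(K=2)
probs_train = net(torch.randn(5))          # random stand-in for the 5 training points
new_data = torch.arange(5.5, 6.0, 0.005)   # new_data from the issue
probs_new = net(new_data)
print(probs_train.shape, probs_new.shape)
```

Because the softmax is taken over the last dimension, each row is a valid probability vector for dist.Categorical, and the leading dimension always equals len(fg_data), so it matches the plate size for both training and new data.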