Gaussian Mixture Model: Predicting for new-data with Enumeration in Guide #2404

Open · prpr2770 opened this issue Apr 9, 2020 · 4 comments

prpr2770 commented Apr 9, 2020

Issue Description

While following the Gaussian Mixture Model tutorial, I noticed that the following error is raised when the new_data created earlier is fed to full_guide():

ValueError: Shape mismatch inside plate('fg_data') at site assignment_new dim -1, 180 vs 5

/opt/conda/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
      9 def _context_wrap(context, fn, *args, **kwargs):
     10     with context:
---> 11         return fn(*args, **kwargs)
     12 
     13 

<ipython-input-197-9648525bdb5f> in full_guide(fg_data)
     22         assignment_probs = pyro.param('assignment_probs', torch.ones(len(fg_data), K) / K,
     23                                       constraint=constraints.unit_interval)
---> 24         pyro.sample('assignment_new', dist.Categorical(assignment_probs))
     25 
     26     """

/opt/conda/lib/python3.7/site-packages/pyro/primitives.py in sample(name, fn, *args, **kwargs)
    111             msg["is_observed"] = True
    112         # apply the stack and return its return value
--> 113         apply_stack(msg)
    114         return msg["value"]
    115 

/opt/conda/lib/python3.7/site-packages/pyro/poutine/runtime.py in apply_stack(initial_msg)
    191         pointer = pointer + 1
    192 
--> 193         frame._process_message(msg)
    194 
    195         if msg["stop"]:

/opt/conda/lib/python3.7/site-packages/pyro/poutine/plate_messenger.py in _process_message(self, msg)
     13     def _process_message(self, msg):
     14         super()._process_message(msg)
---> 15         return BroadcastMessenger._pyro_sample(msg)
     16 
     17     def __enter__(self):

/opt/conda/lib/python3.7/contextlib.py in inner(*args, **kwds)
     72         def inner(*args, **kwds):
     73             with self._recreate_cm():
---> 74                 return func(*args, **kwds)
     75         return inner
     76 

/opt/conda/lib/python3.7/site-packages/pyro/poutine/broadcast_messenger.py in _pyro_sample(msg)
     57                 if target_batch_shape[f.dim] is not None and target_batch_shape[f.dim] != f.size:
     58                     raise ValueError("Shape mismatch inside plate('{}') at site {} dim {}, {} vs {}".format(
---> 59                         f.name, msg['name'], f.dim, f.size, target_batch_shape[f.dim]))
     60                 target_batch_shape[f.dim] = f.size
     61             # Starting from the right, if expected size is None at an index,

ValueError: Shape mismatch inside plate('fg_data') at site assignment_new dim -1, 180 vs 5

The assignment_new site shouldn't have batch shape 5 along the fg_data plate; that size was evidently picked up from the cached assignment_probs parameter, which was learned on the original training dataset of length 5.
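
For reference, a minimal check of the stale parameter, assuming the tutorial's SVI training loop has already run and registered assignment_probs in the param store (the clear_param_store call is shown only as a caveat, since it would also discard the learned global parameters):

import pyro

# Assumes the tutorial's SVI training on the original dataset has already run,
# so "assignment_probs" is cached in the param store.
print(pyro.param("assignment_probs").shape)  # first dim matches the training data, not new_data

# Clearing the store would remove the stale parameter, but it also discards the
# learned global parameters, so it is not a real fix here:
# pyro.clear_param_store()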

Environment

  • OS: Ubuntu 16.04
  • Python version: 3.7.7 [GCC 7.3.0]
  • PyTorch version: 1.4.0
  • Pyro version: 1.3.1

Code Snippet

import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import config_enumerate

# K, global_guide, and the trained global parameters come from the GMM tutorial.
new_data = torch.arange(5.5, 6.0, 0.005)

@config_enumerate
def full_guide(fg_data):
    # Global variables.
    with poutine.block(hide_types=["param"]):  # Keep our learned values of global parameters.
        global_guide(fg_data)

    # Local variables.
    with pyro.plate('fg_data', len(fg_data)):
        assignment_probs = pyro.param('assignment_probs', torch.ones(len(fg_data), K) / K,
                                      constraint=constraints.unit_interval)
        pyro.sample('assignment_new', dist.Categorical(assignment_probs))

full_guide(new_data)
@jun2tong

Any update on this? Is there a temporary workaround?

fritzo changed the title from "[discussion][bug] Gaussian Mixture Model: Predicting for new-data with Enumeration in Guide" to "Gaussian Mixture Model: Predicting for new-data with Enumeration in Guide" on Jul 7, 2020

fritzo commented Jul 7, 2020

The problem here is that full_guide() learns a pyro.param "assignment_probs" for your specific training data. To support novel data, you could instead learn an amortized guide, where "assignment_probs" is not a fixed parameter but a function (typically a neural net) of the observed data. Here's a sketch of an amortized guide that might work for you:

from pyro.nn import PyroModule  # plus the imports from the reproduction snippet above

class AmortizedGuide(PyroModule):
    def __init__(self):
        super().__init__()
        self.nn = MyNeuralNet()  # you'll need to decide on an architecture

    @config_enumerate
    def forward(self, fg_data):
        # Global variables.
        with poutine.block(hide_types=["param"]):  # Keep our learned values of global parameters.
            global_guide(fg_data)

        # Local variables.
        with pyro.plate('fg_data', len(fg_data)):
            assignment_probs = self.nn(fg_data)  # these can now handle novel data
            pyro.sample('assignment_new', dist.Categorical(assignment_probs))

full_guide = AmortizedGuide()
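
For concreteness, one hypothetical way to fill in MyNeuralNet for the tutorial's one-dimensional data is a small MLP that maps each scalar observation to K assignment probabilities; the architecture and layer sizes below are illustrative assumptions, not part of the tutorial or of the sketch above:

import torch.nn as nn

K = 2  # number of mixture components in the tutorial

class MyNeuralNet(nn.Module):
    """Illustrative encoder: maps each data point to K assignment probabilities."""

    def __init__(self, hidden=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, hidden),
            nn.ReLU(),
            nn.Linear(hidden, K),
            nn.Softmax(dim=-1),  # each row sums to 1, valid Categorical probs
        )

    def forward(self, fg_data):
        # fg_data has shape (N,); return shape (N, K)
        return self.net(fg_data.unsqueeze(-1))

Whatever architecture is chosen, its weights still need to be registered with Pyro so that SVI updates them, for example via pyro.module or by making the network itself a PyroModule, and training continues to use TraceEnum_ELBO because the guide enumerates the assignment site.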

@jun2tong

Indeed, amortizing the local random variable works. In that case, is this issue no longer a bug?


fritzo commented Jul 16, 2020

@jun2tong sure, I guess you could call this more of a question than a bug 🙂

fritzo removed the bug label on Jul 16, 2020