Skip to content

Commit

Permalink
Update effect_handlers.ipynb (#3296)
Browse files Browse the repository at this point in the history
```python
print(scale_log_joint({"measurement": torch.tensor(9.5), "weight": torch.tensor(8.23)}, torch.tensor(8.5)))
```

prevents the following error message:

```python
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File ~/.pyenv/versions/pyciemss-main/lib/python3.10/site-packages/pyro/poutine/trace_struct.py:196, in Trace.log_prob_sum(self, site_filter)
    195 try:
--> 196     log_p = site["fn"].log_prob(
    197         site["value"], *site["args"], **site["kwargs"]
    198     )
    199 except ValueError as e:

File ~/.pyenv/versions/pyciemss-main/lib/python3.10/site-packages/torch/distributions/normal.py:79, in Normal.log_prob(self, value)
     78 if self._validate_args:
---> 79     self._validate_sample(value)
     80 # compute the variance

File ~/.pyenv/versions/pyciemss-main/lib/python3.10/site-packages/torch/distributions/distribution.py:271, in Distribution._validate_sample(self, value)
    270 if not isinstance(value, torch.Tensor):
--> 271     raise ValueError('The value argument to log_prob must be a Tensor')
    273 event_dim_start = len(value.size()) - len(self._event_shape)

ValueError: The value argument to log_prob must be a Tensor

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
Cell In[5], line 9
      6     return _log_joint
      8 scale_log_joint = make_log_joint(scale)
----> 9 print(scale_log_joint({"measurement": 9.5, "weight": 8.23}, 8.5))

Cell In[5], line 5, in make_log_joint.<locals>._log_joint(cond_data, *args, **kwargs)
      3 conditioned_model = poutine.condition(model, data=cond_data)
      4 trace = poutine.trace(conditioned_model).get_trace(*args, **kwargs)
----> 5 return trace.log_prob_sum()

File ~/.pyenv/versions/pyciemss-main/lib/python3.10/site-packages/pyro/poutine/trace_struct.py:202, in Trace.log_prob_sum(self, site_filter)
    200     _, exc_value, traceback = sys.exc_info()
    201     shapes = self.format_shapes(last_site=site["name"])
--> 202     raise ValueError(
    203         "Error while computing log_prob_sum at site '{}':\n{}\n{}\n".format(
    204             name, exc_value, shapes
    205         )
    206     ).with_traceback(traceback) from e
    207 log_p = scale_and_mask(log_p, site["scale"], site["mask"]).sum()
    208 site["log_prob_sum"] = log_p

File ~/.pyenv/versions/pyciemss-main/lib/python3.10/site-packages/pyro/poutine/trace_struct.py:196, in Trace.log_prob_sum(self, site_filter)
    194 else:
    195     try:
--> 196         log_p = site["fn"].log_prob(
    197             site["value"], *site["args"], **site["kwargs"]
    198         )
    199     except ValueError as e:
    200         _, exc_value, traceback = sys.exc_info()

File ~/.pyenv/versions/pyciemss-main/lib/python3.10/site-packages/torch/distributions/normal.py:79, in Normal.log_prob(self, value)
     77 def log_prob(self, value):
     78     if self._validate_args:
---> 79         self._validate_sample(value)
     80     # compute the variance
     81     var = (self.scale ** 2)

File ~/.pyenv/versions/pyciemss-main/lib/python3.10/site-packages/torch/distributions/distribution.py:271, in Distribution._validate_sample(self, value)
    257 """
    258 Argument validation for distribution methods such as `log_prob`,
    259 `cdf` and `icdf`. The rightmost dimensions of a value to be
   (...)
    268         distribution's batch and event shapes.
    269 """
    270 if not isinstance(value, torch.Tensor):
--> 271     raise ValueError('The value argument to log_prob must be a Tensor')
    273 event_dim_start = len(value.size()) - len(self._event_shape)
    274 if value.size()[event_dim_start:] != self._event_shape:

ValueError: Error while computing log_prob_sum at site 'weight':
The value argument to log_prob must be a Tensor
Trace Shapes:  
 Param Sites:  
Sample Sites:  
  weight dist |
        value |
```
  • Loading branch information
djinnome committed Dec 4, 2023
1 parent 4df2c1e commit a14fabc
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tutorial/source/effect_handlers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@
" return _log_joint\n",
"\n",
"scale_log_joint = make_log_joint(scale)\n",
"print(scale_log_joint({\"measurement\": 9.5, \"weight\": 8.23}, 8.5))"
"print(scale_log_joint({\"measurement\": torch.tensor(9.5), \"weight\": torch.tensor(8.23)}, torch.tensor(8.5)))"
]
},
{
Expand Down

0 comments on commit a14fabc

Please sign in to comment.