Skip to content

Commit

Permalink
Correctly expose sub-distribution parameters in distributions created by
Browse files Browse the repository at this point in the history
inflated_factory.

PiperOrigin-RevId: 471879189
  • Loading branch information
Googler authored and tensorflower-gardener committed Sep 2, 2022
1 parent 1999fbe commit cc84214
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions tensorflow_probability/python/distributions/inflated.py
Expand Up @@ -224,12 +224,16 @@ def inflated_factory(default_name, distribution_class, inflated_loc,
def my_init(self,
inflated_loc_logits=None, inflated_loc_probs=None,
name=default_name, **kwargs):
parameters = dict(locals())
if 'distribution' in kwargs:
dist = kwargs['distribution']
else:
dist = distribution_class(**{**kwargs, **more_kwargs})
Inflated.__init__(self, dist, inflated_loc_logits, inflated_loc_probs,
inflated_loc, name=name)
# pylint: disable=protected-access
self._parameters = {**parameters, **more_kwargs}
# pylint: enable=protected-access

def my_parameter_properties(unused_cls, dtype, num_classes=None):
return dict(
Expand Down

0 comments on commit cc84214

Please sign in to comment.