Skip to content

Commit

Permalink
fix elbo normalization with multi_sample_guide=True (#1728)
Browse files Browse the repository at this point in the history
  • Loading branch information
danielward27 committed Feb 5, 2024
1 parent 01089cf commit 977cbc2
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions numpyro/infer/elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def get_model_density(key, latent):
return model_log_density

num_guide_samples = None
for name, site in guide_trace.items():
for site in guide_trace.values():
if site["type"] == "sample":
num_guide_samples = site["value"].shape[0]
break
Expand Down Expand Up @@ -210,8 +210,6 @@ def get_model_density(key, latent):
# log p(z) - log q(z)
elbo_particle = model_log_density - guide_log_density

# log p(z) - log q(z)
elbo_particle = model_log_density - guide_log_density
if mutable_params:
if self.num_particles == 1:
return elbo_particle, mutable_params
Expand Down

0 comments on commit 977cbc2

Please sign in to comment.