
Extra -0.5 in log_posterior()? #58

Closed

isaac-cfwong opened this issue Oct 31, 2020 · 7 comments

@isaac-cfwong

In https://github.com/piEsposito/blitz-bayesian-deep-learning/blob/master/blitz/modules/weight_sampler.py#L50,

```python
log_posteriors =  -log_sqrt2pi - torch.log(self.sigma) - (((w - self.mu) ** 2)/(2 * self.sigma ** 2)) - 0.5
```

why is there a -0.5 at the end of the line? The log-likelihood of a Gaussian does not have that -0.5.
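Here is a minimal standalone check (not BLiTZ code; the tensor values are arbitrary) comparing the quoted expression against torch.distributions.Normal, which shows the constant offset:

```python
import numpy as np
import torch

# Arbitrary example values for the variational parameters and a sampled weight.
mu = torch.tensor([0.3])
sigma = torch.tensor([0.8])
w = torch.tensor([0.1])

# Expression as written in weight_sampler.py (with the trailing -0.5).
log_sqrt2pi = float(np.log(np.sqrt(2 * np.pi)))  # same constant as in the module
blitz_style = -log_sqrt2pi - torch.log(sigma) - ((w - mu) ** 2) / (2 * sigma ** 2) - 0.5

# Reference Gaussian log-density from PyTorch.
reference = torch.distributions.Normal(mu, sigma).log_prob(w)

print(blitz_style - reference)  # tensor([-0.5000]): a constant offset of -0.5
```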

@piEsposito
Owner

Hello, and thank you so much for using BLiTZ.

This log posterior refers to the KL divergence between the surrogate posterior and the prior distributions.

Here is the derivation of the KL divergence between two normals:

[Screenshot: derivation of the KL divergence between two normal distributions]
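For reference, the standard closed form of the KL divergence between two univariate Gaussians (presumably what the screenshot derives) is

KL( N(μ₁, σ₁²) || N(μ₂, σ₂²) ) = log(σ₂ / σ₁) + (σ₁² + (μ₁ − μ₂)²) / (2 σ₂²) − 1/2,

and a constant −1/2 does appear there.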

Hope that solves your doubt.

Best regards,

Pi.

@isaac-cfwong
Author

isaac-cfwong commented Oct 31, 2020

Thanks for the reply. The package is really nice work! But I'm still confused by the -0.5 there. Let me try to explain my question step by step in case I've misunderstood anything. With reference to the Bayes by Backprop paper https://arxiv.org/pdf/1505.05424.pdf, Eq. (1) is the cost function, which can be computed with Monte Carlo samples:

F(D, θ) = KL[ q(w|θ) || P(w) ] − E_{q(w|θ)}[ log P(D|w) ]
        ≈ Σ_i  log q(w^(i)|θ) − log P(w^(i)) − log P(D|w^(i)),    with w^(i) ~ q(w|θ)

This cost function is implemented in

```python
loss = 0
for _ in range(sample_nbr):
    outputs = self(inputs)
    loss += criterion(outputs, labels)
    loss += self.nn_kl_divergence() * complexity_cost_weight
return loss / sample_nbr
```
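For context, this loop appears to be the body of the sample_elbo method added by the @variational_estimator decorator. Assuming the signature shown in the BLiTZ README, a training step would look roughly like this (optimizer, model, and data names are illustrative placeholders):

```python
# Sketch only: assumes `regressor` is a @variational_estimator-decorated nn.Module
# exposing sample_elbo() as in the BLiTZ README; the other names are placeholders.
optimizer.zero_grad()
loss = regressor.sample_elbo(inputs=datapoints,
                             labels=labels,
                             criterion=torch.nn.MSELoss(),
                             sample_nbr=3,
                             complexity_cost_weight=1 / len(train_dataset))
loss.backward()
optimizer.step()
```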

The criterion is the minus log-likelihood of the training data, i.e. the −log P(D|w^(i)) term, and self.nn_kl_divergence() is the log q(w^(i)|θ) − log P(w^(i)) term, i.e. the KL divergence between the variational posterior and the prior. This is confirmed by tracing to the following function:

```python
def kl_divergence_from_nn(model):
    """
    Gathers the KL Divergence from a nn.Module object
    Works by gathering each Bayesian layer kl divergence and summing it, doing nothing with the non Bayesian ones
    """
    kl_divergence = 0
    for module in model.modules():
        if isinstance(module, (BayesianModule)):
            kl_divergence += module.log_variational_posterior - module.log_prior
    return kl_divergence
```
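In other words, each Bayesian layer contributes a single-sample Monte Carlo estimate of its KL term,

KL[ q(w|θ) || P(w) ] = E_{q(w|θ)}[ log q(w|θ) − log P(w) ] ≈ log q(w^(i)|θ) − log P(w^(i)),   w^(i) ~ q(w|θ),

which is exactly the module.log_variational_posterior - module.log_prior quantity being summed above.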

That function uses log_variational_posterior and log_prior. Taking the Bayesian linear layer as an example, the two variables are computed here:

```python
self.log_variational_posterior = self.weight_sampler.log_posterior() + b_log_posterior
self.log_prior = self.weight_prior_dist.log_prior(w) + b_log_prior
```

The variational posteriors of the weights and the bias are computed using the weight sampler, so log_posterior() here should refer to log q(w^(i)|θ). log_posterior() is defined here:

```python
def log_posterior(self, w=None):
    """
    Calculates the log_likelihood for each of the weights sampled as a part of the complexity cost
    returns:
        torch.tensor with shape []
    """
    assert (self.w is not None), "You can only have a log posterior for W if you've already sampled it"
    if w is None:
        w = self.w

    log_sqrt2pi = np.log(np.sqrt(2*self.pi))
    log_posteriors = -log_sqrt2pi - torch.log(self.sigma) - (((w - self.mu) ** 2)/(2 * self.sigma ** 2)) - 0.5
    return log_posteriors.sum()
```

Since the variational distribution q is chosen to be Gaussian, log_posterior() here should just be the Gaussian log PDF, and there shouldn't be a -0.5. The extra -0.5 will not affect training, since it is only a constant added to the cost function, but for clarity it should be removed.
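For concreteness, here is a sketch of what the corrected method could look like (identical to the current one except that the trailing constant is dropped):

```python
def log_posterior(self, w=None):
    """
    Calculates the log_likelihood for each of the weights sampled as a part of the complexity cost
    returns:
        torch.tensor with shape []
    """
    assert (self.w is not None), "You can only have a log posterior for W if you've already sampled it"
    if w is None:
        w = self.w

    # Gaussian log-pdf: -log(sqrt(2*pi)) - log(sigma) - (w - mu)^2 / (2 * sigma^2),
    # with no trailing -0.5 constant.
    log_sqrt2pi = np.log(np.sqrt(2 * self.pi))
    log_posteriors = -log_sqrt2pi - torch.log(self.sigma) - (((w - self.mu) ** 2) / (2 * self.sigma ** 2))
    return log_posteriors.sum()
```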

@sansiro77
Contributor

I agree, and the extra -0.5 in log_prior() should also be removed.

@HDRah

HDRah commented Mar 22, 2021

I agree

@chipbreaker

Yes, I think so as well. The -0.5 at the end is wrong.

@piEsposito
Owner

@isaac-cfwong sorry for the ultra-late reply. I agree with you; it should be removed. As it was your idea, I'll leave it to you to open the PR or, if you'd rather not, I can do it myself.

Thank you so much for finding that mistake of mine and solving it.

@piEsposito
Owner

Closing it due to staleness.
