what's the meaning of implementing the hook_compute_diag function? #46

Closed
zengjie617789 opened this issue Mar 21, 2022 · 3 comments


@zengjie617789

I am confused about the details of _hook_compute_diag: why does it multiply the gradient by x, the input of the layer saved before backpropagation?
Other implementations of the Fisher matrix look like this:

for n, p in self.model.named_parameters():
    precision_matrices[n].data += p.grad.data ** 2 / len(self.dataset)

Can anyone explain this? Thank you in advance.

@tfjgeorge
Owner

Hi

The piece of code that you provide relies on an outer loop over each individual example of the dataset used to compute the Fisher. That is inefficient: you can get per-example gradients from a single batched pass using tricks such as the one described in https://arxiv.org/abs/1510.01799

In NNGeometry we leverage such tricks, in order to improve compute efficiency.
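
To make the role of x concrete, here is a minimal NumPy sketch of the idea for a single linear layer without bias. It is an illustration under stated assumptions, not NNGeometry's actual code: the function names are hypothetical, and g stands in for the per-example gradients with respect to the layer output (random values here, rather than values obtained by backpropagation). For example i the gradient with respect to the weight matrix is the outer product of g[i] and x[i], so its elementwise square is the outer product of g[i]**2 and x[i]**2, and the whole diagonal reduces to one matrix product.

```python
import numpy as np

def diag_fisher_loop(x, g):
    """Reference version: one example at a time, squaring each per-example gradient."""
    n, n_in = x.shape
    n_out = g.shape[1]
    diag = np.zeros((n_out, n_in))
    for i in range(n):
        grad_w_i = np.outer(g[i], x[i])  # gradient of example i w.r.t. the weights
        diag += grad_w_i ** 2 / n
    return diag

def diag_fisher_batched(x, g):
    """Same quantity in one matrix product, since (g_i x_i^T)**2 == g_i**2 (x_i**2)^T."""
    return (g ** 2).T @ (x ** 2) / x.shape[0]

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 4))  # layer inputs, one row per example
g = rng.normal(size=(8, 3))  # stand-in gradients w.r.t. layer outputs, one row per example

# Both versions agree.
assert np.allclose(diag_fisher_loop(x, g), diag_fisher_batched(x, g))

# Squaring the batch-averaged gradient is a different quantity, which is why
# the p.grad ** 2 recipe only works when each backward pass sees one example.
batch_grad = g.T @ x / x.shape[0]
assert not np.allclose(batch_grad ** 2, diag_fisher_batched(x, g))
```

The batched version needs both the layer input and the gradient at the layer output, which is presumably why a hook that computes the diagonal multiplies the two together.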

@zengjie617789
Author

zengjie617789 commented Mar 22, 2022

Thank you for your quick response. The repository is awesome.
Here is another problem: I want to apply this code to a model such as YOLOX, which is anchor-free and has more than 403200 outputs before decoding. I am not sure how to set n_output; obviously, it is not wise to set it to such a large number.
Could you give some suggestions on this?
Thank you in advance.

@tfjgeorge
Owner

If I understand your needs correctly, I recommend that you use the FIM_MonteCarlo metric instead of the FIM one. With the latter you need to loop through all 403200 outputs, whereas with the former only the outputs with non-negligible probability will be sampled.

https://nngeometry.readthedocs.io/en/latest/api/metrics.html#nngeometry.metrics.FIM_MonteCarlo
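
The difference between the two estimators can be sketched on a toy softmax model with logits W @ x. This is a generic NumPy illustration that does not call NNGeometry; the function names, the model, and the number of trials are all assumptions chosen for the example. The exact version sums over every output class weighted by its probability, so its cost grows with the number of outputs, while the Monte Carlo version samples classes from the model's own predictive distribution.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def diag_fim_exact(W, x):
    """Diagonal Fisher for one input: explicit sum over all output classes."""
    p = softmax(W @ x)
    diag = np.zeros_like(W)
    for y in range(len(p)):          # cost grows linearly with the number of outputs
        g = -p.copy()
        g[y] += 1.0                  # gradient of log p(y|x) w.r.t. the logits
        diag += p[y] * np.outer(g, x) ** 2
    return diag

def diag_fim_monte_carlo(W, x, trials, rng):
    """Diagonal Fisher for one input: classes sampled from the model's prediction."""
    p = softmax(W @ x)
    ys = rng.choice(len(p), size=trials, p=p)  # unlikely classes are rarely drawn
    g = -np.tile(p, (trials, 1))
    g[np.arange(trials), ys] += 1.0            # one logit gradient per sampled class
    return (g ** 2).mean(axis=0)[:, None] * (x ** 2)[None, :]

rng = np.random.default_rng(0)
W = rng.normal(size=(5, 3))
x = np.array([1.0, -2.0, 0.5])

exact = diag_fim_exact(W, x)
estimate = diag_fim_monte_carlo(W, x, trials=20000, rng=rng)
assert np.allclose(estimate, exact, atol=0.1)
```

With few trials the estimate is noisy, so the number of samples trades accuracy against compute.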
