What's the meaning of implementing the _hook_compute_diag function? #46
Hi. In the piece of code that you provide, you have an outer loop over each individual example of the dataset used to compute the Fisher. That is inefficient: you can get individual gradients using tricks such as the one described in https://arxiv.org/abs/1510.01799. In NNGeometry we leverage such tricks to improve compute efficiency.
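To make the trick concrete, here is a minimal sketch in plain Python. It is a toy illustration, not NNGeometry's actual code. It assumes a single linear layer y = W x without bias and a squared-error loss as a stand-in for the real per-example loss; all function names are hypothetical. For such a layer the per-example weight gradient is the outer product of g = dL/dy (what a backward hook sees) and the layer input x, so dL/dW[j][k] = g[j] * x[k]. Squaring gives the diagonal entry g[j]^2 * x[k]^2, which is why the gradient gets multiplied by x.

```python
import random

def linear(W, x):
    """y = W x for a weight matrix W (out x in) and an input vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def loss(W, x, t):
    """Squared-error loss for one example (stand-in for any per-example loss)."""
    return 0.5 * sum((y - ti) ** 2 for y, ti in zip(linear(W, x), t))

def grad_wrt_output(W, x, t):
    """g = dL/dy for the loss above; a backward hook would receive this."""
    return [y - ti for y, ti in zip(linear(W, x), t)]

def numeric_grad(W, x, t, j, k, eps=1e-6):
    """Central finite difference of the loss w.r.t. W[j][k], used as a check."""
    Wp = [row[:] for row in W]
    Wm = [row[:] for row in W]
    Wp[j][k] += eps
    Wm[j][k] -= eps
    return (loss(Wp, x, t) - loss(Wm, x, t)) / (2 * eps)

def diag_fisher_loop(W, X, T):
    """Baseline: form each per-example gradient g_i x_i^T, square it, average."""
    n_out, n_in = len(W), len(W[0])
    diag = [[0.0] * n_in for _ in range(n_out)]
    for x, t in zip(X, T):
        g = grad_wrt_output(W, x, t)
        for j in range(n_out):
            for k in range(n_in):
                diag[j][k] += (g[j] * x[k]) ** 2 / len(X)
    return diag

def diag_fisher_batched(W, X, T):
    """Same quantity written as one matrix product: (G**2)^T (X**2) / n."""
    G = [grad_wrt_output(W, x, t) for x, t in zip(X, T)]
    G2 = [[g * g for g in row] for row in G]  # n x out
    X2 = [[x * x for x in row] for row in X]  # n x in
    n = len(X)
    return [[sum(G2[i][j] * X2[i][k] for i in range(n)) / n
             for k in range(len(X2[0]))]
            for j in range(len(G2[0]))]

# Small random problem: 4 examples, 3 inputs, 2 outputs.
random.seed(0)
W = [[random.gauss(0, 1) for _ in range(3)] for _ in range(2)]
X = [[random.gauss(0, 1) for _ in range(3)] for _ in range(4)]
T = [[random.gauss(0, 1) for _ in range(2)] for _ in range(4)]

loop_diag = diag_fisher_loop(W, X, T)
batched_diag = diag_fisher_batched(W, X, T)
```

In a tensor library the batched form is a single matrix multiplication of the squared output gradients with the squared layer inputs, so no per-example backward pass is needed.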
Thank you for your instant response. The repository is awesome.
If I understand your needs correctly, I recommend you use FIM_MonteCarlo: https://nngeometry.readthedocs.io/en/latest/api/metrics.html#nngeometry.metrics.FIM_MonteCarlo
I am confused about the details of _hook_compute_diag: why should the gradient be multiplied by x, which is the input of the layer saved before backpropagation? Other implementations of the Fisher matrix look like the one below:
Could anyone explain this? Thank you in advance.