Skip to content

Commit

Permalink
Register forward hook outside
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Jun 19, 2024
1 parent e1ec936 commit b2d1664
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion kronfluence/module/tracked_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def backward_hook(output_gradient: torch.Tensor) -> None:
# Compute and update pseudo-gradient covariance matrix in the backward pass.
self._update_gradient_covariance_matrix(output_gradient.detach())

self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))
self._registered_hooks.append(self.register_forward_hook(forward_hook))

def _release_covariance_matrices(self) -> None:
"""Clears the stored activation and pseudo-gradient covariance matrices from memory."""
Expand Down

0 comments on commit b2d1664

Please sign in to comment.