From b2d16645f7a675100ee7aa78402117971a1881d0 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Wed, 19 Jun 2024 16:24:12 -0400 Subject: [PATCH] Register forward hook outside --- kronfluence/module/tracked_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kronfluence/module/tracked_module.py b/kronfluence/module/tracked_module.py index 71c578f..c71280f 100644 --- a/kronfluence/module/tracked_module.py +++ b/kronfluence/module/tracked_module.py @@ -344,7 +344,7 @@ def backward_hook(output_gradient: torch.Tensor) -> None: # Compute and update pseudo-gradient covariance matrix in the backward pass. self._update_gradient_covariance_matrix(output_gradient.detach()) - self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook)) + self._registered_hooks.append(self.register_forward_hook(forward_hook)) def _release_covariance_matrices(self) -> None: """Clears the stored activation and pseudo-gradient covariance matrices from memory."""