diff --git a/kronfluence/module/tracked_module.py b/kronfluence/module/tracked_module.py index 71c578f..c71280f 100644 --- a/kronfluence/module/tracked_module.py +++ b/kronfluence/module/tracked_module.py @@ -344,7 +344,7 @@ def backward_hook(output_gradient: torch.Tensor) -> None: # Compute and update pseudo-gradient covariance matrix in the backward pass. self._update_gradient_covariance_matrix(output_gradient.detach()) - self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook)) + self._registered_hooks.append(self.register_forward_hook(forward_hook)) def _release_covariance_matrices(self) -> None: """Clears the stored activation and pseudo-gradient covariance matrices from memory."""