Skip to content

Commit

Permalink
Avoid in-place for normalization
Browse files · Browse the repository at this point in the history
  • Loading branch information
pomonam committed Jun 23, 2024
1 parent d469691 commit 2c249ff
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions kronfluence/module/tracked_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> N
output_gradient = output_gradient.to(dtype=self.factor_args.gradient_covariance_dtype)
flattened_gradient, count = self._get_flattened_gradient(output_gradient=output_gradient)
if self._gradient_scale != 1.0:
flattened_gradient.mul_(self._gradient_scale)
flattened_gradient = flattened_gradient * self._gradient_scale

if self._storage[GRADIENT_COVARIANCE_MATRIX_NAME] is None:
dimension = flattened_gradient.size(1)
Expand Down Expand Up @@ -302,7 +302,7 @@ def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.
@torch.no_grad()
def backward_hook(output_gradient: torch.Tensor) -> None:
# Computes and updates pseudo-gradient covariance matrix in the backward pass.
self._update_gradient_covariance_matrix(output_gradient.detach().clone())
self._update_gradient_covariance_matrix(output_gradient.detach())

self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))

Expand Down

0 comments on commit 2c249ff

Please sign in to comment.