
Commit

Clone output gradients
pomonam committed Jun 23, 2024
1 parent 2c0472a commit e2197ea
Showing 3 changed files with 12 additions and 9 deletions.
kronfluence/module/tracked_module.py (5 additions, 4 deletions)

@@ -263,6 +263,8 @@ def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> None:
         """
         output_gradient = output_gradient.to(dtype=self.factor_args.gradient_covariance_dtype)
         flattened_gradient, count = self._get_flattened_gradient(output_gradient=output_gradient)
+        if self._gradient_scale != 1.0:
+            flattened_gradient.mul_(self._gradient_scale)
 
         if self._storage[GRADIENT_COVARIANCE_MATRIX_NAME] is None:
             dimension = flattened_gradient.size(1)
@@ -272,8 +274,7 @@ def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> None:
                 device=flattened_gradient.device,
                 requires_grad=False,
             )
-        self._storage[GRADIENT_COVARIANCE_MATRIX_NAME].addmm_(flattened_gradient.t(), flattened_gradient,
-                                                              alpha=self._gradient_scale ** 2.)
+        self._storage[GRADIENT_COVARIANCE_MATRIX_NAME].addmm_(flattened_gradient.t(), flattened_gradient)
 
         # This is not necessary as `NUM_GRADIENT_COVARIANCE_PROCESSED` should be identical to
         # `NUM_ACTIVATION_COVARIANCE_PROCESSED` in most cases. However, they can be different when using
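Note on the scaling change above: multiplying the flattened gradient by the scale s up front is equivalent to the removed alpha = s ** 2 argument, since (s * G)^T (s * G) = s^2 * (G^T G). A minimal standalone sketch of that equivalence (illustrative only, not kronfluence code):

import torch

s = 0.5                # stand-in for self._gradient_scale
g = torch.randn(8, 4)  # stand-in for a flattened output gradient
# Previous formulation: fold the squared scale into addmm_'s alpha.
old = torch.zeros(4, 4).addmm_(g.t(), g, alpha=s ** 2)
# New formulation: scale the gradient first, then accumulate plainly.
new = torch.zeros(4, 4).addmm_((s * g).t(), s * g)
assert torch.allclose(old, new)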
@@ -301,7 +302,7 @@ def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor) -> None:
             @torch.no_grad()
             def backward_hook(output_gradient: torch.Tensor) -> None:
                 # Computes and updates pseudo-gradient covariance matrix in the backward pass.
-                self._update_gradient_covariance_matrix(output_gradient.detach())
+                self._update_gradient_covariance_matrix(output_gradient.detach().clone())
 
         self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook))
@@ -476,7 +477,7 @@ def shared_backward_hook(output_gradient: torch.Tensor) -> None:
                 cached_activation = self._cached_activations.pop()
                 per_sample_gradient = self._compute_per_sample_gradient(
                     input_activation=cached_activation.to(device=output_gradient.device),
-                    output_gradient=output_gradient.detach().to(dtype=self.factor_args.per_sample_gradient_dtype),
+                    output_gradient=output_gradient.detach().clone().to(dtype=self.factor_args.per_sample_gradient_dtype),
                 )
                 if self._cached_per_sample_gradient is None:
                     self._cached_per_sample_gradient = per_sample_gradient
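Both .clone() changes in this file guard against in-place mutation of a tensor that autograd still owns: the gradient handed to a tensor hook is the same object the backward pass continues to use, so the new in-place mul_ scaling would otherwise corrupt gradients flowing to earlier layers. A minimal standalone sketch of the hazard and the fix (illustrative only; the hook and variable names are invented):

import torch

x = torch.ones(3, requires_grad=True)
y = x * 2.0

def backward_hook(grad: torch.Tensor) -> None:
    safe = grad.detach().clone()  # take a private copy first
    safe.mul_(0.5)                # in-place scaling touches only the clone

y.register_hook(backward_hook)
y.sum().backward()
print(x.grad)  # tensor([2., 2., 2.]); unaffected by the hook's in-place math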
kronfluence/utils/common/factor_arguments.py (6 additions, 2 deletions)
@@ -18,7 +18,9 @@ def test_factor_arguments(strategy: str = "ekfac") -> FactorArguments:
     return factor_args
 
 
-def smart_low_precision_factor_arguments(strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16) -> FactorArguments:
+def smart_low_precision_factor_arguments(
+    strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16
+) -> FactorArguments:
     factor_args = FactorArguments(strategy=strategy)
     factor_args.amp_dtype = dtype
     factor_args.activation_covariance_dtype = dtype
@@ -44,7 +46,9 @@ def reduce_memory_factor_arguments(strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16) -> FactorArguments:
     return factor_args
 
 
-def extreme_reduce_memory_factor_arguments(strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16) -> FactorArguments:
+def extreme_reduce_memory_factor_arguments(
+    strategy: str = "ekfac", dtype: torch.dtype = torch.bfloat16
+) -> FactorArguments:
     factor_args = all_low_precision_factor_arguments(strategy=strategy, dtype=dtype)
     factor_args.lambda_iterative_aggregate = True
     factor_args.cached_activation_cpu_offload = True
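These helpers only build a FactorArguments preset; a hedged usage sketch follows (the import path mirrors this file's location in the diff, and the fit_all_factors call is assumed from the library's documented Analyzer API rather than from this commit):

import torch
from kronfluence.utils.common.factor_arguments import extreme_reduce_memory_factor_arguments

factor_args = extreme_reduce_memory_factor_arguments(strategy="ekfac", dtype=torch.bfloat16)
# The preset is then passed wherever factors are fitted, e.g. (assumed API):
# analyzer.fit_all_factors(factors_name="ekfac_factors", dataset=train_dataset, factor_args=factor_args)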
tests/gpu_tests/test_offload_cpu.py (1 addition, 3 deletions)
@@ -186,6 +186,4 @@ def test_cpu_offloads_identical(
     )
     cached_pairwise_scores = analyzer.load_pairwise_scores(scores_name=scores_name)
 
-    assert check_tensor_dict_equivalence(
-        pairwise_scores, cached_pairwise_scores, atol=ATOL, rtol=RTOL
-    )
+    assert check_tensor_dict_equivalence(pairwise_scores, cached_pairwise_scores, atol=ATOL, rtol=RTOL)
