diff --git a/kronfluence/module/tracked_module.py b/kronfluence/module/tracked_module.py index 42b8369..3cf0551 100644 --- a/kronfluence/module/tracked_module.py +++ b/kronfluence/module/tracked_module.py @@ -501,8 +501,8 @@ def _activation_cache_pre_forward(self, inputs: Any) -> Any: else: self._cached_activations.append(cached_activation) + @torch.no_grad() def _lambda_post_forward(self, outputs: torch.Tensor) -> Any: - @torch.no_grad() def backward_hook(output_gradient: torch.Tensor) -> None: cached_activation = self._cached_activations.pop() if self.factor_args.cached_activation_cpu_offload: