Skip to content

Commit

Permalink
Pass count as tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Jun 20, 2024
1 parent 68f5c71 commit 4ba16aa
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions kronfluence/module/tracked_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,16 @@ def __init__(
# Operations that will be performed before and after a forward pass.
self._pre_forward = do_nothing
self._post_forward = do_nothing
self._num_forward_passes = 0
self._num_backward_passes = 0
self._num_forward_passes = torch.zeros(
1,
requires_grad=False,
dtype=torch.int64,
)
self._num_backward_passes = torch.zeros(
1,
requires_grad=False,
dtype=torch.int64,
)

if factor_args is None:
factor_args = FactorArguments()
Expand Down Expand Up @@ -170,8 +178,16 @@ def set_mode(self, mode: ModuleMode, keep_factors: bool = True) -> None:
"""Sets the module mode of all `TrackedModule` instances within a model."""
self.remove_attention_mask()
self.remove_gradient_scale()
self._num_forward_passes = 0
self._num_backward_passes = 0
self._num_forward_passes = torch.zeros(
1,
requires_grad=False,
dtype=torch.int64,
)
self._num_backward_passes = torch.zeros(
1,
requires_grad=False,
dtype=torch.int64,
)

if not keep_factors:
self._release_covariance_matrices()
Expand Down

0 comments on commit 4ba16aa

Please sign in to comment.